class flamo.optimize.trainer.Trainer(net, max_epochs=10, lr=0.001, patience=5, patience_delta=0.01, step_size=50, step_factor=0.1, train_dir=None, device='cpu')

Trainer class for training a differentiable system with multiple loss functions. It handles the training steps, validation steps, results logging, and the early stopping criterion. By default, it uses torch.optim.Adam() as the optimizer and torch.optim.lr_scheduler.StepLR() as the learning rate scheduler. Each loss (criterion) is registered with the register_criterion() method, and training is started by calling train() with the training and validation datasets. Each loss can be assigned a weight \(\alpha\) and a flag indicating whether it requires the model as an input, as is needed when the loss depends on the model’s parameters.
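The text above states only that each registered loss carries a weight \(\alpha\). Assuming the losses are combined as a weighted sum, the objective minimized at each step would read as follows. Here \(N\) is the number of registered criteria, and \(\ell_i\) additionally receives the model \(\theta\) when its requires_model flag is set:

```latex
\mathcal{L}(\hat{y}, y) = \sum_{i=1}^{N} \alpha_i \, \ell_i(\hat{y}, y), \qquad
\ell_i(\hat{y}, y) \to \ell_i(\hat{y}, y; \theta) \ \text{if requires\_model}_i
```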

Args:
  • net (nn.Module): The differentiable system to be trained.

  • max_epochs (int): Maximum number of training epochs. Default: 10.

  • lr (float): Learning rate for the optimizer. Default: 1e-3.

  • patience (int): Number of epochs to wait for improvement in validation loss before early stopping. Default: 5.

  • patience_delta (float): Minimum decrease in validation loss that counts as an improvement. Default: 0.01.

  • step_size (int): Period of learning rate decay. Default: 50.

  • step_factor (float): Multiplicative factor of learning rate decay. Default: 0.1.

  • train_dir (str): The directory for saving training outputs. Default: None.

  • device (str): Device to use for training. Default: ‘cpu’.
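The lr, step_size and step_factor arguments map naturally onto the default Adam optimizer and StepLR scheduler named in the class description. The sketch below shows that mapping with plain PyTorch. The nn.Linear stand-in for net is hypothetical. This section does not say how often flamo calls scheduler.step() (per epoch or per batch), so the loop simply makes step_size calls.

```python
import torch
from torch import nn

# Hypothetical stand-in for the differentiable system `net`
net = nn.Linear(4, 1)

lr, step_size, step_factor = 1e-3, 50, 0.1  # the Trainer defaults

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# StepLR multiplies the learning rate by `gamma` every `step_size` calls to scheduler.step()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=step_factor)

for _ in range(step_size):
    optimizer.zero_grad()
    net(torch.ones(1, 4)).sum().backward()  # dummy gradient so the optimizer has work to do
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]["lr"])  # close to 1e-4 after one decay
```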

Attributes:
  • device (str): Device to use for training.

  • net (nn.Module): The differentiable system.

  • max_epochs (int): Maximum number of training epochs.

  • lr (float): Learning rate for the optimizer.

  • patience (int): Number of epochs to wait for improvement in validation loss before early stopping.

  • patience_delta (float): Minimum decrease in validation loss that counts as an improvement.

  • min_val_loss (float): Lowest validation loss observed so far, updated by the early stopper.

  • optimizer (torch.optim.Optimizer): The optimizer.

  • train_dir (str): The directory for saving training outputs.

  • criterion (list): List of loss functions.

  • alpha (list): List of weights for the loss functions.

  • requires_model (list): List of flags indicating whether the loss functions require the model as an input.

  • scheduler (torch.optim.lr_scheduler.StepLR): The learning rate scheduler.

Methods:
  • register_criterion(criterion, alpha=1, requires_model=False): Register a loss function and its weight.

  • train(train_dataset, valid_dataset): Train the neural network model.

  • train_step(data): Perform a single training step.

  • valid_step(data): Perform a single validation step.

  • print_results(epoch, time): Print the training results for an epoch.

  • get_train_dir(): Get the directory path for saving training outputs.

  • save_model(epoch): Save the model parameters to a file.

Examples:

>>> trainer = Trainer(net)  # initialize the trainer with a trainable nn.Module net 
>>> alpha_1, alpha_2 = 1, 0.1
>>> loss_1, loss_2 = torch.nn.MSELoss(), torch.nn.L1Loss()
>>> trainer.register_criterion(loss_1, alpha_1)  # register the first loss function with weight 1
>>> trainer.register_criterion(loss_2, alpha_2)  # register the second loss function with weight 0.1
>>> trainer.train(train_dataset, valid_dataset)

early_stop()

Early stopping criterion.
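The documented attributes (patience, patience_delta, min_val_loss) suggest a patience-based rule. Training stops once the validation loss has failed to improve by at least patience_delta for patience consecutive epochs. The EarlyStopper class below is a hypothetical sketch of such a rule, not flamo's implementation. Treating patience_delta as an absolute (not relative) decrease is an assumption.

```python
class EarlyStopper:
    """Hypothetical patience-based early stopper (patience_delta assumed absolute)."""

    def __init__(self, patience=5, patience_delta=0.01):
        self.patience = patience
        self.patience_delta = patience_delta
        self.min_val_loss = float("inf")  # lowest validation loss seen so far
        self.counter = 0                  # epochs since the last real improvement

    def early_stop(self, val_loss):
        if val_loss < self.min_val_loss - self.patience_delta:
            # Improvement larger than patience_delta: record it and reset the counter
            self.min_val_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopper(patience=5, patience_delta=0.01)
losses = [1.0, 0.5, 0.499, 0.498, 0.497, 0.496, 0.495]
decisions = [stopper.early_stop(loss) for loss in losses]
print(decisions)  # stops on the fifth epoch without a 0.01 improvement
```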

get_train_dir()

Get the directory path for saving training outputs.

print_results(e, e_time)

Print the training results for an epoch.

register_criterion(criterion, alpha=1, requires_model=False)

Register a loss function and its weight in the total loss.

Args:
  • criterion (nn.Module): The loss function.

  • alpha (float): The weight of the loss function. Default: 1.

  • requires_model (bool): Whether the loss function requires the model as an input. Default: False.
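The requires_model flag matters for losses that depend on the model's parameters, such as a regularizer. The sketch below defines a hypothetical ParamNormLoss and evaluates the weighted total over parallel lists that mirror the criterion, alpha and requires_model attributes. The (y_pred, y_target, model) call convention for model-dependent losses is an assumption; this section does not state how flamo passes the model. With the real class, registration would be trainer.register_criterion(ParamNormLoss(), 0.1, requires_model=True).

```python
import torch
from torch import nn


class ParamNormLoss(nn.Module):
    """Hypothetical model-dependent loss: squared L2 norm of all parameters of `model`."""

    def forward(self, y_pred, y_target, model):
        # y_pred and y_target are ignored; the value depends only on the model
        return sum((p ** 2).sum() for p in model.parameters())


# Parallel lists mirroring the Trainer's criterion / alpha / requires_model attributes
criterion = [nn.MSELoss(), ParamNormLoss()]
alpha = [1.0, 0.1]
requires_model = [False, True]

net = nn.Linear(2, 1)
with torch.no_grad():
    net.weight.fill_(1.0)
    net.bias.fill_(0.5)

x, y = torch.zeros(3, 2), torch.zeros(3, 1)
y_pred = net(x)  # every prediction equals the bias, 0.5

# Pass the model only to the losses flagged with requires_model=True (assumed convention)
total = sum(
    a * (c(y_pred, y, net) if needs_model else c(y_pred, y))
    for c, a, needs_model in zip(criterion, alpha, requires_model)
)
print(total.item())  # 1.0 * 0.25 + 0.1 * 2.25, about 0.475
```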

save_model(e)

Save the model parameters to a file.

Args:
  • e (int): The epoch number.
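A common way to save model parameters in PyTorch is to write the module's state_dict with torch.save(). The sketch below assumes save_model() works this way. The file name pattern model_e{e}.pt and the save_model helper are hypothetical, not flamo's actual naming. The round trip through load_state_dict() confirms that the parameters are restored.

```python
import os
import tempfile

import torch
from torch import nn


def save_model(net, train_dir, e):
    """Hypothetical sketch: save net's parameters for epoch `e` inside train_dir."""
    os.makedirs(train_dir, exist_ok=True)
    path = os.path.join(train_dir, f"model_e{e}.pt")  # assumed naming scheme
    torch.save(net.state_dict(), path)
    return path


net = nn.Linear(2, 1)
with tempfile.TemporaryDirectory() as tmp:
    path = save_model(net, tmp, e=3)
    restored = nn.Linear(2, 1)
    restored.load_state_dict(torch.load(path))  # reload the saved parameters
```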

train(train_dataset, valid_dataset)

Train the neural network model.

Args:
  • train_dataset (torch.utils.data.Dataset): The training dataset.

  • valid_dataset (torch.utils.data.Dataset): The validation dataset.
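This section documents only the inputs of train(). The sketch below is a minimal, self-contained epoch loop that combines the pieces described for this class: Adam, StepLR, a validation pass, and patience-based early stopping. It is not flamo's implementation. The toy data, the DataLoader batch size, the learning rate and the once-per-epoch scheduler.step() are all assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy problem: learn y = 2x with a single linear layer
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x
train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
valid_loader = DataLoader(TensorDataset(x, y), batch_size=16)

net = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

max_epochs, patience, patience_delta = 200, 5, 1e-4
min_val_loss, counter = float("inf"), 0
history = []

for epoch in range(max_epochs):
    net.train()
    for inputs, targets in train_loader:  # training steps
        optimizer.zero_grad()
        loss = loss_fn(net(inputs), targets)
        loss.backward()
        optimizer.step()

    net.eval()
    with torch.no_grad():  # validation steps: average loss over batches
        val_loss = sum(loss_fn(net(i), t).item() for i, t in valid_loader) / len(valid_loader)
    history.append(val_loss)
    scheduler.step()  # assumed once per epoch

    # Early stopping on the average validation loss
    if val_loss < min_val_loss - patience_delta:
        min_val_loss, counter = val_loss, 0
    else:
        counter += 1
        if counter >= patience:
            break
```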

train_step(data)

Perform a single training step.

Args:
  • data (tuple): A tuple containing the input data and the target data (inputs, targets).

Returns:
  • float: The loss value of the training step.
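A single training step usually follows the standard PyTorch pattern: clear gradients, run the forward pass, compute the weighted loss, backpropagate, update the parameters and return the loss as a Python float. The sketch below follows that pattern. The train_step signature and the criteria list of (loss_fn, alpha) pairs are hypothetical simplifications of the Trainer's attributes, and model-dependent losses are omitted.

```python
import torch
from torch import nn


def train_step(net, optimizer, criteria, data):
    """Sketch of one training step; `criteria` is a hypothetical list of (loss_fn, alpha) pairs."""
    inputs, targets = data
    optimizer.zero_grad()       # clear gradients from the previous step
    estimations = net(inputs)   # forward pass
    loss = sum(alpha * loss_fn(estimations, targets) for loss_fn, alpha in criteria)
    loss.backward()             # backpropagate the weighted total
    optimizer.step()            # update the parameters
    return loss.item()          # plain float, as documented


torch.manual_seed(0)
net = nn.Linear(3, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
criteria = [(nn.MSELoss(), 1.0), (nn.L1Loss(), 0.1)]
data = (torch.randn(8, 3), torch.randn(8, 1))

before = net.weight.detach().clone()
step_loss = train_step(net, optimizer, criteria, data)
```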

valid_step(data)

Perform a single validation step.

Args:
  • data (tuple): A tuple containing the input data and the target data (inputs, targets).

Returns:
  • float: The loss value for the validation step.
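A validation step computes the same loss as a training step but builds no autograd graph and leaves the parameters untouched. The sketch below is a minimal version under the same hypothetical (loss_fn, alpha) convention as a training step. Whether flamo also switches the network to eval mode is an assumption.

```python
import torch
from torch import nn


def valid_step(net, criteria, data):
    """Sketch of one validation step: weighted loss, no gradients, no parameter update."""
    inputs, targets = data
    net.eval()              # assumed: disables training-only behaviour such as dropout
    with torch.no_grad():   # no autograd graph is built
        estimations = net(inputs)
        loss = sum(alpha * loss_fn(estimations, targets) for loss_fn, alpha in criteria)
    return loss.item()


torch.manual_seed(0)
net = nn.Linear(3, 1)
criteria = [(nn.MSELoss(), 1.0), (nn.L1Loss(), 0.1)]
data = (torch.randn(8, 3), torch.randn(8, 1))

before = net.weight.detach().clone()
val_loss = valid_step(net, criteria, data)
```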