class flamo.optimize.trainer.Trainer(net: Module, max_epochs: int = 10, lr: float = 0.001, patience: int = 5, patience_delta: float = 0.01, step_size: int = 50, step_factor: float = 0.1, train_dir: str = None, device: str = 'cpu')

Trainer class for training a differentiable system with multiple loss functions. It handles the training and validation steps, results logging, and the early stopping criterion. By default, it uses torch.optim.Adam() as the optimizer and torch.optim.lr_scheduler.StepLR() as the learning rate scheduler. Each loss (criterion) is registered with the register_criterion() method, and training is started with the train() method, which takes the training and validation datasets. Each loss can be assigned a weight \(\alpha\) and a flag indicating whether the loss function requires the model as an input, which might be needed when the loss depends on the model's parameters.

Arguments / Attributes:
  • net (nn.Module): The differentiable system to be trained.

  • max_epochs (int): Maximum number of training epochs. Default: 10.

  • lr (float): Learning rate for the optimizer. Default: 1e-3.

  • patience (int): Number of epochs to wait for improvement in validation loss before early stopping. Default: 5.

  • patience_delta (float): Minimum improvement in validation loss to be considered as an improvement. Default: 0.01.

  • step_size (int): Period of learning rate decay. Default: 50.

  • step_factor (float): Multiplicative factor of learning rate decay. Default: 0.1.

  • train_dir (str): The directory for saving training outputs. Default: None.

  • device (str): Device to use for training. Default: 'cpu'.

Attributes:
  • min_val_loss (float): Minimum validation loss, updated by the early stopper.

  • optimizer (torch.optim.Optimizer): The optimizer.

  • criterion (list): List of loss functions.

  • alpha (list): List of weights for the loss functions.

  • requires_model (list): List of flags indicating whether the loss functions require the model as an input.

  • scheduler (torch.optim.lr_scheduler.StepLR): The learning rate scheduler.

Examples:

>>> trainer = Trainer(net)  # initialize the trainer with a trainable nn.Module net
>>> alpha_1, alpha_2 = 1, 0.1
>>> loss_1, loss_2 = torch.nn.MSELoss(), torch.nn.L1Loss()
>>> trainer.register_criterion(loss_1, alpha_1)  # register the first loss function with weight 1
>>> trainer.register_criterion(loss_2, alpha_2)  # register the second loss function with weight 0.1
>>> trainer.train(train_dataset, valid_dataset)
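
The train_dataset and valid_dataset passed to train() above can be any torch.utils.data.Dataset yielding (input, target) pairs, which is the tuple format consumed by train_step() and valid_step(). A minimal sketch using torch.utils.data.TensorDataset (the shapes below are placeholders and assume that net maps each input to an output matching its target):

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> inputs, targets = torch.randn(100, 1, 256), torch.randn(100, 1, 256)  # hypothetical shapes
>>> train_dataset = TensorDataset(inputs[:80], targets[:80])  # 80 pairs for training
>>> valid_dataset = TensorDataset(inputs[80:], targets[80:])  # 20 pairs for validation
>>> trainer.train(train_dataset, valid_dataset)
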
early_stop()

Early stopping criterion: training stops when the validation loss has not improved by at least patience_delta for patience consecutive epochs.

get_train_dir()

Get the directory path where the training outputs are saved.

print_results(e: int, e_time: float)

Print a string with the training results for an epoch.

register_criterion(criterion: Module, alpha: int = 1, requires_model: bool = False)

Register a loss function (criterion) and its weight with the trainer.

Arguments:
  • criterion (nn.Module): The loss function.

  • alpha (float): The weight of the loss function. Default: 1.

  • requires_model (bool): Whether the loss function requires the model as an input. Default: False.
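
A hypothetical example of a criterion that requires the model, e.g. a regularization term on the model's parameters. How the Trainer passes the model to the criterion when requires_model=True is an assumption in this sketch:

>>> import torch
>>> class ParamL2Penalty(torch.nn.Module):
...     # hypothetical loss: L2 norm of all model parameters, ignores prediction and target
...     def forward(self, prediction, target, model):
...         return sum(p.pow(2).sum() for p in model.parameters())
...
>>> trainer.register_criterion(ParamL2Penalty(), alpha=1e-4, requires_model=True)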

save_model(e: int)

Save the model parameters to a file.

Arguments:
  • e (int): The epoch number.

train(train_dataset: Dataset, valid_dataset: Dataset)

Train the neural network model.

Arguments:
  • train_dataset (torch.utils.data.Dataset): The training dataset.

  • valid_dataset (torch.utils.data.Dataset): The validation dataset.

train_step(data: tuple)

Perform a single training step.

Arguments:
  • data (tuple): A tuple containing the input data and the target data (inputs, targets).

Returns:
  • float: The loss value of the training step.

valid_step(data: tuple)

Perform a single validation step.

Arguments:
  • data (tuple): A tuple containing the input data and the target data (inputs, targets).

Returns:
  • float: The loss value for the validation step.

flamo.optimize.trainer.get_str_results(epoch: int | None = None, train_loss: list | None = None, valid_loss: list | None = None, time: int | None = None)

Construct the string printed at the end of each epoch, summarizing the training performance.

Arguments:
  • epoch (int): The epoch number. Default: None.

  • train_loss (list): List of training loss values. Default: None.

  • valid_loss (list): List of validation loss values. Default: None.

  • time (float): The time taken for the epoch. Default: None.

Returns:
  • str: The formatted string to be printed.
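
A minimal usage sketch; the exact formatting of the returned string is defined by the implementation, and the values below are placeholders:

>>> from flamo.optimize.trainer import get_str_results
>>> msg = get_str_results(epoch=0, train_loss=[0.250, 0.180], valid_loss=[0.210], time=1.8)
>>> print(msg)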