dacapo.experiments.trainers.trainer

Module Contents

Classes

Trainer

Trainer Abstract Base Class

class dacapo.experiments.trainers.trainer.Trainer

Trainer Abstract Base Class

This class serves as the blueprint for all trainer classes in the dacapo library. It defines the abstract methods that every subclass must implement to train a neural network model.
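The contract can be illustrated with Python's `abc` module. The sketch below is not the dacapo source. It uses a hypothetical `TrainerSketch` class with simplified `Any` type hints in place of the dacapo and torch types, to show how the abstract methods force a subclass to implement the full interface before it can be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterator, List


class TrainerSketch(ABC):
    """Hypothetical stand-in mirroring the abstract interface documented for Trainer."""

    # Attributes documented on Trainer; concrete subclasses are expected to set them.
    iteration: int
    batch_size: int
    learning_rate: float

    @abstractmethod
    def create_optimizer(self, model: Any) -> Any: ...

    @abstractmethod
    def iterate(
        self, num_iterations: int, model: Any, optimizer: Any, device: Any
    ) -> Iterator[Any]: ...

    @abstractmethod
    def can_train(self, datasets: List[Any]) -> bool: ...

    @abstractmethod
    def build_batch_provider(
        self, datasets: List[Any], model: Any, task: Any, snapshot_container: Any
    ) -> None: ...


class IncompleteTrainer(TrainerSketch):
    # Implements only one of the four abstract methods.
    def can_train(self, datasets: List[Any]) -> bool:
        return True


try:
    IncompleteTrainer()
except TypeError as err:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    print("cannot instantiate:", err)
```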

iteration: int
batch_size: int
learning_rate: float

abstract create_optimizer(model: dacapo.experiments.model.Model) → torch.optim.Optimizer

Creates an optimizer for the model.

Parameters:

model (Model) – The model for which the optimizer will be created.

Returns:

The optimizer created for the model.

Return type:

torch.optim.Optimizer
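One plausible implementation builds a standard torch optimizer over the model's parameters using the trainer's `learning_rate`. This is a minimal sketch, assuming the model behaves like a `torch.nn.Module` (exposing `parameters()`); `AdamTrainerSketch` is a hypothetical class, and the choice of Adam is illustrative rather than what any particular dacapo trainer uses.

```python
import torch


class AdamTrainerSketch:
    """Hypothetical trainer showing one way to implement create_optimizer."""

    def __init__(self, learning_rate: float = 1e-4) -> None:
        self.learning_rate = learning_rate

    def create_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
        # Optimize all of the model's parameters with the trainer's learning rate.
        return torch.optim.Adam(model.parameters(), lr=self.learning_rate)


model = torch.nn.Linear(4, 2)  # Stand-in model with a weight and a bias parameter.
optimizer = AdamTrainerSketch(learning_rate=1e-3).create_optimizer(model)
print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])
```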

abstract iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer: torch.optim.Optimizer, device: torch.device) → Iterator[dacapo.experiments.training_iteration_stats.TrainingIterationStats]

Performs the given number of training iterations.

Parameters:
  • num_iterations (int) – Number of training iterations.

  • model (Model) – The model to be trained.

  • optimizer (torch.optim.Optimizer) – The optimizer for the model.

  • device (torch.device) – The device (GPU/CPU) on which the model will be trained.

Returns:

An iterator that yields the training statistics for each iteration.

Return type:

Iterator[TrainingIterationStats]
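Because `iterate` returns an iterator, a natural way to implement it is as a generator that runs one step per loop and yields a stats object, letting the caller log or checkpoint between iterations. The sketch below is an assumption-laden illustration: `IterationStatsSketch` is a hypothetical stand-in for `TrainingIterationStats`, and a hypothetical `step` callable replaces the `model`, `optimizer`, and `device` arguments so the example runs without a GPU or a real model.

```python
import time
from dataclasses import dataclass
from typing import Callable, Iterator


@dataclass
class IterationStatsSketch:
    """Hypothetical stand-in for TrainingIterationStats."""
    iteration: int
    loss: float
    time: float


class LoopTrainerSketch:
    def __init__(self) -> None:
        # Index of the next iteration; later calls to iterate() resume from here.
        self.iteration = 0

    def iterate(
        self, num_iterations: int, step: Callable[[int], float]
    ) -> Iterator[IterationStatsSketch]:
        start = self.iteration
        for i in range(start, start + num_iterations):
            t0 = time.perf_counter()
            loss = step(i)  # Hypothetical: one forward/backward/optimizer step.
            self.iteration = i + 1
            yield IterationStatsSketch(
                iteration=i, loss=loss, time=time.perf_counter() - t0
            )


trainer = LoopTrainerSketch()
stats = list(trainer.iterate(3, step=lambda i: 1.0 / (i + 1)))
print([s.iteration for s in stats], trainer.iteration)
```

Since the body is a generator, no work happens until the caller consumes it, and `self.iteration` advances only as stats are pulled.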

abstract can_train(datasets: List[dacapo.experiments.datasplits.datasets.Dataset]) → bool

Checks whether the trainer can train on a given set of datasets.

Some trainers may have specific requirements for their training datasets.

Parameters:

datasets (List[Dataset]) – The training datasets.

Returns:

True if the trainer can train on the given datasets, False otherwise.

Return type:

bool
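As an example of such a requirement, a supervised trainer might refuse datasets that lack ground truth. This is a minimal sketch under stated assumptions: `DatasetSketch` is a hypothetical stand-in with a `gt` field, and both the ground-truth check and the rule that an empty list is rejected are illustrative design choices, not documented dacapo behavior.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class DatasetSketch:
    """Hypothetical stand-in for a Dataset with raw data and optional ground truth."""
    name: str
    raw: Any
    gt: Optional[Any] = None


def can_train(datasets: List[DatasetSketch]) -> bool:
    # Example requirement: at least one dataset, and every dataset has ground truth.
    return len(datasets) > 0 and all(ds.gt is not None for ds in datasets)


labeled = DatasetSketch("a", raw=[0, 1], gt=[1, 0])
unlabeled = DatasetSketch("b", raw=[2, 3])
print(can_train([labeled]), can_train([labeled, unlabeled]))  # True False
```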

abstract build_batch_provider(datasets: List[dacapo.experiments.datasplits.datasets.Dataset], model: dacapo.experiments.model.Model, task: dacapo.experiments.tasks.task.Task, snapshot_container: dacapo.store.array_store.LocalContainerIdentifier) → None

Initializes the training pipeline.

This method uses the datasets, model, task, and snapshot container to set up the pipeline that supplies training batches.

Parameters:
  • datasets (List[Dataset]) – The datasets to pull data from.

  • model (Model) – The model, used to inform the pipeline of the required input and output sizes.

  • task (Task) – The task, used to transform ground truth into training targets.

  • snapshot_container (LocalContainerIdentifier) – Defines where snapshots will be saved.
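Because the method returns `None`, the pipeline it builds has to be stored on the trainer for later use (for example, by `iterate`). The sketch below illustrates that pattern only. `PipelineTrainerSketch` and `next_batch` are hypothetical; datasets are plain lists of `(raw, gt)` pairs, a hypothetical `make_target` callable stands in for the task's ground-truth-to-target transform, a plain directory string stands in for the snapshot container, and the model argument (used for input/output sizes) is omitted.

```python
import random
from typing import Any, Callable, Iterator, List, Optional, Tuple

Sample = Tuple[Any, Any]  # (raw, ground truth) pair
Batch = List[Tuple[Any, Any]]  # list of (raw, target) pairs


class PipelineTrainerSketch:
    """Hypothetical trainer whose build_batch_provider stores a batch generator."""

    def __init__(self, batch_size: int, seed: int = 0) -> None:
        self.batch_size = batch_size
        self._rng = random.Random(seed)
        self._batches: Optional[Iterator[Batch]] = None
        self.snapshot_dir: Optional[str] = None

    def build_batch_provider(
        self,
        datasets: List[List[Sample]],
        make_target: Callable[[Any], Any],
        snapshot_dir: str,
    ) -> None:
        # Returns nothing: the pipeline is kept on the trainer for later consumption.
        self.snapshot_dir = snapshot_dir

        def batches() -> Iterator[Batch]:
            while True:
                batch: Batch = []
                for _ in range(self.batch_size):
                    dataset = self._rng.choice(datasets)  # Pick a dataset at random,
                    raw, gt = self._rng.choice(dataset)  # then a sample from it.
                    batch.append((raw, make_target(gt)))  # Ground truth -> target.
                yield batch

        self._batches = batches()

    def next_batch(self) -> Batch:
        if self._batches is None:
            raise RuntimeError("call build_batch_provider() first")
        return next(self._batches)


trainer = PipelineTrainerSketch(batch_size=2)
trainer.build_batch_provider(
    datasets=[[(1, 0), (2, 1)], [(3, 1)]],
    make_target=lambda gt: gt * 10,
    snapshot_dir="snapshots/",
)
print(trainer.next_batch())
```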