dacapo.experiments.trainers.trainer

Classes

Trainer

Trainer Abstract Base Class

Module Contents

class dacapo.experiments.trainers.trainer.Trainer

Trainer Abstract Base Class

This class is the blueprint for all trainer classes in the dacapo library. It declares the methods that every subclass must implement in order to train a neural network model.

iteration

The number of training iterations.

Type:

int

batch_size

The size of the training batch.

Type:

int

learning_rate

The learning rate for the optimizer.

Type:

float

create_optimizer(model: Model) -> torch.optim.Optimizer: Creates an optimizer for the model.

iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]: Performs a number of training iterations.

can_train(datasets: List[Dataset]) -> bool: Checks if the trainer can train with a specific set of datasets.

build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None: Initializes the training pipeline using various components.

Note

The Trainer class is an abstract class that cannot be instantiated directly. It is meant to be subclassed.
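The sketch below illustrates what "cannot be instantiated directly" means in practice. It assumes the class is built on Python's `abc` module, which the `abstract` markers on the methods suggest. `TrainerSketch`, `IncompleteTrainer`, and `try_instantiate` are hypothetical stand-ins written for this illustration and are not part of dacapo.

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Hypothetical stand-in that mirrors the documented Trainer interface."""

    iteration: int
    batch_size: int
    learning_rate: float

    @abstractmethod
    def create_optimizer(self, model): ...

    @abstractmethod
    def iterate(self, num_iterations, model, optimizer, device): ...

    @abstractmethod
    def can_train(self, datasets): ...

    @abstractmethod
    def build_batch_provider(self, datasets, model, task, snapshot_container): ...


class IncompleteTrainer(TrainerSketch):
    """Implements only one of the four abstract methods."""

    def can_train(self, datasets):
        return True


def try_instantiate(cls):
    """Return an instance of cls, or None if cls is still abstract."""
    try:
        return cls()
    except TypeError:
        # abc refuses to instantiate a class with unimplemented abstract methods.
        return None
```

Under this assumption, a subclass that skips any of the four methods fails at construction time rather than when the missing method is first called.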

iteration: int
batch_size: int
learning_rate: float
abstract create_optimizer(model: dacapo.experiments.model.Model) -> torch.optim.Optimizer

Creates an optimizer for the model.

Parameters:

model (Model) – The model for which the optimizer will be created.

Returns:

The optimizer created for the model.

Return type:

torch.optim.Optimizer

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Examples

>>> optimizer = trainer.create_optimizer(model)

Note

This method must be implemented by the subclass.

abstract iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[dacapo.experiments.training_iteration_stats.TrainingIterationStats]

Performs a number of training iterations.

Parameters:
  • num_iterations (int) – Number of training iterations.

  • model (Model) – The model to be trained.

  • optimizer (torch.optim.Optimizer) – The optimizer for the model.

  • device (torch.device) – The device (GPU/CPU) where the model will be trained.

Returns:

An iterator of the training statistics.

Return type:

Iterator[TrainingIterationStats]

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Examples

>>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
...     print(iteration_stats)

Note

This method must be implemented by the subclass.
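Because `iterate` returns an iterator, one natural way to implement it is as a generator that yields one statistics object per training step. The following is a minimal sketch of that pattern, not dacapo's implementation: `StatsSketch` and `CountingTrainer` are hypothetical stand-ins, and the "loss" is a placeholder number rather than the result of a real forward and backward pass.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class StatsSketch:
    """Hypothetical stand-in for TrainingIterationStats."""

    iteration: int
    loss: float


class CountingTrainer:
    """Hypothetical trainer that fakes one training step per iteration."""

    def __init__(self) -> None:
        self.iteration = 0

    def iterate(
        self, num_iterations, model=None, optimizer=None, device=None
    ) -> Iterator[StatsSketch]:
        for _ in range(num_iterations):
            # A real trainer would run a forward/backward pass here.
            loss = 1.0 / (self.iteration + 1)  # placeholder, not a real loss
            yield StatsSketch(iteration=self.iteration, loss=loss)
            self.iteration += 1


trainer = CountingTrainer()
stats = list(trainer.iterate(3))
```

With a generator, each step runs only when the caller asks for the next item, so the caller can log, validate, or stop early between iterations.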

abstract can_train(datasets: List[dacapo.experiments.datasplits.datasets.Dataset]) -> bool

Checks if the trainer can train with a specific set of datasets.

Some trainers may have specific requirements for their training datasets.

Parameters:

datasets (List[Dataset]) – The training datasets.

Returns:

True if the trainer can train on the given datasets, False otherwise.

Return type:

bool

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Examples

>>> can_train = trainer.can_train(datasets)

Note

This method must be implemented by the subclass.
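As an illustration of a dataset requirement, the sketch below shows a trainer that accepts only datasets carrying ground truth. The requirement itself, the `DatasetSketch` class, and its `gt` field are assumptions made for this example; the source does not say which checks concrete trainers perform.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DatasetSketch:
    """Hypothetical stand-in for Dataset; only the fields used below."""

    name: str
    gt: Optional[object] = None  # ground-truth data, if any (assumed field)


class NeedsGroundTruthTrainer:
    """Hypothetical trainer that requires ground truth in every dataset."""

    def can_train(self, datasets: List[DatasetSketch]) -> bool:
        return all(dataset.gt is not None for dataset in datasets)


trainer = NeedsGroundTruthTrainer()
labelled = DatasetSketch("a", gt=[0, 1])
unlabelled = DatasetSketch("b")
```

Here `trainer.can_train([labelled])` returns True, while adding `unlabelled` to the list makes it return False.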

abstract build_batch_provider(datasets: List[dacapo.experiments.datasplits.datasets.Dataset], model: dacapo.experiments.model.Model, task: dacapo.experiments.tasks.task.Task, snapshot_container: dacapo.store.array_store.LocalContainerIdentifier) -> None

Initializes the training pipeline using various components.

This method uses the datasets, model, task, and snapshot_container to set up the training pipeline.

Parameters:
  • datasets (List[Dataset]) – The datasets to pull data from.

  • model (Model) – The model to inform the pipeline of required input/output sizes.

  • task (Task) – The task to transform ground truth into target.

  • snapshot_container (LocalContainerIdentifier) – Defines where snapshots will be saved.

Raises:

NotImplementedError – If the method is not implemented by the subclass.

Examples

>>> trainer.build_batch_provider(datasets, model, task, snapshot_container)

Note

This method must be implemented by the subclass.
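Since the method returns None, one plausible reading is that it stores the pipeline on the trainer for later use. The sketch below assumes exactly that and uses a plain Python iterator in place of a real data pipeline. `PipelineSketchTrainer`, its `_batches` attribute, and `next_batch` are hypothetical names, and the `model`, `task`, and `snapshot_container` arguments are accepted but unused here.

```python
from itertools import cycle
from typing import Iterator, List, Optional


class PipelineSketchTrainer:
    """Hypothetical trainer whose pipeline is a plain Python iterator."""

    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size
        self._batches: Optional[Iterator[List[str]]] = None

    def build_batch_provider(
        self, datasets, model=None, task=None, snapshot_container=None
    ) -> None:
        # Cycle through the datasets forever and group them into batches.
        source = cycle(datasets)

        def batches() -> Iterator[List[str]]:
            while True:
                yield [next(source) for _ in range(self.batch_size)]

        # Store the pipeline on the trainer instead of returning it.
        self._batches = batches()

    def next_batch(self) -> List[str]:
        if self._batches is None:
            raise RuntimeError("call build_batch_provider() first")
        return next(self._batches)


trainer = PipelineSketchTrainer(batch_size=2)
trainer.build_batch_provider(["a", "b", "c"])
first = trainer.next_batch()
second = trainer.next_batch()
```

The guard in `next_batch` makes the ordering explicit: under this design the pipeline has to be built before any batch can be requested.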