dacapo.experiments.trainers.trainer
Classes
- Trainer – Trainer Abstract Base Class
Module Contents
- class dacapo.experiments.trainers.trainer.Trainer
Trainer Abstract Base Class
This serves as the blueprint for all trainer classes in the dacapo library. It defines the essential methods that every subclass must implement to train a neural network model.
- iteration
The number of training iterations.
- Type:
int
- batch_size
The size of the training batch.
- Type:
int
- learning_rate
The learning rate for the optimizer.
- Type:
float
- create_optimizer(model: Model) -> torch.optim.Optimizer: Creates an optimizer for the model.
- iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]: Performs a number of training iterations.
- can_train(datasets: List[Dataset]) -> bool: Checks whether the trainer can train with a specific set of datasets.
- build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None: Initializes the training pipeline using various components.
Note
The Trainer class is an abstract class that cannot be instantiated directly. It is meant to be subclassed.
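Because Trainer is abstract, a concrete trainer has to override every abstract method before it can be instantiated. The standalone sketch below mirrors that contract with Python's abc module; it does not import dacapo, TrainerSketch, NoOpTrainer and run_training are hypothetical names, and Any stands in for the dacapo types (Model, Dataset, Task, LocalContainerIdentifier). The call order inside run_training is an assumption for illustration, not dacapo's own training loop.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterator, List


class TrainerSketch(ABC):
    """Hypothetical mirror of the Trainer interface (not dacapo's source)."""

    iteration: int = 0
    batch_size: int = 1
    learning_rate: float = 1e-3

    @abstractmethod
    def create_optimizer(self, model: Any) -> Any: ...

    @abstractmethod
    def iterate(self, num_iterations: int, model: Any, optimizer: Any,
                device: Any) -> Iterator[Any]: ...

    @abstractmethod
    def can_train(self, datasets: List[Any]) -> bool: ...

    @abstractmethod
    def build_batch_provider(self, datasets: List[Any], model: Any,
                             task: Any, snapshot_container: Any) -> None: ...


class NoOpTrainer(TrainerSketch):
    """Hypothetical subclass: overrides every abstract method trivially."""

    def create_optimizer(self, model):
        return None

    def iterate(self, num_iterations, model, optimizer, device):
        for _ in range(num_iterations):
            self.iteration += 1
            yield self.iteration  # stand-in for a statistics record

    def can_train(self, datasets):
        return True

    def build_batch_provider(self, datasets, model, task, snapshot_container):
        return None


def run_training(trainer, datasets, model, task, snapshot_container,
                 num_iterations, device):
    """Hypothetical driver; the order of calls is an assumption."""
    if not trainer.can_train(datasets):
        raise ValueError("trainer cannot train on these datasets")
    trainer.build_batch_provider(datasets, model, task, snapshot_container)
    optimizer = trainer.create_optimizer(model)
    return list(trainer.iterate(num_iterations, model, optimizer, device))


try:
    TrainerSketch()  # fails: the abstract methods are not implemented
except TypeError as err:
    print("cannot instantiate:", err)

print(run_training(NoOpTrainer(), [], None, None, None, 3, "cpu"))  # [1, 2, 3]
```

Leaving any one of the four methods unimplemented in a subclass makes instantiation fail with the same TypeError.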
- iteration: int
- batch_size: int
- learning_rate: float
- abstract create_optimizer(model: dacapo.experiments.model.Model) → torch.optim.Optimizer
Creates an optimizer for the model.
- Parameters:
model (Model) – The model for which the optimizer will be created.
- Returns:
The optimizer created for the model.
- Return type:
torch.optim.Optimizer
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> optimizer = trainer.create_optimizer(model)
Note
This method must be implemented by the subclass.
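A concrete create_optimizer typically builds the optimizer from the model's parameters and the trainer's learning_rate attribute. So that it runs without PyTorch, the sketch below swaps in hypothetical plain-Python stand-ins (Param, ScalarModel, PlainSGD, SGDTrainer). With PyTorch, a subclass might instead return something like torch.optim.SGD(model.parameters(), lr=self.learning_rate).

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Param:
    """Hypothetical scalar parameter with an accumulated gradient."""
    value: float
    grad: float = 0.0


@dataclass
class ScalarModel:
    """Hypothetical stand-in for a dacapo Model with one parameter."""
    weight: Param = field(default_factory=lambda: Param(0.0))

    def parameters(self) -> List[Param]:
        return [self.weight]


class PlainSGD:
    """Hypothetical stand-in for torch.optim.SGD: p <- p - lr * grad."""

    def __init__(self, params: List[Param], lr: float):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = 0.0

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad


class SGDTrainer:
    """Hypothetical trainer showing only create_optimizer."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate

    def create_optimizer(self, model: ScalarModel) -> PlainSGD:
        # With PyTorch this might be:
        # torch.optim.SGD(model.parameters(), lr=self.learning_rate)
        return PlainSGD(model.parameters(), lr=self.learning_rate)


model = ScalarModel()
optimizer = SGDTrainer(learning_rate=0.1).create_optimizer(model)
model.weight.grad = 2.0
optimizer.step()
print(model.weight.value)  # -0.2
```

Reading the learning rate from the trainer keeps the optimizer settings in the trainer configuration rather than in the model.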
- abstract iterate(num_iterations: int, model: dacapo.experiments.model.Model, optimizer: torch.optim.Optimizer, device: torch.device) → Iterator[dacapo.experiments.training_iteration_stats.TrainingIterationStats]
Performs the given number of training iterations.
- Parameters:
num_iterations (int) – Number of training iterations.
model (Model) – The model to be trained.
optimizer (torch.optim.Optimizer) – The optimizer for the model.
device (torch.device) – The device (GPU/CPU) where the model will be trained.
- Returns:
An iterator over the training statistics.
- Return type:
Iterator[TrainingIterationStats]
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
...     print(iteration_stats)
Note
This method must be implemented by the subclass.
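Since iterate returns an iterator, a natural implementation is a generator that yields one statistics record per step, letting the caller log, checkpoint or stop between iterations. The sketch below fits a single scalar weight to a target by gradient descent; IterationStats is a hypothetical stand-in for TrainingIterationStats, the model is a plain dict, and optimizer and device are accepted but unused because the update is done inline on the CPU.

```python
import time
from dataclasses import dataclass
from typing import Iterator


@dataclass
class IterationStats:
    """Hypothetical stand-in for TrainingIterationStats."""
    iteration: int
    loss: float
    time: float


class ToyTrainer:
    """Hypothetical trainer minimizing (w - target)^2 for a scalar w."""

    def __init__(self, learning_rate: float, target: float):
        self.iteration = 0
        self.learning_rate = learning_rate
        self.target = target

    def iterate(self, num_iterations: int, model: dict, optimizer: object,
                device: str) -> Iterator[IterationStats]:
        # optimizer and device are unused here; the update is done inline.
        for _ in range(num_iterations):
            start = time.perf_counter()
            error = model["w"] - self.target
            loss = error ** 2
            grad = 2 * error                         # d(loss)/dw
            model["w"] -= self.learning_rate * grad  # plain gradient step
            stats = IterationStats(self.iteration, loss,
                                   time.perf_counter() - start)
            self.iteration += 1
            yield stats


model = {"w": 0.0}
trainer = ToyTrainer(learning_rate=0.25, target=1.0)
for stats in trainer.iterate(3, model, optimizer=None, device="cpu"):
    print(stats.iteration, stats.loss)  # 0 1.0, then 1 0.25, then 2 0.0625
```

Incrementing the iteration counter before yielding keeps it accurate even if the caller stops consuming the generator early.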
- abstract can_train(datasets: List[dacapo.experiments.datasplits.datasets.Dataset]) → bool
Checks if the trainer can train with a specific set of datasets.
Some trainers may have specific requirements for their training datasets.
- Parameters:
datasets (List[Dataset]) – The training datasets.
- Returns:
True if the trainer can train on the given datasets, False otherwise.
- Return type:
bool
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> can_train = trainer.can_train(datasets)
Note
This method must be implemented by the subclass.
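As one illustration of such a requirement, a trainer that learns from ground truth could refuse datasets that lack it. The sketch assumes a hypothetical ToyDataset with an optional gt attribute and a hypothetical SupervisedTrainer; the actual requirements depend on the concrete trainer.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class ToyDataset:
    """Hypothetical stand-in for a dacapo Dataset."""
    name: str
    raw: Sequence[float]
    gt: Optional[Sequence[int]] = None


class SupervisedTrainer:
    """Hypothetical trainer that needs ground truth in every dataset."""

    def can_train(self, datasets: List[ToyDataset]) -> bool:
        # Require at least one dataset, and ground truth on all of them.
        return len(datasets) > 0 and all(ds.gt is not None for ds in datasets)


labeled = ToyDataset("a", raw=[0.1, 0.2], gt=[0, 1])
unlabeled = ToyDataset("b", raw=[0.3, 0.4])
trainer = SupervisedTrainer()
print(trainer.can_train([labeled]))             # True
print(trainer.can_train([labeled, unlabeled]))  # False
```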
- abstract build_batch_provider(datasets: List[dacapo.experiments.datasplits.datasets.Dataset], model: dacapo.experiments.model.Model, task: dacapo.experiments.tasks.task.Task, snapshot_container: dacapo.store.array_store.LocalContainerIdentifier) → None
Initializes the training pipeline using various components.
This method uses the datasets, model, task, and snapshot_container to set up the training pipeline.
- Parameters:
datasets (List[Dataset]) – The datasets to pull data from.
model (Model) – The model, used to inform the pipeline of the required input/output sizes.
task (Task) – The task, used to transform ground truth into targets.
snapshot_container (LocalContainerIdentifier) – Defines where snapshots will be saved.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> trainer.build_batch_provider(datasets, model, task, snapshot_container)
Note
This method must be implemented by the subclass.
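The sketch below shows one way the four arguments could fit together: the model fixes the crop size, the task turns ground truth into targets, the datasets supply samples, and the snapshot location is stored for later use. Everything here is a hypothetical stand-in (CropModel, ThresholdTask, ToyBatchTrainer, its batches generator, and a plain string for the snapshot container); it is not how any particular dacapo trainer builds its pipeline.

```python
import random
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Tuple


@dataclass
class CropModel:
    """Hypothetical model stand-in that only reports its input size."""
    input_size: int


class ThresholdTask:
    """Hypothetical task: ground-truth labels above 0 become target 1.0."""

    def create_target(self, gt: Sequence[int]) -> List[float]:
        return [1.0 if label > 0 else 0.0 for label in gt]


Sample = Tuple[List[float], List[int]]  # (raw, gt) of equal length


class ToyBatchTrainer:
    def __init__(self, batch_size: int, seed: int = 0):
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        self.batches: Optional[Iterator[List[Dict[str, List[float]]]]] = None
        self.snapshot_container: Optional[str] = None

    def build_batch_provider(self, datasets: List[Sample], model: CropModel,
                             task: ThresholdTask,
                             snapshot_container: str) -> None:
        self.snapshot_container = snapshot_container  # where snapshots would go
        size = model.input_size

        def provider():
            while True:
                batch = []
                for _ in range(self.batch_size):
                    raw, gt = self.rng.choice(datasets)  # pick a dataset
                    start = self.rng.randrange(len(raw) - size + 1)  # random crop
                    batch.append({
                        "raw": raw[start:start + size],
                        "target": task.create_target(gt[start:start + size]),
                    })
                yield batch

        self.batches = provider()


trainer = ToyBatchTrainer(batch_size=2)
data = [([0.1, 0.2, 0.3, 0.4, 0.5], [0, 0, 1, 1, 0])]
trainer.build_batch_provider(data, CropModel(input_size=3), ThresholdTask(),
                             "snapshots/run_1")
batch = next(trainer.batches)
print(len(batch), len(batch[0]["raw"]))  # 2 3
```

Building the provider once and pulling batches from it lazily keeps setup separate from the per-iteration work done in iterate.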