dacapo.experiments.trainers.gunpowder_trainer
Module Contents
Classes
- GunpowderTrainer: Trainer Abstract Base Class
Attributes
- dacapo.experiments.trainers.gunpowder_trainer.logger
- class dacapo.experiments.trainers.gunpowder_trainer.GunpowderTrainer(trainer_config)
Trainer Abstract Base Class
This class serves as the blueprint for trainer classes in the dacapo library. It defines the methods that every subclass must implement to train a neural network model.
- iteration = 0
- create_optimizer(model)
Creates an optimizer for the model.
- Parameters:
model (Model) – The model for which the optimizer will be created.
- Returns:
The optimizer created for the model.
- Return type:
torch.optim.Optimizer
- build_batch_provider(datasets, model, task, snapshot_container=None)
Initializes the training pipeline from the given datasets, model, task, and optional snapshot container.
- Parameters:
datasets (List[Dataset]) – The datasets to pull data from.
model (Model) – The model to inform the pipeline of required input/output sizes.
task (Task) – The task to transform ground truth into target.
snapshot_container (LocalContainerIdentifier) – Defines where snapshots will be saved. Optional; defaults to None.
- iterate(num_iterations, model, optimizer, device)
Performs a number of training iterations.
- Parameters:
num_iterations (int) – Number of training iterations.
model (Model) – The model to be trained.
optimizer (torch.optim.Optimizer) – The optimizer for the model.
device (torch.device) – The device (GPU/CPU) where the model will be trained.
- Returns:
An iterator of the training statistics.
- Return type:
Iterator[TrainingIterationStats]
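The iterator contract described for iterate can be illustrated with a minimal, library-free sketch. Everything in it is hypothetical: `SketchTrainer` and `IterationStats` are stand-ins invented for illustration, not dacapo classes, the loss value is a placeholder, and the way the `iteration` counter advances is an assumption rather than documented behaviour.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class IterationStats:
    """Hypothetical stand-in for TrainingIterationStats."""

    iteration: int
    loss: float


class SketchTrainer:
    """Hypothetical trainer that only mimics the shape of iterate()."""

    iteration = 0  # class-level default, mirroring the attribute listed above

    def iterate(self, num_iterations: int) -> Iterator[IterationStats]:
        for _ in range(num_iterations):
            # A real trainer would request a batch, run the model,
            # compute the loss, and step the optimizer here.
            loss = 1.0 / (self.iteration + 1)  # placeholder, not a real loss
            yield IterationStats(iteration=self.iteration, loss=loss)
            # Assumption: the counter advances once per yielded stats object.
            self.iteration += 1


trainer = SketchTrainer()
stats = list(trainer.iterate(3))  # statistics for iterations 0, 1, 2
more = list(trainer.iterate(2))   # continues with iterations 3, 4
```

Because the method is a generator, no work happens until the caller consumes it, and in this sketch a second call resumes from the stored counter instead of restarting at zero.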
- next()
- can_train(datasets) → bool
Checks if the trainer can train with a specific set of datasets.
Some trainers may have specific requirements for their training datasets.
- Parameters:
datasets (List[Dataset]) – The training datasets.
- Returns:
True if the trainer can train on the given datasets, False otherwise.
- Return type:
bool
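As an illustration of the kind of requirement can_train might enforce, here is a minimal sketch. The rule it applies (every dataset must provide ground truth) is an assumption chosen for the example, and `SketchDataset` with its `gt` field is a hypothetical stand-in, not the dacapo Dataset class.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchDataset:
    """Hypothetical stand-in for a training dataset."""

    name: str
    gt: Optional[str] = None  # placeholder for a ground-truth array


def can_train(datasets: List[SketchDataset]) -> bool:
    """Return True only if every dataset provides ground truth (assumed rule)."""
    return all(dataset.gt is not None for dataset in datasets)


labelled = SketchDataset(name="a", gt="labels")
unlabelled = SketchDataset(name="b")

ok = can_train([labelled])                  # every dataset has ground truth
not_ok = can_train([labelled, unlabelled])  # one dataset lacks ground truth
```

Note that `all()` over an empty list returns True, so a trainer written this way would need a separate check if an empty dataset list should be rejected.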