dacapo.experiments.trainers.gunpowder_trainer
Attributes
Classes
GunpowderTrainer class for training a model using gunpowder. This class is a subclass of the Trainer class.
Module Contents
- dacapo.experiments.trainers.gunpowder_trainer.logger
- class dacapo.experiments.trainers.gunpowder_trainer.GunpowderTrainer(trainer_config)
A Trainer subclass that trains a model using gunpowder, a data loading and augmentation library. It implements the abstract methods defined in the Trainer class and trains a model on a dataset for a specific task.
- learning_rate
The learning rate for the optimizer.
- Type:
float
- batch_size
The size of the training batch.
- Type:
int
- num_data_fetchers
The number of data fetchers.
- Type:
int
- print_profiling
The number of iterations after which to print profiling stats.
- Type:
int
- snapshot_iteration
The number of iterations after which to save a snapshot.
- Type:
int
- min_masked
The minimum value of the mask.
- Type:
float
- augments
The list of augmentations to apply to the data.
- Type:
List[Augment]
- mask_integral_downsample_factor
The downsample factor for the mask integral.
- Type:
int
- clip_raw
Whether to clip the raw data.
- Type:
bool
- scheduler
The learning rate scheduler.
- Type:
torch.optim.lr_scheduler.LinearLR
- create_optimizer(model: Model) -> torch.optim.Optimizer
Creates an optimizer for the model.
- build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None
Initializes the training pipeline using various components.
- iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]
Performs a number of training iterations.
- next() -> Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]
Fetches the next batch of data.
- __enter__() -> GunpowderTrainer
Enters the context manager.
- can_train(datasets: List[Dataset]) -> bool
Checks if the trainer can train with a specific set of datasets.
Note
The GunpowderTrainer class is a subclass of the Trainer class. It is used to train a model using gunpowder.
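The scheduler attribute above is typed as torch.optim.lr_scheduler.LinearLR, which scales the base learning rate by a factor that moves linearly from a start factor to an end factor over a fixed number of steps. The following is a minimal pure-Python sketch of that schedule. The function name `linear_lr` and all parameter values are illustrative assumptions, not the values GunpowderTrainer configures.

```python
def linear_lr(base_lr, step, start_factor=0.01, end_factor=1.0, total_iters=1000):
    """Closed form of a linear warm-up schedule, mirroring LinearLR's behaviour.

    The default values are assumptions chosen for illustration only.
    """
    # Clamp progress so the factor stays at end_factor once total_iters is reached.
    progress = min(step, total_iters) / total_iters
    factor = start_factor + (end_factor - start_factor) * progress
    return base_lr * factor


# Learning rate at the start, midway, at the end of warm-up, and well after it.
schedule = [linear_lr(1e-4, step) for step in (0, 500, 1000, 5000)]
```

With these assumed values the learning rate starts at 1% of the base rate and reaches the full base rate after 1000 steps, after which it stays constant.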
- iteration = 0
- learning_rate
- batch_size
- num_data_fetchers
- print_profiling = 100
- snapshot_iteration
- min_masked
- augments
- mask_integral_downsample_factor = 4
- clip_raw
- gt_min_reject
- scheduler = None
- create_optimizer(model)
Creates an optimizer for the model.
- Parameters:
model (Model) – The model for which the optimizer will be created.
- Returns:
The optimizer created for the model.
- Return type:
torch.optim.Optimizer
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> optimizer = trainer.create_optimizer(model)
- build_batch_provider(datasets, model, task, snapshot_container=None)
Initializes the training pipeline using various components.
- Parameters:
datasets (List[Dataset]) – The list of datasets.
model (Model) – The model to be trained.
task (Task) – The task to be performed.
snapshot_container (LocalContainerIdentifier) – The snapshot container.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> trainer.build_batch_provider(datasets, model, task, snapshot_container)
- iterate(num_iterations, model, optimizer, device)
Performs a number of training iterations.
- Parameters:
num_iterations (int) – The number of training iterations.
model (Model) – The model to be trained.
optimizer (torch.optim.Optimizer) – The optimizer for the model.
device (torch.device) – The device (GPU/CPU) where the model will be trained.
- Returns:
An iterator of the training statistics.
- Return type:
Iterator[TrainingIterationStats]
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
...     print(iteration_stats)
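Because iterate returns an iterator rather than a list, the caller receives statistics one iteration at a time and can log or validate between steps. The sketch below shows that generator pattern in isolation. `IterationStats` and `train_step` are hypothetical stand-ins, not dacapo's TrainingIterationStats or its actual training step.

```python
from dataclasses import dataclass
from typing import Callable, Iterator


@dataclass
class IterationStats:
    """Hypothetical stand-in for TrainingIterationStats."""
    iteration: int
    loss: float


def iterate(num_iterations: int, train_step: Callable[[int], float],
            start: int = 0) -> Iterator[IterationStats]:
    """Lazily run training steps, yielding stats after each one."""
    for iteration in range(start, start + num_iterations):
        loss = train_step(iteration)  # one optimisation step (assumed callable)
        yield IterationStats(iteration=iteration, loss=loss)


# Toy step whose "loss" shrinks as the iteration count grows.
stats = list(iterate(3, lambda i: 1.0 / (i + 1)))
```

Each step runs only when the consumer asks for the next item, so a caller can stop early without paying for the remaining iterations.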
- next()
Fetches the next batch of data.
- Returns:
A tuple containing the raw data, ground truth data, target data, weight data, and mask data.
- Return type:
Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> raw, gt, target, weight, mask = trainer.next()
- can_train(datasets) -> bool
Checks if the trainer can train with a specific set of datasets.
- Parameters:
datasets (List[Dataset]) – The list of datasets.
- Returns:
True if the trainer can train with the datasets, False otherwise.
- Return type:
bool
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Examples
>>> can_train = trainer.can_train(datasets)
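The page does not state which criterion can_train applies. As a sketch of the general shape of such a check, the example below assumes a dataset is trainable only when it provides ground truth. `ToyDataset` and its `gt` field are hypothetical, and the criterion itself is an assumption rather than documented behaviour.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyDataset:
    """Hypothetical dataset; `gt` is None when no ground truth is available."""
    name: str
    gt: Optional[str] = None


def can_train(datasets: List[ToyDataset]) -> bool:
    # Assumed criterion: every dataset must provide ground truth.
    return all(dataset.gt is not None for dataset in datasets)


labelled = [ToyDataset("a", gt="a_labels"), ToyDataset("b", gt="b_labels")]
mixed = labelled + [ToyDataset("c")]  # the last dataset has no ground truth
```

Under this assumption, `can_train(labelled)` holds while `can_train(mixed)` does not, because one unlabelled dataset is enough to fail the check.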
- visualize_pipeline(bind_address='0.0.0.0', bind_port=0)
Visualizes the pipeline for the run, including all produced arrays.
- Parameters:
bind_address (str) – Bind address for the Neuroglancer web server.
bind_port (int) – Bind port for the Neuroglancer web server.
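The defaults follow common socket conventions: an address of 0.0.0.0 listens on all IPv4 interfaces, and a port of 0 asks the operating system to pick any free port. The sketch below illustrates the port-0 behaviour with the standard library alone. `bind_ephemeral` is a hypothetical helper, and whether the Neuroglancer web server resolves its port in exactly this way is an assumption.

```python
import socket


def bind_ephemeral(bind_address: str = "127.0.0.1", bind_port: int = 0) -> int:
    """Bind a TCP socket and return the port the OS actually assigned."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((bind_address, bind_port))
        # With bind_port=0 the OS picks a free port; read it back from the socket.
        return sock.getsockname()[1]


# The sketch binds to the loopback address so that nothing is exposed externally.
port = bind_ephemeral()
```

Since the assigned port is not known in advance, a caller relying on bind_port=0 needs to read the resulting address back from the server after it starts.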