dacapo.experiments.trainers.gunpowder_trainer
=============================================

.. py:module:: dacapo.experiments.trainers.gunpowder_trainer


Attributes
----------

.. autoapisummary::

   dacapo.experiments.trainers.gunpowder_trainer.logger


Classes
-------

.. autoapisummary::

   dacapo.experiments.trainers.gunpowder_trainer.GunpowderTrainer


Module Contents
---------------

.. py:data:: logger

.. py:class:: GunpowderTrainer(trainer_config)

   GunpowderTrainer class for training a model using gunpowder.

   This class is a subclass of the Trainer class and implements the abstract
   methods defined there. It uses gunpowder, a data loading and augmentation
   library, to train a model on a dataset for a specific task.

   .. attribute:: learning_rate

      The learning rate for the optimizer.

      :type: float

   .. attribute:: batch_size

      The number of samples in each training batch.

      :type: int

   .. attribute:: num_data_fetchers

      The number of data fetchers.

      :type: int

   .. attribute:: print_profiling

      The interval, in iterations, at which profiling statistics are printed.

      :type: int

   .. attribute:: snapshot_iteration

      The interval, in iterations, at which a snapshot is saved.

      :type: int

   .. attribute:: min_masked

      The minimum value of the mask.

      :type: float

   .. attribute:: augments

      The list of augmentations to apply to the data.

      :type: List[Augment]

   .. attribute:: mask_integral_downsample_factor

      The downsample factor for the mask integral.

      :type: int

   .. attribute:: clip_raw

      Whether to clip the raw data.

      :type: bool

   .. attribute:: scheduler

      The learning rate scheduler.

      :type: torch.optim.lr_scheduler.LinearLR

   .. method:: create_optimizer(model: Model) -> torch.optim.Optimizer

      Creates an optimizer for the model.

   .. method:: build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None

      Builds the training pipeline from the datasets, model, and task.

   .. method:: iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]

      Performs a number of training iterations.

   .. method:: __iter__() -> Iterator[None]

      Initializes the training pipeline.

   .. method:: next() -> Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]

      Fetches the next batch of data.

   .. method:: __enter__() -> GunpowderTrainer

      Enters the context manager.

   .. method:: __exit__(exc_type, exc_val, exc_tb) -> None

      Exits the context manager.

   .. method:: can_train(datasets: List[Dataset]) -> bool

      Checks whether the trainer can train on a specific set of datasets.

   .. note::

      The GunpowderTrainer class is a subclass of the Trainer class. It is used to train a model using gunpowder.


   .. py:attribute:: iteration
      :value: 0


   .. py:attribute:: learning_rate


   .. py:attribute:: batch_size


   .. py:attribute:: num_data_fetchers


   .. py:attribute:: print_profiling
      :value: 100


   .. py:attribute:: snapshot_iteration


   .. py:attribute:: min_masked


   .. py:attribute:: augments


   .. py:attribute:: mask_integral_downsample_factor
      :value: 4


   .. py:attribute:: clip_raw


   .. py:attribute:: gt_min_reject


   .. py:attribute:: scheduler
      :value: None


   .. py:method:: create_optimizer(model)

      Creates an optimizer for the model.

      :param model: The model for which the optimizer will be created.
      :type model: Model

      :returns: The optimizer created for the model.
      :rtype: torch.optim.Optimizer

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> optimizer = trainer.create_optimizer(model)


   .. py:method:: build_batch_provider(datasets, model, task, snapshot_container=None)

      Builds the training pipeline from the datasets, model, and task.

      :param datasets: The list of datasets.
      :type datasets: List[Dataset]
      :param model: The model to be trained.
      :type model: Model
      :param task: The task to be performed.
      :type task: Task
      :param snapshot_container: The container in which training snapshots are stored.
      :type snapshot_container: LocalContainerIdentifier

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> trainer.build_batch_provider(datasets, model, task, snapshot_container)


   .. py:method:: iterate(num_iterations, model, optimizer, device)

      Performs a number of training iterations.

      :param num_iterations: The number of training iterations.
      :type num_iterations: int
      :param model: The model to be trained.
      :type model: Model
      :param optimizer: The optimizer for the model.
      :type optimizer: torch.optim.Optimizer
      :param device: The device (GPU or CPU) on which the model will be trained.
      :type device: torch.device

      :returns: An iterator over the per-iteration training statistics.
      :rtype: Iterator[TrainingIterationStats]

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device):
      ...     print(iteration_stats)


   .. py:method:: next()

      Fetches the next batch of data.

      :returns: A tuple containing the raw data, ground truth data, target data, weight data, and mask data.
      :rtype: Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> raw, gt, target, weight, mask = trainer.next()


   .. py:method:: can_train(datasets) -> bool

      Checks whether the trainer can train on a specific set of datasets.

      :param datasets: The list of datasets.
      :type datasets: List[Dataset]

      :returns: True if the trainer can train on the datasets, False otherwise.
      :rtype: bool

      :raises NotImplementedError: If the method is not implemented by the subclass.

      .. rubric:: Examples

      >>> can_train = trainer.can_train(datasets)


   .. py:method:: visualize_pipeline(bind_address='0.0.0.0', bind_port=0)

      Visualizes the pipeline for the run, including all produced arrays.

      :param bind_address: Bind address for the Neuroglancer web server.
      :type bind_address: str
      :param bind_port: Bind port for the Neuroglancer web server.
      :type bind_port: int
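
The methods documented above are meant to be used together. The sketch below is a minimal, self-contained illustration of one plausible calling order. It assumes that the pipeline is built with ``build_batch_provider``, that the trainer is then entered as a context manager, and that ``iterate`` yields one statistics record per iteration while advancing the ``iteration`` counter. ``ToyTrainer``, its placeholder loss, and the ``TrainingIterationStats`` dataclass are hypothetical stand-ins written for this example. They mimic only the calling protocol, not dacapo's implementation, and no gunpowder pipeline is built.

.. code-block:: python

   from dataclasses import dataclass
   from typing import Iterator, List


   @dataclass
   class TrainingIterationStats:
       # Hypothetical stand-in for dacapo's per-iteration statistics record.
       iteration: int
       loss: float


   class ToyTrainer:
       """Hypothetical trainer that mimics the calling protocol only."""

       def __init__(self) -> None:
           self.iteration = 0  # mirrors the documented ``iteration`` attribute
           self._pipeline_ready = False
           self._entered = False

       def build_batch_provider(self, datasets, model, task, snapshot_container=None) -> None:
           # A real trainer would assemble its data pipeline here; the toy only records it.
           self._pipeline_ready = True

       def __enter__(self) -> "ToyTrainer":
           # Assumed: entering the context is where a real pipeline would be started.
           self._entered = True
           return self

       def __exit__(self, exc_type, exc_val, exc_tb) -> None:
           # Assumed: leaving the context is where a real pipeline would be torn down.
           self._entered = False

       def iterate(self, num_iterations, model, optimizer, device) -> Iterator[TrainingIterationStats]:
           # Because this is a generator, the check runs on the first next() call.
           if not (self._pipeline_ready and self._entered):
               raise RuntimeError("build the batch provider and enter the trainer first")
           for _ in range(num_iterations):
               # A real trainer would fetch a batch, run forward/backward and step the
               # optimizer here; the toy reports a placeholder loss instead.
               stats = TrainingIterationStats(self.iteration, loss=1.0 / (self.iteration + 1))
               self.iteration += 1
               yield stats


   trainer = ToyTrainer()
   trainer.build_batch_provider(datasets=[], model=None, task=None)
   with trainer as t:
       stats: List[TrainingIterationStats] = list(t.iterate(3, model=None, optimizer=None, device="cpu"))

   for s in stats:
       print(s.iteration, round(s.loss, 3))

Making ``iterate`` a generator lets a caller log, validate, or stop early as each iteration's statistics arrive, without waiting for the whole run. In this toy version the ``iteration`` counter advances as each record is produced, so it reflects only the iterations that were actually consumed.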