dacapo.experiments.architectures.architecture

Classes

Architecture

An abstract base class for defining the architecture of a neural network model.

Module Contents

class dacapo.experiments.architectures.architecture.Architecture(*args, **kwargs)

An abstract base class for defining the architecture of a neural network model. It inherits from PyTorch’s Module and from Python’s built-in ABC (Abstract Base Class). Other classes can inherit from it to define their own architecture variants. Subclasses must implement several abstract properties; the class also provides additional members related to the architecture design.
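The inheritance pattern described above can be sketched with the standard library alone. This is a minimal stand-in, not dacapo's code: `ArchitectureSketch` and `TinyArchitecture` are hypothetical names, plain tuples replace `funlib.geometry.Coordinate`, and the PyTorch `Module` base is omitted so the snippet runs without PyTorch installed.

```python
from abc import ABC, abstractmethod


class ArchitectureSketch(ABC):
    """Hypothetical stand-in that mirrors the documented abstract interface."""

    @property
    @abstractmethod
    def input_shape(self) -> tuple:
        """Spatial input shape, excluding channel and batch dimensions."""

    @property
    @abstractmethod
    def num_in_channels(self) -> int:
        """Number of input channels."""

    @property
    @abstractmethod
    def num_out_channels(self) -> int:
        """Number of output channels."""


class TinyArchitecture(ArchitectureSketch):
    """Hypothetical subclass that implements every abstract property."""

    def __init__(self, input_shape, num_in_channels, num_out_channels):
        self._input_shape = tuple(input_shape)
        self._num_in_channels = num_in_channels
        self._num_out_channels = num_out_channels

    @property
    def input_shape(self) -> tuple:
        return self._input_shape

    @property
    def num_in_channels(self) -> int:
        return self._num_in_channels

    @property
    def num_out_channels(self) -> int:
        return self._num_out_channels


arch = TinyArchitecture((128, 128, 128), num_in_channels=1, num_out_channels=3)
print(arch.input_shape, arch.num_in_channels, arch.num_out_channels)
```

In this sketch, Python's `abc` machinery raises a `TypeError` when a class that leaves any abstract property unimplemented is instantiated, so `ArchitectureSketch()` itself cannot be constructed.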

input_shape

The spatial input shape for the neural network architecture.

Type:

Coordinate

eval_shape_increase

The amount by which to increase the input shape during prediction.

Type:

Coordinate

num_in_channels

The number of input channels required by the architecture.

Type:

int

num_out_channels

The number of output channels provided by the architecture.

Type:

int

dims

Returns the number of dimensions of the input shape.

scale()

Scales the input voxel size as required by the architecture.

Note

The class is abstract; derived classes must implement its abstract properties.

property input_shape: funlib.geometry.Coordinate
Abstractmethod:

Abstract property that defines the spatial input shape of the neural network architecture. The shape excludes the channel and batch dimensions.

Returns:

The spatial input shape.

Return type:

Coordinate

Raises:

NotImplementedError – If the method is not implemented in the derived class.

Examples

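>>> from funlib.geometry import Coordinate
>>> # MyModel is a hypothetical subclass of Architecture
>>> input_shape = Coordinate((128, 128, 128))
>>> model = MyModel(input_shape)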

Note

This property must be implemented in the derived class.

property eval_shape_increase: funlib.geometry.Coordinate

The amount by which the input shape is increased during prediction.

Returns:

The amount by which to increase each dimension of the input shape.

Return type:

Coordinate

Raises:

NotImplementedError – If the method is not implemented in the derived class.

Examples

>>> input_shape = Coordinate((128, 128, 128))
>>> eval_shape_increase = Coordinate((0, 0, 0))
>>> model = MyModel(input_shape, eval_shape_increase)

Note

Overriding this property in the derived class is optional.
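One plausible reading of "increase the input shape" is an element-wise addition of the increase to the training-time input shape. The sketch below works under that assumption only; `eval_input_shape` is a hypothetical helper, not part of dacapo, and plain tuples stand in for `Coordinate`.

```python
def eval_input_shape(input_shape, eval_shape_increase):
    """Add the per-dimension increase to the training-time input shape.

    Assumption: the increase is applied by element-wise addition.
    """
    if len(input_shape) != len(eval_shape_increase):
        raise ValueError("shapes must have the same number of dimensions")
    return tuple(s + inc for s, inc in zip(input_shape, eval_shape_increase))


train_shape = (128, 128, 128)

# An increase of zero leaves the shape unchanged.
print(eval_input_shape(train_shape, (0, 0, 0)))     # (128, 128, 128)

# A non-zero increase yields a larger shape for prediction.
print(eval_input_shape(train_shape, (64, 64, 64)))  # (192, 192, 192)
```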

property num_in_channels: int
Abstractmethod:

Abstract property that returns the number of input channels required by the architecture.

Returns:

Required number of input channels.

Return type:

int

Raises:

NotImplementedError – If the method is not implemented in the derived class.

Examples

>>> input_shape = Coordinate((128, 128, 128))
>>> num_in_channels = 1
>>> model = MyModel(input_shape, num_in_channels)

Note

This property must be implemented in the derived class.

property num_out_channels: int
Abstractmethod:

Abstract property that returns the number of output channels provided by the architecture.

Returns:

Number of output channels.

Return type:

int

Raises:

NotImplementedError – If the method is not implemented in the derived class.

Examples

>>> input_shape = Coordinate((128, 128, 128))
>>> num_out_channels = 1
>>> model = MyModel(input_shape, num_out_channels)

Note

This property must be implemented in the derived class.

property dims: int

Returns the number of dimensions of the input shape.

Returns:

The number of dimensions.

Return type:

int

Raises:

NotImplementedError – If the method is not implemented in the derived class.

Examples

>>> input_shape = Coordinate((128, 128, 128))
>>> model = MyModel(input_shape)
>>> model.dims
3

Note

Overriding this property in the derived class is optional.
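The relationship between `dims` and `input_shape` shown in the example can be sketched as follows. `ShapeHolder` is a hypothetical class used only for illustration, and it assumes the dimensionality is simply the length of the spatial shape, with a plain tuple standing in for `Coordinate`.

```python
class ShapeHolder:
    """Hypothetical class; only illustrates deriving dims from input_shape."""

    def __init__(self, input_shape):
        # Spatial shape only: no channel or batch dimensions.
        self.input_shape = tuple(input_shape)

    @property
    def dims(self) -> int:
        # Assumption: the number of dimensions equals the length of the shape.
        return len(self.input_shape)


print(ShapeHolder((128, 128, 128)).dims)  # 3
print(ShapeHolder((256, 256)).dims)       # 2
```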

scale(input_voxel_size: funlib.geometry.Coordinate) → funlib.geometry.Coordinate

Scales the input voxel size as required by the architecture.

Parameters:

input_voxel_size (Coordinate) – The voxel size of the input data.

Returns:

The scaled voxel size.

Return type:

Coordinate

Raises:

NotImplementedError – If the method is not implemented in the derived class.

Examples

>>> input_shape = Coordinate((128, 128, 128))
>>> input_voxel_size = Coordinate((1, 1, 1))
>>> model = MyModel(input_shape)
>>> model.scale(input_voxel_size)
Coordinate((1, 1, 1))

Note

Overriding this method in the derived class is optional.
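As an illustration of when a subclass might override `scale`, the sketch below contrasts a pass-through version, matching the example above where the voxel size comes back unchanged, with a hypothetical upsampling variant. Both class names and the integer `factor` parameter are assumptions made for this sketch, and plain tuples stand in for `Coordinate`.

```python
class IdentityScale:
    """Hypothetical class: returns the voxel size unchanged."""

    def scale(self, input_voxel_size):
        return tuple(input_voxel_size)


class UpsamplingScale(IdentityScale):
    """Hypothetical override for a network that upsamples its input."""

    def __init__(self, factor):
        self.factor = factor

    def scale(self, input_voxel_size):
        # Upsampling by `factor` shrinks each voxel by the same factor.
        if any(v % self.factor for v in input_voxel_size):
            raise ValueError("voxel size must be divisible by the factor")
        return tuple(v // self.factor for v in input_voxel_size)


print(IdentityScale().scale((1, 1, 1)))     # (1, 1, 1)
print(UpsamplingScale(2).scale((8, 8, 8)))  # (4, 4, 4)
```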