dacapo.experiments.tasks.losses.mse_loss

Module Contents

Classes

MSELoss


class dacapo.experiments.tasks.losses.mse_loss.MSELoss

A class representing the mean squared error (MSE) loss function.
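For reference, the mean squared error between a prediction $p$ and a target $t$, each with $N$ elements, is the average of the squared elementwise differences. This is also what PyTorch's `torch.nn.MSELoss` computes with its default `reduction='mean'`:

$$
\mathrm{MSE}(p, t) = \frac{1}{N} \sum_{i=1}^{N} \left(p_i - t_i\right)^2
$$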

Methods:

compute(prediction, target, weight)

Computes the MSE loss between the prediction and target, weighted by weight.

compute(prediction, target, weight)

Computes the MSE loss between the provided prediction and target, weighted by the weight tensor.

Parameters:

prediction : torch.Tensor

The prediction tensor for which the loss is calculated.

target : torch.Tensor

The target tensor against which the prediction is compared.

weight : torch.Tensor

The weight tensor used to weight the prediction in the loss calculation.

Returns:

torch.Tensor

The computed MSELoss tensor.
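The docstring does not spell out how the weight enters the loss. The sketch below shows one common way to compute a weighted MSE in PyTorch. It assumes that the weight multiplies both the prediction and the target elementwise before a mean-reduced MSE is taken. The helper name `weighted_mse` is hypothetical, and this is not necessarily how `MSELoss.compute` is implemented.

```python
import torch
import torch.nn.functional as F


def weighted_mse(
    prediction: torch.Tensor, target: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """Hypothetical weighted MSE (assumption: weight scales both tensors)."""
    # Assumption: weight is broadcastable to prediction/target and is applied
    # elementwise to both before the squared error is taken.
    weighted_prediction = prediction * weight
    weighted_target = target * weight
    # Mean-reduced MSE over every element, so zero-weight elements still
    # count toward the denominator of the mean.
    return F.mse_loss(weighted_prediction, weighted_target)


prediction = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.0, 0.0, 3.0, 0.0])
weight = torch.tensor([1.0, 1.0, 0.0, 0.0])  # mask out the last two elements

loss = weighted_mse(prediction, target, weight)
print(loss.item())  # 1.0
```

Under this assumption, each squared error is scaled by the square of its weight, because (w·p − w·t)² = w²·(p − t)². Elements with zero weight add nothing to the sum but are still counted in N. With all weights equal to one, the result reduces to the plain MSE.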