dacapo.utils.balance_weights

Functions

balance_weights(label_data, num_classes[, masks, ...])

Computes per-voxel class-balancing weights (an error scale) from label data.

Module Contents

dacapo.utils.balance_weights.balance_weights(label_data: numpy.ndarray, num_classes: int, masks: List[numpy.ndarray] = list(), slab=None, clipmin: float = 0.05, clipmax: float = 0.95, moving_counts: List[Dict[int, Tuple[int, int]]] | None = None)

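Computes a per-voxel error scale that balances classes. Each masked-in voxel is weighted inversely to the (clipped) frequency of its class, so that every class present carries roughly the same total weight. Frequencies are computed independently per slab and can be accumulated across calls through moving_counts.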
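For a single slab, the steps listed in the Notes reduce to the formula below. The symbols are introduced here only for illustration: n_c is the running count of masked-in voxels labelled c, N is the running sum of the masked-in error scale (the masked-in voxel count when masks are binary), K is num_classes, and ℓ(v) is the label of voxel v.

```latex
f_c = \frac{n_c}{N}, \qquad
\hat f_c = \min\bigl(\max(f_c,\ \mathrm{clipmin}),\ \mathrm{clipmax}\bigr), \qquad
w_c = \frac{1}{K\,\hat f_c}, \qquad
\mathrm{error\_scale}[v] \leftarrow \mathrm{error\_scale}[v]\; w_{\ell(v)}

n_c\, w_c = \frac{N}{K} \quad \text{whenever } \hat f_c = f_c
```

For example, with K = 2 and a 10-voxel slab holding 9 voxels of class 0 and 1 of class 1, the weights are 1/(2 · 0.9) ≈ 0.556 and 1/(2 · 0.1) = 5, so both classes carry a total weight of about 5 = N/K. Clipping breaks this equality only for very rare or very common classes, keeping every present class's weight between 1/(K · clipmax) and 1/(K · clipmin).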

Parameters:
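  • label_data (np.ndarray) – Integer class label for each voxel. Labels must lie in [0, num_classes).

  • num_classes (int) – The number of classes.

  • masks (List[np.ndarray], optional) – Arrays with the same shape as label_data that multiply the error scale; voxels where any mask is 0 get weight 0 and are excluded from the class counts. Defaults to an empty list.

  • slab (optional) – Per-dimension size of the blocks (slabs) in which class frequencies and weights are computed independently. Defaults to None.

  • clipmin (float, optional) – Lower bound applied to each class frequency before weighting, which caps the weight of rare classes at 1 / (num_classes * clipmin). Defaults to 0.05.

  • clipmax (float, optional) – Upper bound applied to each class frequency before weighting, which keeps the weight of dominant classes at or above 1 / (num_classes * clipmax). Defaults to 0.95.

  • moving_counts (Optional[List[Dict[int, Tuple[int, int]]]], optional) – Running counts, one dict per slab mapping each class id to (num, den): the accumulated number of voxels of that class and the accumulated masked-in total. Pass the returned value back in on the next call to accumulate frequencies across batches. Defaults to None.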
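The slab parameter is consumed per dimension in the Notes: start offsets come from range(0, m, s) and each slab is slice(start, start + s). The helper iter_slabs below is hypothetical (not part of dacapo) and reproduces only that tiling; its iteration order is also the order in which moving_counts holds one entry per slab.

```python
import itertools
from typing import Iterator, Tuple

import numpy as np


def iter_slabs(shape: Tuple[int, ...], slab: Tuple[int, ...]) -> Iterator[Tuple[slice, ...]]:
    """Hypothetical helper: yield one tuple of slices per slab, in the Notes' loop order."""
    # One range of start offsets per dimension, stepping by that dimension's slab size.
    slab_ranges = (range(0, m, s) for m, s in zip(shape, slab))
    for start in itertools.product(*slab_ranges):
        # Slices running past the array edge are truncated by NumPy when used.
        yield tuple(slice(st, st + sz) for st, sz in zip(start, slab))


labels = np.arange(24).reshape(4, 6)
for slices in iter_slabs(labels.shape, (2, 4)):
    print(labels[slices].shape)  # prints (2, 4), (2, 2), (2, 4), (2, 2) in turn
```

Because NumPy truncates slices at the array edge, a slab size that does not divide the array shape produces smaller slabs along the border, as the (2, 2) blocks show.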

Returns:

A pair (error_scale, moving_counts): the per-voxel weights, with the same shape as label_data, and the updated per-slab running counts.

Return type:

Tuple[np.ndarray, List[Dict[int, Tuple[int, int]]]]

Raises:
  • AssertionError – If the number of unique labels is greater than the number of classes.

  • AssertionError – If any label is negative or not less than num_classes, i.e. lies outside [0, num_classes).
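A minimal sketch of input checks consistent with the conditions above follows. check_labels is a hypothetical helper, not part of dacapo. The upper bound is strict because each label is used as an index into the weight vector w = np.zeros(num_classes) built in the Notes, so a label equal to num_classes would be out of range.

```python
import numpy as np


def check_labels(label_data: np.ndarray, num_classes: int) -> None:
    """Hypothetical mirror of the documented assertions."""
    unique_labels = np.unique(label_data)
    # No more distinct labels than there are classes.
    assert len(unique_labels) <= num_classes, (
        f"found {len(unique_labels)} distinct labels but num_classes={num_classes}"
    )
    # Every label must be a valid index into a weight vector of length num_classes.
    assert 0 <= label_data.min() and label_data.max() < num_classes, (
        f"labels must lie in [0, {num_classes})"
    )
```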

Examples

>>> import numpy as np
>>> from dacapo.utils.balance_weights import balance_weights
>>> label_data = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]])
>>> num_classes = 3
>>> masks = [np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])]
>>> error_scale, moving_counts = balance_weights(label_data, num_classes, masks)
>>> error_scale.shape
(3, 3)
>>> # All three classes are equally frequent, so every voxel gets the same weight.
>>> bool(np.allclose(error_scale, error_scale[0, 0]))
True

Notes

The balanced error scale is computed as follows:

    # Start from weight 1 everywhere and zero out masked-out voxels.
    error_scale = np.ones(label_data.shape, dtype=np.float32)
    for mask in masks:
        error_scale = error_scale * mask

    # Visit the array slab by slab; each slab keeps its own running counts.
    slab_ranges = (range(0, m, s) for m, s in zip(error_scale.shape, slab))
    for ind, start in enumerate(itertools.product(*slab_ranges)):
        slab_counts = moving_counts[ind]  # {class_id: (num, den)} for this slab
        slices = tuple(slice(start[d], start[d] + slab[d]) for d in range(len(slab)))
        scale_slab = error_scale[slices]  # a view, so edits below update error_scale
        labels_slab = label_data[slices]
        # Every class's denominator grows by the masked-in total of this slab.
        masked_in = scale_slab.sum()
        classes, counts = np.unique(labels_slab[np.nonzero(scale_slab)], return_counts=True)
        updated_fracs = []
        for key, (num, den) in slab_counts.items():
            slab_counts[key] = (num, den + masked_in)
        # Classes present in the slab also add their voxel counts to the numerator.
        for class_id, num in zip(classes, counts):
            (old_num, den) = slab_counts[class_id]
            slab_counts[class_id] = (num + old_num, den)
            updated_fracs.append(slab_counts[class_id][0] / slab_counts[class_id][1])
        # Clip the running class fractions in place.
        fracs = np.array(updated_fracs)
        if clipmin is not None or clipmax is not None:
            np.clip(fracs, clipmin, clipmax, fracs)
        # Weight each present class by 1 / (num_classes * frac) and apply per voxel.
        total_frac = 1.0
        w_sparse = total_frac / float(num_classes) / fracs
        w = np.zeros(num_classes)
        w[classes] = w_sparse
        labels_slab = labels_slab.astype(np.int64)
        scale_slab *= np.take(w, labels_slab)
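To experiment with this procedure outside dacapo, the steps above can be assembled into a standalone function. balance_weights_sketch is hypothetical and is not the library implementation. Two behaviours it needs are assumptions, because the Notes do not show them: slab=None is treated as a single slab covering the whole array, and a slab seen for the first time starts every class at (0, 0).

```python
import itertools
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np

Counts = Dict[int, Tuple[float, float]]


def balance_weights_sketch(
    label_data: np.ndarray,
    num_classes: int,
    masks: Sequence[np.ndarray] = (),
    slab: Optional[Tuple[int, ...]] = None,
    clipmin: Optional[float] = 0.05,
    clipmax: Optional[float] = 0.95,
    moving_counts: Optional[List[Counts]] = None,
) -> Tuple[np.ndarray, List[Counts]]:
    """Hypothetical re-implementation of the steps listed in the Notes."""
    if moving_counts is None:
        moving_counts = []
    # Assumption: slab=None means one slab covering the whole array.
    if slab is None:
        slab = label_data.shape

    # Start from 1 everywhere and zero out masked-out voxels.
    error_scale = np.ones(label_data.shape, dtype=np.float32)
    for mask in masks:
        error_scale = error_scale * mask

    slab_ranges = (range(0, m, s) for m, s in zip(error_scale.shape, slab))
    for ind, start in enumerate(itertools.product(*slab_ranges)):
        # Assumption: a new slab starts with zero counts for every class.
        if ind >= len(moving_counts):
            moving_counts.append({c: (0, 0) for c in range(num_classes)})
        slab_counts = moving_counts[ind]
        slices = tuple(slice(st, st + sz) for st, sz in zip(start, slab))
        scale_slab = error_scale[slices]  # a view: in-place edits reach error_scale
        labels_slab = label_data[slices].astype(np.int64)

        # Every class's denominator grows by the masked-in total.
        masked_in = float(scale_slab.sum())
        for key, (num, den) in slab_counts.items():
            slab_counts[key] = (num, den + masked_in)

        # Present classes add their voxel counts to the numerator.
        classes, counts = np.unique(labels_slab[np.nonzero(scale_slab)], return_counts=True)
        frac_list = []
        for class_id, n in zip(classes, counts):
            old_num, den = slab_counts[int(class_id)]
            slab_counts[int(class_id)] = (old_num + int(n), den)
            frac_list.append((old_num + int(n)) / den)
        fracs = np.array(frac_list, dtype=np.float64)
        if clipmin is not None or clipmax is not None:
            fracs = np.clip(fracs, clipmin, clipmax)

        # Weight present classes by 1 / (num_classes * frac); others stay 0.
        w = np.zeros(num_classes)
        w[classes] = 1.0 / num_classes / fracs
        scale_slab *= np.take(w, labels_slab)

    return error_scale, moving_counts
```

Calling it again with the returned moving_counts reuses the accumulated counts, so a class that was rare in earlier batches keeps a comparatively high weight until enough of its voxels have been seen.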