dacapo.experiments.starts.start
===============================

.. py:module:: dacapo.experiments.starts.start


Attributes
----------

.. autoapisummary::

   dacapo.experiments.starts.start.logger
   dacapo.experiments.starts.start.head_keys


Classes
-------

.. autoapisummary::

   dacapo.experiments.starts.start.Start


Functions
---------

.. autoapisummary::

   dacapo.experiments.starts.start.match_heads


Module Contents
---------------

.. py:data:: logger

.. py:data:: head_keys
   :value: ['prediction_head.weight', 'prediction_head.bias', 'chain.1.weight', 'chain.1.bias']

.. py:function:: match_heads(model, head_weights, old_head, new_head)

   Matches the prediction head of the model to a new head by copying weights from the old head. For each label that appears in both the old head and the new head, the weights stored for that label in the old head are copied to that label's position in the new head.

   :param model: obj
                 The model into which the weights are to be loaded.
   :param head_weights: dict
                        The weights of the old head.
   :param old_head: list
                    The labels of the old head.
   :param new_head: list
                    The labels of the new head.
   :returns: obj
             The model with the weights of the old head copied to the new head.
   :rtype: model
   :raises RuntimeError: If the old head is not found in the new head, a RuntimeError is raised, which is logged and handled by loading only the layers common to the model and the weights.

   .. rubric:: Examples

   >>> model = match_heads(model, head_weights, old_head, new_head)

   .. rubric:: Notes

   This function is called by the Start class to match the head of the model to the new head by copying the weights of the old head to the new head.


.. py:class:: Start(start_config)

   This class interfaces with the dacapo store to retrieve and load the weights of the starter model used for finetuning.

   .. attribute:: run
      :type: str

      The run whose stored weights are retrieved for the model.

   .. attribute:: criterion
      :type: str

      The policy that was used to decide when to store the weights.

   .. attribute:: channels
      :type: int

      The number of channels in the input data.

   .. method:: __init__(start_config)

      Initializes the Start class from ``start_config``, which specifies the run and criterion whose stored weights are used to initialize a model.

   .. method:: initialize_weights(model, new_head=None)

      Retrieves the weights from the dacapo store and loads them into the model.

   .. rubric:: Notes

   This class retrieves the weights of the starter model used for finetuning from the dacapo store and loads them into a model.


   .. py:attribute:: channels
      :value: None


   .. py:method:: initialize_weights(model, new_head=None)

      Retrieves the weights from the dacapo store and loads them into the model.

      :param model: obj
                    The model into which the weights are to be loaded.
      :param new_head: list
                       The labels of the new head.
      :returns: obj
                The model with the weights loaded from the dacapo store.
      :rtype: model
      :raises RuntimeError: If weights for a non-existent or mismatched layer are loaded, the resulting RuntimeError is logged and handled by loading only the layers common to the model and the stored weights.

      .. rubric:: Examples

      >>> model = start.initialize_weights(model, new_head)

      .. rubric:: Notes

      This method of the Start class retrieves the weights from the dacapo store and loads them into the model.
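
The label matching that ``match_heads`` describes can be sketched as follows. This is a minimal illustration under stated assumptions, not DaCapo's code. Nested Python lists stand in for tensors, and each head-related entry is assumed to index labels along its first dimension. ``match_heads_sketch`` is a hypothetical name, and ``HEAD_KEYS`` copies the documented ``head_keys`` value.

.. code-block:: python

   import copy

   # Mirrors the documented ``head_keys`` value. Assumption: each entry's first
   # index corresponds to one label (output channel) of the prediction head.
   HEAD_KEYS = [
       "prediction_head.weight",
       "prediction_head.bias",
       "chain.1.weight",
       "chain.1.bias",
   ]


   def match_heads_sketch(model_state, head_weights, old_head, new_head):
       """Hypothetical helper: copy each shared label's row to its new position.

       ``model_state`` and ``head_weights`` map parameter names to nested lists
       (stand-ins for tensors). ``model_state`` is modified in place and returned.
       """
       for new_index, label in enumerate(new_head):
           if label not in old_head:
               continue  # a label new to this head keeps its current values
           old_index = old_head.index(label)
           for key in HEAD_KEYS:
               # Skip keys that this particular model or checkpoint does not have.
               if key in model_state and key in head_weights:
                   row = copy.deepcopy(head_weights[key][old_index])
                   model_state[key][new_index] = row
       return model_state


   # Usage: "mito" sits at position 0 in the old head and position 1 in the new one.
   old_head = ["mito", "nucleus"]
   new_head = ["er", "mito"]
   head_weights = {
       "prediction_head.weight": [[1.0, 1.0], [2.0, 2.0]],
       "prediction_head.bias": [0.1, 0.2],
   }
   model_state = {
       "prediction_head.weight": [[0.0, 0.0], [0.0, 0.0]],
       "prediction_head.bias": [0.0, 0.0],
   }
   match_heads_sketch(model_state, head_weights, old_head, new_head)
   print(model_state["prediction_head.weight"])  # [[0.0, 0.0], [1.0, 1.0]]
   print(model_state["prediction_head.bias"])    # [0.0, 0.1]

Only labels present in both heads are copied. ``"er"`` is new, so its row keeps the value the model already had, while ``"nucleus"`` is simply dropped.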
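
The fallback described for ``initialize_weights`` (on a RuntimeError, load only the layers common to the model and the stored weights) can be illustrated with the minimal sketch below. It is an illustration under assumptions, not DaCapo's implementation. Nested Python lists stand in for tensors, and ``shape`` and ``load_common_layers`` are hypothetical helpers. With PyTorch, the same filtering would typically compare ``Tensor.size()`` against the entries of ``model.state_dict()`` and pass the merged dictionary to ``model.load_state_dict``.

.. code-block:: python

   def shape(value):
       """Hypothetical helper: return the shape of a nested list; a scalar is ()."""
       dims = []
       while isinstance(value, list):
           dims.append(len(value))
           value = value[0] if value else None
       return tuple(dims)


   def load_common_layers(model_state, stored_weights):
       """Hypothetical helper: merge only the stored entries that fit the model.

       An entry fits when its key exists in ``model_state`` and its shape matches
       the model's value. Returns the merged state and the sorted skipped keys.
       """
       compatible = {
           key: value
           for key, value in stored_weights.items()
           if key in model_state and shape(value) == shape(model_state[key])
       }
       merged = dict(model_state)  # start from the model's own values
       merged.update(compatible)   # overwrite only the compatible entries
       skipped = sorted(set(stored_weights) - set(compatible))
       return merged, skipped


   # Usage: the stored head has 2 output channels but the new model has 3,
   # and "extra.bias" does not exist in the new model at all.
   model_state = {
       "backbone.weight": [[0.0, 0.0]],
       "prediction_head.weight": [[0.0], [0.0], [0.0]],
   }
   stored_weights = {
       "backbone.weight": [[5.0, 6.0]],
       "prediction_head.weight": [[1.0], [2.0]],
       "extra.bias": [1.0],
   }
   merged, skipped = load_common_layers(model_state, stored_weights)
   print(merged["backbone.weight"])  # [[5.0, 6.0]]
   print(skipped)                    # ['extra.bias', 'prediction_head.weight']

Filtering by both key and shape lets a finetuning run reuse a pretrained backbone even when the prediction head has a different number of outputs.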