Shortcuts

Source code for connectomics.model.loss.criterion

from __future__ import print_function, division
from typing import Optional, List, Union, Tuple
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from .loss import *
from .regularization import *
from ..utils import get_functional_act, SplitActivation


class Criterion(object):
    """Calculating losses and regularizations given the prediction, target and weight mask.

    Args:
        device (torch.device): model running device. GPUs are recommended for
            model training and inference.
        target_opt (List[str], optional): target options. Defaults to ['1'].
        loss_opt (List[List[str]], optional): loss options for the specified
            targets. Defaults to [['WeightedBCE']].
        output_act (List[List[str]], optional): activation functions for each
            loss option. Defaults to [['none']].
        loss_weight (List[List[float]], optional): the scalar weight of each
            loss. Defaults to [[1.]].
        regu_opt (Optional[List[str]], optional): regularization options.
            Defaults to None.
        regu_target (Optional[List[List[int]]], optional): indices of
            predictions for applying regularization. Defaults to None.
        regu_weight (Optional[List[float]], optional): the scalar weight of
            each regularization. Defaults to None.
        do_2d (bool, optional): whether to conduct 2D training. Defaults to False.
    """

    # Maps a loss option name (as used in ``loss_opt``) to its class.
    loss_dict = {
        'WeightedMSE': WeightedMSE,
        'WeightedMAE': WeightedMAE,
        'WeightedBCE': WeightedBCE,
        'DiceLoss': DiceLoss,
        'WeightedCE': WeightedCE,
        'WeightedBCEWithLogitsLoss': WeightedBCEWithLogitsLoss,
    }

    # Maps a regularization option name (as used in ``regu_opt``) to its class.
    regu_dict = {
        'Binary': BinaryReg,
        'FgContour': FgContourConsistency,
        'ContourDT': ContourDTConsistency,
        'FgDT': ForegroundDTConsistency,
    }

    # NOTE: the list defaults below are shared between calls, but this class
    # only reads them and never mutates them in place.
    def __init__(self,
                 device: torch.device,
                 target_opt: List[str] = ['1'],
                 loss_opt: List[List[str]] = [['WeightedBCE']],
                 output_act: List[List[str]] = [['none']],
                 loss_weight: List[List[float]] = [[1.]],
                 regu_opt: Optional[List[str]] = None,
                 regu_target: Optional[List[List[int]]] = None,
                 regu_weight: Optional[List[float]] = None,
                 do_2d: bool = False):
        self.device = device
        self.target_opt = target_opt
        # Used in ``evaluate`` to split the prediction for each target.
        self.splitter = SplitActivation(
            target_opt, split_only=True, do_2d=do_2d)

        self.num_target = len(target_opt)
        self.num_regu = 0 if regu_opt is None else len(regu_opt)

        self.loss_opt = loss_opt
        self.loss_fn = self.get_loss(loss_opt)
        self.loss_w = loss_weight

        self.regu_opt = regu_opt
        self.regu_fn = self.get_regu(regu_opt)
        self.regu_t = regu_target
        self.regu_w = regu_weight

        self.act = self.get_act(output_act)

    def get_regu(self, regu_opt):
        """Instantiate one regularizer per entry of ``regu_opt``.

        Returns:
            A list of regularizer instances, or None when ``regu_opt`` is None.

        Raises:
            AssertionError: if an option is not a key of ``regu_dict``.
        """
        if regu_opt is None:
            return None

        regu = []
        for ropt in regu_opt:
            assert ropt in self.regu_dict, \
                "Unknown regularization option: {!r}".format(ropt)
            regu.append(self.regu_dict[ropt]())
        return regu

    def get_loss(self, loss_opt):
        """Instantiate the loss functions of every target.

        Returns:
            A nested list where ``out[i][j]`` is an instance of the loss
            named by ``loss_opt[i][j]``.

        Raises:
            AssertionError: if an option is not a key of ``loss_dict``.
        """
        out = []
        for i in range(self.num_target):
            target_losses = []
            for lopt in loss_opt[i]:
                assert lopt in self.loss_dict, \
                    "Unknown loss option: {!r}".format(lopt)
                target_losses.append(self.loss_dict[lopt]())
            out.append(target_losses)
        return out

    def get_act(self, output_act):
        """Resolve the activation of every (target, loss) pair.

        Returns:
            A nested list where ``out[i][j]`` is the result of
            ``get_functional_act(output_act[i][j])``.
        """
        return [
            [get_functional_act(act) for act in output_act[i]]
            for i in range(self.num_target)
        ]

    def to_torch(self, data):
        """Move ``data`` to ``self.device``.

        Tensors (including ``Tensor`` subclasses) are moved with
        ``non_blocking=True``; anything else is passed to
        ``torch.from_numpy`` first.
        """
        if isinstance(data, torch.Tensor):
            return data.to(self.device, non_blocking=True)
        return torch.from_numpy(data).to(self.device)

    def evaluate(self,
                 pred: Tensor,
                 target: Union[List[Tensor], List[np.ndarray]],
                 weight: Union[List[Tensor], List[np.ndarray]],
                 key: Optional[str] = None,
                 losses_vis: Optional[dict] = None,  # visualizing individual losses
                 ) -> Tuple[Tensor, dict]:
        """Compute the weighted sum of all losses and regularizations.

        Args:
            pred: prediction, split per target by ``self.splitter``.
            target: one target per entry of ``target_opt``.
            weight: ``weight[i][j]`` is the weight mask for the j-th loss of
                the i-th target.
            key: optional suffix appended to every tag stored in
                ``losses_vis``.
            losses_vis: dict that receives the individual loss terms. It is
                updated in place when given; a new dict is created when None.

        Returns:
            A tuple ``(loss, losses_vis)``.

        Raises:
            AssertionError: if a tag to be written is already present in
                ``losses_vis``.
        """
        # A ``{}`` default would be created once and shared by every call,
        # so entries from an earlier call would trip the duplicate-tag
        # assertions below. Create the dict per call instead.
        if losses_vis is None:
            losses_vis = {}

        # split the prediction for each target
        x = self.splitter(pred)

        loss = 0.0
        for i in range(self.num_target):
            target_t = self.to_torch(target[i])
            for j, loss_fn in enumerate(self.loss_fn[i]):
                w = weight[i][j]
                # A weight whose last dimension has size 1 is treated as
                # "no weight mask" (presumably a placeholder - confirm
                # against the data loader).
                w_mask = None if w.shape[-1] == 1 else self.to_torch(w)
                loss_temp = self.loss_w[i][j] * loss_fn(
                    self.act[i][j](x[i]),
                    target=target_t,
                    weight_mask=w_mask)
                loss += loss_temp

                loss_tag = self.target_opt[i] + '_' + \
                    self.loss_opt[i][j] + '_' + str(i)
                if key is not None:
                    loss_tag += '_' + key
                assert loss_tag not in losses_vis, \
                    "Duplicate loss tag: {!r}".format(loss_tag)
                losses_vis[loss_tag] = loss_temp

        for i in range(self.num_regu):
            regu_inputs = [x[j] for j in self.regu_t[i]]
            regu_temp = self.regu_w[i] * self.regu_fn[i](*regu_inputs)
            loss += regu_temp

            targets_name = [self.target_opt[j] for j in self.regu_t[i]]
            regu_tag = '_'.join(targets_name) + '_' + \
                self.regu_opt[i] + '_' + str(i)
            if key is not None:
                regu_tag += '_' + key
            assert regu_tag not in losses_vis, \
                "Duplicate regularization tag: {!r}".format(regu_tag)
            losses_vis[regu_tag] = regu_temp

        return loss, losses_vis

    def __call__(self,
                 pred: Union[Tensor, OrderedDict],
                 target: Union[List[Tensor], List[np.ndarray]],
                 weight: Union[List[Tensor], List[np.ndarray]],
                 ) -> Tuple[Tensor, dict]:
        """Evaluate the criterion on a tensor or a dict of predictions.

        A ``Tensor`` is evaluated once. Otherwise ``pred`` is treated as a
        mapping (e.g. the OrderedDict predicted by DeepLab): every value is
        evaluated with its key as the tag suffix and the losses are summed.

        Returns:
            A tuple ``(loss, losses_vis)``.
        """
        # One dict per call, shared by all outputs of this call only.
        losses_vis = {}

        if isinstance(pred, Tensor):
            return self.evaluate(pred, target, weight, losses_vis=losses_vis)

        # evaluate OrderedDict predicted by DeepLab
        loss = 0.0
        for key, pred_k in pred.items():
            temp_loss, losses_vis = self.evaluate(
                pred_k, target, weight, key, losses_vis)
            loss += temp_loss

        return loss, losses_vis

    @classmethod
    def build_from_cfg(cls, cfg, device):
        """Build a Criterion class based on the config options.

        Args:
            cfg (yacs.config.CfgNode): YACS configuration options.
            device (torch.device): model running device type. GPUs are
                recommended for model training and inference.
        """
        return cls(device,
                   cfg.MODEL.TARGET_OPT,
                   cfg.MODEL.LOSS_OPTION,
                   cfg.MODEL.OUTPUT_ACT,
                   cfg.MODEL.LOSS_WEIGHT,
                   cfg.MODEL.REGU_OPT,
                   cfg.MODEL.REGU_TARGET,
                   cfg.MODEL.REGU_WEIGHT,
                   do_2d=cfg.DATASET.DO_2D)

© Copyright 2019-2021, Zudi Lin and Donglai Wei. Revision 4bb1d5dc.

Built with Sphinx using a theme provided by Read the Docs.