Shortcuts

Source code for connectomics.engine.trainer

from __future__ import print_function, division
from typing import Optional

import os
import time
import math
import GPUtil
import numpy as np
from yacs.config import CfgNode

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

from .solver import *
from ..model import *
from ..utils.monitor import build_monitor
from ..data.augmentation import build_train_augmentor, TestAugmentor
from ..data.dataset import build_dataloader, get_dataset
from ..data.dataset.build import _get_file_list, _make_path_list
from ..data.utils import build_blending_matrix, writeh5
from ..data.utils import get_padsize, array_unpad


class Trainer(object):
    r"""Trainer class for supervised learning.

    Args:
        cfg (yacs.config.CfgNode): YACS configuration options.
        device (torch.device): model running device. GPUs are recommended for model training and inference.
        mode (str): running mode of the trainer (``'train'`` or ``'test'``). Default: ``'train'``
        rank (int, optional): node rank for distributed training. Default: `None`
        checkpoint (str, optional): the checkpoint file to be loaded. Default: `None`
    """

    def __init__(self,
                 cfg: 'CfgNode',
                 device: torch.device,
                 mode: str = 'train',
                 rank: Optional[int] = None,
                 checkpoint: Optional[str] = None):
        assert mode in ['train', 'test']
        self.cfg = cfg
        self.device = device
        self.output_dir = cfg.DATASET.OUTPUT_PATH
        self.mode = mode
        self.rank = rank
        # Only rank 0 (or a non-distributed run) logs and writes checkpoints.
        self.is_main_process = rank is None or rank == 0
        self.inference_singly = (mode == 'test') and cfg.INFERENCE.DO_SINGLY

        self.model = build_model(self.cfg, self.device, rank)
        if self.mode == 'train':
            self.optimizer = build_optimizer(self.cfg, self.model)
            self.lr_scheduler = build_lr_scheduler(self.cfg, self.optimizer)
            self.scaler = GradScaler() if cfg.MODEL.MIXED_PRECESION else None
            # May be overwritten by update_checkpoint() when resuming.
            self.start_iter = self.cfg.MODEL.PRE_MODEL_ITER
            self.update_checkpoint(checkpoint)

            # stochastic weight averaging
            if self.cfg.SOLVER.SWA.ENABLED:
                self.swa_model, self.swa_scheduler = build_swa_model(
                    self.cfg, self.model, self.optimizer)

            self.augmentor = build_train_augmentor(self.cfg)
            self.criterion = Criterion.build_from_cfg(self.cfg, self.device)
            if self.is_main_process:
                self.monitor = build_monitor(self.cfg)
                self.monitor.load_info(self.cfg, self.model)

            # train() reads total_iter_nums on every rank, so it is set
            # regardless of is_main_process.
            self.total_iter_nums = self.cfg.SOLVER.ITERATION_TOTAL - self.start_iter
            self.total_time = 0
        else:
            self.update_checkpoint(checkpoint)
            # build test-time augmentor and update output filename
            self.augmentor = TestAugmentor.build_from_cfg(cfg, activation=True)
            if not self.cfg.DATASET.DO_CHUNK_TITLE and not self.inference_singly:
                self.test_filename = self.cfg.INFERENCE.OUTPUT_NAME
                self.test_filename = self.augmentor.update_name(
                    self.test_filename)

        # For chunk-based or one-file-at-a-time processing the dataloader is
        # created later by run_chunk() / test_singly().
        self.dataset, self.dataloader = None, None
        if not self.cfg.DATASET.DO_CHUNK_TITLE and not self.inference_singly:
            self.dataloader = build_dataloader(
                self.cfg, self.augmentor, self.mode, rank=rank)
            self.dataloader = iter(self.dataloader)
            if self.mode == 'train' and cfg.DATASET.VAL_IMAGE_NAME is not None:
                self.val_loader = build_dataloader(
                    self.cfg, None, mode='val', rank=rank)

    def train(self):
        r"""Training function of the trainer class.

        Runs ``self.total_iter_nums`` iterations starting from
        ``self.start_iter`` and, once the loop is done, saves the SWA model
        if one was built.
        """
        self.model.train()

        for i in range(self.total_iter_nums):
            iter_total = self.start_iter + i
            self.start_time = time.perf_counter()
            self.optimizer.zero_grad()

            # load data
            sample = next(self.dataloader)
            volume = sample.out_input
            target, weight = sample.out_target_l, sample.out_weight_l
            self.data_time = time.perf_counter() - self.start_time

            # prediction
            volume = volume.to(self.device, non_blocking=True)
            with autocast(enabled=self.cfg.MODEL.MIXED_PRECESION):
                pred = self.model(volume)
                loss, losses_vis = self.criterion(pred, target, weight)

            self._train_misc(loss, pred, volume, target, weight,
                             iter_total, losses_vis)

        self.maybe_save_swa_model()

    def _train_misc(self, loss, pred, volume, target, weight,
                    iter_total, losses_vis):
        r"""Finish one training iteration after the forward pass.

        Performs the backward pass and optimizer step, updates the monitor,
        saves checkpoints / runs validation at the configured intervals,
        steps the SWA model and learning-rate scheduler, and prints timing
        information on the main process.
        """
        self.backward_pass(loss)  # backward pass

        # logging and update record
        if hasattr(self, 'monitor'):
            do_vis = self.monitor.update(iter_total, loss, losses_vis,
                                         self.optimizer.param_groups[0]['lr'])
            if do_vis:
                self.monitor.visualize(
                    volume, target, pred, weight, iter_total)
                if torch.cuda.is_available():
                    GPUtil.showUtilization(all=True)

        # Save model
        if (iter_total+1) % self.cfg.SOLVER.ITERATION_SAVE == 0:
            self.save_checkpoint(iter_total)

        if (iter_total+1) % self.cfg.SOLVER.ITERATION_VAL == 0:
            self.validate(iter_total)

        # update learning rate
        self.maybe_update_swa_model(iter_total)
        self.scheduler_step(iter_total, loss)

        if self.is_main_process:
            self.iter_time = time.perf_counter() - self.start_time
            self.total_time += self.iter_time
            avg_iter_time = self.total_time / (iter_total+1-self.start_iter)
            est_time_left = avg_iter_time * \
                (self.total_iter_nums+self.start_iter-iter_total-1) / 3600.0
            print('[Iteration %05d] Data time: %.4fs, Iter time: %.4fs, Avg iter time: %.4fs, Time Left %.2fh.' % (
                iter_total, self.data_time, self.iter_time, avg_iter_time, est_time_left))

        # Release some GPU memory and ensure same GPU usage in the consecutive iterations according to
        # https://discuss.pytorch.org/t/gpu-memory-consumption-increases-while-training/2770
        del volume, target, pred, weight, loss, losses_vis

    def validate(self, iter_total):
        r"""Validation function of the trainer class.

        Sums the criterion loss over every batch of ``self.val_loader``,
        logs it through the monitor (if present) and saves
        ``checkpoint_best`` whenever the summed loss is the lowest seen so
        far, including the very first validation run. Does nothing when no
        validation loader was built. The model is switched back to training
        mode before returning.

        Args:
            iter_total (int): current training iteration, used for logging
                and passed on to :meth:`save_checkpoint`.
        """
        if not hasattr(self, 'val_loader'):
            return

        self.model.eval()
        num_batches = 0
        with torch.no_grad():
            val_loss = 0.0
            for num_batches, sample in enumerate(self.val_loader, start=1):
                volume = sample.out_input
                target, weight = sample.out_target_l, sample.out_weight_l

                # prediction
                volume = volume.to(self.device, non_blocking=True)
                with autocast(enabled=self.cfg.MODEL.MIXED_PRECESION):
                    pred = self.model(volume)
                    loss, _ = self.criterion(pred, target, weight)
                    val_loss += loss.data

        if num_batches == 0:
            # Nothing was evaluated: there is no loss to log or compare and
            # the per-batch locals below were never bound.
            if self.is_main_process:
                print('Validation loader is empty; skip validation.')
            self.model.train()
            return

        if hasattr(self, 'monitor'):
            self.monitor.logger.log_tb.add_scalar(
                'Validation_Loss', val_loss, iter_total)
            # Visualizes the last validation batch only.
            self.monitor.visualize(volume, target, pred,
                                   weight, iter_total, suffix='Val')

        # The first validation result is by definition the best one so far.
        if not hasattr(self, 'best_val_loss') or val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.save_checkpoint(iter_total, is_best=True)

        # Release some GPU memory and ensure same GPU usage in the consecutive iterations according to
        # https://discuss.pytorch.org/t/gpu-memory-consumption-increases-while-training/2770
        del pred, loss, val_loss

        # model.train() only called at the beginning of Trainer.train().
        self.model.train()

    def test(self):
        r"""Inference function of the trainer class.

        Runs the model (through the test-time augmentor) on every batch of
        ``self.dataloader``, accumulates the blended predictions into one
        buffer per volume, normalizes by the accumulated blending weights and
        converts the result to ``uint8``.

        Returns:
            The list of per-volume prediction arrays when
            ``self.output_dir`` is `None`; otherwise the predictions are
            written to ``self.test_filename`` inside ``self.output_dir`` and
            nothing is returned.
        """
        if self.cfg.INFERENCE.DO_EVAL:
            self.model.eval()
        else:
            self.model.train()

        output_scale = self.cfg.INFERENCE.OUTPUT_SCALE
        spatial_size = list(np.ceil(
            np.array(self.cfg.MODEL.OUTPUT_SIZE) *
            np.array(output_scale)).astype(int))
        channel_size = self.cfg.MODEL.OUT_PLANES

        sz = tuple([channel_size] + spatial_size)
        ww = build_blending_matrix(spatial_size, self.cfg.INFERENCE.BLENDING)

        output_size = [tuple(np.ceil(np.array(x) * np.array(output_scale)).astype(int))
                       for x in self.dataloader._dataset.volume_size]
        # One (channel, *spatial) accumulator and one spatial weight map per
        # volume; allocated directly to avoid the temporary copies of stacking.
        result = [np.zeros((channel_size,) + x, dtype=np.float32)
                  for x in output_size]
        weight = [np.zeros(x, dtype=np.float32) for x in output_size]
        num_batches = len(self.dataloader)
        print("Total number of batches: ", num_batches)

        # Positions carry the volume index first, which must not be scaled.
        pos_scale = np.array([1] + list(output_scale))

        start = time.perf_counter()
        with torch.no_grad():
            for i, sample in enumerate(self.dataloader):
                print('progress: %d/%d batches, total time %.2fs' %
                      (i+1, num_batches, time.perf_counter()-start))

                pos, volume = sample.pos, sample.out_input
                volume = volume.to(self.device, non_blocking=True)
                output = self.augmentor(self.model, volume)

                if torch.cuda.is_available() and i % 50 == 0:
                    GPUtil.showUtilization(all=True)

                for idx in range(output.shape[0]):
                    # st[0] is the volume index, st[1:] the start coordinates.
                    st = (np.array(pos[idx]) * pos_scale).astype(int).tolist()
                    out_block = output[idx]
                    if result[st[0]].ndim - out_block.ndim == 1:  # 2d model
                        out_block = out_block[:, np.newaxis, :]

                    result[st[0]][:, st[1]:st[1]+sz[1], st[2]:st[2]+sz[2],
                                  st[3]:st[3]+sz[3]] += out_block * ww[np.newaxis, :]
                    weight[st[0]][st[1]:st[1]+sz[1], st[2]:st[2]+sz[2],
                                  st[3]:st[3]+sz[3]] += ww

        end = time.perf_counter()
        print("Prediction time: %.2fs" % (end-start))

        for vol_id in range(len(result)):
            if result[vol_id].ndim > weight[vol_id].ndim:
                weight[vol_id] = np.expand_dims(weight[vol_id], axis=0)
            result[vol_id] /= weight[vol_id]  # in-place to save memory
            result[vol_id] *= 255
            result[vol_id] = result[vol_id].astype(np.uint8)

            if self.cfg.INFERENCE.UNPAD:
                pad_size = (np.array(self.cfg.DATASET.PAD_SIZE) *
                            np.array(output_scale)).astype(int).tolist()
                pad_size = get_padsize(pad_size)
                result[vol_id] = array_unpad(result[vol_id], pad_size)

        if self.output_dir is None:
            return result
        else:
            print('Final prediction shapes are:')
            for k in range(len(result)):
                print(result[k].shape)
            writeh5(os.path.join(self.output_dir, self.test_filename), result,
                    ['vol%d' % (x) for x in range(len(result))])
            print('Prediction saved as: ', self.test_filename)

    def test_singly(self):
        r"""Run inference on the input images one file at a time.

        For every image name starting at
        ``cfg.INFERENCE.DO_SINGLY_START_INDEX`` a dataset and dataloader are
        built for that single file, the output filename is derived from
        ``cfg.INFERENCE.OUTPUT_NAME`` plus the zero-padded file index, and
        :meth:`test` is called.
        """
        dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH)
        img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME)
        assert len(dir_name) == 1  # avoid ambiguity when DO_SINGLY is True

        num_file = len(img_name)
        start_idx = self.cfg.INFERENCE.DO_SINGLY_START_INDEX
        # Width of the zero-padded index; loop-invariant and, unlike
        # math.log10, well defined when there are no files.
        digits = len(str(num_file))
        for i in range(start_idx, num_file):
            dataset = get_dataset(
                self.cfg, self.augmentor, self.mode, self.rank,
                dir_name_init=dir_name, img_name_init=[img_name[i]])
            self.dataloader = build_dataloader(
                self.cfg, self.augmentor, self.mode, dataset, self.rank)
            self.dataloader = iter(self.dataloader)

            self.test_filename = self.cfg.INFERENCE.OUTPUT_NAME + \
                '_' + str(i).zfill(digits) + '.h5'
            self.test_filename = self.augmentor.update_name(
                self.test_filename)

            self.test()

    # -----------------------------------------------------------------------------
    # Misc functions
    # -----------------------------------------------------------------------------

    @staticmethod
    def _unwrap_model(model):
        r"""Return the module held by a ``DataParallel`` /
        ``DistributedDataParallel`` wrapper, or ``model`` itself if it is not
        wrapped, so state-dict keys never carry the ``module.`` prefix.
        """
        parallel_types = (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)
        return model.module if isinstance(model, parallel_types) else model

    def backward_pass(self, loss):
        r"""Back-propagate ``loss`` and take one optimizer step, going through
        the gradient scaler when mixed precision is enabled.
        """
        if self.cfg.MODEL.MIXED_PRECESION:
            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            self.scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            self.scaler.step(self.optimizer)

            # Updates the scale for next iteration.
            self.scaler.update()

        else:  # standard backward pass
            loss.backward()
            self.optimizer.step()

    def save_checkpoint(self, iteration: int, is_best: bool = False):
        r"""Save the model checkpoint.

        Only the main process writes. The file stores the next iteration
        index together with the model, optimizer and lr-scheduler states.

        Args:
            iteration (int): zero-based index of the iteration just finished.
            is_best (bool): write to ``checkpoint_best.pth.tar`` instead of an
                iteration-numbered file. Default: `False`
        """
        if self.is_main_process:
            print("Save model checkpoint at iteration ", iteration)
            state = {'iteration': iteration + 1,
                     # Saving DataParallel or DistributedDataParallel models
                     'state_dict': self._unwrap_model(self.model).state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'lr_scheduler': self.lr_scheduler.state_dict()}

            # Saves checkpoint to experiment directory
            filename = 'checkpoint_%05d.pth.tar' % (iteration + 1)
            if is_best:
                filename = 'checkpoint_best.pth.tar'
            filename = os.path.join(self.output_dir, filename)
            torch.save(state, filename)

    def update_checkpoint(self, checkpoint: Optional[str] = None):
        r"""Update the model with the specified checkpoint file path.

        Model weights are copied for every parameter whose name and size
        match the current model. In training mode, unless
        ``cfg.SOLVER.ITERATION_RESTART`` is set, the optimizer state,
        lr-scheduler state and start iteration are restored as well.

        Args:
            checkpoint (str, optional): path of the checkpoint file; nothing
                is done when `None`. Default: `None`
        """
        if checkpoint is None:
            return

        # load pre-trained model
        print('Load pretrained checkpoint: ', checkpoint)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        # map_location lets a checkpoint saved on another device (e.g. a GPU
        # that is absent on this host) be loaded onto the trainer's device.
        checkpoint = torch.load(checkpoint, map_location=self.device)
        print('checkpoints: ', checkpoint.keys())

        # update model weights
        if 'state_dict' in checkpoint.keys():
            pretrained_dict = checkpoint['state_dict']
            pretrained_dict = update_state_dict(
                self.cfg, pretrained_dict, mode=self.mode)
            target_model = self._unwrap_model(self.model)  # nn.DataParallel
            model_dict = target_model.state_dict()

            # 1. filter out unnecessary keys by name
            pretrained_dict = {k: v for k,
                               v in pretrained_dict.items() if k in model_dict}
            # 2. overwrite entries in the existing state dict (if size match)
            for param_tensor in pretrained_dict:
                if model_dict[param_tensor].size() == pretrained_dict[param_tensor].size():
                    model_dict[param_tensor] = pretrained_dict[param_tensor]
            # 3. load the new state dict
            target_model.load_state_dict(model_dict)

        if self.mode == 'train' and not self.cfg.SOLVER.ITERATION_RESTART:
            if hasattr(self, 'optimizer') and 'optimizer' in checkpoint.keys():
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            if hasattr(self, 'lr_scheduler') and 'lr_scheduler' in checkpoint.keys():
                self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

            if hasattr(self, 'start_iter') and 'iteration' in checkpoint.keys():
                self.start_iter = checkpoint['iteration']

    def maybe_save_swa_model(self):
        r"""Save the SWA model to ``checkpoint_swa.pth.tar`` if one was built.

        For batch-norm models the SWA model is first run forward on
        ``cfg.SOLVER.SWA.BN_UPDATE_ITER`` batches to update its normalization
        statistics.
        """
        if not hasattr(self, 'swa_model'):
            return

        if self.cfg.MODEL.NORM_MODE in ['bn', 'sync_bn']:  # update bn statistics
            # Only the forward pass matters here, so skip building the
            # autograd graph.
            with torch.no_grad():
                for _ in range(self.cfg.SOLVER.SWA.BN_UPDATE_ITER):
                    sample = next(self.dataloader)
                    volume = sample.out_input
                    volume = volume.to(self.device, non_blocking=True)
                    with autocast(enabled=self.cfg.MODEL.MIXED_PRECESION):
                        self.swa_model(volume)

        # save swa model
        if self.is_main_process:
            print("Save SWA model checkpoint.")
            state = {'state_dict': self.swa_model.module.state_dict()}
            filename = 'checkpoint_swa.pth.tar'
            filename = os.path.join(self.output_dir, filename)
            torch.save(state, filename)

    def maybe_update_swa_model(self, iter_total):
        r"""Merge the current weights into the SWA model every
        ``cfg.SOLVER.SWA.MERGE_ITER`` iterations once
        ``cfg.SOLVER.SWA.START_ITER`` has been reached. No-op without SWA.
        """
        if not hasattr(self, 'swa_model'):
            return

        swa_start = self.cfg.SOLVER.SWA.START_ITER
        swa_merge = self.cfg.SOLVER.SWA.MERGE_ITER
        if iter_total >= swa_start and iter_total % swa_merge == 0:
            self.swa_model.update_parameters(self.model)

    def scheduler_step(self, iter_total, loss):
        r"""Advance the learning-rate schedule by one iteration.

        The SWA scheduler takes over once ``cfg.SOLVER.SWA.START_ITER`` is
        reached; otherwise the regular scheduler is stepped, receiving
        ``loss`` when it is a ``ReduceLROnPlateau`` scheduler.
        """
        if hasattr(self, 'swa_scheduler') and iter_total >= self.cfg.SOLVER.SWA.START_ITER:
            self.swa_scheduler.step()
            return

        if self.cfg.SOLVER.LR_SCHEDULER_NAME == 'ReduceLROnPlateau':
            self.lr_scheduler.step(loss)
        else:
            self.lr_scheduler.step()

    # -----------------------------------------------------------------------------
    # Chunk processing for TileDataset
    # -----------------------------------------------------------------------------
    def run_chunk(self, mode: str):
        r"""Run chunk-based training and inference for large-scale datasets.

        Args:
            mode (str): ``'train'`` trains for ``cfg.DATASET.DATA_CHUNK_ITER``
                iterations on each chunk in turn; any other value runs
                inference chunk by chunk, skipping chunks whose output file
                already exists.
        """
        self.dataset = get_dataset(self.cfg, self.augmentor, mode)
        if mode == 'train':
            num_chunk = self.total_iter_nums // self.cfg.DATASET.DATA_CHUNK_ITER
            # train() now runs one chunk's worth of iterations per call.
            self.total_iter_nums = self.cfg.DATASET.DATA_CHUNK_ITER
            for chunk in range(num_chunk):
                self.dataset.updatechunk()
                self.dataloader = build_dataloader(self.cfg, self.augmentor, mode,
                                                   dataset=self.dataset.dataset)
                self.dataloader = iter(self.dataloader)
                print('start train for chunk %d' % chunk)
                self.train()
                print('finished train for chunk %d' % chunk)
                self.start_iter += self.cfg.DATASET.DATA_CHUNK_ITER
                del self.dataloader
            return

        # inference mode
        num_chunk = len(self.dataset.chunk_ind)
        print("Total number of chunks: ", num_chunk)
        for chunk in range(num_chunk):
            # Advance to the next chunk without loading it, so finished
            # chunks can be skipped cheaply.
            self.dataset.updatechunk(do_load=False)
            self.test_filename = self.cfg.INFERENCE.OUTPUT_NAME + \
                '_' + self.dataset.get_coord_name() + '.h5'
            self.test_filename = self.augmentor.update_name(
                self.test_filename)
            if not os.path.exists(os.path.join(self.output_dir, self.test_filename)):
                self.dataset.loadchunk()
                self.dataloader = build_dataloader(self.cfg, self.augmentor, mode,
                                                   dataset=self.dataset.dataset)
                self.dataloader = iter(self.dataloader)
                self.test()

© Copyright 2019-2021, Zudi Lin and Donglai Wei. Revision 4bb1d5dc.

Built with Sphinx using a theme provided by Read the Docs.