|
- # Copyright (c) Open-MMLab. All rights reserved.
- import io
- import os
- import os.path as osp
- import pkgutil
- import time
- import warnings
- from collections import OrderedDict
- from importlib import import_module
- from tempfile import TemporaryDirectory
-
- import torch
- import torchvision
- from torch.optim import Optimizer
- from torch.utils import model_zoo
- from torch.nn import functional as F
-
- import mmcv
- from mmcv.fileio import FileClient
- from mmcv.fileio import load as load_file
- from mmcv.parallel import is_module_wrapper
- from mmcv.utils import mkdir_or_exist
- from mmcv.runner import get_dist_info
-
# Environment variable that, when set, gives the MMCV cache directory.
ENV_MMCV_HOME = 'MMCV_HOME'
# Environment variable giving the base cache directory; ``mmcv`` is appended
# to it when ENV_MMCV_HOME is not set.
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Base cache directory used when ENV_XDG_CACHE_HOME is not set either.
DEFAULT_CACHE_DIR = '~/.cache'
-
-
def _get_mmcv_home():
    """Return the MMCV cache directory, creating it when necessary.

    The directory is taken from the ``MMCV_HOME`` environment variable. When
    that is unset, it is the ``mmcv`` sub-directory of ``XDG_CACHE_HOME``,
    which itself defaults to ``~/.cache``. A leading ``~`` is expanded.

    Returns:
        str: Path of the cache directory.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, 'mmcv')
    home_dir = os.path.expanduser(os.getenv(ENV_MMCV_HOME, fallback_home))

    mkdir_or_exist(home_dir)
    return home_dir
-
-
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Copy the weights in ``state_dict`` into ``module``.

    Adapted from :meth:`torch.nn.Module.load_state_dict`. Unlike the torch
    version, ``strict`` defaults to ``False`` and key mismatches are reported
    even when ``strict`` is ``False``.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): If ``True``, a mismatch between the keys of
            :attr:`state_dict` and those expected by ``module`` raises
            ``RuntimeError`` instead of being reported. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger that receives the
            mismatch report. If not specified, ``print`` is used.

    Raises:
        RuntimeError: If ``strict`` is ``True``, a mismatch was found and the
            current rank is 0.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    # Work on a copy so the caller's mapping is untouched; ``_metadata`` is an
    # attribute rather than an item, so it is re-attached by hand.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Walk the module tree in pre-order and call ``_load_from_state_dict`` on
    # every node, which enables checkpoint version control.
    pending = [(module, '')]
    while pending:
        current, prefix = pending.pop()
        # Unwrap parallel wrappers at every level in case the model has a
        # complicated structure, e.g. nn.Module(nn.Module(DDP)).
        if is_module_wrapper(current):
            current = current.module
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(prefix[:-1], {})
        current._load_from_state_dict(state_dict, prefix, local_metadata,
                                      True, all_missing_keys,
                                      unexpected_keys, err_msg)
        children = [(child, prefix + name + '.')
                    for name, child in current._modules.items()
                    if child is not None]
        # Reversed so that the first child is popped (visited) first.
        pending.extend(reversed(children))

    # "num_batches_tracked" of BN layers is not reported as missing.
    missing_keys = [
        name for name in all_missing_keys
        if 'num_batches_tracked' not in name
    ]

    if unexpected_keys:
        joined_unexpected = ', '.join(unexpected_keys)
        err_msg.append(
            f'unexpected key in source state_dict: {joined_unexpected}\n')
    if missing_keys:
        joined_missing = ', '.join(missing_keys)
        err_msg.append(
            f'missing keys in source state_dict: {joined_missing}\n')

    rank, _ = get_dist_info()
    if err_msg and rank == 0:
        header = 'The model and loaded state dict do not match exactly\n'
        report = '\n'.join([header] + err_msg)
        if strict:
            raise RuntimeError(report)
        if logger is not None:
            logger.warning(report)
        else:
            print(report)
    print("finish load")
-
-
def load_url_dist(url, model_dir=None):
    """Load a checkpoint from ``url``, letting local rank 0 go first.

    The local rank is read from the ``LOCAL_RANK`` environment variable and
    falls back to the rank reported by ``get_dist_info``. Local rank 0 calls
    ``model_zoo.load_url`` first; when more than one process is running, all
    processes then meet at a barrier, after which the other ranks call
    ``model_zoo.load_url`` themselves.

    Args:
        url (str): URL of the checkpoint.
        model_dir (str, optional): Forwarded to ``model_zoo.load_url``.

    Returns:
        The object returned by ``model_zoo.load_url``.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    def _fetch():
        return model_zoo.load_url(url, model_dir=model_dir)

    if local_rank == 0:
        checkpoint = _fetch()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch()
    return checkpoint
-
-
def load_pavimodel_dist(model_path, map_location=None):
    """Load a checkpoint from pavi modelcloud, letting local rank 0 go first.

    Local rank 0 (``LOCAL_RANK`` environment variable, falling back to the
    rank from ``get_dist_info``) downloads and loads the checkpoint first.
    When more than one process is running, all processes then meet at a
    barrier, after which the other ranks download and load it themselves.

    Args:
        model_path (str): Path of the model inside modelcloud.
        map_location (str | None): Same as :func:`torch.load`.

    Returns:
        The object returned by :func:`torch.load`.

    Raises:
        ImportError: If ``pavi`` is not installed.
    """
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    def _download_and_load():
        cloud_model = modelcloud.get(model_path)
        # The file must be read before the temporary directory is removed.
        with TemporaryDirectory() as tmp_dir:
            local_copy = osp.join(tmp_dir, cloud_model.name)
            cloud_model.download(local_copy)
            return torch.load(local_copy, map_location=map_location)

    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))
    if local_rank == 0:
        checkpoint = _download_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _download_and_load()
    return checkpoint
-
-
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through a ``FileClient``, local rank 0 first.

    Local rank 0 (``LOCAL_RANK`` environment variable, falling back to the
    rank from ``get_dist_info``) fetches and loads the checkpoint first. When
    more than one process is running, all processes then meet at a barrier,
    after which the other ranks fetch and load it themselves.

    Args:
        filename (str): Location of the checkpoint, passed to
            ``FileClient.get``.
        backend (str): ``FileClient`` backend; only ``'ceph'`` is accepted.
        map_location (str | None): Same as :func:`torch.load`.

    Returns:
        The object returned by :func:`torch.load`.

    Raises:
        ValueError: If ``backend`` is not an accepted backend.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _fetch_and_load():
        client = FileClient(backend=backend)
        raw = io.BytesIO(client.get(filename))
        return torch.load(raw, map_location=map_location)

    if local_rank == 0:
        checkpoint = _fetch_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch_and_load()
    return checkpoint
-
-
def get_torchvision_models():
    """Collect the ``model_urls`` tables of the torchvision model modules.

    Every non-package module found under ``torchvision.models`` is imported;
    the ``model_urls`` attribute of each module that defines one is merged
    into a single dict (later modules override earlier ones on key clashes).

    Returns:
        dict: Merged ``model_urls`` entries.
    """
    collected = dict()
    for module_info in pkgutil.walk_packages(torchvision.models.__path__):
        if module_info.ispkg:
            continue
        submodule = import_module(f'torchvision.models.{module_info.name}')
        if hasattr(submodule, 'model_urls'):
            collected.update(submodule.model_urls)
    return collected
-
-
def get_external_models():
    """Return the ``open-mmlab://`` model table.

    The table shipped with mmcv (``model_zoo/open_mmlab.json``) is loaded
    first. If an ``open_mmlab.json`` file exists in the MMCV cache directory,
    its entries are merged on top and override the shipped ones.

    Returns:
        dict: The merged table.
    """
    cache_dir = _get_mmcv_home()

    shipped_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    urls = load_file(shipped_path)
    assert isinstance(urls, dict)

    user_path = osp.join(cache_dir, 'open_mmlab.json')
    if osp.exists(user_path):
        user_urls = load_file(user_path)
        assert isinstance(user_urls, dict)
        urls.update(user_urls)

    return urls
-
-
def get_mmcls_models():
    """Load the ``mmcls://`` model table shipped with mmcv.

    Returns:
        The parsed content of ``model_zoo/mmcls.json`` in the mmcv package.
    """
    return load_file(osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json'))
-
-
def get_deprecated_model_names():
    """Load the table of deprecated ``open-mmlab://`` model names.

    Returns:
        dict: Parsed content of ``model_zoo/deprecated.json`` in the mmcv
        package, mapping a deprecated model name to its replacement.
    """
    json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json')
    renamed = load_file(json_path)
    assert isinstance(renamed, dict)
    return renamed
-
-
def _process_mmcls_checkpoint(checkpoint):
    """Keep only the backbone weights of an mmcls checkpoint.

    Args:
        checkpoint (dict): Checkpoint with a ``state_dict`` entry.

    Returns:
        dict: A new checkpoint whose only entry, ``state_dict``, is an
        ``OrderedDict`` holding the entries whose key starts with
        ``backbone.``, with that prefix removed. All other entries of the
        input checkpoint are dropped.
    """
    prefix = 'backbone.'
    backbone_weights = OrderedDict(
        (key[len(prefix):], value)
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix))
    return dict(state_dict=backbone_weights)
-
-
def _load_checkpoint(filename, map_location=None):
    """Load a checkpoint from a model zoo, a URL or a local file.

    The source is chosen from the prefix of ``filename``, tested in this
    order: ``modelzoo://`` (deprecated alias of ``torchvision://``),
    ``torchvision://``, ``open-mmlab://``, ``mmcls://``, ``http(s)://``,
    ``pavi://``, ``s3://``; anything else is treated as a local path.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.

    Raises:
        IOError: If a local path is expected but is not an existing file.
    """

    def _load_local(path):
        # Shared by the two branches that read from the local file system.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        if not osp.isfile(path):
            raise IOError(f'{path} is not a checkpoint file')
        return torch.load(path, map_location=map_location)

    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        zoo = get_torchvision_models()
        return load_url_dist(zoo[filename[len('modelzoo://'):]])

    if filename.startswith('torchvision://'):
        zoo = get_torchvision_models()
        return load_url_dist(zoo[filename[len('torchvision://'):]])

    if filename.startswith('open-mmlab://'):
        zoo = get_external_models()
        model_name = filename[len('open-mmlab://'):]
        renamed = get_deprecated_model_names()
        if model_name in renamed:
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
                          f'of open-mmlab://{renamed[model_name]}')
            model_name = renamed[model_name]
        target = zoo[model_name]
        # A table entry is either a URL or a path relative to the MMCV home.
        if target.startswith(('http://', 'https://')):
            return load_url_dist(target)
        return _load_local(osp.join(_get_mmcv_home(), target))

    if filename.startswith('mmcls://'):
        zoo = get_mmcls_models()
        raw = load_url_dist(zoo[filename[len('mmcls://'):]])
        return _process_mmcls_checkpoint(raw)

    if filename.startswith(('http://', 'https://')):
        return load_url_dist(filename)

    if filename.startswith('pavi://'):
        return load_pavimodel_dist(
            filename[len('pavi://'):], map_location=map_location)

    if filename.startswith('s3://'):
        return load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)

    return _load_local(filename)
-
-
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None,
                    load_ema=True):
    """Load checkpoint from a file or URI.

    The weights are taken from the first of these checkpoint entries that is
    present: ``state_dict_ema`` (only when ``load_ema`` is true),
    ``state_dict``, ``model``. If none is present the checkpoint itself is
    used as the state dict. Before loading, the keys are adapted:

    - if the first key starts with ``module.``, the first 7 characters of
      every key are removed;
    - if the smallest key starts with ``encoder`` (MoBY), only the
      ``encoder.*`` entries are kept, with ``encoder.`` removed;
    - ``absolute_pos_embed`` is reshaped to the model's layout;
    - entries whose key contains ``relative_position_index`` or
      ``relative_position_bias_table`` are dropped;
    - for every ``*.q_bias`` entry a fused ``*.qkv.bias`` entry is added,
      built from ``q_bias``, zeros and the matching ``v_bias``;
    - ``pos_embed`` is resized (bicubic) to the model's patch grid when the
      grid sizes differ, dropping one leading extra token.

    Note that, unless the prefix stripping or the MoBY filtering created a
    new dict, these adaptations modify the state dict stored inside the
    returned checkpoint.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for messages. If
            not specified, ``print`` is used instead.
        load_ema (bool): Whether to prefer the ``state_dict_ema`` entry of
            the checkpoint when it exists. Default: ``True``.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """

    # ``logger`` is optional; fall back to ``print`` rather than failing on
    # ``None.info(...)``.
    def _info(msg):
        if logger is not None:
            logger.info(msg)
        else:
            print(msg)

    def _warn(msg):
        if logger is not None:
            logger.warning(msg)
        else:
            print(msg)

    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint
    if load_ema and 'state_dict_ema' in checkpoint:
        state_dict = checkpoint['state_dict_ema']
        _info('loading from state_dict_ema')
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        _info('loading from state_dict')
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
        if logger is not None:
            logger.info('loading from model')
        # This branch has always echoed to stdout as well.
        print('loading from model')
    else:
        state_dict = checkpoint

    # strip prefix of state_dict (guarded: the state dict may be empty)
    if state_dict and next(iter(state_dict)).startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if state_dict and min(state_dict).startswith('encoder'):
        state_dict = {
            k.replace('encoder.', ''): v
            for k, v in state_dict.items() if k.startswith('encoder.')
        }

    # reshape absolute position embedding from (N, L, C) to (N, C, H, W)
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            _warn('Error in loading absolute_pos_embed, pass')
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                N2, H, W, C2).permute(0, 3, 1, 2)

    # Iterate over a snapshot of the keys because entries are removed and
    # added while looping.
    for key in list(state_dict.keys()):
        if ('relative_position_index' in key
                or 'relative_position_bias_table' in key):
            state_dict.pop(key)
            continue

        if '.q_bias' in key:
            # Fuse q/v biases into one qkv bias; the k part is all zeros.
            q_bias = state_dict[key]
            v_bias = state_dict[key.replace('q_bias', 'v_bias')]
            state_dict[key.replace('q_bias', 'qkv.bias')] = torch.cat(
                [q_bias, torch.zeros_like(q_bias), v_bias], 0)

    # interpolate position bias table if needed
    # NOTE(review): the loop above removes every entry whose key contains
    # "relative_position_bias_table", so this list is expected to be empty.
    # The resize logic is kept in case that pruning is relaxed.
    relative_position_bias_table_keys = [
        k for k in state_dict if 'relative_position_bias_table' in k
    ]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        model_state = model.state_dict()
        if table_key not in model_state:
            _warn('relative_position_bias_table exits in pretrained model '
                  'but not in current one, pass')
            continue
        table_current = model_state[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            _warn(f'Error in loading {table_key}, pass')
        elif L1 != L2:
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            table_pretrained_resized = F.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2),
                mode='bicubic')
            state_dict[table_key] = table_pretrained_resized.view(
                nH2, L2).permute(1, 0)

    rank, _ = get_dist_info()
    if 'pos_embed' in state_dict:
        pos_embed_checkpoint = state_dict['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        H, W = model.patch_embed.patch_shape
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = 1
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        if orig_size != new_size:
            if rank == 0:
                print('Position interpolate from %dx%d to %dx%d' %
                      (orig_size, orig_size, H, W))
            # Only the position tokens are interpolated; the leading extra
            # token is not carried over into the new embedding.
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = F.interpolate(
                pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
            state_dict['pos_embed'] = pos_tokens.permute(
                0, 2, 3, 1).flatten(1, 2)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
-
-
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    The ``_metadata`` attribute of the input, when present, is carried over
    to the result. ``get_state_dict`` stores per-module version information
    there and ``load_state_dict`` reads it back, so dropping it would lose
    that information from saved checkpoints.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on cpu, in the same key order.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # ``_metadata`` is an attribute, not an item, so it must be copied
    # explicitly.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        state_dict_cpu._metadata = metadata
    return state_dict_cpu
-
-
def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Store the parameters and buffers of ``module`` in ``destination``.

    Adapted from :meth:`torch.nn.Module._save_to_state_dict`. Only the
    tensors registered directly on ``module`` are stored; child modules are
    not visited.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
        keep_vars (bool): If true the tensors are stored as they are,
            otherwise their ``detach()``-ed versions are stored.
    """
    # Parameters first, then buffers. Buffers are not filtered by
    # ``_non_persistent_buffers_set``, to allow nn.BatchNorm2d.
    for registry in (module._parameters, module._buffers):
        for name, tensor in registry.items():
            if tensor is None:
                continue
            if keep_vars:
                destination[prefix + name] = tensor
            else:
                destination[prefix + name] = tensor.detach()
-
-
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Return a dictionary holding the whole state of ``module``.

    Parameters and buffers of the module and of all its descendants are
    included; keys are the dotted parameter and buffer names.

    Adapted from :meth:`torch.nn.Module.state_dict` so that parallel
    wrappers are unwrapped at every level of the module tree, in case the
    model has a complicated structure, e.g. nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Dict to fill. A new ``OrderedDict`` with
            an empty ``_metadata`` attribute is created when omitted.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # Unwrap a parallel wrapper before inspecting the module.
    if is_module_wrapper(module):
        module = module.module

    # From here on this mirrors torch.nn.Module.state_dict().
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    # Record this module's version under its own (dot-less) prefix.
    local_metadata = dict(version=module._version)
    destination._metadata[prefix[:-1]] = local_metadata

    _save_to_state_dict(module, destination, prefix, keep_vars)

    for child_name, child in module._modules.items():
        if child is None:
            continue
        get_state_dict(
            child, destination, prefix + child_name + '.',
            keep_vars=keep_vars)

    # A hook may return a replacement for the destination dict.
    for hook in module._state_dict_hooks.values():
        replacement = hook(module, destination, prefix, local_metadata)
        if replacement is not None:
            destination = replacement
    return destination
-
-
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. ``meta`` always receives the mmcv version and the current
    time, plus ``CLASSES`` when the model defines a non-``None`` one. A
    ``meta`` dict passed by the caller is updated in place.

    A ``filename`` starting with ``pavi://`` is uploaded to pavi modelcloud;
    any other value is written to the local file system, creating the parent
    directory first.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer` or dict, optional): Optimizer to be
            saved, or a dict mapping names to optimizers.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither a dict nor ``None``.
        ImportError: If a ``pavi://`` target is used without ``pavi``
            installed.
    """
    if meta is not None and not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    if meta is None:
        meta = {}
    meta['mmcv_version'] = mmcv.__version__
    meta['time'] = time.asctime()

    if is_module_wrapper(model):
        model = model.module

    # save class name to the meta
    class_names = getattr(model, 'CLASSES', None)
    if class_names is not None:
        meta['CLASSES'] = class_names

    checkpoint = dict(
        meta=meta, state_dict=weights_to_cpu(get_state_dict(model)))

    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict()
            for name, optim in optimizer.items()
        }

    if not filename.startswith('pavi://'):
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
        return

    try:
        from pavi import modelcloud
        from pavi.exception import NodeNotFoundError
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    root = modelcloud.Folder()
    model_dir, model_name = osp.split(filename[len('pavi://'):])
    try:
        cloud_model = modelcloud.get(model_dir)
    except NodeNotFoundError:
        cloud_model = root.create_training_model(model_dir)

    # Write to a temporary file, then upload it before the directory is
    # removed.
    with TemporaryDirectory() as tmp_dir:
        checkpoint_file = osp.join(tmp_dir, model_name)
        with open(checkpoint_file, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
        cloud_model.create_file(checkpoint_file, name=model_name)
|