#672 [Document and Code] Fix topk bugs, add hub and model_zoo, modify supportedlist

Merged
zoulq merged 14 commits from ziqi/MSAdapter:master into master 8 months ago
  1. +17
    -0
      SupportedList.md
  2. +16
    -0
      SupportedList_en.md
  3. +1
    -0
      msadapter/pytorch/__init__.py
  4. +4
    -1
      msadapter/pytorch/functional.py
  5. +164
    -1
      msadapter/pytorch/hub.py
  6. +4
    -1
      msadapter/pytorch/tensor.py
  7. +3
    -0
      msadapter/pytorch/utils/model_zoo.py
  8. +24
    -0
      testing/ut/pytorch/torch/test_hub.py

+ 17
- 0
SupportedList.md View File

@@ -10,6 +10,8 @@
- [torch.optim](#jump7)
- [torch.utils.data](#jump9)
- [torch.distributed](#jump10)
- [torch.utils.model\_zoo](#jump11)
- [torch.hub](#jump12)


### <span id="jump8">通用限制</span>
@@ -1231,6 +1233,7 @@
| lr_scheduler.CyclicLR | 支持 | |
| lr_scheduler.OneCycleLR | 支持 | |
| lr_scheduler.CosineAnnealingWarmRestarts | 支持 | |

### <span id="jump9">torch.utils.data</span>
| MSAdapter接口 | 状态 | 约束 |
| --------------- | ---- |------------------------------|
@@ -1271,3 +1274,17 @@
| is_nccl_available | 支持 | |
| get_backend | 支持 | |
| get_process_group_ranks | 支持 | |

### <span id="jump11">torch.utils.model_zoo</span>
| MSAdapter接口 | 状态 | 约束 |
| --------------- | ---- |------------------------------|
| load_url | 支持 | |


### <span id="jump12">torch.hub</span>
| MSAdapter接口 | 状态 | 约束 |
| --------------- | ---- |------------------------------|
| download_url_to_file | 支持 | |
| get_dir | 支持 | |
| load_state_dict_from_url | 支持 | |
| set_dir | 支持 | |

+ 16
- 0
SupportedList_en.md View File

@@ -9,6 +9,8 @@ English | [简体中文](SupportedList.md)
- [torch.optim](#jump7)
- [torch.utils.data](#jump9)
- [torch.distributed](#jump10)
- [torch.utils.model\_zoo](#jump11)
- [torch.hub](#jump12)

### <span id="jump8">General Constraint</span>
- Configuring `layout`, `device`, `requires_grad`, or `memory_format` is not supported.
@@ -1273,3 +1275,17 @@ English | [简体中文](SupportedList.md)
| is_nccl_available | Supported | |
| get_backend | Supported | |
| get_process_group_ranks | Supported | |

### <span id="jump11">torch.utils.model_zoo</span>
| MSAdapter APIs | Status | Restrictions |
| --------------- | ---- |------------------------------|
| load_url | Supported | |


### <span id="jump12">torch.hub</span>
| MSAdapter APIs | Status | Restrictions |
| --------------- | ---- |------------------------------|
| download_url_to_file | Supported | |
| get_dir | Supported | |
| load_state_dict_from_url | Supported | |
| set_dir | Supported | |

+ 1
- 0
msadapter/pytorch/__init__.py View File

@@ -22,6 +22,7 @@ from msadapter.pytorch.serialization import *
import msadapter.pytorch.linalg as linalg
from msadapter.pytorch.common.dtype import ms_dtype as dtype
import msadapter.pytorch.amp as amp
from msadapter.pytorch import hub

def _assert(condition, message):
    """Raise ``AssertionError`` carrying ``message`` when ``condition`` is falsy.

    Args:
        condition: value evaluated for truthiness.
        message: payload attached to the raised ``AssertionError``.
    """
    # NOTE: this is a plain ``assert`` statement, so the check is skipped
    # entirely when Python runs with ``-O``.
    assert condition, message


+ 4
- 1
msadapter/pytorch/functional.py View File

@@ -2090,7 +2090,10 @@ def scatter(input, dim, index, src):

def topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
    """Select ``k`` elements of ``input`` via ``ms.ops.topk`` and return values and indices.

    Args:
        input: tensor to search; converted with ``cast_to_ms_tensor``.
        k: number of elements to select. ``k == 0`` is handled without
            calling ``ms.ops.topk``.
        dim: forwarded to ``ms.ops.topk``.
        largest: forwarded to ``ms.ops.topk``.
        sorted: forwarded to ``ms.ops.topk``.
        out: optional destination handed to ``_out_inplace_assign``.

    Returns:
        The result of ``_out_inplace_assign`` for the ``(values, indices)`` pair.
    """
    input_x = cast_to_ms_tensor(input)
    if k == 0:
        # ``ms.ops.topk`` must not run for k == 0, so the guard comes first and
        # zero-filled placeholders are returned instead.
        # NOTE(review): the placeholders use shape ``[]`` and int32 indices --
        # confirm this matches what callers expect from torch.topk(k=0).
        output = (ms.ops.zeros([], dtype=input.dtype), ms.ops.zeros([], dtype=ms.int32))
    else:
        output = ms.ops.topk(input_x, k, dim, largest, sorted)
    return _out_inplace_assign(out, output, "topk")

def addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):


+ 164
- 1
msadapter/pytorch/hub.py View File

@@ -1,9 +1,18 @@
# Part of the code was borrowed from https://github.com/pytorch/pytorch/blob/v1.12.1/torch/hub.py

import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile
import warnings
import zipfile
from typing import Callable, Dict, Optional, Union, Any
from urllib.request import urlopen, Request
from urllib.parse import urlparse
from .serialization import load as torch_load # noqa: F401

try:
from tqdm.auto import tqdm # automatically select proper tqdm submodule if available
@@ -32,7 +41,8 @@ except ImportError:
if self.total is None:
sys.stderr.write("\r{0:.1f} bytes".format(self.n))
else:
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
sys.stderr.write("\r{0:.1f}%".format(
100 * self.n / float(self.total)))
sys.stderr.flush()

def close(self):
@@ -47,6 +57,64 @@ except ImportError:

sys.stderr.write('\n')

# Public names exported by this module.
__all__ = [
'download_url_to_file',
'get_dir',
'load_state_dict_from_url',
'set_dir',
]

# matches bfd8deac from resnet18-bfd8deac.pth
# (used by load_state_dict_from_url to pull the hash prefix out of a file name)
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')


# NOTE(review): ENV_GITHUB_TOKEN, VAR_DEPENDENCY, MODULE_HUBCONF and
# READ_DATA_CHUNK are not referenced in the visible part of this file --
# confirm they are still needed.
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
# Environment variable that overrides the cache root (read by _get_torch_home).
ENV_TORCH_HOME = 'TORCH_HOME'
# Environment variable consulted when ENV_TORCH_HOME is unset (read by _get_torch_home).
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Fallback cache location when ENV_XDG_CACHE_HOME is unset; '~' is expanded by _get_torch_home.
DEFAULT_CACHE_DIR = '~/.cache'
VAR_DEPENDENCY = 'dependencies'
MODULE_HUBCONF = 'hubconf.py'
READ_DATA_CHUNK = 8192
# Hub directory override: written by set_dir(), read by get_dir(); None means "use the default".
_hub_dir = None

def _get_torch_home():
    """Return the cache root directory with ``~`` expanded.

    ``$TORCH_HOME`` is used when set. Otherwise the result is the ``torch``
    sub-directory of ``$XDG_CACHE_HOME``, which itself falls back to
    ``DEFAULT_CACHE_DIR``.
    """
    xdg_cache_home = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    default_home = os.path.join(xdg_cache_home, 'torch')
    return os.path.expanduser(os.getenv(ENV_TORCH_HOME, default_home))


def get_dir():
    r"""
    Get the Torch Hub cache directory used for storing downloaded models & weights.

    If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Desktop Group (XDG) specification of the Linux
    filesystem layout, with a default value ``~/.cache`` if the environment
    variable is not set.

    Returns:
        The directory set through :func:`set_dir` if any, otherwise the
        ``hub`` sub-directory of the cache root.
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn(
            'TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    # An explicit set_dir() call takes precedence over the environment.
    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), 'hub')


def set_dir(d):
    r"""
    Optionally set the Torch Hub directory used to save downloaded models & weights.

    The path is stored with ``~`` expanded and is returned by later
    :func:`get_dir` calls.

    Args:
        d (string): path to a local folder to save downloaded models & weights.
    """
    global _hub_dir
    _hub_dir = os.path.expanduser(d)


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
r"""Download object at the given URL to a local path.
@@ -58,6 +126,7 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
Default: None
progress (bool, optional): whether or not to display a progress bar to stderr
Default: True

"""
file_size = None
req = Request(url, headers={"User-Agent": "torch.hub"})
@@ -102,3 +171,97 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
f.close()
if os.path.exists(f.name):
os.remove(f.name)


def _is_legacy_zip_format(filename):
    """Tell whether ``filename`` is a zip archive holding exactly one non-directory entry.

    Args:
        filename: path of the file to inspect.

    Returns:
        ``True`` for a zip file with a single entry that is not a directory,
        ``False`` otherwise (including files that are not zip archives).
    """
    # Review thread preserved from the merge request:
    #   zoulq: decide whether these comments should stay based on what they actually say.
    #   ziqi: removed.
    if not zipfile.is_zipfile(filename):
        return False
    # Open through a context manager so the archive handle is always closed.
    with zipfile.ZipFile(filename) as archive:
        infolist = archive.infolist()
    return len(infolist) == 1 and not infolist[0].is_dir()


def _legacy_zip_load(filename, model_dir, map_location):
    """Extract a single-member zip checkpoint into ``model_dir`` and load it.

    Args:
        filename: path of the zip archive.
        model_dir: directory the archive is extracted into.
        map_location: forwarded to ``torch_load``.

    Returns:
        Whatever ``torch_load`` returns for the extracted file.

    Raises:
        RuntimeError: if the archive does not contain exactly one member.
    """
    warnings.warn('Falling back to the old format < 1.6. This support will be '
                  'deprecated in favor of default zipfile format introduced in 1.6. '
                  'Please redo torch.save() to save it in the new zipfile format.')
    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
    # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    # E.g. resnet18-5c106cde.pth which is widely used.
    with zipfile.ZipFile(filename) as f:
        members = f.infolist()
        if len(members) != 1:
            raise RuntimeError(
                'Only one file(not dir) is allowed in the zipfile')
        f.extractall(model_dir)
        extracted_name = members[0].filename
    extracted_file = os.path.join(model_dir, extracted_name)
    return torch_load(extracted_file, map_location=map_location)


def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: Optional[Union[Callable[[str], str], Dict[str, str]]] = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None
) -> Dict[str, Any]:
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch_load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set.

    Returns:
        The object produced by ``torch_load`` for the cached (or freshly
        downloaded) file.
    """
    # Review thread preserved from the merge request:
    #   zoulq: same remark as above (keep or drop these comments according to what they actually say).
    #   ziqi: this text explains the input arguments, so it is better kept.

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    # An already existing directory is fine; any other OSError propagates.
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            # No hash is passed on when the file name carries no ``-<hash>.`` part.
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location)
    return torch_load(cached_file, map_location=map_location)

+ 4
- 1
msadapter/pytorch/tensor.py View File

@@ -2056,7 +2056,10 @@ class Tensor(StubTensor, metaclass=_TensorMeta):

def topk(self, k, dim=None, largest=True, sorted=True):
    """Select ``k`` elements of this tensor via ``ms.ops.topk`` and return values and indices.

    Args:
        k: number of elements to select. ``k == 0`` is handled without
            calling ``ms.ops.topk``.
        dim: forwarded to ``ms.ops.topk``.
        largest: forwarded to ``ms.ops.topk``.
        sorted: forwarded to ``ms.ops.topk``.

    Returns:
        The ``(values, indices)`` pair converted with ``cast_to_adapter_tensor``.
    """
    input_ms = cast_to_ms_tensor(self)
    if k == 0:
        # ``ms.ops.topk`` must not run for k == 0, so the guard comes first and
        # zero-filled placeholders are returned instead.
        # NOTE(review): the placeholders use shape ``[]`` and int32 indices --
        # confirm this matches what callers expect from torch.Tensor.topk(k=0).
        output = (ms.ops.zeros([], dtype=self.dtype), ms.ops.zeros([], dtype=ms.int32))
    else:
        output = ms.ops.topk(input_ms, k, dim, largest, sorted)
    return cast_to_adapter_tensor(output)

def maximum(self, other):


+ 3
- 0
msadapter/pytorch/utils/model_zoo.py View File

@@ -0,0 +1,3 @@
from msadapter.pytorch.hub import tqdm, load_state_dict_from_url as load_url # noqa: F401
zoulq commented 8 months ago
Review
hub新增接口要看是否增加基本用例。至少能够看护接口的正确性
ziqi commented 8 months ago
Review
已新增 test_hub.py 用例

__all__ = ['tqdm', 'load_url']

+ 24
- 0
testing/ut/pytorch/torch/test_hub.py View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import mindspore as ms
import msadapter.pytorch as ms_torch
import torch
import numpy as np
from mindspore import context

from ...utils import set_mode_by_env_config
set_mode_by_env_config()

def test_get_dir():
    """The adapter's hub directory must equal the one reported by torch."""
    ms_hub_dir = ms_torch.hub.get_dir()
    torch_hub_dir = torch.hub.get_dir()
    assert ms_hub_dir == torch_hub_dir

def test_load_state_dict_from_url():
    """Download a checkpoint through the adapter hub API and check its entry count.

    The file is downloaded into the current directory and removed afterwards.
    """
    target_url = 'https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth'
    ms_torch_file_name = 'ms_torch_mobilenet.pth'
    cached_path = os.path.join('.', ms_torch_file_name)
    try:
        ms_torch_state_dict = ms_torch.hub.load_state_dict_from_url(
            target_url, model_dir='.', file_name=ms_torch_file_name)
        assert len(ms_torch_state_dict) == 244
    finally:
        # Clean up even when the download, the load or the assertion fails.
        if os.path.exists(cached_path):
            os.remove(cached_path)

Loading…
Cancel
Save