zoulq commented 8 months ago
Review
同上
ziqi/MSAdapter:master
into master
8 months ago
@@ -10,6 +10,8 @@ | |||
- [torch.optim](#jump7) | |||
- [torch.utils.data](#jump9) | |||
- [torch.distributed](#jump10) | |||
- [torch.utils.model\_zoo](#jump11) | |||
- [torch.hub](#jump12) | |||
### <span id="jump8">通用限制</span> | |||
@@ -1231,6 +1233,7 @@ | |||
| lr_scheduler.CyclicLR | 支持 | | | |||
| lr_scheduler.OneCycleLR | 支持 | | | |||
| lr_scheduler.CosineAnnealingWarmRestarts | 支持 | | | |||
### <span id="jump9">torch.utils.data</span> | |||
| MSAdapter接口 | 状态 | 约束 | | |||
| --------------- | ---- |------------------------------| | |||
@@ -1271,3 +1274,17 @@ | |||
| is_nccl_available | 支持 | | | |||
| get_backend | 支持 | | | |||
| get_process_group_ranks | 支持 | | | |||
### <span id="jump11">torch.utils.model_zoo</span> | |||
| MSAdapter接口 | 状态 | 约束 | | |||
| --------------- | ---- |------------------------------| | |||
| load_url | 支持 | | | |||
### <span id="jump12">torch.hub</span> | |||
| MSAdapter接口 | 状态 | 约束 | | |||
| --------------- | ---- |------------------------------| | |||
| download_url_to_file | 支持 | | | |||
| get_dir | 支持 | | | |||
| load_state_dict_from_url | 支持 | | | |||
| set_dir | 支持 | | |
@@ -9,6 +9,8 @@ English | [简体中文](SupportedList.md) | |||
- [torch.optim](#jump7) | |||
- [torch.utils.data](#jump9) | |||
- [torch.distributed](#jump10) | |||
- [torch.utils.model\_zoo](#jump11) | |||
- [torch.hub](#jump12) | |||
### <span id="jump8">General Constraint</span> | |||
- Not support the function of configuration `layout`, `device`, `requires_grad`, `memory_format`. | |||
@@ -1273,3 +1275,17 @@ English | [简体中文](SupportedList.md) | |||
| is_nccl_available | Supported | | | |||
| get_backend | Supported | | | |||
| get_process_group_ranks | Supported | | | |||
### <span id="jump11">torch.utils.model_zoo</span> | |||
| MSAdapter APIs | Status | Restrictions | | |||
| --------------- | ---- |------------------------------| | |||
| load_url | supported | | | |||
### <span id="jump12">torch.hub</span> | |||
| MSAdapter APIs | Status | Restrictions | | |||
| --------------- | ---- |------------------------------| | |||
| download_url_to_file | supported | | | |||
| get_dir | supported | | | |||
| load_state_dict_from_url | supported | | | |||
| set_dir | supported | | |
@@ -22,6 +22,7 @@ from msadapter.pytorch.serialization import * | |||
import msadapter.pytorch.linalg as linalg | |||
from msadapter.pytorch.common.dtype import ms_dtype as dtype | |||
import msadapter.pytorch.amp as amp | |||
from msadapter.pytorch import hub | |||
def _assert(condition, message): | |||
assert condition, message | |||
@@ -2090,7 +2090,10 @@ def scatter(input, dim, index, src): | |||
def topk(input, k, dim=None, largest=True, sorted=True, *, out=None): | |||
input_x = cast_to_ms_tensor(input) | |||
output = ms.ops.topk(input_x, k, dim, largest, sorted) | |||
if k == 0: | |||
output = (ms.ops.zeros([], dtype=input.dtype), ms.ops.zeros([], dtype=ms.int32)) | |||
else: | |||
output = ms.ops.topk(input_x, k, dim, largest, sorted) | |||
return _out_inplace_assign(out, output, "topk") | |||
def addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None): | |||
@@ -1,9 +1,18 @@ | |||
# Part of the code was borrowed from https://github.com/pytorch/pytorch/blob/v1.12.1/torch/hub.py | |||
import errno | |||
import hashlib | |||
import os | |||
import re | |||
import shutil | |||
import sys | |||
import tempfile | |||
import warnings | |||
import zipfile | |||
from typing import Callable, Dict, Optional, Union, Any | |||
from urllib.request import urlopen, Request | |||
from urllib.parse import urlparse | |||
from .serialization import load as torch_load # noqa: F401 | |||
try: | |||
from tqdm.auto import tqdm # automatically select proper tqdm submodule if available | |||
@@ -32,7 +41,8 @@ except ImportError: | |||
if self.total is None: | |||
sys.stderr.write("\r{0:.1f} bytes".format(self.n)) | |||
else: | |||
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) | |||
sys.stderr.write("\r{0:.1f}%".format( | |||
100 * self.n / float(self.total))) | |||
sys.stderr.flush() | |||
def close(self): | |||
@@ -47,6 +57,64 @@ except ImportError: | |||
sys.stderr.write('\n') | |||
__all__ = [ | |||
'download_url_to_file', | |||
'get_dir', | |||
'load_state_dict_from_url', | |||
'set_dir', | |||
] | |||
# matches bfd8deac from resnet18-bfd8deac.pth | |||
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN' | |||
ENV_TORCH_HOME = 'TORCH_HOME' | |||
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' | |||
DEFAULT_CACHE_DIR = '~/.cache' | |||
VAR_DEPENDENCY = 'dependencies' | |||
MODULE_HUBCONF = 'hubconf.py' | |||
READ_DATA_CHUNK = 8192 | |||
_hub_dir = None | |||
def _get_torch_home(): | |||
torch_home = os.path.expanduser( | |||
os.getenv(ENV_TORCH_HOME, | |||
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, | |||
DEFAULT_CACHE_DIR), 'torch'))) | |||
return torch_home | |||
def get_dir(): | |||
r""" | |||
Get the Torch Hub cache directory used for storing downloaded models & weights. | |||
If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where | |||
environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. | |||
``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux | |||
filesystem layout, with a default value ``~/.cache`` if the environment | |||
variable is not set. | |||
""" | |||
# Issue warning to move data if old env is set | |||
if os.getenv('TORCH_HUB'): | |||
warnings.warn( | |||
'TORCH_HUB is deprecated, please use env TORCH_HOME instead') | |||
if _hub_dir is not None: | |||
return _hub_dir | |||
return os.path.join(_get_torch_home(), 'hub') | |||
def set_dir(d): | |||
r""" | |||
Optionally set the Torch Hub directory used to save downloaded models & weights. | |||
Args: | |||
d (string): path to a local folder to save downloaded models & weights. | |||
""" | |||
global _hub_dir | |||
_hub_dir = os.path.expanduser(d) | |||
def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||
r"""Download object at the given URL to a local path. | |||
@@ -58,6 +126,7 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||
Default: None | |||
progress (bool, optional): whether or not to display a progress bar to stderr | |||
Default: True | |||
""" | |||
file_size = None | |||
req = Request(url, headers={"User-Agent": "torch.hub"}) | |||
@@ -102,3 +171,97 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |||
f.close() | |||
if os.path.exists(f.name): | |||
os.remove(f.name) | |||
def _is_legacy_zip_format(filename): | |||
if zipfile.is_zipfile(filename): | |||
|
|||
infolist = zipfile.ZipFile(filename).infolist() | |||
return len(infolist) == 1 and not infolist[0].is_dir() | |||
return False | |||
def _legacy_zip_load(filename, model_dir, map_location): | |||
warnings.warn('Falling back to the old format < 1.6. This support will be ' | |||
'deprecated in favor of default zipfile format introduced in 1.6. ' | |||
'Please redo torch.save() to save it in the new zipfile format.') | |||
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. | |||
# We deliberately don't handle tarfile here since our legacy serialization format was in tar. | |||
# E.g. resnet18-5c106cde.pth which is widely used. | |||
with zipfile.ZipFile(filename) as f: | |||
members = f.infolist() | |||
if len(members) != 1: | |||
raise RuntimeError( | |||
'Only one file(not dir) is allowed in the zipfile') | |||
f.extractall(model_dir) | |||
extraced_name = members[0].filename | |||
extracted_file = os.path.join(model_dir, extraced_name) | |||
return torch_load(extracted_file, map_location=map_location) | |||
def load_state_dict_from_url( | |||
url: str, | |||
model_dir: Optional[str] = None, | |||
map_location: Optional[Union[Callable[[str], str], Dict[str, str]]] = None, | |||
progress: bool = True, | |||
check_hash: bool = False, | |||
file_name: Optional[str] = None | |||
) -> Dict[str, Any]: | |||
r"""Loads the Torch serialized object at the given URL. | |||
If downloaded file is a zip file, it will be automatically | |||
decompressed. | |||
If the object is already present in `model_dir`, it's deserialized and | |||
returned. | |||
The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where | |||
``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. | |||
Args: | |||
url (string): URL of the object to download | |||
model_dir (string, optional): directory in which to save the object | |||
map_location (optional): a function or a dict specifying how to remap storage locations (see torch_load) | |||
progress (bool, optional): whether or not to display a progress bar to stderr. | |||
Default: True | |||
check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention | |||
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||
digits of the SHA256 hash of the contents of the file. The hash is used to | |||
ensure unique names and to verify the contents of the file. | |||
Default: False | |||
file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set. | |||
""" | |||
zoulq commented 8 months ago
Review
同上 同上
ziqi commented 8 months ago
Review
属于对入参的解释,建议保留 属于对入参的解释,建议保留
|
|||
# Issue warning to move data if old env is set | |||
if os.getenv('TORCH_MODEL_ZOO'): | |||
warnings.warn( | |||
'TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') | |||
if model_dir is None: | |||
hub_dir = get_dir() | |||
model_dir = os.path.join(hub_dir, 'checkpoints') | |||
try: | |||
os.makedirs(model_dir) | |||
except OSError as e: | |||
if e.errno == errno.EEXIST: | |||
# Directory already exists, ignore. | |||
pass | |||
else: | |||
# Unexpected OSError, re-raise. | |||
raise | |||
parts = urlparse(url) | |||
filename = os.path.basename(parts.path) | |||
if file_name is not None: | |||
filename = file_name | |||
cached_file = os.path.join(model_dir, filename) | |||
if not os.path.exists(cached_file): | |||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||
hash_prefix = None | |||
if check_hash: | |||
r = HASH_REGEX.search(filename) # r is Optional[Match[str]] | |||
hash_prefix = r.group(1) if r else None | |||
download_url_to_file(url, cached_file, hash_prefix, progress=progress) | |||
if _is_legacy_zip_format(cached_file): | |||
return _legacy_zip_load(cached_file, model_dir, map_location) | |||
return torch_load(cached_file, map_location=map_location) |
@@ -2056,7 +2056,10 @@ class Tensor(StubTensor, metaclass=_TensorMeta): | |||
def topk(self, k, dim=None, largest=True, sorted=True): | |||
input = cast_to_ms_tensor(self) | |||
output = ms.ops.topk(input, k, dim, largest, sorted) | |||
if k == 0: | |||
output = (ms.ops.zeros([], dtype=self.dtype), ms.ops.zeros([], dtype=ms.int32)) | |||
else: | |||
output = ms.ops.topk(input, k, dim, largest, sorted) | |||
return cast_to_adapter_tensor(output) | |||
def maximum(self, other): | |||
@@ -0,0 +1,3 @@ | |||
from msadapter.pytorch.hub import tqdm, load_state_dict_from_url as load_url # noqa: F401 | |||
zoulq commented 8 months ago
Review
hub新增接口要看是否增加基本用例。至少能够看护接口的正确性 hub新增接口要看是否增加基本用例。至少能够看护接口的正确性
ziqi commented 8 months ago
Review
已新增 test_hub.py 用例 已新增 test_hub.py 用例
|
|||
__all__ = ['tqdm', 'load_url'] |
@@ -0,0 +1,24 @@ | |||
#!/usr/bin/env python | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import mindspore as ms | |||
import msadapter.pytorch as ms_torch | |||
import torch | |||
import numpy as np | |||
from mindspore import context | |||
from ...utils import set_mode_by_env_config | |||
set_mode_by_env_config() | |||
def test_get_dir(): | |||
ms_hub_dir = ms_torch.hub.get_dir() | |||
torch_hub_dir = torch.hub.get_dir() | |||
assert ms_hub_dir == torch_hub_dir | |||
def test_load_state_dict_from_url(): | |||
target_url = 'https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth' | |||
ms_torch_file_name = 'ms_torch_mobilenet.pth' | |||
ms_torch_state_dict = ms_torch.hub.load_state_dict_from_url(target_url, model_dir='.', file_name=ms_torch_file_name) | |||
assert len(ms_torch_state_dict) == 244 | |||
os.remove(os.path.join('.', ms_torch_file_name)) |
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》
这些注释根据实际含义确定是否保留
已删除