#918 load/save fix

Merged
Erpim merged 22 commits from load_fix into master 3 weeks ago
  1. mindtorch/torch/_tensor.py (+20 -0)
  2. mindtorch/torch/_utils.py (+87 -0)
  3. mindtorch/torch/functional.py (+5 -1)
  4. mindtorch/torch/linalg/linalg.py (+4 -1)
  5. mindtorch/torch/nn/parameter.py (+33 -25)
  6. mindtorch/torch/serialization.py (+251 -208)
  7. mindtorch/torch/storage.py (+24 -6)
  8. mindtorch/torch/tensor.py (+47 -16)
  9. mindtorch/torch/utils/data/dataloader.py (+7 -2)
  10. testing/ut/pytorch/torch/test_hub.py (+2 -2)
  11. testing/ut/pytorch/torch/test_import.py (+1 -0)
  12. testing/ut/pytorch/torch/test_serialization.py (+240 -3)

mindtorch/torch/_tensor.py (+20 -0)

@@ -0,0 +1,20 @@
from ._utils import _set_obj_state
def _rebuild_from_type(func, type, args, dict):
from mindtorch.torch.tensor import Tensor # pylint: disable=R0401, C0415
if type is Tensor:
return func(*args)

ret = func(*args).as_subclass(type)
hanjr marked this conversation as resolved
Erpim commented 4 weeks ago
Review
Shouldn't as_subclass be unsupported at the moment? What scenario would reach this function?
hanjr commented 4 weeks ago
Review
Added a simple implementation, `def as_subclass(self, cls): return cls(self)`; Haoyu's example reaches this function.
hanjr commented 4 weeks ago
Review
Removed; it is not needed at the moment.
ret.__dict__ = dict
return ret

def _rebuild_from_type_v2(func, new_type, args, state):
from mindtorch.torch.tensor import Tensor # pylint: disable=R0401, C0415
ret = func(*args)
if not isinstance(ret, new_type):
ret = ret.as_subclass(new_type)
if getattr(ret.__class__, "__setstate__", Tensor.__setstate__) is not Tensor.__setstate__:
ret.__setstate__(state)
else:
ret = _set_obj_state(ret, state)
return ret
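
Taken together, these hooks are what a pickled tensor's reduce tuple points at: rebuild the base tensor from its storage arguments, re-apply the subclass, then restore any extra state. A minimal sketch of that control flow with plain-Python stand-ins (`Base`, `MyTensor`, `make_base` and the `scale` attribute are all hypothetical, not mindtorch APIs):

```python
class Base:                               # stands in for Tensor
    def __init__(self, value):
        self.value = value

class MyTensor(Base):                     # a user subclass with extra state
    pass

def make_base(value):                     # stands in for the storage rebuild func
    return Base(value)

def rebuild_from_type_v2(func, new_type, args, state):
    ret = func(*args)
    if not isinstance(ret, new_type):
        ret.__class__ = new_type          # simplified as_subclass
    ret.__dict__.update(state)            # simplified _set_obj_state
    return ret

restored = rebuild_from_type_v2(make_base, MyTensor, (3,), {"scale": 2.0})
assert type(restored) is MyTensor and restored.scale == 2.0
```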

mindtorch/torch/_utils.py (+87 -0)

@@ -1,5 +1,6 @@
import sys
import traceback
import copyreg
import mindtorch.torch.common.dtype as _dtype
from mindtorch.torch.common.dtype import finfo, iinfo
from mindtorch.utils import unsupported_attr
@@ -51,6 +52,53 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
warning("'async' is deprecated; use 'non_blocking'")
return kwargs['async']

def _rebuild_tensor(storage, storage_offset, size, stride):
unsupported_attr(stride)
from mindtorch.torch.tensor import tensor # pylint: disable=R0401, C0415
t = tensor([], dtype=storage.dtype, device=storage._untyped().device)
return t.set_(storage._untyped(), storage_offset, size)
zoulq commented 3 weeks ago
Review
Wouldn't it be faster to just create a tensor directly?
Erpim commented 3 weeks ago
Review
A tensor cannot be created directly here. In the load scenario the tensor's values are updated by mutating the storage; if the tensor were created directly, the external storage and the tensor would share no link, so they would not stay in sync.
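
The aliasing requirement described here can be shown with numpy alone: a tensor built over the storage's buffer sees later writes to the storage, while an independently created copy never would. A hedged illustration (the names below are stand-ins, not mindtorch APIs):

```python
import numpy as np

storage = np.zeros(4, dtype=np.float32)    # stands in for an _UntypedStorage
tensor_view = storage[:]                   # tensor built via set_: shares memory
tensor_copy = storage.copy()               # tensor "created directly": own memory

storage[:] = [1, 2, 3, 4]                  # load() later fills the storage
assert tensor_view.tolist() == [1.0, 2.0, 3.0, 4.0]   # the view stays in sync
assert tensor_copy.tolist() == [0.0, 0.0, 0.0, 0.0]   # the copy never updates
```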


def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
unsupported_attr(backward_hooks)
unsupported_attr(metadata)
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
tensor.requires_grad = requires_grad
return tensor

def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
from mindtorch.torch.functional import from_numpy # pylint: disable=R0401, C0415
tensor = from_numpy(data).to(dtype=dtype, device=device)
tensor.requires_grad = requires_grad
return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
unsupported_attr(backward_hooks)
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415
param = Parameter(data, requires_grad)
param.set_(data.storage()._untyped(), 0, data.size())
return param

def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
unsupported_attr(backward_hooks)
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415
param = Parameter(data, requires_grad)
param._backward_hooks = backward_hooks
param = _set_obj_state(param, state)
return param

def _rebuild_mindtorch_parameter(data, requires_grad, name, layerwise_parallel):
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415
param = Parameter(data, requires_grad, name, layerwise_parallel)
return param

def _rebuild_mindtorch_parameter_with_state(data, requires_grad, name, layerwise_parallel, state):
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415
param = Parameter(data, requires_grad, name, layerwise_parallel)
param = _set_obj_state(param, state)
return param

def _import_dotted_name(name):
components = name.split('.')
obj = __import__(components[0])
@@ -127,3 +175,42 @@ def _unflatten_dense_tensors(flat, tensors):
unsupported_attr(flat)
unsupported_attr(tensors)
raise NotImplementedError("`_unflatten_dense_tensors` is not implemented now.")

def _set_obj_state(obj, state):
if isinstance(state, tuple):
if not len(state) == 2:
raise RuntimeError(f"Invalid serialized state: {state}")
dict_state = state[0]
slots_state = state[1]
else:
dict_state = state
slots_state = None

if dict_state:
for k, v in dict_state.items():
setattr(obj, k, v)

if slots_state:
for k, v in slots_state.items():
setattr(obj, k, v)
return obj

def _get_obj_state(obj):
getstate_fn = getattr(obj, "__getstate__", None)
if getstate_fn is not None:
state = getstate_fn()
else:
slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined]
if slots_to_save:
state = (
obj.__dict__,
{
name: getattr(obj, name)
for name in slots_to_save
if hasattr(obj, name)
},
)
else:
state = obj.__dict__.copy()

return state
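
`_get_obj_state`/`_set_obj_state` mirror CPython's default pickling state: a plain `__dict__`, or a `(dict, slots)` pair when the class declares `__slots__`. A small sketch of the `__slots__` branch, assuming a hypothetical `Slotted` class:

```python
import copyreg

class Slotted:                        # hypothetical class exercising __slots__
    __slots__ = ("alpha",)

obj = Slotted()
obj.alpha = 1
slots = copyreg._slotnames(Slotted)   # -> ['alpha']; same call as above
state = ({}, {n: getattr(obj, n) for n in slots if hasattr(obj, n)})

clone = Slotted.__new__(Slotted)      # rebuild without calling __init__
_dict_state, slots_state = state
for k, v in slots_state.items():      # the path _set_obj_state takes
    setattr(clone, k, v)
assert clone.alpha == 1
```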

mindtorch/torch/functional.py (+5 -1)

@@ -12,7 +12,11 @@ except ImportError:
import mindspore as ms
from mindspore import ops
from mindspore.common import dtype as mstype
from mindspore.scipy.ops import SolveTriangular
try:
from mindspore.scipy.ops import SolveTriangular # not supported on Windows CPU
except ImportError:
# do nothing here.
...
from mindspore.ops.primitive import _primexpr
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore._c_expression import Tensor as ms_Tensor_
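
The guarded import above degrades gracefully on platforms where the scipy op is unavailable. A sketch of the same pattern with an explicit fallback (the sentinel binding is an editorial suggestion; the merged code simply leaves the name unbound on ImportError):

```python
# Guarded-import pattern with an explicit sentinel; the merged code omits
# the sentinel, so SolveTriangular stays unbound when the import fails.
try:
    from mindspore.scipy.ops import SolveTriangular
except ImportError:          # e.g. Windows CPU builds
    SolveTriangular = None   # callers should check for None before use
```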


mindtorch/torch/linalg/linalg.py (+4 -1)

@@ -3,7 +3,10 @@

import mindspore as ms
from mindspore.ops.primitive import _primexpr
from mindspore.scipy.ops import SolveTriangular
try:
from mindspore.scipy.ops import SolveTriangular # not supported on Windows CPU
except ImportError:
...
from mindtorch.torch.common._inner import _out_inplace_assign
from mindtorch.utils import unsupported_attr, pynative_mode_condition, \
is_under_gpu_context, is_under_ascend_context, set_multiple_name_tuple


mindtorch/torch/nn/parameter.py (+33 -25)

@@ -15,8 +15,7 @@ from mindspore.parallel._ps_context import _insert_accumu_init_info
from mindtorch.torch.tensor import Tensor, cast_to_ms_tensor, cast_to_adapter_tensor
from mindtorch.torch.common.dtype import _msdtype2typeDict
from mindtorch.torch.functional import empty as torch_empty
from mindtorch.utils import graph_mode_condition

from mindtorch.torch import _utils
__all__ = ['Parameter', 'ParameterTuple', 'UninitializedParameter', 'UninitializedBuffer']

def init_to_value(init):
@@ -41,7 +40,11 @@ def init_to_value(init):

class Parameter(ms.Parameter):
_base_type = {}
def __new__(cls, data, *args, **kwargs):

def __new__(cls, data=None, requires_grad=True, name=None, layerwise_parallel=False, # pylint: disable = W0613
parallel_optimizer=True): # pylint: disable = W0613
if data is None:
data = 1
init_data_flag = bool(isinstance(data, ms.Tensor) and data.has_init)
rc = sys.getrefcount(data)
input_class, *class_init_args = Parameter._get_parameter_new_args(data, rc)
@@ -55,16 +58,22 @@ class Parameter(ms.Parameter):
return obj

def __reduce_ex__(self, _):
data = self
state = _utils._get_obj_state(self)
if self.init_mode is not None:
data = self.init_mode
else:
# cast to break deep infinite loop while deepcopy
data = ms.Tensor(self)
return (
Parameter, (data, self.requires_grad, self.name, self.layerwise_parallel))
if not state:

return (_utils._rebuild_mindtorch_parameter, (data, self.requires_grad, self.name,
self.layerwise_parallel))

return (_utils._rebuild_mindtorch_parameter_with_state, (data, self.requires_grad, self.name,
self.layerwise_parallel, state))

def __init__(self, data, requires_grad=True, name=None, layerwise_parallel=False, parallel_optimizer=True):
def __init__(self, data=None, requires_grad=True, name=None, layerwise_parallel=False, parallel_optimizer=True):
if data is None:
data = 1
self.adapter_flag = True
super().__init__(default_input=data, name=name, requires_grad=requires_grad,
layerwise_parallel=layerwise_parallel, parallel_optimizer=parallel_optimizer)
@@ -185,23 +194,22 @@ class Parameter(ms.Parameter):
def shape(self):
return self._shape

def set_(self, source=None, storage_offset=0, size=None, stride=None):
if storage_offset or size or stride:
raise ValueError("Currently, `Parameter.set_` specifying `storage_offset`, "
"`size` or `stride` are not supported.")

if source is None:
raise ValueError("Currently, `Parameter.set_` only supported specify the `source`, " \
"please ensure that it is not None.")

if graph_mode_condition():
raise RuntimeError('`Parameter.set_` is an in-place operation and "x.set_()" is not supported to use '
'in MindSpore static graph mode.')

source = cast_to_ms_tensor(source)
self.set_data(source, True)
return self

def __setstate__(self, state):
if isinstance(state, tuple):
if len(state) == 4:
self.set_(*state)
return
elif len(state) == 5:
data = state[0]
Parameter.__init__(self, data, requires_grad=state[3])
self.set_dtype(data.dtype)
self.set_data(data=data, slice_shape=True)
self._requires_grad = state[3]
return

def __getstate__(self):
state = {key: value for key, value in self.__dict__.items() if key not in Parameter().__dict__}
return state
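
With `__reduce_ex__` now returning the `_rebuild_mindtorch_parameter[_with_state]` pair, pickling and deep-copying a Parameter round-trips extra attributes. A usage sketch, assuming a working mindtorch install (the attribute `kkk` is arbitrary extra state, mirroring the tests below):

```python
import pickle
import mindtorch.torch as torch

p = torch.nn.Parameter(torch.tensor(2.0))
p.kkk = 3                                  # non-default state -> *_with_state path
q = pickle.loads(pickle.dumps(p))
assert isinstance(q, torch.nn.Parameter) and q.kkk == 3
```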

def _init_parameter_api():
param_func = dir(Parameter)


mindtorch/torch/serialization.py (+251 -208)

@@ -2,6 +2,7 @@
# pylint: disable=unused-argument
# pylint: disable=eval-used
# pylint: disable=broad-except
import difflib
import os
import io
import struct
@@ -11,23 +12,21 @@ import pathlib
import shutil
import zipfile
import tarfile
import warnings
import tempfile
import operator
import inspect
from functools import reduce
from dataclasses import dataclass
from enum import Enum
from contextlib import closing, contextmanager
from collections.abc import Mapping, Sequence
from typing import Any, BinaryIO, Union, IO, Optional, Type, Dict, Tuple
from typing_extensions import TypeAlias
from ml_dtypes import bfloat16
import numpy as np
from mindtorch.module_hooker import torch_disable, torch_pop
from mindtorch.torch import _utils
from mindtorch.torch.storage import _UntypedStorage, _TypedStorage
from mindtorch.torch.tensor import tensor, Tensor
from mindtorch.torch.nn.modules.module import Module, Parameter
from mindtorch.torch.nn.modules.module import Module
from mindtorch.torch.logging import warning
import mindtorch.torch.common.dtype as _dtype
from mindtorch.torch.storage import _get_dtype_from_pickle_storage_type

DEFAULT_PROTOCOL = 2
LONG_SIZE = struct.Struct('=l').size
@@ -46,41 +45,29 @@ __all__ = [
'load',
]

dtype_map = {
"HalfStorage": np.float16,
"FloatStorage": np.float32,
'BFloat16Storage': bfloat16,
'LongStorage': np.int64,
'ByteStorage': np.uint8,
'BoolStorage': np.bool_,
'IntStorage': np.int32,
'ShortStorage': np.int16,
'CharStorage': np.int8,
'DoubleStorage': np.float64,
}

_storage_classes_dict = {_dtype.double: "DoubleStorage",
_dtype.float: "FloatStorage",
_dtype.half: "HalfStorage",
_dtype.long: "LongStorage",
_dtype.int: "IntStorage",
_dtype.int16: "ShortStorage",
_dtype.int8: "CharStorage",
_dtype.uint8: "ByteStorage",
_dtype.bool: "BoolStorage",
_dtype.bfloat16: "BFloat16Storage",
_dtype.cdouble: "ComplexDoubleStorage",
_dtype.cfloat: "ComplexFloatStorage",
}

element_size_map = {
"HalfStorage": 2,
"FloatStorage": 3,
'BFloat16Storage': 2,
'LongStorage': 4,
'ByteStorage': 1,
'BoolStorage': 1
}

def typename(o):
if isinstance(o, Tensor):
return o.type()

module = ''
class_name = ''
if hasattr(o, '__module__') and o.__module__ != 'builtins' \
and o.__module__ != '__builtin__' and o.__module__ is not None:
module = o.__module__ + '.'

if hasattr(o, '__qualname__'):
class_name = o.__qualname__
elif hasattr(o, '__name__'):
class_name = o.__name__
else:
class_name = o.__class__.__name__

return module + class_name


class SourceChangeWarning(Warning):
pass

def get_source_lines_and_file(obj, error_msg = None) :
try:
@@ -106,6 +93,12 @@ def mkdtemp():
finally:
shutil.rmtree(path)

class _HasStorage:
def __init__(self, storage):
self._storage = storage

def storage(self):
return self._storage

class PyTorchFileReader:

@@ -141,6 +134,11 @@ class PyTorchFileReader:
return self.file.getinfo(filename).header_offset
return None

def get_storage_from_record(self, name, numel, dtype):
filename = f"{self.directory}/{name}"
storage = _UntypedStorage
return _HasStorage(storage.from_buffer(self.read_record(name)))


class PyTorchFileWriter:
def __init__(self, file):
@@ -338,44 +336,9 @@ def _should_read_directly(f):
except AttributeError:
return False

def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
if size == ():
size = ()
stride = (1,)
num_elemets = 1
else:
num_elemets = reduce(operator.mul, size)
array = storage[storage_offset: storage_offset + num_elemets]
origin_dtype = None
if array.dtype == bfloat16:
origin_dtype = 'bfloat16'
array = array.astype(np.float32)

if stride is not None and len(stride) > 1 and stride[0] == 1 and stride[1] > 1:
stride = tuple((s * 4 for s in stride))
array = np.lib.stride_tricks.as_strided(array, size, stride)
else:
order = "C"
array = array.reshape(size, order=order)
if origin_dtype == 'bfloat16':
return tensor(array, dtype=_dtype.bfloat16)
param = tensor(array)
return param

def _rebuild_parameter(data, requires_grad, backward_hooks):
param = Parameter(data, requires_grad)
param._backward_hooks = backward_hooks
return param

@dataclass
class FakeParameter:
storage: np.ndarray = None
storage_offset: int = None
size: tuple = None
requires_grad: bool = None

def _rebuild_tensor_legacy(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
return FakeParameter(storage, storage_offset, size, requires_grad)
def normalize_storage_type(storage_type):
import mindtorch.torch as ms_torch # pylint: disable=R0401, C0415
return getattr(ms_torch, storage_type.__name__)

def _maybe_decode_ascii(bytes_str: Union[bytes, str]):
if isinstance(bytes_str, bytes):
@@ -430,21 +393,34 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
# source_lines, _, source_file = get_source_lines_and_file(obj)
# source = ''.join(source_lines)
# except Exception:
# warnings.warn("Couldn't retrieve source code for container of "
# warning("Couldn't retrieve source code for container of "
# "type " + obj.__name__ + ". It won't be checked "
# "for correctness upon loading.")
# return ('module', obj, source_file, source)
raise NotImplementedError("Do not support save module now. Please use torch.save to save model parameters."
"If you want to save model parameters, "
"please use 'torch.save(net.state_dict(), filename)'")
from mindtorch.torch import is_storage # pylint: disable=R0401, C0415
if isinstance(obj, _TypedStorage) or is_storage(obj):
storage = None
if isinstance(obj, _TypedStorage):
import mindtorch.torch as ms_torch # pylint: disable=R0401, C0415
storage = obj._storage
storage_dtype = obj.dtype
storage_type_str = obj.pickle_storage_type()
storage_type = getattr(ms_torch, storage_type_str)
dtype = obj.dtype
storage_numel = obj.size()
elif isinstance(obj, _UntypedStorage):
storage = obj
storage_dtype = _dtype.uint8
storage_type = normalize_storage_type(type(obj))
dtype = _dtype.uint8
storage_numel = storage.nbytes()
else:
raise TypeError(f'type not recognized: {type(obj)}')

if isinstance(obj, (Parameter, Tensor)):
storage = obj
storage_dtype = obj.dtype
storage_type = _storage_classes_dict[obj.dtype]
storage_numel = obj.numel()

storage_dataptr = id(storage)
storage_dataptr = storage.data_ptr()
if storage_dataptr != 0:
if storage_dataptr in storage_dtypes:
if storage_dtype != storage_dtypes[storage_dataptr]:
@@ -455,20 +431,19 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
storage_dtypes[storage_dataptr] = storage_dtype

view_metadata: Optional[Tuple[str, int, int]]
storage_key = id_map.setdefault(storage_dataptr, str(len(id_map)))
offset = 0
storage_key = str(id(storage))
location = 'cpu'
if storage_key not in serialized_storages:
serialized_storages[storage_key] = (storage, obj.dtype)
serialized_storages[storage_key] = (storage, dtype)
view_metadata = None
mindtorch_info = storage.shape

res = ('storage',
storage_type,
storage_key,
location,
storage_numel,
view_metadata,
mindtorch_info)
view_metadata)
return res
return None

@@ -494,8 +469,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
f.flush()
for key in serialized_storage_keys:
storage, dtype = serialized_storages[key]
f.write(np.array(storage.numel(), dtype=np.uint64).tobytes())
f.write(storage.get_bytes())
storage._write_file(f, _should_read_directly(f), True, _utils._element_size(dtype))
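
On the legacy (non-zip) path each storage record is written as an 8-byte element count followed by the raw bytes, which `_set_from_file` reads back symmetrically (see storage.py below). A numpy-only sketch of that record framing, with hypothetical helper names:

```python
import io
import numpy as np

def write_storage(f, array):          # hypothetical helper: one storage record
    f.write(np.array(array.size, dtype=np.int64).tobytes())  # element count
    f.write(array.tobytes())                                  # raw bytes

def read_storage(f, dtype):           # mirrors _set_from_file's framing
    numel = np.frombuffer(f.read(8), np.int64).item()
    return np.frombuffer(f.read(numel * np.dtype(dtype).itemsize), dtype=dtype)

buf = io.BytesIO()
write_storage(buf, np.arange(4, dtype=np.float32))
buf.seek(0)
assert read_storage(buf, np.float32).tolist() == [0.0, 1.0, 2.0, 3.0]
```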

def _save(obj, zip_file, pickle_module, pickle_protocol):
serialized_storages = {}
@@ -507,14 +481,24 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
raise NotImplementedError("Do not support save module now. Please use torch.save to save model parameters."
"If you want to save model parameters, "
"please use 'torch.save(net.state_dict(), filename)'")
if isinstance(obj, (Parameter, Tensor)):
storage = obj
storage_dtype = obj.dtype
storage_type = _storage_classes_dict[obj.dtype]
storage_numel = obj.numel()
storage_shape = storage.shape

storage_dataptr = id(storage)
from mindtorch.torch import is_storage # pylint: disable=R0401, C0415
if isinstance(obj, _TypedStorage) or is_storage(obj):

if isinstance(obj, _TypedStorage):
import mindtorch.torch as ms_torch # pylint: disable=R0401, C0415
storage = obj._storage
storage_dtype = obj.dtype
storage_type_str = obj.pickle_storage_type()
storage_type = getattr(ms_torch, storage_type_str)
storage_numel = obj.size()

else:
storage = obj
storage_dtype = _dtype.uint8
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()

storage_dataptr = storage.data_ptr()
if storage_dataptr != 0:
if storage_dataptr in storage_dtypes:
if storage_dtype != storage_dtypes[storage_dataptr]:
@@ -524,19 +508,16 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
else:
storage_dtypes[storage_dataptr] = storage_dtype

view_metadata: Optional[Tuple[str, int, int]]
storage_key = id_map.setdefault(storage_dataptr, str(len(id_map)))
storage_key = id_map.setdefault(id(storage), str(len(id_map)))
location = 'cpu'
if storage_key not in serialized_storages:
serialized_storages[storage_key] = storage
serialized_storages[storage_key] = storage

return ('storage',
storage_type,
storage_key,
location,
storage_numel)

res = ('storage',
storage_type,
storage_key,
location,
storage_numel,
storage_shape)
return res
return None

data_buf = io.BytesIO()
@@ -549,9 +530,16 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
for key in sorted(serialized_storages.keys()):
name = f'archive/data/{key}'
storage = serialized_storages[key]
storage_data = storage.get_bytes()
storage_data = storage.inner_data
zip_file.write_record(name, storage_data)

class StorageType():
def __init__(self, name):
self.dtype = _get_dtype_from_pickle_storage_type(name)

def __str__(self):
return f'StorageType(dtype={self.dtype})'


def load(f: FILE_LIKE,
map_location=None,
@@ -571,38 +559,129 @@ def load(f: FILE_LIKE,
with _open_zipfile_reader(opened_file, ) as opened_zipfile:
if _is_torchscript_zip(opened_zipfile):
raise ValueError('do not support torchscript now')
return _load(opened_zipfile,
torch_disable()
result = _load(opened_zipfile,
pickle_module,
overall_storage=overall_storage,
**pickle_load_args)

return _legacy_load(opened_file, pickle_module, **pickle_load_args)

torch_pop()
return result
torch_disable()
result = _legacy_load(opened_file, pickle_module, **pickle_load_args)
zoulq commented 3 weeks ago
Review
Are there still any scenarios that depend on PyTorch?
hanjr commented 3 weeks ago
Review
There is no scenario that depends on PyTorch. When loading torch weights, the saved torch function references get read back; this code only guarantees that those references resolve to the identically named functions implemented in mindtorch.
torch_pop()
return result
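
`torch_disable()`/`torch_pop()` bracket the unpickling so that any `torch.*` references recorded in the checkpoint resolve to the mindtorch modules of the same name. A sketch of the guard's shape, with a defensive try/finally added (the merged code calls the pair without one):

```python
from mindtorch.module_hooker import torch_disable, torch_pop

def _guarded_load(do_load):
    torch_disable()        # route `torch.*` lookups to mindtorch equivalents
    try:
        return do_load()
    finally:
        torch_pop()        # restore the previous module mapping
```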

def _legacy_load(f, pickle_module, **pickle_load_args):
deserialized_objects: Dict[int, Any] = {}
class UnpicklerWrapper(pickle_module.Unpickler):
def find_class(self, mod_name, name):
if name == '_rebuild_tensor_v2':
name = '_rebuild_tensor_legacy'
if mod_name == 'torch._utils':
return eval(name)
if mod_name == 'torch':
return str(name)
if isinstance(name, str) and 'Storage' in name:
try:
return StorageType(name)
except KeyError:
pass
return super().find_class(mod_name, name)

def legacy_load(f):
deserialized_objects: Dict[int, Any] = {}
def _check_container_source(container_type, source_file, original_source):
try:
current_source = ''.join(get_source_lines_and_file(container_type)[0])
except Exception:
warning("Couldn't retrieve source code for container of "
"type " + container_type.__name__ + ". It won't be checked "
"for correctness upon loading.")
return
if original_source != current_source:
if container_type.dump_patches:
file_name = container_type.__name__ + '.patch'
diff = difflib.unified_diff(current_source.split('\n'),
original_source.split('\n'),
source_file,
source_file, lineterm="")
lines = '\n'.join(diff)
try:
with open(file_name, 'a+') as f:
file_size = f.seek(0, 2)
f.seek(0)
if file_size == 0:
f.write(lines)
elif file_size != len(lines) or f.read() != lines:
raise OSError
msg = ("Saved a reverse patch to " + file_name + ". "
"Run `patch -p0 < " + file_name + "` to revert your changes.")
except OSError:
msg = ("Tried to save a patch, but couldn't create a "
"writable file " + file_name + ". Make sure it "
"doesn't exist and your working directory is "
"writable.")
else:
msg = ("you can retrieve the original source code by "
"accessing the object's source attribute or set "
"`torch.nn.Module.dump_patches = True` and use the "
"patch tool to revert the changes.")
msg = f"source code of class '{typename(container_type)}' has changed. {msg}"
warning(msg, SourceChangeWarning)


def legacy_load(file):
deserialized_objects: Dict[int, Any] = {}
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
if all(saved_id[1:]):
_check_container_source(*saved_id) #TODO
return saved_id[0]
return deserialized_objects[int(saved_id)]

with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
with closing(tarfile.open(fileobj=file, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
mkdtemp() as tmpdir:
raise ValueError('do not support legacy load for Pytorch.')

tar.extract('storages', path=tmpdir)
with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as _file:
num_storages = pickle_module.load(_file, **pickle_load_args)
for i in range(num_storages):
args = pickle_module.load(_file, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type.dtype
element_size = _utils._element_size(dtype)
nbytes = np.frombuffer(_file.read(8), np.int64).item() * element_size
data = np.fromfile(_file, dtype=np.uint8, count=nbytes, offset=0)
obj = _UntypedStorage.from_buffer(data)

deserialized_objects[key] = _TypedStorage(
wrap_storage=obj,
dtype=dtype,
_internal=True)

storage_views = pickle_module.load(_file, **pickle_load_args)
for target_cdata, root_cdata, offset, numel in storage_views:
root = deserialized_objects[root_cdata]
element_size = _utils._element_size(root.dtype)
offset_bytes = offset * element_size
deserialized_objects[target_cdata] = _TypedStorage(
wrap_storage=root._untyped()[offset_bytes:offset_bytes + numel * element_size],
dtype=root.dtype,
_internal=True)

tar.extract('tensors', path=tmpdir)
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as _file:
num_tensors = pickle_module.load(_file, **pickle_load_args)
for _ in range(num_tensors):
args = pickle_module.load(_file, **pickle_load_args)
key, storage_id, original_tensor_type = args
storage = deserialized_objects[storage_id]
ndim, = struct.unpack('<i', _file.read(4))
_file.read(4)
numel = struct.unpack(f'<{ndim}q', _file.read(8 * ndim))
stride = struct.unpack(f'<{ndim}q', _file.read(8 * ndim))
storage_offset, = struct.unpack('<q', _file.read(8))
tmp_tensor = tensor([], dtype=storage.dtype).set_(
storage._untyped(), storage_offset, numel, stride)
deserialized_objects[key] = tmp_tensor

pickle_file = tar.extractfile('pickle')
unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()
return result

deserialized_objects = {}

@@ -617,25 +696,29 @@ def _legacy_load(f, pickle_module, **pickle_load_args):
"Do not support load module now. Please use 'torch.load' to load model parameters."
"Model parameters should be saved in 'PyTorch' by 'torch.save(net.state_dict(), filename)'.")
if typename == 'storage':
if len(data) == 6:
storage_type, root_key, location, numel, view_metadata, mindtorch_info = data
else:
storage_type, root_key, location, numel, view_metadata = data
storage_type, root_key, location, numel, view_metadata = data
location = _maybe_decode_ascii(location)
dtype = storage_type.dtype

nbytes = numel * _utils._element_size(dtype)

if root_key not in deserialized_objects:
typed_storage = np.empty(numel, dtype_map[storage_type])
deserialized_objects[root_key] = typed_storage
else:
typed_storage = deserialized_objects[root_key]
obj = _UntypedStorage(nbytes)
deserialized_objects[root_key] = _TypedStorage(
wrap_storage=obj, dtype=dtype)

typed_storage = deserialized_objects[root_key]
if view_metadata is not None:
view_key, offset, view_size = view_metadata
offset_bytes = offset * _utils._element_size(dtype)
view_size_bytes = view_size * _utils._element_size(dtype)
if view_key not in deserialized_objects:
deserialized_objects[view_key] = typed_storage[offset: offset + view_size]
deserialized_objects[view_key] = _TypedStorage(
wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes],
dtype=dtype)
res = deserialized_objects[view_key]
else:
res = typed_storage
if mindtorch_info is not None:
res = _rebuild_tensor_legacy(res, 0, mindtorch_info, None, False, None)
return res
raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

@@ -668,55 +751,19 @@ def _legacy_load(f, pickle_module, **pickle_load_args):
unpickler = UnpicklerWrapper(f, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()

deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

offset = f.tell() if f_should_read_directly else None
for key in deserialized_storage_keys:
assert key in deserialized_objects
typed_storage = deserialized_objects[key]
f.read(8)
array = np.frombuffer(f.read(typed_storage.nbytes), typed_storage.dtype)
typed_storage[:] = array
if typed_storage.dtype == bfloat16:
assert np.allclose(typed_storage.astype(np.float32), array.astype(np.float32))
else:
assert np.allclose(typed_storage, array)
typed_storage._storage._set_from_file(
f, 0, f_should_read_directly,
_utils._element_size(typed_storage.dtype))
if offset is not None:
offset = f.tell()

def result_convert(result):
elem_type = type(result)
if isinstance(result, FakeParameter):
if result.size == ():
num_elemets = 1
else:
num_elemets = reduce(operator.mul, result.size)
array = result.storage[result.storage_offset: result.storage_offset + num_elemets]
array = array.reshape(result.size)
if array.dtype == bfloat16:
array = array.astype(np.float32)
return tensor(array, dtype=_dtype.bfloat16)
return tensor(array)
elif isinstance(result, Mapping):
try:
return elem_type({key: result_convert(result[key]) for key in result})
except TypeError:
return {key: result_convert(result[key]) for key in result}
elif isinstance(result, tuple) and hasattr(result, '_fields'):
return elem_type(*(result_convert(d) for d in result))
elif isinstance(result, (tuple, list)):
return [result_convert(d) for d in result]
elif isinstance(result, Sequence) and not isinstance(result, string_classes):
try:
return elem_type([result_convert(d) for d in result])
except TypeError:
return [result_convert(d) for d in result]
else:
return result

new_result = result_convert(result)

return new_result
return result


def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl', **pickle_load_args):
@@ -740,13 +787,20 @@ def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl',
if not zip_file.has_record(byteordername) and \
get_default_load_endianness() is None and \
sys.byteorder == 'big':
warnings.warn("The default load endianness for checkpoints without a byteorder mark "
warning("The default load endianness for checkpoints without a byteorder mark "
"on big endian machines was changed from 'native' to 'little' endian, "
"to avoid this behavior please use "
"torch.serialization.set_default_load_endianness to set "
"the desired default load endianness",
UserWarning)

def load_tensor(dtype, numel, key, location):
name = f'data/{key}'

tmp_storage = zip_file.get_storage_from_record(name, numel, _UntypedStorage).storage()._untyped()
loaded_storages[key] = _TypedStorage(
wrap_storage=tmp_storage, dtype=dtype)

def persistent_load(saved_id):
assert isinstance(saved_id, tuple)
typename = _maybe_decode_ascii(saved_id[0])
@@ -754,29 +808,17 @@ def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl',

assert typename == 'storage', \
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
shape = None
if len(data) == 5:
storage_type, key, location, numel, shape = data
storage_type, key, location, numel = data
if storage_type is _UntypedStorage:
dtype = _dtype.uint8
else:
storage_type, key, location, numel = data
name = f'data/{key}'
if name in loaded_storages:
return loaded_storages[name]
dtype = storage_type.dtype

if overall_storage is not None:
array = np.memmap(overall_storage, dtype=dtype_map[storage_type],
offset=zip_file.open_record(name)._fileobj.tell(), shape=(numel,))
else:
array = np.frombuffer(zip_file.read_record(name), dtype_map[storage_type])
if shape is not None:
array = np.reshape(array, shape)
if dtype_map[storage_type] == bfloat16:
array = array.astype(np.float32)
array = tensor(array, dtype=_dtype.bfloat16)
else:
array = tensor(array)
loaded_storages[name] = array
return array
if key not in loaded_storages:
nbytes = numel * _utils._element_size(dtype)
load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))

return loaded_storages[key]
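
The `persistent_load` hook implements pickle's persistent-ID protocol: the pickled graph carries `('storage', storage_type, key, location, numel)` tuples, and each key is materialized once into `loaded_storages`. A minimal, self-contained demonstration of that mechanism (the payloads and names below are hypothetical):

```python
import io
import pickle

blobs = {"0": b"\x00\x01\x02\x03"}            # stands in for zip data records

class Pickler(pickle.Pickler):
    def persistent_id(self, obj):
        if isinstance(obj, bytes):            # pretend bytes objects are storages
            return ("storage", "ByteStorage", "0", "cpu", len(obj))
        return None                           # everything else pickles normally

class Unpickler(pickle.Unpickler):
    def persistent_load(self, saved_id):
        typename, _storage_type, key, _location, _numel = saved_id
        assert typename == "storage"
        return blobs[key]                     # materialize the storage by key

buf = io.BytesIO()
Pickler(buf, protocol=2).dump({"w": blobs["0"]})
buf.seek(0)
assert Unpickler(buf).load() == {"w": b"\x00\x01\x02\x03"}
```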

load_module_mapping: Dict[str, str] = {
'torch.tensor': 'torch._tensor'
@@ -787,10 +829,11 @@ def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl',
raise NotImplementedError(
"Do not support load module now. Please use 'torch.load' to load model parameters."
"Model parameters should be saved in 'PyTorch' by 'torch.save(net.state_dict(), filename)'.")
if mod_name == 'torch._utils':
return eval(name)
if mod_name == 'torch':
return str(name)
if isinstance(name, str) and 'Storage' in name:
try:
return StorageType(name)
except KeyError:
pass
mod_name = load_module_mapping.get(mod_name, mod_name)
return super().find_class(mod_name, name)



mindtorch/torch/storage.py (+24 -6)

@@ -4,6 +4,7 @@ import collections
from functools import lru_cache
from ast import literal_eval
from typing import Any
from ml_dtypes import bfloat16 as np_bfloat16
import mindspore as ms
import mindtorch.torch.common.dtype as _dtype
from mindtorch.torch.common.dtype import _TypeDict
@@ -71,13 +72,20 @@ class _StorageBase():
self._update_referenced_tensor()
return self

def _update_referenced_tensor(self, strict=True):
def _update_referenced_tensor(self, strict=True, size=None):
if self.referenced_tensor is not None:
np_data = np.frombuffer(self.inner_data,
_TypeDict.get(self.referenced_tensor.dtype))
if size is not None:
np_data = np_data.reshape(size)
if strict:
np_data = np_data.reshape(self.referenced_tensor.shape)
value = ms.Tensor.from_numpy(np_data)
if np_data.dtype == np_bfloat16:
np_data = np_data.astype(np.float32)
value = ms.Tensor.from_numpy(np_data)
value = value.astype(_dtype.bfloat16)
else:
value = ms.Tensor.from_numpy(np_data)
self.referenced_tensor.assign_value(value)

def nbytes(self):
@@ -154,7 +162,7 @@ class _StorageBase():

def resize_(self, size):
if size <= self.size():
self.inner_data = np.frombuffer(self.inner_data, dtype=np.uint8, count=size)
self.inner_data = self.inner_data[:size]
else:
append_data = np.random.randint(0, 255, size=size - self.size(), dtype=np.uint8)
self.inner_data = np.concatenate((self.inner_data, append_data), axis=0)
@@ -173,7 +181,7 @@ class _StorageBase():
raise RuntimeError("Currently, in `storage._set_from_file` only is_real_file==True supported.")
nbytes = np.frombuffer(f.read(8), np.int64).item() * element_size
array = np.fromfile(f, dtype=np.uint8, count=nbytes, offset=offset)
self.inner_data = array
self.inner_data[:] = array
self._update_referenced_tensor()
return self

@@ -370,7 +378,8 @@ class _TypedStorage:
self[0:len(self)] = value
return self

def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
def __new__(cls, *args, wrap_storage=None, dtype=None, device=None, _internal=True):
unsupported_attr(_internal)
if cls == _LegacyStorage:
raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")

@@ -436,7 +445,8 @@ class _TypedStorage:
wrap_storage=wrap_storage,
dtype=cls.dtype)

def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
def __init__(self, *args, device=None, dtype=None, wrap_storage=None, _internal=True):
unsupported_attr(_internal)
arg_error_msg = (
'_TypedStorage.__init__ received an invalid combination '
'of arguments. Expected one of:\n'
@@ -908,3 +918,11 @@ _storage_classes_dict = {_dtype.double: DoubleStorage,
_dtype.cdouble: ComplexDoubleStorage,
_dtype.cfloat: ComplexFloatStorage,
}


def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
try:
return _storage_type_to_dtype_map()[pickle_storage_type]
except KeyError as e:
raise KeyError(
f'pickle storage type "{pickle_storage_type}" is not recognized') from e
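
Since numpy has no native bfloat16, the storage side relies on `ml_dtypes.bfloat16` and widens to float32 before handing the buffer to `ms.Tensor.from_numpy`, as in `_update_referenced_tensor` above. A numpy-only sketch of that widening step:

```python
import numpy as np
from ml_dtypes import bfloat16

raw = np.array([1.5, 2.25], dtype=bfloat16).tobytes()   # bytes held by a storage
np_data = np.frombuffer(raw, dtype=bfloat16)
widened = np_data.astype(np.float32)    # from_numpy cannot consume bfloat16
assert widened.dtype == np.float32 and widened.tolist() == [1.5, 2.25]
```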

mindtorch/torch/tensor.py (+47 -16)

@@ -4,13 +4,17 @@ import os
import abc
import numbers
import operator
from collections import OrderedDict
# from functools import reduce, lru_cache
from copy import deepcopy
from functools import reduce
import numpy as np
import mindspore as ms
from mindspore import Tensor as ms_Tensor
from mindspore.scipy.ops import SolveTriangular
try:
from mindspore.scipy.ops import SolveTriangular # not supported on Windows CPU
except ImportError:
...
from mindspore.common import dtype as mstype
import mindspore.ops as P
from mindspore.ops.primitive import _primexpr
@@ -39,6 +43,8 @@ from mindtorch.torch.logging import warning, info
import mindtorch.torch._register_numpy_primitive as numpy_cell
from mindtorch.torch._default_dtype import _not_default_fp32_dtype, get_default_dtype
from mindtorch.torch._C.Size import Size
from mindtorch.torch._tensor import _rebuild_from_type_v2
from mindtorch.torch import _utils

_dtypeDict = {
'float16': mstype.float16,
@@ -255,6 +261,8 @@ class _TensorMeta(type(ms_Tensor), abc.ABCMeta):
"""

class Tensor(StubTensor, metaclass=_TensorMeta):

layout = property(lambda self: object(), lambda self, v: None, lambda self: None)
def __init__(self, *data, requires_grad=False, dtype=None, inner=False, cast_tensor=False):
if cast_tensor:
if len(data) != 1:
@@ -616,11 +624,38 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return out

def __getstate__(self):
pickled = {"input_data": self.asnumpy(), "dtype": self.dtype}
return pickled
state = {key: value for key, value in self.__dict__.items() if key not in Tensor().__dict__}
return state

def __reduce_ex__(self, protocol):
state = _utils._get_obj_state(self)
if isinstance(self, Tensor) and not state:
return self._reduce_ex_internal()
func, args = self._reduce_ex_internal()
return (_rebuild_from_type_v2, (func, type(self), args, state))

def _reduce_ex_internal(self):
backward_hooks = OrderedDict()
args = (
_TypedStorage(
wrap_storage=self.storage()._untyped(),
dtype=self.dtype),
0,
tuple(self.size()),
self.stride(),
self.requires_grad,
backward_hooks)
return (_utils._rebuild_tensor_v2, args)

def __setstate__(self, state):
Tensor.__init__(self, state["input_data"], dtype=state["dtype"], inner=True)
if isinstance(state, tuple):
if len(state) == 4:
self.set_(*state)
return
elif len(state) == 5:
data = state[0]
Tensor.__init__(self, data, dtype=data.dtype, inner=True, requires_grad=state[3])
return
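
`__reduce_ex__` and `_reduce_ex_internal` together serialize a Tensor as `(storage, offset, size, stride, requires_grad, hooks)`, matching torch's `_rebuild_tensor_v2` signature; extra attributes travel through `_rebuild_from_type_v2`. A usage sketch assuming a working mindtorch install (`kkk` is arbitrary extra state, as in test_save_load_6 below):

```python
import pickle
import mindtorch.torch as torch

t = torch.tensor([1.0, 2.0])
t.kkk = 3                        # extra attribute -> _rebuild_from_type_v2 path
u = pickle.loads(pickle.dumps(t))
assert u.kkk == 3 and u.shape == t.shape
```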

@property
def grad_fn(self):
@@ -687,7 +722,7 @@ class Tensor(StubTensor, metaclass=_TensorMeta):

def storage(self):
if graph_mode_condition():
raise NotImplementedError('Currently, `tensor.storage()` is not supported in graph mode. '
warning('Currently, `tensor.storage()` is not supported in graph mode. '
'Please replace `Storage` related interfaces with the equivalent interface.')

if self.dtype == mindtorch_dtype.bfloat16:
@@ -1591,22 +1626,18 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(input_ms.copy())

def set_(self, source=None, storage_offset=0, size=None, stride=None):
if storage_offset or size or stride:
raise ValueError("Currently, `Tensor.set_` specifying `storage_offset`, "
"`size` or `stride` are not supported.")

if source is None:
raise ValueError("Currently, `Tensor.set_` only supported specify the `source`, " \
"please ensure that it is not None.")

unsupported_attr(storage_offset)
unsupported_attr(stride)
if graph_mode_condition():
raise RuntimeError('`Tensor.set_` is an in-place operation and "x.set_()" is not supported to use '
warning('`Tensor.set_` is an in-place operation and "x.set_()" is not supported to use '
'in MindSpore static graph mode.')

if isinstance(source, Tensor):
if source.dtype != self.dtype:
raise RuntimeError("In `tensor.set_`, sourse.dtype must equal to self.dtype.")
source = cast_to_ms_tensor(source)
if size:
source = source.reshape(size)
self.assign_value(source)
return self

@@ -1615,12 +1646,12 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
if source.dtype != self.dtype:
raise RuntimeError("In `tensor.set_`, _TypedStorage.dtype must equal to self.dtype.")
source._storage.referenced_tensor = self
source._storage._update_referenced_tensor(strict=False)
source._storage._update_referenced_tensor(strict=False, size=size)
return self

# handle source is a _UntypedStorage
source.referenced_tensor = self
source._update_referenced_tensor(strict=False)
source._update_referenced_tensor(strict=False, size=size)
return self

def to(self, *args, **kwargs):


mindtorch/torch/utils/data/dataloader.py (+7 -2)

@@ -6,7 +6,7 @@ in `./_utils/worker.py`.
"""
from __future__ import absolute_import

import functools
import sys
import itertools
import logging
import os
@@ -218,7 +218,7 @@ class DataLoader(Generic[T_co]):
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
num_workers: int = 1, collate_fn: Optional[_collate_fn_t] = None,
num_workers: int = None, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
@@ -226,6 +226,11 @@ class DataLoader(Generic[T_co]):
persistent_workers: bool = False,
pin_memory_device: str = ""):
# torch._C._log_api_usage_once("python.data_loader")
if num_workers is None:
if sys.platform == "win32":
num_workers = 0
else:
num_workers = 1

if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
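
The `num_workers` default is now platform-dependent: `None` resolves to 0 on Windows and 1 elsewhere. A sketch of the resolution rule (the win32 rationale, fragile spawn-based worker processes, is an assumption, not stated in the diff):

```python
import sys

def resolve_num_workers(num_workers=None):   # hypothetical helper
    if num_workers is None:                   # new default: platform-dependent
        return 0 if sys.platform == "win32" else 1
    if num_workers < 0:
        raise ValueError('num_workers option should be non-negative')
    return num_workers
```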


testing/ut/pytorch/torch/test_hub.py (+2 -2)

@@ -9,9 +9,9 @@ import numpy as np
from mindspore import context
import pytest

from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, SKIP_ENV_CPU
set_mode_by_env_config()
@SKIP_ENV_CPU(reason="need stable network")
def test_get_dir():
ms_hub_dir = ms_torch.hub.get_dir()
torch_hub_dir = torch.hub.get_dir()


testing/ut/pytorch/torch/test_import.py (+1 -0)

@@ -3,6 +3,7 @@ from mindtorch.module_hooker import torch_enable, torch_pop
from ...utils import set_mode_by_env_config
set_mode_by_env_config()

@pytest.fixture(scope='function')
def test_import():
torch_enable()
import torch


testing/ut/pytorch/torch/test_serialization.py (+240 -3)

@@ -1,11 +1,13 @@
import os
import pytest
import numpy as np
import torch
import mindtorch.torch as pytorch
from ...utils import set_mode_by_env_config, param_compare
from ...utils import set_mode_by_env_config, param_compare, SKIP_ENV_CPU, SKIP_ENV_GPU, SKIP_ENV_ASCEND

set_mode_by_env_config()

@pytest.fixture(scope='function')
def test_save_load_1():
state_dict_torch ={}
state_dict_mindtorch = {}
@@ -31,6 +33,7 @@ def test_save_load_1():
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_2():
state_dict_torch = {}
state_dict_mindtorch = {}
@@ -56,6 +59,7 @@ def test_save_load_2():
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_3():
state_dict_torch = {}
state_dict_mindtorch = {}
@@ -80,7 +84,7 @@ def test_save_load_3():
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_4():
state_dict_torch = {}
state_dict_mindtorch = {}
@@ -105,6 +109,7 @@ def test_save_load_4():
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_bf16_1():
state_dict_mindtorch = {}
a = pytorch.tensor(800000, dtype=pytorch.bfloat16)
@@ -125,6 +130,7 @@ def test_save_load_bf16_1():
param_compare(b.to(pytorch.float32), state_dict_mindtorch["b"].to(pytorch.float32))
assert c == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_bf16_2():
state_dict_mindtorch = {}
a = pytorch.tensor(800000, dtype=pytorch.bfloat16)
@@ -145,7 +151,7 @@ def test_save_load_bf16_2():
param_compare(b.to(pytorch.float32), state_dict_mindtorch["b"].to(pytorch.float32))
assert c == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_bf16_3():
state_dict_mindtorch = {}
a = torch.tensor(800000, dtype=torch.bfloat16)
@@ -164,6 +170,7 @@ def test_save_load_bf16_3():
param_compare(b.to(torch.float32), state_dict_mindtorch["b"].to(pytorch.float32))
assert c == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_bf16_4():
state_dict_mindtorch = {}
a = torch.tensor(800000, dtype=torch.bfloat16)
@@ -182,6 +189,229 @@ def test_save_load_bf16_4():
param_compare(b.to(torch.float32), state_dict_mindtorch["b"].to(pytorch.float32))
assert c == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_parameter_1():
state_dict_torch ={}
state_dict_mindtorch = {}
a = np.random.rand(3, 3).astype(np.float32)
b = np.random.rand(1, 64,64, 3).astype(np.float32)
c = 1
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a))
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b))
state_dict_torch["c"] = c

state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a))
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b))
state_dict_mindtorch["c"] = c

torch.save(state_dict_torch, "test_save_load_parameter_1_torch.pth")
pytorch.save(state_dict_mindtorch, "test_save_load_parameter_1_mindtorch.pth")

state_dict_torch = torch.load("test_save_load_parameter_1_torch.pth")
state_dict_mindtorch = pytorch.load("test_save_load_parameter_1_mindtorch.pth")
os.remove("test_save_load_parameter_1_torch.pth")
os.remove("test_save_load_parameter_1_mindtorch.pth")
param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"])
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_parameter_2():
state_dict_torch ={}
state_dict_mindtorch = {}
a = np.random.rand(3, 3).astype(np.float32)
b = np.random.rand(1, 64,64, 3).astype(np.float32)
c = 1
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a))
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b))
state_dict_torch["c"] = c

state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a))
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b))
state_dict_mindtorch["c"] = c

torch.save(state_dict_torch, "test_save_load_parameter_2_torch.pth", _use_new_zipfile_serialization=False)
pytorch.save(state_dict_mindtorch, "test_save_load_parameter_2_mindtorch.pth", _use_new_zipfile_serialization=False)

state_dict_torch = torch.load("test_save_load_parameter_2_torch.pth")
state_dict_mindtorch = pytorch.load("test_save_load_parameter_2_mindtorch.pth")
os.remove("test_save_load_parameter_2_torch.pth")
os.remove("test_save_load_parameter_2_mindtorch.pth")
param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"])
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]


@pytest.fixture(scope='function')
def test_save_load_parameter_3():
state_dict_torch = {}
state_dict_mindtorch = {}
a = np.random.rand(3, 3).astype(np.float32)
b = np.random.rand(1, 64, 64, 3).astype(np.float32)
c = 1
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a))
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b))
state_dict_torch["c"] = c

state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a))
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b))
state_dict_mindtorch["c"] = c

torch.save(state_dict_torch, "test_save_load_parameter_3_torch.pth")

state_dict_torch = torch.load("test_save_load_parameter_3_torch.pth")
state_dict_mindtorch = pytorch.load("test_save_load_parameter_3_torch.pth")
os.remove("test_save_load_parameter_3_torch.pth")

param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"])
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_parameter_4():
state_dict_torch = {}
state_dict_mindtorch = {}
a = np.random.rand(3, 3).astype(np.float32)
b = np.random.rand(1, 64, 64, 3).astype(np.float32)
c = 1
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a))
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b))
state_dict_torch["c"] = c

state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a))
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b))
state_dict_mindtorch["c"] = c

torch.save(state_dict_torch, "test_save_load_parameter_4_torch.pth", _use_new_zipfile_serialization=False)

state_dict_torch = torch.load("test_save_load_parameter_4_torch.pth")
state_dict_mindtorch = pytorch.load("test_save_load_parameter_4_torch.pth")
os.remove("test_save_load_parameter_4_torch.pth")

param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"])
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"])
assert state_dict_torch["c"] == state_dict_mindtorch["c"]

@pytest.fixture(scope='function')
def test_save_load_net():
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self, num_classes: int = 10) -> None:
super(Net, self).__init__()

self.features = nn.Sequential(
nn.Conv2d(3, 64, (11, 11), (4, 4), (2, 2), bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d((3, 3), (2, 2)),
)

self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

self.classifier = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(256 * 6 * 6, 4096),
)

net = Net()
state_dict = {
'net': net.state_dict(),
}

torch.save(state_dict, 'torch_module.pt', _use_new_zipfile_serialization=True)

torch.save(state_dict, 'torch_module_oldfile.pt', _use_new_zipfile_serialization=False)

import mindtorch.torch as pytorch
import mindtorch.torch.nn as nn
class Net(nn.Module):
def __init__(self, num_classes: int = 10) -> None:
super(Net, self).__init__()

self.features = nn.Sequential(
nn.Conv2d(3, 64, (11, 11), (4, 4), (2, 2), bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d((3, 3), (2, 2)),
)

self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

self.classifier = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(256 * 6 * 6, 4096),
)

net = Net()
state = pytorch.load("torch_module.pt")
net.load_state_dict(state['net'])
os.remove("torch_module.pt")
state_dict = {
'net': net.state_dict(),
}
pytorch.save(state_dict, 'mindtorch_module.pt', _use_new_zipfile_serialization=True)
os.remove("mindtorch_module.pt")

state = pytorch.load("torch_module_oldfile.pt")
net.load_state_dict(state['net'])
os.remove("torch_module_oldfile.pt")
state_dict = {
'net': net.state_dict(),
}
pytorch.save(state_dict, 'mindtorch_module_oldfile.pt', _use_new_zipfile_serialization=False)
os.remove('mindtorch_module_oldfile.pt')


@pytest.fixture(scope='function')
@SKIP_ENV_ASCEND(reason="This function need torch version >= 2.1.0")
@SKIP_ENV_GPU(reason="This function need torch version >= 2.1.0")
@SKIP_ENV_CPU(reason="This function need torch version >= 2.1.0")
def test_save_load_5():
a = torch.tensor(2.)
a.kkk = 3
torch.save(a, 'a.pth')
tensor = pytorch.load('a.pth')
os.remove('a.pth')
assert tensor.kkk == a.kkk
param_compare(a, tensor)


@pytest.fixture(scope='function')
def test_save_load_6():
a = pytorch.tensor(2.)
a.kkk = 3
pytorch.save(a, 'a.pth')
tensor = pytorch.load('a.pth')
os.remove('a.pth')
assert tensor.kkk == a.kkk
param_compare(a, tensor)


@pytest.fixture(scope='function')
@SKIP_ENV_ASCEND(reason="This function need torch version >= 2.1.0")
@SKIP_ENV_GPU(reason="This function need torch version >= 2.1.0")
@SKIP_ENV_CPU(reason="This function need torch version >= 2.1.0")
def test_save_load_7():
a = torch.nn.Parameter(torch.tensor(2.))
a.kkk = 3
torch.save(a, 'a.pth')
tensor = pytorch.load('a.pth')
os.remove('a.pth')
assert tensor.kkk == a.kkk
param_compare(a.detach(), tensor)


@pytest.fixture(scope='function')
def test_save_load_8():
a = pytorch.nn.Parameter(pytorch.tensor(2.))
a.kkk = 3
pytorch.save(a, 'a.pth')
tensor = pytorch.load('a.pth')
os.remove('a.pth')
assert tensor.kkk == a.kkk
param_compare(a, tensor)

if __name__ == '__main__':
test_save_load_1()
test_save_load_2()
@@ -191,3 +421,10 @@ if __name__ == '__main__':
test_save_load_bf16_2()
test_save_load_bf16_3()
test_save_load_bf16_4()
test_save_load_parameter_1()
test_save_load_parameter_2()
test_save_load_parameter_3()
test_save_load_parameter_4()
test_save_load_net()
test_save_load_5()
test_save_load_6()
