|
- import numpy as np
- import mindspore as ms
- import mindspore.nn as nn
- import msadapter.pytorch.common.dtype as msdapter_dtype
-
- _error_msg = "[numpy backward issue.] For '{}', it can not backward, please use other function instead."
-
#TODO: NumpyLstsq constructs the same output that torch.lstsq generates
#Later, torch.lstsq will be deprecated and used linalg.lstsq instead, the NumpyLstsq will be deprecated as well
class NumpyLstsq(nn.Cell):
    """Least-squares solve on the host via numpy, mimicking torch.lstsq's
    (solution, QR) output pair."""
    def __init__(self, op_name=None):
        super().__init__()
        # Kept only so bprop can report which public op hit the numpy fallback.
        self.op_name = op_name
    def construct(self, input, A):
        # Solve min ||A x - input|| with numpy; [0] is the solution array.
        solution_np = np.linalg.lstsq(A.asnumpy(), input.asnumpy())[0]
        solution = ms.Tensor(solution_np)
        #TODO: linalg.lstsq not support qr as return, thus the qr will be set to zeros
        qr_placeholder = ms.ops.zeros(A.shape, A.dtype)
        return solution, qr_placeholder
    def bprop(self, input, A, out, dout):
        # numpy computation is opaque to autodiff; fail loudly rather than
        # silently returning wrong gradients.
        raise RuntimeError(_error_msg.format(self.op_name))
-
#TODO: NumpyLstsq constructs the same output that torch.linalg.lstsq generates
class NumpyFullLstsq(nn.Cell):
    """Full least-squares via numpy, mimicking torch.linalg.lstsq's
    (solution, residuals, rank, singular_values) 4-tuple."""
    def __init__(self, op_name=None, rcond=None):
        super().__init__()
        self.op_name = op_name
        # rcond is forwarded verbatim to np.linalg.lstsq.
        self.rcond = rcond
    def construct(self, a, b):
        # numpy returns (solution, residuals, rank, singular_values).
        result = np.linalg.lstsq(a.asnumpy(), b.asnumpy(), rcond=self.rcond)
        solution, residuals, rank, singular_values = result
        return (ms.Tensor(solution), ms.Tensor(residuals),
                ms.Tensor(rank), ms.Tensor(singular_values))
    def bprop(self, a, b, out, dout):
        # numpy computation cannot be differentiated by the framework.
        raise RuntimeError(_error_msg.format(self.op_name))
-
class NumpyEigvals(nn.Cell):
    """Eigenvalues via np.linalg.eigvals, mimicking torch.linalg.eigvals.

    The result is always complex: complex128 when the input is float64 or
    complex128, complex64 otherwise.
    """
    def __init__(self):
        super().__init__()
        self.op_name = "torch.linalg.eigvals"
    def construct(self, A):
        A_np = A.asnumpy()
        output = np.linalg.eigvals(A_np)
        # Bug fix: the original `A_np.dtype is np.float64` compared a dtype
        # object against a scalar type with identity, which is always False,
        # so float64/complex128 inputs were wrongly downcast to complex64.
        # dtype equality (used by `in`) is the correct comparison.
        if A_np.dtype in (np.float64, np.complex128):
            output = output.astype(np.complex128)
        else:
            output = output.astype(np.complex64)
        return ms.Tensor(output)
    def bprop(self, A, out, dout):
        # numpy computation cannot be differentiated by the framework.
        raise RuntimeError(_error_msg.format(self.op_name))
-
def _svd_not_compute_uv(input, full_matrices=False):
    """Return only the singular values of `input` as an ms.Tensor
    (numpy SVD with compute_uv=False)."""
    values = np.linalg.svd(input.asnumpy(), full_matrices, compute_uv=False)
    return ms.Tensor(values)
-
def _svd_compute_uv(input, full_matrices=False):
    """Full SVD via numpy, returned torch-style as (S, U, V) where V is the
    transpose of numpy's Vh."""
    u_np, s_np, vh_np = np.linalg.svd(input.asnumpy(), full_matrices, compute_uv=True)
    #TODO: Currently ms.ops.swapaxes has problem on GRAPH mode
    v = ms.Tensor(np.swapaxes(vh_np, -1, -2))
    return ms.Tensor(s_np), ms.Tensor(u_np), v
-
class NumpySvd(nn.Cell):
    """SVD via numpy: (S, U, V) when compute_uv, singular values only otherwise."""
    def __init__(self, op_name=None):
        super().__init__()
        self.op_name = op_name
    def construct(self, input, full_matrices=False, compute_uv=True):
        if compute_uv:
            return _svd_compute_uv(input, full_matrices)
        return _svd_not_compute_uv(input, full_matrices)
    # Bug fix: a custom Cell bprop must accept one argument per construct
    # input plus (out, dout); the original signature omitted
    # full_matrices/compute_uv and would fail MindSpore's bprop arity check.
    def bprop(self, input, full_matrices, compute_uv, out, dout):
        # numpy computation cannot be differentiated by the framework.
        raise RuntimeError(_error_msg.format(self.op_name))
-
class NumpySvdvals(nn.Cell):
    """Singular values via numpy, mimicking torch.linalg.svdvals."""
    def __init__(self):
        super().__init__()
        self.op_name = 'torch.linalg.svdvals'
    def construct(self, input, full_matrices=False):
        return _svd_not_compute_uv(input, full_matrices)
    # Bug fix: bprop's signature must mirror construct's inputs plus
    # (out, dout); the original omitted full_matrices.
    def bprop(self, input, full_matrices, out, dout):
        # numpy computation cannot be differentiated by the framework.
        raise RuntimeError(_error_msg.format(self.op_name))
-
class NumpyI0(nn.Cell):
    """Modified Bessel function of the first kind, order 0, via np.i0."""
    def __init__(self, op_name=None):
        super().__init__()
        self.op_name = op_name
    def construct(self, A):
        result = ms.Tensor(np.i0(A.asnumpy()))
        # Integer inputs are promoted to float32 — presumably to mirror
        # torch.i0 returning a floating tensor; floats keep np.i0's dtype.
        if A.dtype in msdapter_dtype.all_int_type:
            result = result.astype(ms.float32)
        return result
    def bprop(self, A, out, dout):
        # numpy computation cannot be differentiated by the framework.
        raise RuntimeError(_error_msg.format(self.op_name))
|