|
- From ed543c49ec315e1e8cb9c6f8cb0fbc9a08deaad1 Mon Sep 17 00:00:00 2001
- From: lvhaoyu <lvhaoyu@huawei.com>
- Date: Tue, 22 Nov 2022 15:03:16 +0800
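- Subject: [PATCH] Adapt einops to msadapter.pytorch
-
- Import `msadapter.pytorch as torch` instead of `torch` in
- einops/_backends.py (TorchBackend), einops/_torch_specific.py and
- einops/layers/torch.py, and create the EinMix weight and bias with
- `uniform(...)` instead of `uniform_(...)`. Also add the missing
- newline at the end of einops/_backends.py.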
-
- ---
- einops/_backends.py | 4 ++--
- einops/_torch_specific.py | 2 +-
- einops/layers/torch.py | 6 +++---
- 3 files changed, 6 insertions(+), 6 deletions(-)
-
- diff --git a/einops/_backends.py b/einops/_backends.py
- index ba6fca8..80cce17 100644
- --- a/einops/_backends.py
- +++ b/einops/_backends.py
- @@ -316,7 +316,7 @@ class TorchBackend(AbstractBackend):
- framework_name = 'torch'
-
- def __init__(self):
- - import torch
- + import msadapter.pytorch as torch
- self.torch = torch
-
- def is_appropriate_type(self, tensor):
- @@ -677,4 +677,4 @@ class OneFlowBackend(AbstractBackend):
- return oneflow
-
- def einsum(self, pattern, *x):
- - return self.flow.einsum(pattern, *x)
- \ No newline at end of file
- + return self.flow.einsum(pattern, *x)
- diff --git a/einops/_torch_specific.py b/einops/_torch_specific.py
- index 204d935..5ea9253 100644
- --- a/einops/_torch_specific.py
- +++ b/einops/_torch_specific.py
- @@ -10,7 +10,7 @@ Importantly, whole lib is designed so that you can't use it
-
- from typing import Dict, List
-
- -import torch
- +import msadapter.pytorch as torch
- from einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
-
-
- diff --git a/einops/layers/torch.py b/einops/layers/torch.py
- index 3199241..cb9af63 100644
- --- a/einops/layers/torch.py
- +++ b/einops/layers/torch.py
- @@ -1,6 +1,6 @@
- from typing import Optional, Dict, cast
-
- -import torch
- +import msadapter.pytorch as torch
-
- from . import RearrangeMixin, ReduceMixin
- from ._einmix import _EinmixMixin
- @@ -29,10 +29,10 @@ class Reduce(ReduceMixin, torch.nn.Module):
-
- class EinMix(_EinmixMixin, torch.nn.Module):
- def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
- - self.weight = torch.nn.Parameter(torch.zeros(weight_shape).uniform_(-weight_bound, weight_bound),
- + self.weight = torch.nn.Parameter(torch.zeros(weight_shape).uniform(-weight_bound, weight_bound),
- requires_grad=True)
- if bias_shape is not None:
- - self.bias = torch.nn.Parameter(torch.zeros(bias_shape).uniform_(-bias_bound, bias_bound),
- + self.bias = torch.nn.Parameter(torch.zeros(bias_shape).uniform(-bias_bound, bias_bound),
- requires_grad=True)
- else:
- self.bias = None
- --
- 2.25.1
|