-
- from ringmo_framework.models.backbone.swin_transformer_v2 import FinetuneSwinV2
- from ringmo_framework.models.layers import Linear
- from trans_ckpt_tools.print_pth_ckpt import print_ms_net_name, print_torch_pth
- from mindspore import nn, Tensor, Parameter
- from mindspore import ops as P
- from mindspore import dtype as mstype
- from mindspore.common.initializer import initializer, One, Zero
-
-
- def test_case_pth2ckpt():
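- """Print the MindSpore parameter names of SwinV2-Base (and, optionally, the keys of a
- PyTorch .pth checkpoint) to help build the pth -> ckpt name mapping by hand."""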
-
- def swin_v2_base_p4_w8_img256():
- return FinetuneSwinV2(image_size=256, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=8)
-
- print("===============Mindspore Net Begin=================")
- print_ms_net_name(swin_v2_base_p4_w8_img256())
- print("===============Mindspore Net End===================")
-
- filename = '/mnt/c/Users/admin/Downloads/swinv2_base_patch4_window8_256'
- print("===============Pytorch weight Begin===================")
- # print_torch_pth(filename)
- print("===============Pytorch weight End===================")
-
-
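- # WindowCosineAttention is the SwinV2-style window attention: q and k are L2-normalized
- # (cosine attention) and scaled by a learnable per-head logit_scale clamped at log(100),
- # and the attention map receives a log-spaced continuous relative position bias from
- # LogSpacedCPB instead of the learned bias table used in SwinV1.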
- class WindowCosineAttention(nn.Cell):
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
-
- super(WindowCosineAttention, self).__init__()
- if isinstance(dim, tuple) and len(dim) == 1:
- dim = dim[0]
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- self.log = P.Log()
- self.ones = P.Ones()
- logit_scale = Tensor((self.log(10 * self.ones((num_heads, 1, 1), mstype.float16))))
- self.logit_scale = Parameter(logit_scale, requires_grad=True, name="logit_scale")
- # get pair-wise relative position index for each token inside the window
- self.relative_position_bias = LogSpacedCPB(self.window_size, num_heads, pretrained_window_size)
-
- self.q = Linear(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
- self.k = Linear(in_channels=dim, out_channels=dim, has_bias=False)
- self.v = Linear(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
-
- self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
- self.proj = Linear(in_channels=dim, out_channels=dim, has_bias=True)
- self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
- self.softmax = nn.Softmax(axis=-1)
- self.cast = P.Cast()
- self.reshape = P.Reshape()
- self.transpose = P.Transpose()
- self.reshape_P = P.Reshape()
- self.matmul = P.BatchMatMul()
- self.exp = P.Exp()
- # clamp the learnable temperature at log(100), matching the torch reference clamp of log(1/0.01)
- max_scale = Tensor(100.0, mstype.float16)
- self.value_max = self.log(max_scale)
- self.value_min = Tensor(-10000.0, mstype.float16) # same dtype as logit_scale to avoid a mixed-dtype clip
- self.Normalize1 = P.L2Normalize(axis=3, epsilon=1e-4)
- self.Normalize2 = P.L2Normalize(axis=2, epsilon=1e-4)
-
- def construct(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- x = self.cast(x, mstype.float16)
- q = self.reshape(self.q(x), (B_, N, self.num_heads, C // self.num_heads)) # no fixed scale: cosine attention uses the learnable logit_scale instead
- q = self.transpose(q, (0, 2, 1, 3))
- q = self.Normalize1(q)
- k = self.reshape(self.k(x), (B_, N, self.num_heads, C // self.num_heads))
- k = self.transpose(k, (0, 2, 3, 1))
- k = self.Normalize2(k)
- v = self.reshape(self.v(x), (B_, N, self.num_heads, C // self.num_heads))
- v = self.transpose(v, (0, 2, 1, 3))
-
- attn = self.matmul(q, k)
-
- logit_scale = P.clip_by_value(self.logit_scale, clip_value_min=self.value_min, clip_value_max=self.value_max)
- logit_scale = self.exp(logit_scale)
-
- attn = attn * logit_scale
- attn = self.cast(attn, mstype.float32)
- attn = attn + self.relative_position_bias()
-
- if mask is not None:
- nW, ws2, _ = mask.shape
- mask = self.reshape_P(mask, (1, -1, 1, ws2, ws2))
- attn = self.reshape_P(attn, (B_ // nW, nW, self.num_heads, N, N,)) + mask
- attn = self.reshape_P(attn, (-1, self.num_heads, N, N,))
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
- attn = self.attn_drop(attn)
- attn = self.cast(attn, mstype.float16)
- x = self.reshape(self.transpose(self.matmul(attn, v), (0, 2, 1, 3)), (B_, N, C))
- x = self.cast(x, mstype.float16)
- x = self.proj(x)
- x = self.cast(x, mstype.float32)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
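- # A minimal usage sketch (illustrative only, not part of the original tests): the shapes
- # below assume an 8x8 window and 4 heads.
- # attn = WindowCosineAttention(dim=128, window_size=(8, 8), num_heads=4)
- # windows = P.Ones()((2, 64, 128), mstype.float32) # (num_windows*B, Wh*Ww, C)
- # out = attn(windows) # same shape as the input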
-
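- # TestPart mirrors WindowCosineAttention but keeps the matmuls in float16 while doing the
- # L2-normalization and softmax in float32, and init_weight() fills every Linear with ones
- # (biases with zeros) so its output can be compared element-wise against TorchAttn below.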
- class TestPart(nn.Cell):
- def __init__(self, dim, num_heads, window_size=(7, 7), qkv_bias=True, pretrained_window_size=[0, 0], attn_drop=0., proj_drop=0.):
- super().__init__()
- self.matmul = P.BatchMatMul()
- self.exp = P.Exp()
- self.reshape = P.Reshape()
- self.transpose = P.Transpose()
- self.cast = P.Cast()
- self.log = P.Log()
- self.ones = P.Ones()
-
- self.num_heads = num_heads
- self.window_size = window_size
-
- logit_scale = Tensor((self.log(10 * self.ones((num_heads, 1, 1), mstype.float32))))
- self.logit_scale = Parameter(logit_scale, requires_grad=True, name="logit_scale")
-
- max_scale = Tensor(100.0, mstype.float32)
- self.value_max = self.log(max_scale)
- self.value_min = Tensor(-1000.0, mstype.float32)
-
- self.q = Linear(in_channels=dim, out_channels=dim, has_bias=qkv_bias).to_float(mstype.float16)
- self.k = Linear(in_channels=dim, out_channels=dim, has_bias=False).to_float(mstype.float16)
- self.v = Linear(in_channels=dim, out_channels=dim, has_bias=qkv_bias).to_float(mstype.float16)
-
- self.Normalize = P.L2Normalize(axis=-1, epsilon=1e-12)
-
- # get pair-wise relative position index for each token inside the window
- self.relative_position_bias = LogSpacedCPB(self.window_size, num_heads, pretrained_window_size)
-
- self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
- self.proj = Linear(in_channels=dim, out_channels=dim, has_bias=True).to_float(mstype.float16)
- self.softmax = nn.Softmax(axis=-1)
- self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
-
- self.init_weight()
-
- def construct(self, x, mask=None):
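- """Same computation as WindowCosineAttention.construct, with explicit fp16/fp32 casts
- around the normalization, bias addition and softmax."""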
- B_, N, C = x.shape
-
- q = self.reshape(self.q(x), (B_, N, self.num_heads, C // self.num_heads)) # float16
- # Normalize needs to be calculated in float32
- q = self.cast(q, mstype.float32)
- q = self.Normalize(q) # float32
- q = self.transpose(q, (0, 2, 1, 3))
-
- k = self.reshape(self.k(x), (B_, N, self.num_heads, C // self.num_heads))
- k = self.cast(k, mstype.float32)
- k = self.Normalize(k)
- k = self.transpose(k, (0, 2, 3, 1))
-
- v = self.reshape(self.v(x), (B_, N, self.num_heads, C // self.num_heads))
- v = self.transpose(v, (0, 2, 1, 3))
-
- q = self.cast(q, mstype.float16)
- k = self.cast(k, mstype.float16)
- attn = self.matmul(q, k) # float16
-
- logit_scale = P.clip_by_value(self.logit_scale, clip_value_min=self.value_min, clip_value_max=self.value_max)
- logit_scale = self.exp(logit_scale) # float32
-
- attn = self.cast(attn, mstype.float32)
- attn = attn * logit_scale # float32
-
- attn = attn + self.relative_position_bias()
- # return attn
-
- if mask is not None:
- nW, ws2, _ = mask.shape
- mask = self.reshape(mask, (1, -1, 1, ws2, ws2))
- attn = self.reshape(attn, (B_ // nW, nW, self.num_heads, N, N,)) + mask
- attn = self.reshape(attn, (-1, self.num_heads, N, N,))
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- attn = self.cast(attn, mstype.float16)
- x = self.reshape(self.transpose(self.matmul(attn, v), (0, 2, 1, 3)), (B_, N, C))
- x = self.proj(x)
- x = self.cast(x, mstype.float32)
- x = self.proj_drop(x)
-
- return x
-
- def init_weight(self):
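- """Fill every Linear weight with ones and every bias with zeros, so the output depends
- only on the architecture and the input, not on random initialization."""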
- for _, cell in self.cells_and_names():
- if isinstance(cell, Linear):
- cell.weight.set_data(initializer(One(), shape=cell.weight.shape, dtype=cell.weight.dtype))
- if cell.bias is not None:
- cell.bias.set_data(initializer(Zero(), shape=cell.bias.shape, dtype=cell.bias.dtype))
-
-
- import torch
- import torch.nn.functional as F
- import numpy as np
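-
- # TorchAttn is the PyTorch reference: essentially the SwinV2 WindowAttention module
- # (cosine attention + continuous position bias MLP), with the same constant weight
- # init as TestPart so both frameworks compute the same function of the input.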
- class TorchAttn(torch.nn.Module):
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = torch.nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = torch.nn.Sequential(torch.nn.Linear(2, 512, bias=True),
- torch.nn.ReLU(inplace=True),
- torch.nn.Linear(512, num_heads, bias=False))
-
- # get relative_coords_table
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
- relative_coords_table = torch.stack(
- torch.meshgrid([relative_coords_h,
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = torch.nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = torch.nn.Parameter(torch.zeros(dim))
- self.v_bias = torch.nn.Parameter(torch.zeros(dim))
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = torch.nn.Dropout(attn_drop)
- self.proj = torch.nn.Linear(dim, dim)
- self.proj_drop = torch.nn.Dropout(proj_drop)
- self.softmax = torch.nn.Softmax(dim=-1)
-
- self.init_weight()
-
- def forward(self, x, mask=None):
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def init_weight(self):
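- """Constant init (weights = 1, biases = 0), matching TestPart.init_weight."""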
- for m in self.modules():
- if isinstance(m, torch.nn.Linear):
- m.weight.data.fill_(1)
- if m.bias is not None:
- m.bias.data.zero_()
-
-
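- # LogSpacedCPB implements the SwinV2 log-spaced continuous position bias: a 2-layer MLP maps
- # log-scaled relative window coordinates to a per-head bias table, which is then gathered by
- # the pair-wise relative position index and squashed with 16 * sigmoid.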
- class LogSpacedCPB(nn.Cell):
- def __init__(self, window_size, num_heads, pretrained_window_size):
- super(LogSpacedCPB, self).__init__()
- self.window_size = window_size # Wh, Ww
- # mlp to generate continuous relative position bias
- self.num_heads = num_heads
- self.cpb_mlp0 = Linear(2, 512, has_bias=True).to_float(mstype.float16)
- self.cpb_act1 = nn.ReLU()
- self.cpb_mlp2 = Linear(512, num_heads, has_bias=False).to_float(mstype.float16)
- self.cast = P.Cast()
-
- # get relative_coords_table
- relative_coords_h = Tensor(np.arange(-(self.window_size[0] - 1), self.window_size[0]), mstype.float32)
- relative_coords_w = Tensor(np.arange(-(self.window_size[1] - 1), self.window_size[1]), mstype.float32)
- relative_coords_table = P.Stack(axis=0)(
- P.Meshgrid(indexing='ij')((relative_coords_h, relative_coords_w)))
- relative_coords_table = P.Transpose()(relative_coords_table, (1, 2, 0))
- relative_coords_table = P.ExpandDims()(relative_coords_table, 0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- # log-scale the coordinates: sign(x) * log2(|x| + 1) / log2(8), computed in numpy at init time
- coords_np = relative_coords_table.asnumpy()
- relative_coords_table = Tensor(np.sign(coords_np) * np.log2(np.abs(coords_np) + 1.0) / np.log2(8), mstype.float32)
-
- self.relative_coords_table = Parameter(
- Tensor(relative_coords_table, mstype.float32),
- requires_grad=False, name="relative_coords_table")
-
- # get pair-wise relative position index for each token inside the window
- coords_h = Tensor(np.arange(window_size[0]), mstype.int32)
- coords_w = Tensor(np.arange(window_size[1]), mstype.int32)
- coords = P.Stack(axis=0)(P.Meshgrid(indexing='ij')((coords_h, coords_w))) # 2, Wh, Ww
- coords_flatten = P.Flatten()(coords) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = P.Transpose()(relative_coords, (1, 2, 0)).asnumpy() # Wh*Ww, Wh*Ww, 2
-
- relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1
-
- relative_position_index = np.sum(relative_coords, axis=-1) # Wh*Ww, Wh*Ww
-
- self.relative_position_index = Parameter(
- Tensor(relative_position_index, mstype.int32),
- requires_grad=False, name="relative_position_index")
-
- self.reshape = P.Reshape()
- self.transpose = P.Transpose()
- self.expand_dim = P.ExpandDims()
- self.sigmoid = P.Sigmoid()
-
- def construct(self):
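- """Return the relative position bias with shape (1, num_heads, Wh*Ww, Wh*Ww) in float32."""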
- x = self.cpb_mlp0(self.relative_coords_table)
- x = self.cpb_act1(x)
- x = self.cpb_mlp2(x)
- x = self.reshape(x, (-1, self.num_heads))
- relative_position_bias = x[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = self.transpose(relative_position_bias, (2, 0, 1)) # nH, Wh*Ww, Wh*Ww
- relative_position_bias = self.cast(relative_position_bias, mstype.float32)
- relative_position_bias = 16 * self.sigmoid(relative_position_bias)
-
- relative_position_bias = self.expand_dim(relative_position_bias, 0)
- return relative_position_bias
-
-
- def _count_unequal_element(data_expected, data_me, rtol, atol):
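- """Report the element-wise mismatches between two arrays that failed np.allclose."""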
- assert data_expected.shape == data_me.shape
- total_count = len(data_expected.flatten())
- error = np.abs(data_expected - data_me)
- greater = np.greater(error, atol + np.abs(data_me) * rtol)
- loss_count = np.count_nonzero(greater)
- # tolerate a small fraction (rtol) of mismatched elements before reporting
- if (loss_count / total_count) < rtol:
- print('AllClose.')
- else:
- print('\ndata_expected: {0}\ndata_me: {1}\nabs_error: {2}\nrel_error: {3}\n'
- 'mismatch_ratio: {4}\nmax_error: {5}\n'.format(
- data_expected[greater], data_me[greater], error[greater],
- (error[greater] - atol) / np.abs(data_me[greater]), len(error[greater]) / total_count, np.max(error[greater])))
-
-
- def allclose_nparray(data_expected, data_me, rtol=1.0e-3, atol=1.0e-3, equal_nan=True):
- """AllClose."""
- if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)):
- if np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
- print('AllClose')
- else:
- print('No AllClose 0.')
- elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
- _count_unequal_element(data_expected, data_me, rtol, atol)
- else:
- if np.array(data_expected).shape == np.array(data_me).shape:
- print('AllClose.')
- else:
- print('No AllClose 1.')
-
-
- def test_case_attention():
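- """Run the same random input through the MindSpore TestPart and the PyTorch TorchAttn
- (both with constant-initialized weights) and compare the outputs element-wise."""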
- ms_attn = TestPart(128, 4)
- torch_attn = TorchAttn(128, (7, 7), 4)
-
- x1 = np.random.uniform(size=(4, 49, 128))
- # x2 = np.random.uniform(size=(4, 8, 256))
- ms_x = Tensor(x1, dtype=mstype.float32)
- th_x = torch.tensor(x1, dtype=torch.float32)
-
- allclose_nparray(th_x.detach().numpy(), ms_x.asnumpy())
- print("diff", np.max(np.abs(th_x.detach().numpy() - ms_x.asnumpy())))
- print('input x: ', x1)
-
- ms_out = ms_attn(ms_x)
- th_out = torch_attn(th_x)
-
- print('ms_out shape: ', ms_out.shape)
- print('ms_out type: ', ms_out.dtype)
- print('ms_out: ', ms_out)
- print('th_out shape: ', th_out.shape)
- print('th_out: ', th_out)
- allclose_nparray(th_out.detach().numpy(), ms_out.asnumpy())
- print("diff", np.max(np.abs(th_out.detach().numpy() - ms_out.asnumpy())))
-
-
- def test_corpos():
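- """Inspect the LogSpacedCPB bias next to the PyTorch attention output. The two tensors
- have different meanings and shapes (bias (1, nH, N, N) vs. attention output (B, N, C)),
- so the allclose check here is only a rough sanity/inspection aid."""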
- ms_cpb = LogSpacedCPB((7, 7), 8, (0, 0))
- torch_attn = TorchAttn(128, (7, 7), 8)
-
- x = np.random.uniform(size=(4, 49, 128))
- ms_x = Tensor(x, dtype=mstype.float32)
- th_x = torch.tensor(x, dtype=torch.float32)
-
- ms_out = ms_cpb()
- # th_out = torch_attn(th_x).unsqueeze(0)
- th_out = torch_attn(th_x)
-
- print('ms_out shape: ', ms_out.shape)
- print('ms_out: ', ms_out)
- print('th_out shape: ', th_out.shape)
- print('th_out: ', th_out)
- allclose_nparray(th_out.detach().numpy(), ms_out.asnumpy())
- print("diff", np.max(np.abs(th_out.detach().numpy() - ms_out.asnumpy())))
-
- if __name__ == '__main__':
- test_case_attention()
- # test_corpos()