|
- import mindspore as ms
- import mindspore.nn as nn
- import config as cfg
-
-
def proj_feat(x, hidden_size, feat_size):
    """Fold a flattened transformer token sequence back into a 5-D feature map.

    Args:
        x: tensor of shape (batch, n_tokens, hidden_size), where
           n_tokens == feat_size[0] * feat_size[1] * feat_size[2].
        hidden_size: channel width of each token embedding.
        feat_size: spatial grid (depth, height, width) of the patch layout.

    Returns:
        Tensor of shape (batch, hidden_size, depth, height, width), i.e.
        channels-first as expected by the 3-D convolutional decoder blocks.
    """
    batch = x.shape[0]
    depth, height, width = feat_size
    # Un-flatten tokens onto the spatial grid, keeping channels last for now.
    grid = x.reshape([batch, depth, height, width, hidden_size])
    # Move the channel axis to position 1 (NCDHW layout).
    return grid.transpose([0, 4, 1, 2, 3])
-
-
class vit(nn.Cell):
    """3-D Vision Transformer backbone (MONAI-style ViT) for UNETR.

    Splits the input volume into patches, embeds them, runs `num_layer`
    transformer blocks and returns both the final (normalized) token
    sequence and the per-layer hidden states, which UNETR taps as skip
    connections.

    Args:
        in_channel: number of input image channels.
        img_size: input volume size, e.g. (D, H, W).
        patch_size: patch size per spatial dimension.
        hidden_size: transformer embedding width.
        mlp_dim: hidden width of each transformer MLP.
        num_layer: number of transformer blocks.
        num_head: number of attention heads.
        pos_embed: positional-embedding type forwarded to the patch embedding.
        classification: if True, prepend a CLS token and attach a 2-class head.
        dropout_rate: dropout probability used throughout.
    """

    def __init__(self, in_channel, img_size, patch_size, hidden_size, mlp_dim, num_layer, num_head, pos_embed,
                 classification, dropout_rate):
        super(vit, self).__init__()
        self.in_channel = in_channel
        self.img_size = img_size
        self.patch_size = patch_size
        self.hid_size = hidden_size
        self.mlp_dim = mlp_dim
        self.num_layer = num_layer
        self.num_head = num_head
        self.pos_embed = pos_embed
        self.classification = classification
        self.dro_rate = dropout_rate
        self.num_class = 2
        self.spatial_dim = 3
        self.acti = 'Tanh'
        self.qkv_bias = False

        self.patch_embedding = PatchEmbeddingBlock(in_channels=self.in_channel,
                                                   img_size=self.img_size,
                                                   patch_size=self.patch_size,
                                                   hidden_size=self.hid_size,
                                                   num_heads=self.num_head,
                                                   pos_embed=self.pos_embed,
                                                   dropout_rate=self.dro_rate,
                                                   spatial_dims=self.spatial_dim)
        self.block = nn.SequentialCell(
            [TransformerBlock(self.hid_size, self.mlp_dim, self.num_head, self.dro_rate, self.qkv_bias) for _ in
             range(self.num_layer)])
        # MindSpore's LayerNorm requires normalized_shape as a tuple/list,
        # not a bare int (fix: was nn.LayerNorm(self.hid_size)).
        self.norm = nn.LayerNorm((self.hid_size,))
        if self.classification:
            # Fix: MindSpore has ms.Parameter, not nn.Parameter, and
            # ops.zeros takes a shape tuple, not *dims.
            self.cls_token = ms.Parameter(ms.ops.zeros((1, 1, self.hid_size), ms.float32))
            if self.acti == "Tanh":
                # Fix: the MindSpore dense layer is nn.Dense, not nn.Linear.
                self.classification_head = nn.SequentialCell(nn.Dense(self.hid_size, self.num_class), nn.Tanh())
            else:
                self.classification_head = nn.Dense(self.hid_size, self.num_class)

    def construct(self, x):
        """Return (tokens, hidden_states): final tokens (or class logits when
        `classification` is True) and the output of every transformer block."""
        x = self.patch_embedding(x)
        # Fix: cls_token / classification_head only exist when classification
        # is enabled; the original used them unconditionally and crashed for
        # classification=False (the configuration UNETR uses).
        if self.classification:
            cls_token = ms.ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
            x = ms.ops.concat((cls_token, x), axis=1)
        hidden_states_out = []
        for blk in self.block:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        if self.classification:
            x = self.classification_head(x[:, 0])
        return x, hidden_states_out
-
-
class UNETR(nn.Cell):
    """UNETR: 3-D U-Net with a Vision-Transformer encoder.

    All hyper-parameters are read from the shared `config` module. The ViT
    produces a token sequence plus per-layer hidden states; hidden states
    3, 6 and 9 are projected back to volumetric feature maps and used as
    skip connections for a convolutional decoder, mirroring the UNETR
    architecture (Hatamizadeh et al.).
    """

    def __init__(self):
        super(UNETR, self).__init__()
        # Configuration-driven hyper-parameters.
        self.in_channel = cfg.in_channel
        self.out_channel = cfg.out_channel
        self.img_size = cfg.img_size
        self.ft_size = cfg.ft_size
        self.hid_size = cfg.hid_size
        self.mlp_dim = cfg.mlp_dim
        self.num_head = cfg.num_head
        self.pos_embed = cfg.pos_embed
        self.norm = cfg.norm
        self.conv_block = cfg.conv_block
        self.res_block = cfg.res_block
        self.dro_rate = cfg.dro_rate
        self.num_layer = cfg.num_layer
        self.patch_size = cfg.patch_size
        # Token grid: one token per patch along each spatial axis.
        self.feat_size = tuple(self.img_size[i] // self.patch_size[i] for i in range(3))
        self.classification = False  # backbone only; no CLS head

        # Transformer encoder.
        self.vit = vit(in_channel=self.in_channel, img_size=self.img_size,
                       patch_size=self.patch_size, hidden_size=self.hid_size,
                       mlp_dim=self.mlp_dim, num_layer=self.num_layer,
                       num_head=self.num_head, pos_embed=self.pos_embed,
                       classification=self.classification,
                       dropout_rate=self.dro_rate)

        # Encoder side: raw input plus progressively upsampled ViT features.
        self.encode1 = UnetrBasicBlock(spatial_dims=3, in_channels=self.in_channel,
                                       out_channels=self.ft_size, kernel_size=3,
                                       stride=1, norm_name=self.norm,
                                       res_block=self.res_block)
        self.encode2 = UnetrPrUpBlock(spatial_dims=3, in_channels=self.hid_size,
                                      out_channels=self.ft_size * 2, num_layer=2,
                                      kernel_size=3, stride=1, upsample_kernel_size=2,
                                      norm_name=self.norm, conv_block=self.conv_block,
                                      res_block=self.res_block)
        self.encode3 = UnetrPrUpBlock(spatial_dims=3, in_channels=self.hid_size,
                                      out_channels=self.ft_size * 4, num_layer=1,
                                      kernel_size=3, stride=1, upsample_kernel_size=2,
                                      norm_name=self.norm, conv_block=self.conv_block,
                                      res_block=self.res_block)
        self.encode4 = UnetrPrUpBlock(spatial_dims=3, in_channels=self.hid_size,
                                      out_channels=self.ft_size * 8, num_layer=0,
                                      kernel_size=3, stride=1, upsample_kernel_size=2,
                                      norm_name=self.norm, conv_block=self.conv_block,
                                      res_block=self.res_block)

        # Decoder side: upsample and fuse with the matching skip connection.
        self.decode4 = UnetrUpBlock(spatial_dims=3, in_channels=self.hid_size,
                                    out_channels=self.ft_size * 8, kernel_size=3,
                                    upsample_kernel_size=2, norm_name=self.norm,
                                    res_block=self.res_block)
        self.decode3 = UnetrUpBlock(spatial_dims=3, in_channels=self.ft_size * 8,
                                    out_channels=self.ft_size * 4, kernel_size=3,
                                    upsample_kernel_size=2, norm_name=self.norm,
                                    res_block=self.res_block)
        self.decode2 = UnetrUpBlock(spatial_dims=3, in_channels=self.ft_size * 4,
                                    out_channels=self.ft_size * 2, kernel_size=3,
                                    upsample_kernel_size=2, norm_name=self.norm,
                                    res_block=self.res_block)
        self.decode1 = UnetrUpBlock(spatial_dims=3, in_channels=self.ft_size * 2,
                                    out_channels=self.ft_size, kernel_size=3,
                                    upsample_kernel_size=2, norm_name=self.norm,
                                    res_block=self.res_block)

        # 1x1x1 projection to the segmentation classes.
        self.out = UnetOutBlock(spatial_dims=3, in_channels=self.ft_size,
                                out_channels=self.out_channel)

    def construct(self, x):
        """Forward pass: ViT encoding, skip extraction, U-Net style decoding."""
        tokens, hidden = self.vit(x)
        # Skip connections from the raw input and from transformer layers 3/6/9.
        enc1 = self.encode1(x)
        enc2 = self.encode2(proj_feat(hidden[3], self.hid_size, self.feat_size))
        enc3 = self.encode3(proj_feat(hidden[6], self.hid_size, self.feat_size))
        enc4 = self.encode4(proj_feat(hidden[9], self.hid_size, self.feat_size))
        # Bottleneck: final token sequence folded back to a volume.
        bottleneck = proj_feat(tokens, self.hid_size, self.feat_size)
        up3 = self.decode4(bottleneck, enc4)
        up2 = self.decode3(up3, enc3)
        up1 = self.decode2(up2, enc2)
        up0 = self.decode1(up1, enc1)
        return self.out(up0)
|