import random

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import mean_squared_error
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

try:
    from batchnorm import SynchronizedBatchNorm3d
except ImportError:
    SynchronizedBatchNorm3d = None


def normalization(planes, norm='bn'):
    if norm == 'bn':
        m = nn.BatchNorm3d(planes)
    elif norm == 'gn':
        m = nn.GroupNorm(4, planes)
    elif norm == 'in':
        m = nn.InstanceNorm3d(planes)
    elif norm == 'sync_bn':
        if SynchronizedBatchNorm3d is None:
            raise ImportError('sync_bn requires the batchnorm module')
        m = SynchronizedBatchNorm3d(planes)
    else:
        raise ValueError('normalization type {} is not supported'.format(norm))
    return m
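

# A minimal usage sketch of the factory above; the layer choice ('gn') and the
# channel count (32) are illustrative only. Defined but not called.
def _demo_normalization():
    layer = normalization(32, norm='gn')   # -> nn.GroupNorm(4, 32)
    x = torch.randn(2, 32, 8, 8, 8)        # (N, C, D, H, W)
    assert layer(x).shape == x.shape       # normalization preserves shape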


class Conv3d_Block(nn.Module):
    def __init__(self, num_in, num_out, kernel_size=1, stride=1, g=1, padding=None, norm=None):
        super(Conv3d_Block, self).__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.bn = normalization(num_in, norm=norm)
        self.act_fn = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding, stride=stride, groups=g, bias=False)

    def forward(self, x):  # BN + ReLU + Conv (pre-activation order)
        h = self.act_fn(self.bn(x))
        h = self.conv(h)
        return h


class DilatedConv3DBlock(nn.Module):
    def __init__(self, num_in, num_out, kernel_size=(1, 1, 1), stride=1, g=1, d=(1, 1, 1), norm=None):
        super(DilatedConv3DBlock, self).__init__()
        assert isinstance(kernel_size, tuple) and isinstance(d, tuple)

        # "same" padding for stride 1: (k - 1) // 2 * dilation per axis
        padding = tuple((ks - 1) // 2 * dd for ks, dd in zip(kernel_size, d))

        self.bn = normalization(num_in, norm=norm)
        self.act_fn = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding, stride=stride, groups=g, dilation=d, bias=False)

    def forward(self, x):
        h = self.act_fn(self.bn(x))
        h = self.conv(h)
        return h
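

# Sketch verifying that padding = (k - 1) // 2 * d keeps the spatial size
# unchanged at stride 1, which is what DilatedConv3DBlock relies on; the
# channel count and input size are illustrative. Defined but not called.
def _demo_dilated_padding():
    block = DilatedConv3DBlock(8, 8, kernel_size=(3, 3, 3), stride=1,
                               d=(2, 2, 2), norm='bn')
    x = torch.randn(1, 8, 16, 16, 16)
    assert block(x).shape == (1, 8, 16, 16, 16)   # "same" output size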


class MFunit(nn.Module):
    def __init__(self, num_in, num_out, g=1, stride=1, d=(1, 1), norm=None):
        """The second 3x3x1 group conv is replaced by 3x3x3.

        :param num_in: number of input channels
        :param num_out: number of output channels
        :param g: number of groups for the group convs
        :param stride: 1 or 2
        :param d: tuple, d[0] for the first 3x3x3 conv, d[1] for the 3x3x1 conv
        :param norm: normalization type passed to normalization()
        """
        super(MFunit, self).__init__()
        num_mid = num_in if num_in <= num_out else num_out
        self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
        self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)
        self.conv3x3x3_m1 = DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride, g=g, d=(d[0], d[0], d[0]), norm=norm)  # dilated
        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=1, g=g, d=(d[1], d[1], 1), norm=norm)
        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=1, g=g, d=(1, d[1], d[1]), norm=norm)

        # Skip connection: project and/or downsample so the shapes match
        if num_in != num_out or stride != 1:
            if stride == 1:
                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
            if stride == 2:
                # an MF block with stride=2 uses a 2x2x2 shortcut conv
                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)

    def forward(self, x):
        x1 = self.conv1x1x1_in1(x)
        x2 = self.conv1x1x1_in2(x1)
        x3 = self.conv3x3x3_m1(x2)
        x4 = self.conv3x3x3_m2(x3)

        shortcut = x
        if hasattr(self, 'conv1x1x1_shortcut'):
            shortcut = self.conv1x1x1_shortcut(shortcut)
        if hasattr(self, 'conv2x2x2_shortcut'):
            shortcut = self.conv2x2x2_shortcut(shortcut)

        return x4 + shortcut
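

# Sketch of the MFunit residual behaviour: with stride=2 the main path and the
# 2x2x2 shortcut are both halved, so the sum is well defined. The channel
# counts and input size are illustrative. Defined but not called.
def _demo_mfunit():
    unit = MFunit(32, 64, g=16, stride=2, norm='bn')
    x = torch.randn(1, 32, 16, 16, 16)
    assert unit(x).shape == (1, 64, 8, 8, 8)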


class DMFUnit(nn.Module):
    # Weighted add of three parallel dilated branches
    def __init__(self, num_in, num_out, g=1, stride=1, norm=None, dilation=None):
        super(DMFUnit, self).__init__()
        self.weight1 = nn.Parameter(torch.ones(1))
        self.weight2 = nn.Parameter(torch.ones(1))
        self.weight3 = nn.Parameter(torch.ones(1))

        num_mid = num_in if num_in <= num_out else num_out

        self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
        self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)

        self.conv3x3x3_m1 = nn.ModuleList()
        if dilation is None:
            dilation = [1, 2, 3]
        for i in range(3):
            self.conv3x3x3_m1.append(
                DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride, g=g, d=(dilation[i], dilation[i], dilation[i]), norm=norm)
            )

        # The second conv has no dilation
        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=(1, 1, 1), g=g, d=(1, 1, 1), norm=norm)
        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=(1, 1, 1), g=g, d=(1, 1, 1), norm=norm)

        # Skip connection
        if num_in != num_out or stride != 1:
            if stride == 1:
                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
            if stride == 2:
                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)

    def forward(self, x):
        x1 = self.conv1x1x1_in1(x)
        x2 = self.conv1x1x1_in2(x1)
        x3 = self.weight1 * self.conv3x3x3_m1[0](x2) + self.weight2 * self.conv3x3x3_m1[1](x2) + self.weight3 * self.conv3x3x3_m1[2](x2)
        x4 = self.conv3x3x3_m2(x3)
        shortcut = x
        if hasattr(self, 'conv1x1x1_shortcut'):
            shortcut = self.conv1x1x1_shortcut(shortcut)
        if hasattr(self, 'conv2x2x2_shortcut'):
            shortcut = self.conv2x2x2_shortcut(shortcut)
        return x4 + shortcut
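

# Sketch of the DMFUnit fusion: three parallel dilated branches (d = 1, 2, 3)
# are combined with the learnable scalars weight1..weight3, so the network can
# re-balance receptive fields during training. Sizes are illustrative.
# Defined but not called.
def _demo_dmfunit():
    unit = DMFUnit(32, 64, g=16, stride=2, norm='bn', dilation=[1, 2, 3])
    x = torch.randn(1, 32, 16, 16, 16)
    assert unit(x).shape == (1, 64, 8, 8, 8)
    assert all(w.requires_grad for w in (unit.weight1, unit.weight2, unit.weight3))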


class MFNet(nn.Module):
    # [96]  Flops: 13.361G & Params: 1.81M
    # [112] Flops: 16.759G & Params: 2.46M
    # [128] Flops: 20.611G & Params: 3.19M
    def __init__(self, c=12, n=32, channels=128, groups=16, norm='bn', num_classes=4):
        super(MFNet, self).__init__()

        # Entry flow. The hardcoded 14 input channels match the output of
        # time_exten_dimison (7 variables -> 14 channels), which the DMFNet
        # subclass attaches in front of this encoder.
        self.encoder_block1 = nn.Conv3d(14, n, kernel_size=3, padding=1, stride=2, bias=False)  # H//2
        self.encoder_block2 = nn.Sequential(
            MFunit(n, channels, g=groups, stride=2, norm=norm),  # H//4 down
            MFunit(channels, channels, g=groups, stride=1, norm=norm),
            MFunit(channels, channels, g=groups, stride=1, norm=norm)
        )
        self.encoder_block3 = nn.Sequential(
            MFunit(channels, channels * 2, g=groups, stride=2, norm=norm),  # H//8
            MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm),
            MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm)
        )
        self.encoder_block4 = nn.Sequential(  # H//16; unused in the 3-level forward below
            MFunit(channels * 2, channels * 3, g=groups, stride=2, norm=norm),
            MFunit(channels * 3, channels * 3, g=groups, stride=1, norm=norm),
            MFunit(channels * 3, channels * 2, g=groups, stride=1, norm=norm),
        )

        self.upsample1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//8
        self.decoder_block1 = MFunit(channels * 2 + channels * 2, channels * 2, g=groups, stride=1, norm=norm)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//4
        self.decoder_block2 = MFunit(channels * 2 + channels, channels, g=groups, stride=1, norm=norm)

        self.upsample3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//2
        self.decoder_block3 = MFunit(channels + n, n, g=groups, stride=1, norm=norm)
        self.upsample4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H
        self.seg = nn.Conv3d(n, num_classes, kernel_size=1, padding=0, stride=1, bias=False)

        self.softmax = nn.Softmax(dim=1)

        # Initialization
        norm_layers = (nn.BatchNorm3d, nn.GroupNorm)
        if SynchronizedBatchNorm3d is not None:
            norm_layers = norm_layers + (SynchronizedBatchNorm3d,)
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, norm_layers):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # time_exten_dimison / time_descent_dimison are attached by the DMFNet
        # subclass: they expand the time axis into channels on the way in and
        # collapse it again on the way out.
        x = self.time_exten_dimison(x)

        # Encoder
        x1 = self.encoder_block1(x)   # H//2 down
        x2 = self.encoder_block2(x1)  # H//4 down
        x3 = self.encoder_block3(x2)  # H//8 down
        # The deepest level is disabled; the 4-level variant would be:
        # x4 = self.encoder_block4(x3)        # H//16
        # y1 = self.upsample1(x4)             # H//8
        # y1 = torch.cat([x3, y1], dim=1)
        # y1 = self.decoder_block1(y1)

        # Decoder
        y2 = self.upsample2(x3)             # H//4
        y2 = torch.cat([x2, y2], dim=1)
        y2 = self.decoder_block2(y2)
        y3 = self.upsample3(y2)             # H//2
        y3 = torch.cat([x1, y3], dim=1)
        y3 = self.decoder_block3(y3)
        y4 = self.upsample4(y3)             # H
        y4 = self.seg(y4)
        if hasattr(self, 'softmax'):
            y4 = self.softmax(y4)
        y4 = self.time_descent_dimison(y4)
        return y4


class time_exten_dimison(nn.Module):
    """Expand the time dimension: a (1, 1, 3) conv that lifts the c=7 input
    variables to n=14 channels while keeping the 7-step time axis."""
    def __init__(self, c=7, n=14):
        super(time_exten_dimison, self).__init__()
        self.time_exten_dimison = nn.Conv3d(c, n, kernel_size=(1, 1, 3), padding=(0, 0, 1), stride=(1, 1, 1), bias=False)

    def forward(self, x):
        return self.time_exten_dimison(x)


class time_descent_dimison(nn.Module):
    """Collapse the time dimension: a (1, 1, 5) conv with stride (1, 1, 4)
    that reduces the num_classes=4 channels to 1 and the time axis to length 1."""
    def __init__(self, c=4, n=1):
        super(time_descent_dimison, self).__init__()
        self.time_descent_dimison = nn.Conv3d(c, n, kernel_size=(1, 1, 5), padding=(0, 0, 0), stride=(1, 1, 4), bias=False)

    def forward(self, x):
        return self.time_descent_dimison(x)
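

# Sketch of the temporal-head arithmetic: time_exten_dimison keeps the last
# axis at length 7 (kernel 3, padding 1), while time_descent_dimison collapses
# it via floor((7 - 5) / 4) + 1 = 1. Batch and spatial sizes are illustrative.
# Defined but not called.
def _demo_time_heads():
    up = time_exten_dimison()        # 7 -> 14 channels, time axis preserved
    down = time_descent_dimison()    # 4 -> 1 channels, time axis 7 -> 1
    x = torch.randn(1, 7, 40, 80, 7)
    assert up(x).shape == (1, 14, 40, 80, 7)
    y = torch.randn(1, 4, 40, 80, 7)
    assert down(y).shape == (1, 1, 40, 80, 1)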


class DMFNet(MFNet):
    # [128] Flops: 27.045G & Params: 3.88M
    def __init__(self, c=7, n=32, channels=128, groups=16, norm='bn', num_classes=4):
        super(DMFNet, self).__init__(c, n, channels, groups, norm, num_classes)

        self.time_exten_dimison = time_exten_dimison()
        self.encoder_block2 = nn.Sequential(
            DMFUnit(n, channels, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]),  # H//4 down
            DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]),  # dilated conv x3
            DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
        )
        self.encoder_block3 = nn.Sequential(
            DMFUnit(channels, channels * 2, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]),  # H//8
            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]),
            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
        )
        self.time_descent_dimison = time_descent_dimison()
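

# Hedged end-to-end smoke test mirroring how the training loop below calls the
# model: a (batch, 7 variables, 40 lat, 80 lon, 7 days) input maps to a
# single-channel, single-step output. Runs on CPU; batch size is illustrative.
# Defined but not called.
def _demo_dmfnet():
    net = DMFNet(c=7, groups=16, norm='bn', num_classes=4)
    x = torch.randn(1, 7, 40, 80, 7)
    assert net(x).shape == (1, 1, 40, 80, 1)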


for seed in range(1, 2):
    def setup_seed(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True

    # Set the random seed for this run
    setup_seed(seed)
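
    # Sketch of what setup_seed guarantees: after reseeding, the torch RNG
    # stream repeats, which is what makes the per-seed runs reproducible.
    # Defined but not called; invoke by hand when checking determinism.
    def _demo_seeding():
        setup_seed(seed)
        first = torch.rand(3)
        setup_seed(seed)
        assert torch.equal(first, torch.rand(3))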

    # Required inputs: mld, u, v, sss, temp, precipitation, evaporation, and
    # the salinity below the mixed layer
    data = np.load(r'/dataset/7day_for_Nday_data_openi_09_17_atlantic_area_last.npz')

    print(data.files)  # ['hycom_temp', 'slfh', 'sshf', 'ssr', 'str', 'mld', 'analysis_temp', 'u', 'v', 'T_d', 'u_d', 'v_d', 'xx', 'yy']

    hycom_temp = data['hycom_temp'][:, :, 0, :, :]        # surface level of (N, 7, 7, 41, 81)
    slfh = data['slfh'][:]                                # surface latent heat flux, (N, 7, 41, 81)
    sshf = data['sshf'][:]                                # surface sensible heat flux
    ssr = data['ssr'][:]                                  # surface net solar radiation
    str_rad = data['str'][:]                              # surface net thermal radiation ('str' in the file)
    mld = data['mld'][:]                                  # mixed-layer depth
    analysis_temp = data['analysis_temp'][:, :, 0, :, :]  # surface level of (N, 10, 7, 41, 81)
    u = data['u'][:, :, 0, :, :]
    v = data['v'][:, :, 0, :, :]
    T_d = data['T_d'][:]                                  # temperature below the mixed layer
    u_d = data['u_d'][:]
    v_d = data['v_d'][:]
    xx = data['xx'][:]
    yy = data['yy'][:]

    # Move the 7-day sequence axis to the end: (N, 7, 41, 81) -> (N, 41, 81, 7)
    hycom_temp = hycom_temp.transpose(0, 2, 3, 1)
    slfh = slfh.transpose(0, 2, 3, 1)
    sshf = sshf.transpose(0, 2, 3, 1)
    ssr = ssr.transpose(0, 2, 3, 1)
    str_rad = str_rad.transpose(0, 2, 3, 1)
    mld = mld.transpose(0, 2, 3, 1)
    u = u.transpose(0, 2, 3, 1)
    v = v.transpose(0, 2, 3, 1)
    T_d = T_d.transpose(0, 2, 3, 1)
    u_d = u_d.transpose(0, 2, 3, 1)
    v_d = v_d.transpose(0, 2, 3, 1)
    xx = xx.transpose(0, 2, 3, 1)
    yy = yy.transpose(0, 2, 3, 1)

    print(analysis_temp.shape)

    # Alternative label source:
    # data1 = np.load(r'/dataset/7day_for_10day_label_09_17_atlantic_area_last.npz')
    # analysis_temp = data1['analysis_temp'][:]

    train_size = 1952  # first ~60% of the samples
    valid_size = 2624  # next ~20% for validation; the remaining ~20% for testing

    # Net surface heat flux: daily-accumulated components (J m^-2) -> W m^-2
    Q_net = (slfh + sshf + ssr + str_rad) / 86400
    hycom_temp = hycom_temp.reshape(-1, 1, 41, 81, 7)
    hycom_temp = torch.Tensor(hycom_temp)
    hycom_temp_train = hycom_temp[0:train_size, :, :, :, :]
    hycom_temp_valid = hycom_temp[train_size:valid_size, :, :, :, :]
    hycom_temp_test = hycom_temp[valid_size:, :, :, :, :]

    Q_net = Q_net.reshape(-1, 1, 41, 81, 7)
    Q_net = torch.Tensor(Q_net)
    Q_net_train = Q_net[0:train_size, :, :, :, :]
    Q_net_valid = Q_net[train_size:valid_size, :, :, :, :]
    Q_net_test = Q_net[valid_size:, :, :, :, :]

    mld = mld.reshape(-1, 1, 41, 81, 7)
    mld = torch.Tensor(mld)
    mld_train = mld[0:train_size, :, :, :, :]
    mld_valid = mld[train_size:valid_size, :, :, :, :]
    mld_test = mld[valid_size:, :, :, :, :]

    # analysis_temp keeps all 10 lead days; the labels are sliced from it below
    analysis_temp = analysis_temp.reshape(-1, 10, 41, 81)
    analysis_temp = torch.Tensor(analysis_temp)

    u = u.reshape(-1, 1, 41, 81, 7)
    u = torch.Tensor(u)
    u_train = u[0:train_size, :, :, :, :]
    u_valid = u[train_size:valid_size, :, :, :, :]
    u_test = u[valid_size:, :, :, :, :]

    v = v.reshape(-1, 1, 41, 81, 7)
    v = torch.Tensor(v)
    v_train = v[0:train_size, :, :, :, :]
    v_valid = v[train_size:valid_size, :, :, :, :]
    v_test = v[valid_size:, :, :, :, :]

    T_d = T_d.reshape(-1, 1, 41, 81, 7)
    T_d = torch.Tensor(T_d)
    T_d_train = T_d[0:train_size, :, :, :, :]
    T_d_valid = T_d[train_size:valid_size, :, :, :, :]
    T_d_test = T_d[valid_size:, :, :, :, :]

    u_d = u_d.reshape(-1, 1, 41, 81, 7)
    u_d = torch.Tensor(u_d)
    u_d_train = u_d[0:train_size, :, :, :, :]
    u_d_valid = u_d[train_size:valid_size, :, :, :, :]
    u_d_test = u_d[valid_size:, :, :, :, :]

    v_d = v_d.reshape(-1, 1, 41, 81, 7)
    v_d = torch.Tensor(v_d)
    v_d_train = v_d[0:train_size, :, :, :, :]
    v_d_valid = v_d[train_size:valid_size, :, :, :, :]
    v_d_test = v_d[valid_size:, :, :, :, :]

    xx = xx.reshape(-1, 1, 41, 81, 7)
    xx = torch.Tensor(xx)
    xx_train = xx[0:train_size, :, :, :, :]
    xx_valid = xx[train_size:valid_size, :, :, :, :]
    xx_test = xx[valid_size:, :, :, :, :]

    yy = yy.reshape(-1, 1, 41, 81, 7)
    yy = torch.Tensor(yy)
    yy_train = yy[0:train_size, :, :, :, :]
    yy_valid = yy[train_size:valid_size, :, :, :, :]
    yy_test = yy[valid_size:, :, :, :, :]

    print('hycom_temp_train.shape:{}'.format(hycom_temp_train.shape))
    print('Q_net_train.shape:{}'.format(Q_net_train.shape))
    print('mld_train.shape:{}'.format(mld_train.shape))
    print('u_train.shape:{}'.format(u_train.shape))
    print('v_train.shape:{}'.format(v_train.shape))
    print('T_d_train.shape:{}'.format(T_d_train.shape))
    print('v_d_train.shape:{}'.format(v_d_train.shape))

    # dim 0: samples (time), dim 1: the 7 stacked input variables,
    # dims 2-3: lat/lon, dim 4: the 7-day input sequence
    train_data = torch.cat((hycom_temp_train, Q_net_train, mld_train, u_train, v_train, T_d_train, v_d_train), dim=1)  # torch.Size([1952, 7, 41, 81, 7])
    print('train_data.shape:{}'.format(train_data.shape))

    valid_data = torch.cat((hycom_temp_valid, Q_net_valid, mld_valid, u_valid, v_valid, T_d_valid, v_d_valid), dim=1)  # torch.Size([672, 7, 41, 81, 7])
    print('valid_data.shape:{}'.format(valid_data.shape))

    test_data = torch.cat((hycom_temp_test, Q_net_test, mld_test, u_test, v_test, T_d_test, v_d_test), dim=1)
    print('test_data.shape:{}'.format(test_data.shape))

    # Labels are offset 7 days from the inputs; take the first of the 10 lead days
    train_label = analysis_temp[7:train_size + 7, 0, :, :]                   # (1952, 41, 81)
    valid_label = analysis_temp[train_size + 7:valid_size + 7, 0, :40, :80]  # (672, 40, 80)
    test_label = analysis_temp[valid_size + 7:3646, 0, :, :]
    print('train_label.shape:{}'.format(train_label.shape))
    print('valid_label.shape:{}'.format(valid_label.shape))
    print('test_label.shape:{}'.format(test_label.shape))

    train_label = torch.Tensor(train_label)
    valid_label = torch.Tensor(valid_label)
    test_label = torch.Tensor(test_label)
    test_label11 = test_label
    valid_label11 = valid_label

    # Build the data pipeline
    class MyDataset(Dataset):
        def __init__(self, data, label):
            self.data = torch.Tensor(data)
            self.label = torch.Tensor(label)

        def __len__(self):
            return len(self.label)

        def __getitem__(self, idx):
            return self.data[idx], self.label[idx]
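
    # Minimal sketch of the Dataset contract used below; the tensors here are
    # random stand-ins for train_data / train_label. Defined but not called.
    def _demo_dataset():
        ds = MyDataset(torch.randn(10, 7, 41, 81, 7), torch.randn(10, 41, 81))
        sample, target = ds[0]
        assert len(ds) == 10 and sample.shape == (7, 41, 81, 7)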

    batch_size1 = 32
    batch_size2 = 32
    batch_size3 = 3000

    trainset = MyDataset(train_data, train_label)
    trainloader = DataLoader(trainset, batch_size=batch_size1, shuffle=True, drop_last=False, pin_memory=True, num_workers=8)

    # shuffle must stay False here: pred_val is filled by batch index below and
    # compared against valid_label in its original order
    validset = MyDataset(valid_data, valid_label)
    validloader = DataLoader(validset, batch_size=batch_size2, shuffle=False, drop_last=False, pin_memory=True, num_workers=8)

    testset = MyDataset(test_data, test_label)
    testloader = DataLoader(testset, batch_size=batch_size3, shuffle=False, drop_last=False, pin_memory=True, num_workers=8)
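
    # Sketch of a one-batch sanity check on the pipeline; it consumes a single
    # batch from trainloader, so only call it after the data above is loaded.
    # Defined but not called.
    def _demo_loader():
        batch_data, batch_label = next(iter(trainloader))
        assert batch_data.shape[0] <= batch_size1
        assert batch_data.shape[1] == 7   # the 7 stacked input variables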

    model_weights1 = '/model/epo200_lay3_lr0.001_e{}_Unet_3D_1day_model_weights.pth'.format(seed)
    torch.backends.cudnn.enabled = False

    model = DMFNet(c=7, groups=16, norm='bn', num_classes=4).cuda()

    criterion = nn.MSELoss()
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    epochs = 300
    train_losses, valid_losses = [], []
    best_valid_loss = float('inf')

    pred_val = np.zeros((672, 1, 40, 80))

    scores = []

    def rmse(y_true, y_preds):
        return np.sqrt(mean_squared_error(y_pred=y_preds, y_true=y_true))
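
    # rmse() above is sklearn-based; this sketch checks it against the direct
    # definition sqrt(mean((y_true - y_pred)^2)) on illustrative arrays.
    # Defined but not called.
    def _demo_rmse():
        a, b = np.array([1.0, 2.0, 3.0]), np.array([1.5, 2.0, 2.5])
        assert np.isclose(rmse(a, b), np.sqrt(np.mean((a - b) ** 2)))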

    for epoch in range(epochs):
        print('Epoch: {}/{}'.format(epoch + 1, epochs))

        # Train
        model.train()
        losses = 0
        for data, label in tqdm(trainloader):
            data = data.cuda()
            label = label.cuda()

            # Crop to an even 40 x 80 grid so the encoder/decoder strides match up
            data = data[:, :, :40, :80, :]
            label = label[:, :40, :80]

            out = model(data)
            # Mixed-layer budget equations this model approximates:
            #   dS/dt - (E - P) * (S / h) - [u * dS/dx + v * dS/dy] + H * (w_h + dh/dt * ((S - S_h) / h)) = 0   (loss1)
            #   dT/dt - Q / (rho * C_p * h_m) - u * (dT/dx) - v * (dT/dy) + w_e * ((T - T_d) / h) = 0
            out = out.reshape(-1, 1, 40, 80)
            label = label.reshape(-1, 1, 40, 80)
            loss = criterion(out, label)

            losses += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = losses / len(trainloader)
        train_losses.append(train_loss)

        print('Training Loss: {:.10f}'.format(train_loss))

        # Validate
        model.eval()
        losses = 0
        with torch.no_grad():
            for i, (data, label) in enumerate(tqdm(validloader)):
                data = data.cuda()
                label = label.cuda()
                data = data[:, :, :40, :80, :]
                label = label[:, :40, :80]

                out = model(data)
                out = out.reshape(-1, 1, 40, 80)
                label = label.reshape(-1, 1, 40, 80)
                loss = criterion(out, label)

                losses += float(loss)

                pred_val[i * batch_size2:(i + 1) * batch_size2] = out.cpu().numpy()

        valid_loss = losses / len(validloader)
        valid_losses.append(valid_loss)

        valid_label1 = valid_label.numpy().reshape(-1, 1)
        preds1 = pred_val.reshape(-1, 1)

        s = rmse(valid_label1, preds1)
        scores.append(s)
        print('Score: {:.3f}'.format(s))

        # Minimizing the validation loss; to maximize a score instead, flip the
        # comparison and the inf initialization
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            checkpoint = {'best_score': valid_loss,
                          'state_dict': model.state_dict()}
            torch.save(checkpoint, model_weights1)
            torch.save(model.state_dict(),
                       '/model/Unet_3D_lr0.001_model_200_layer3_1day_e{}.pt'.format(seed))

    print(scores)
    print(best_valid_loss)
    print(s)