import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_squared_error

try:
    # Optional synchronized BatchNorm implementation (e.g. the sync_batchnorm
    # module shipped with many segmentation repos).
    from batchnorm import SynchronizedBatchNorm3d
except ImportError:
    # Fall back to plain BatchNorm3d so that norm='sync_bn' still runs
    # on a single GPU.
    SynchronizedBatchNorm3d = nn.BatchNorm3d

def normalization(planes, norm='bn'):
    # Factory for the 3D normalization layer selected by `norm`.
    if norm == 'bn':
        m = nn.BatchNorm3d(planes)
    elif norm == 'gn':
        m = nn.GroupNorm(4, planes)
    elif norm == 'in':
        m = nn.InstanceNorm3d(planes)
    elif norm == 'sync_bn':
        m = SynchronizedBatchNorm3d(planes)
    else:
        raise ValueError('normalization type {} is not supported'.format(norm))
    return m

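# Minimal usage sketch of the factory above (illustrative values): the rest
# of the network stays agnostic to which normalization flavour is selected.
def _check_normalization():
    layer = normalization(32, norm='gn')  # GroupNorm with 4 groups over 32 channels
    x = torch.randn(1, 32, 4, 4, 4)
    assert layer(x).shape == x.shape  # normalization layers preserve shape
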
class Conv3d_Block(nn.Module):
    def __init__(self, num_in, num_out, kernel_size=1, stride=1, g=1, padding=None, norm=None):
        super(Conv3d_Block, self).__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.bn = normalization(num_in, norm=norm)
        self.act_fn = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding,
                              stride=stride, groups=g, bias=False)

    def forward(self, x):  # pre-activation order: BN + ReLU + Conv
        h = self.act_fn(self.bn(x))
        h = self.conv(h)
        return h

class DilatedConv3DBlock(nn.Module):
    def __init__(self, num_in, num_out, kernel_size=(1, 1, 1), stride=1, g=1, d=(1, 1, 1), norm=None):
        super(DilatedConv3DBlock, self).__init__()
        assert isinstance(kernel_size, tuple) and isinstance(d, tuple)

        # "Same" padding for a dilated conv: (k - 1) // 2 * d per axis.
        padding = tuple((ks - 1) // 2 * dd for ks, dd in zip(kernel_size, d))

        self.bn = normalization(num_in, norm=norm)
        self.act_fn = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding,
                              stride=stride, groups=g, dilation=d, bias=False)

    def forward(self, x):
        h = self.act_fn(self.bn(x))
        h = self.conv(h)
        return h

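# Sanity sketch: with odd kernels and the padding rule above, a stride-1
# dilated conv preserves spatial size for any dilation rate.
def _check_dilated_padding():
    blk = DilatedConv3DBlock(8, 8, kernel_size=(3, 3, 3), stride=1, d=(3, 3, 3), norm='bn')
    x = torch.randn(2, 8, 16, 16, 16)
    assert blk(x).shape == (2, 8, 16, 16, 16)
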
class MFunit(nn.Module):
    def __init__(self, num_in, num_out, g=1, stride=1, d=(1, 1), norm=None):
        """The second 3x3x1 group conv is replaced by 3x3x3.

        :param num_in: number of input channels
        :param num_out: number of output channels
        :param g: number of groups for the group convolutions
        :param stride: 1 or 2
        :param d: tuple; d[0] is the dilation of the 3x3x3 conv, d[1] of the 3x3x1 conv
        :param norm: normalization type (see `normalization`)
        """
        super(MFunit, self).__init__()
        num_mid = num_in if num_in <= num_out else num_out
        self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
        self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)
        self.conv3x3x3_m1 = DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride,
                                               g=g, d=(d[0], d[0], d[0]), norm=norm)  # dilated
        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=1,
                                               g=g, d=(d[1], d[1], 1), norm=norm)
        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=1, g=g, d=(1, d[1], d[1]), norm=norm)

        # Skip connection: project/downsample the input when the shape changes.
        if num_in != num_out or stride != 1:
            if stride == 1:
                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
            if stride == 2:
                # For an MF block with stride=2, use a 2x2x2 conv.
                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)

    def forward(self, x):
        x1 = self.conv1x1x1_in1(x)
        x2 = self.conv1x1x1_in2(x1)
        x3 = self.conv3x3x3_m1(x2)
        x4 = self.conv3x3x3_m2(x3)

        shortcut = x
        if hasattr(self, 'conv1x1x1_shortcut'):
            shortcut = self.conv1x1x1_shortcut(shortcut)
        if hasattr(self, 'conv2x2x2_shortcut'):
            shortcut = self.conv2x2x2_shortcut(shortcut)

        return x4 + shortcut

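# Smoke-test sketch (illustrative channel counts): a stride-2 MFunit halves
# every spatial axis while mapping num_in -> num_out channels.
def _check_mfunit():
    unit = MFunit(32, 64, g=16, stride=2, norm='bn')
    x = torch.randn(1, 32, 16, 16, 16)
    assert unit(x).shape == (1, 64, 8, 8, 8)
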
class DMFUnit(nn.Module):
    # Weighted addition of three parallel dilated branches.
    def __init__(self, num_in, num_out, g=1, stride=1, norm=None, dilation=None):
        super(DMFUnit, self).__init__()
        self.weight1 = nn.Parameter(torch.ones(1))
        self.weight2 = nn.Parameter(torch.ones(1))
        self.weight3 = nn.Parameter(torch.ones(1))

        num_mid = num_in if num_in <= num_out else num_out

        self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
        self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)

        self.conv3x3x3_m1 = nn.ModuleList()
        if dilation is None:
            dilation = [1, 2, 3]
        for i in range(3):
            self.conv3x3x3_m1.append(
                DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride, g=g,
                                   d=(dilation[i], dilation[i], dilation[i]), norm=norm)
            )

        # The second conv has no dilation.
        self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=(1, 1, 1),
                                               g=g, d=(1, 1, 1), norm=norm)
        # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=(1, 1, 1), g=g, d=(1, 1, 1), norm=norm)

        # Skip connection: project/downsample the input when the shape changes.
        if num_in != num_out or stride != 1:
            if stride == 1:
                self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
            if stride == 2:
                self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)

    def forward(self, x):
        x1 = self.conv1x1x1_in1(x)
        x2 = self.conv1x1x1_in2(x1)
        # Fuse the three dilated branches with learnable scalar weights.
        x3 = (self.weight1 * self.conv3x3x3_m1[0](x2)
              + self.weight2 * self.conv3x3x3_m1[1](x2)
              + self.weight3 * self.conv3x3x3_m1[2](x2))
        x4 = self.conv3x3x3_m2(x3)
        shortcut = x
        if hasattr(self, 'conv1x1x1_shortcut'):
            shortcut = self.conv1x1x1_shortcut(shortcut)
        if hasattr(self, 'conv2x2x2_shortcut'):
            shortcut = self.conv2x2x2_shortcut(shortcut)
        return x4 + shortcut

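# Sketch: the three dilated branches all keep the same output shape (which is
# what makes the weighted sum well defined), and the scalar weights are
# ordinary trainable parameters.
def _check_dmfunit():
    unit = DMFUnit(32, 64, g=16, stride=1, norm='bn', dilation=[1, 2, 3])
    x = torch.randn(1, 32, 16, 16, 16)
    assert unit(x).shape == (1, 64, 16, 16, 16)
    assert unit.weight1.requires_grad  # branch weights are learned jointly with the convs
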
class MFNet(nn.Module):
    # [96]  Flops: 13.361G & Params: 1.81M
    # [112] Flops: 16.759G & Params: 2.46M
    # [128] Flops: 20.611G & Params: 3.19M
    def __init__(self, c=13, n=32, channels=128, groups=16, norm='bn', num_classes=4):
        super(MFNet, self).__init__()

        # Expanding the time dimension (the conv itself is attached by the
        # DMFNet subclass as self.time_exten_dimison).
        # self.time_dimison = nn.Conv3d(c, n, kernel_size=(1, 1, 1), padding=(0, 0, 3), stride=(1, 1, 1), bias=False)

        # Entry flow; 13 matches the channel count produced by time_exten_dimison.
        self.encoder_block1 = nn.Conv3d(13, n, kernel_size=3, padding=1, stride=2, bias=False)  # H//2
        self.encoder_block2 = nn.Sequential(
            MFunit(n, channels, g=groups, stride=2, norm=norm),  # H//4 down
            MFunit(channels, channels, g=groups, stride=1, norm=norm),
            MFunit(channels, channels, g=groups, stride=1, norm=norm)
        )

        self.encoder_block3 = nn.Sequential(
            MFunit(channels, channels * 2, g=groups, stride=2, norm=norm),  # H//8
            MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm),
            MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm)
        )

        self.encoder_block4 = nn.Sequential(  # H//8, channels*4
            MFunit(channels * 2, channels * 3, g=groups, stride=2, norm=norm),  # H//16
            MFunit(channels * 3, channels * 3, g=groups, stride=1, norm=norm),
            MFunit(channels * 3, channels * 2, g=groups, stride=1, norm=norm),
        )

        self.upsample1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//8
        self.decoder_block1 = MFunit(channels * 2 + channels * 2, channels * 2, g=groups, stride=1, norm=norm)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//4
        self.decoder_block2 = MFunit(channels * 2 + channels, channels, g=groups, stride=1, norm=norm)

        self.upsample3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H//2
        self.decoder_block3 = MFunit(channels + n, n, g=groups, stride=1, norm=norm)
        self.upsample4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)  # H
        self.seg = nn.Conv3d(n, num_classes, kernel_size=1, padding=0, stride=1, bias=False)

        self.softmax = nn.Softmax(dim=1)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, SynchronizedBatchNorm3d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # time_exten_dimison / time_descent_dimison are attached by the
        # DMFNet subclass; MFNet is not used stand-alone in this script.
        x = self.time_exten_dimison(x)

        # Encoder, shapes for a (1, 13, 40, 200, 8) input:
        x1 = self.encoder_block1(x)   # H//2 down -> (1, 32, 20, 100, 4)
        x2 = self.encoder_block2(x1)  # H//4 down -> (1, 128, 10, 50, 2)
        x3 = self.encoder_block3(x2)  # H//8 down -> (1, 256, 5, 25, 1)
        # The deepest stage and its decoder path are disabled:
        # x4 = self.encoder_block4(x3)      # H//16
        # y1 = self.upsample1(x4)           # H//8
        # y1 = torch.cat([x3, y1], dim=1)
        # y1 = self.decoder_block1(y1)

        # Decoder
        y2 = self.upsample2(x3)            # H//4
        y2 = torch.cat([x2, y2], dim=1)
        y2 = self.decoder_block2(y2)
        y3 = self.upsample3(y2)            # H//2
        y3 = torch.cat([x1, y3], dim=1)
        y3 = self.decoder_block3(y3)
        y4 = self.upsample4(y3)            # H
        y4 = self.seg(y4)                  # n -> num_classes channels
        if hasattr(self, 'softmax'):
            y4 = self.softmax(y4)
        y4 = self.time_descent_dimison(y4)
        return y4

class time_exten_dimison(nn.Module):
    # Expand the channel dimension (6 -> 13) while shrinking the trailing
    # time axis with a 1x1x3 conv and no padding: D 10 -> 8.
    def __init__(self, c=6, n=13, channels=128, groups=16, norm='bn', num_classes=4):
        super(time_exten_dimison, self).__init__()
        self.time_exten_dimison = nn.Conv3d(c, n, kernel_size=(1, 1, 3), padding=(0, 0, 0),
                                            stride=(1, 1, 1), bias=False)

    def forward(self, x):
        return self.time_exten_dimison(x)


class time_descent_dimison(nn.Module):
    # Collapse the time axis: a 1x1x5 conv with stride 4 maps D 8 -> 1 and
    # channels num_classes (4) -> 1.
    def __init__(self, c=4, n=1, channels=128, groups=16, norm='bn', num_classes=4):
        super(time_descent_dimison, self).__init__()
        self.time_descent_dimison = nn.Conv3d(c, n, kernel_size=(1, 1, 5), padding=(0, 0, 0),
                                              stride=(1, 1, 4), bias=False)

    def forward(self, x):
        return self.time_descent_dimison(x)

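# Shape arithmetic sketch for the two time convs, using the standard formula
# L_out = (L_in + 2*pad - dilation*(k - 1) - 1) // stride + 1:
#   exten:   D: (10 - 3) // 1 + 1 = 8, channels 6 -> 13
#   descent: D: (8 - 5) // 4 + 1 = 1, channels 4 -> 1
def _check_time_convs():
    up, down = time_exten_dimison(), time_descent_dimison()
    assert up(torch.randn(1, 6, 40, 200, 10)).shape == (1, 13, 40, 200, 8)
    assert down(torch.randn(1, 4, 40, 200, 8)).shape == (1, 1, 40, 200, 1)
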
class DMFNet(MFNet):  # softmax output
    # [128] Flops: 27.045G & Params: 3.88M
    def __init__(self, c=13, n=32, channels=128, groups=16, norm='bn', num_classes=4):
        super(DMFNet, self).__init__(c, n, channels, groups, norm, num_classes)

        self.time_exten_dimison = time_exten_dimison()
        # Replace the plain MFunit encoder stages with multi-dilation DMFUnits.
        self.encoder_block2 = nn.Sequential(
            DMFUnit(n, channels, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]),  # H//4 down
            DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]),
            DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
        )

        self.encoder_block3 = nn.Sequential(
            DMFUnit(channels, channels * 2, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]),  # H//8
            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]),
            DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
        )

        self.time_descent_dimison = time_descent_dimison()
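
# End-to-end smoke-test sketch (CPU, norm='bn'): a 6-variable, 10-day input
# over the 40x200 grid comes back as a single-channel, single-step field,
# matching the reshape to (-1, 1, 40, 200) used in the training loop below.
def _check_dmfnet():
    net = DMFNet(c=6, groups=16, norm='bn', num_classes=4)
    x = torch.randn(1, 6, 40, 200, 10)
    assert net(x).shape == (1, 1, 40, 200, 1)
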
for seed in range(2023, 2024):
    for date_append in range(1, 15):
        def setup_seed(seed):
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            np.random.seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True

        # Set the random seed.
        setup_seed(seed)

        # Required inputs: mld, u, v, sss, temperature, precipitation,
        # evaporation, and salinity below the mixed layer.
        data = np.load(r'/dataset/10day_for_14day_all_variables_deep1_pacific_10_19_SSS_learn.npz')
        print(data.files)
        # Period: 2000-01-01 to 2019-12-31. The heat-flux fields
        # (surface_latent_heat_flux, surface_sensible_heat_flux,
        # surface_net_radiation) and sst_surface are present in the file
        # but unused in this run.
        evaporation = data['evaporation'][:]
        total_precipitation = data['total_precipitation'][:]
        mld = data['mld'][:]
        sss_surface = data['sss_deep'][:]
        u_surface = data['u_deep'][:]
        v_surface = data['v_deep'][:]
        sss_surface_label = data['sss_deep_label'][:]

        print(sss_surface.shape)        # (3642, 10, 40, 200)
        print(sss_surface_label.shape)  # (3638, 40, 200)
        # Move the 10-day window to the trailing axis:
        # (N, 10, 40, 200) -> (N, 40, 200, 10).
        evaporation = evaporation.transpose(0, 2, 3, 1)
        total_precipitation = total_precipitation.transpose(0, 2, 3, 1)
        mld = mld.transpose(0, 2, 3, 1)
        sss_surface = sss_surface.transpose(0, 2, 3, 1)
        u_surface = u_surface.transpose(0, 2, 3, 1)
        v_surface = v_surface.transpose(0, 2, 3, 1)

        # (Disabled) MinMaxScaler normalization: each variable was optionally
        # flattened to a column, scaled, and reshaped back. A net heat flux
        # Q_net = latent + sensible + net_radiation (divided by 86400) was an
        # alternative forcing input; both paths are switched off here.

        train_size = 2208
        valid_size = 2912  # indices [train_size:valid_size] are validation; the remainder would be test

        # (Disabled) the same reshape/to-tensor/split pipeline for Q_net and
        # the individual heat-flux variables.

        # Reshape each predictor to (N, 1, 40, 200, 10), convert to tensors,
        # and split chronologically into train / validation.
        evaporation = torch.Tensor(evaporation.reshape(-1, 1, 40, 200, 10))
        evaporation_train = evaporation[0:train_size]
        evaporation_valid = evaporation[train_size:valid_size]

        total_precipitation = torch.Tensor(total_precipitation.reshape(-1, 1, 40, 200, 10))
        total_precipitation_train = total_precipitation[0:train_size]
        total_precipitation_valid = total_precipitation[train_size:valid_size]

        mld = torch.Tensor(mld.reshape(-1, 1, 40, 200, 10))
        mld_train = mld[0:train_size]
        mld_valid = mld[train_size:valid_size]

        sss_surface = torch.Tensor(sss_surface.reshape(-1, 1, 40, 200, 10))
        sss_surface_train = sss_surface[0:train_size]
        sss_surface_valid = sss_surface[train_size:valid_size]

        u_surface = torch.Tensor(u_surface.reshape(-1, 1, 40, 200, 10))
        u_surface_train = u_surface[0:train_size]
        u_surface_valid = u_surface[train_size:valid_size]

        v_surface = torch.Tensor(v_surface.reshape(-1, 1, 40, 200, 10))
        v_surface_train = v_surface[0:train_size]
        v_surface_valid = v_surface[train_size:valid_size]

        # (Disabled) identical preparation for the deeper fields (T_d, S_d,
        # u_d, v_d), the coordinate grids (xx, yy), the corresponding test
        # splits, and the older 16-channel torch.cat groupings that used them.

        # Stack the six predictors along the channel axis -> (N, 6, 40, 200, 10).
        train_data = torch.cat((evaporation_train, total_precipitation_train, sss_surface_train,
                                mld_train, u_surface_train, v_surface_train), dim=1)
        valid_data = torch.cat((evaporation_valid, total_precipitation_valid, sss_surface_valid,
                                mld_valid, u_surface_valid, v_surface_valid), dim=1)

        print(train_data.shape)
        print(valid_data.shape)

        # Labels: a 10-day input window predicts salinity date_append days
        # after the window ends, hence the 10 + date_append offset. (The
        # extra 0 index assumes the stored label array carries a lead-time
        # axis after the sample axis.)
        sss_train_label = sss_surface_label[10 + date_append:train_size + 10 + date_append, 0, :, :]
        sss_valid_label = sss_surface_label[train_size + 10 + date_append:valid_size + 10 + date_append, 0, :, :]

        sss_train_label = sss_train_label.reshape(-1, 1, 40, 200)
        sss_valid_label = sss_valid_label.reshape(-1, 1, 40, 200)

        train_label = sss_train_label
        valid_label = sss_valid_label
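
        # Sanity sketch: with these offsets each split keeps its full length;
        # this assumes the label array extends at least 10 + date_append days
        # beyond valid_size for every lead time in the loop.
        assert train_label.shape[0] == train_size
        assert valid_label.shape[0] == valid_size - train_size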
        print(train_label.shape)

        # Build the data pipeline.
        class MyDataset(Dataset):
            def __init__(self, data, label):
                self.data = torch.Tensor(data)
                self.label = torch.Tensor(label)

            def __len__(self):
                return len(self.label)

            def __getitem__(self, idx):
                return self.data[idx], self.label[idx]

        batch_size1 = 32    # train
        batch_size2 = 32    # validation
        batch_size3 = 3000  # (test loader, disabled below)

        trainset = MyDataset(train_data, train_label)
        trainloader = DataLoader(trainset, batch_size=batch_size1, shuffle=True, drop_last=False,
                                 pin_memory=True, num_workers=4)

        validset = MyDataset(valid_data, valid_label)
        # shuffle=False so that pred_val (filled batch by batch below) stays
        # aligned with valid_label when the RMSE score is computed.
        validloader = DataLoader(validset, batch_size=batch_size2, shuffle=False, drop_last=False,
                                 pin_memory=True, num_workers=4)

        # testset/testloader for a held-out test split are disabled here.

        model_weights1 = '/model/epo200_lay3_lr0.001_e{}_Unet_3D_{}day_model_weights_deep1.pth'.format(seed, date_append)
        torch.backends.cudnn.enabled = False

        model = DMFNet(c=6, groups=16, norm='sync_bn', num_classes=4).cuda()

        criterion = nn.MSELoss()
        # Define the optimizer.
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        epochs = 200
        train_losses, valid_losses = [], []
        best_score1 = float('inf')  # lowest validation loss seen so far

        # Buffer for validation predictions; 704 = valid_size - train_size.
        pred_val = np.zeros((704, 1, 40, 200))

        scores = []

        def rmse(y_true, y_preds):
            return np.sqrt(mean_squared_error(y_pred=y_preds, y_true=y_true))

        for epoch in range(epochs):
            print('Epoch: {}/{}'.format(epoch + 1, epochs))
            # Train.
            model.train()
            losses = 0.0
            for data, label in tqdm(trainloader):
                data = data.cuda()    # (B, 6, 40, 200, 10): evaporation, precipitation, sss, mld, u, v
                label = label.cuda()  # (B, 1, 40, 200)
                optimizer.zero_grad()
                out = model(data)
                # Mixed-layer budgets the network is meant to emulate:
                # ∂S/∂t - (E - P)·(S/h) - [u·∂S/∂x + v·∂S/∂y] + H·(w_h + dh/dt·((S - S_h)/h)) = 0
                # ∂T/∂t - Q/(ρ·C_p·h_m) - u·∂T/∂x - v·∂T/∂y + w_e·((T - T_d)/h) = 0
                out = out.reshape(-1, 1, 40, 200)
                loss = criterion(out, label)

                losses += loss.item()  # accumulate a float, not the autograd graph

                loss.backward()
                optimizer.step()
            train_loss = losses / len(trainloader)
            train_losses.append(train_loss)

            print('Training Loss: {:.10f}'.format(train_loss))

            # Validate.
            model.eval()  # use running BN statistics for validation
            losses = 0.0
            with torch.no_grad():
                for i, (data, label) in enumerate(tqdm(validloader)):
                    data = data.cuda()
                    label = label.cuda()

                    out = model(data)
                    out = out.reshape(-1, 1, 40, 200)
                    loss = criterion(out, label)

                    losses += float(loss)

                    out1 = out.detach().cpu().numpy()
                    pred_val[i * batch_size2:(i + 1) * batch_size2] = out1

            valid_loss = losses / len(validloader)
            valid_losses.append(valid_loss)

            valid_label1 = valid_label.reshape(-1, 1)
            preds1 = pred_val.reshape(-1, 1)

            s = rmse(valid_label1, preds1)
            scores.append(s)
            print('Score: {:.3f}'.format(s))

            if valid_loss < best_score1:  # keep the checkpoint with the lowest validation loss
                best_score1 = valid_loss
                checkpoint = {'best_score': valid_loss,
                              'state_dict': model.state_dict()}
                torch.save(checkpoint, model_weights1)
                torch.save(model.state_dict(),
                           '/model/Unet_3D_lr0.001_model_200_layer3_{}day_e{}_deep1.pt'.format(date_append, seed))

        print(scores)
        print(best_score1)
        print(s)
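
# Sketch: restoring the best checkpoint later (hypothetical helper, not part
# of the original training loop; assumes the same DMFNet configuration as above).
def load_best_checkpoint(path, device='cuda'):
    model = DMFNet(c=6, groups=16, norm='sync_bn', num_classes=4).to(device)
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model, ckpt['best_score']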