|
- import torch.nn as nn
- import torch.nn.functional as F
- import torch
- from functools import partial
- from collections import OrderedDict
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- import torch.fft
- from params import get_args
- from torch.utils.checkpoint import checkpoint_sequential
- import random
- from torch.utils.data import Dataset, DataLoader
- from tqdm import tqdm
- from sklearn.metrics import mean_squared_error
- from sklearn.preprocessing import MinMaxScaler
- import batchnorm
-
- import torch.nn as nn
- import torch.nn.functional as F
- import torch
-
- import torch.nn as nn
- import torch.nn.functional as F
- import torch
-
-
-
-
-
-
- class unetConv2(nn.Module):
- def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
- super(unetConv2, self).__init__()
- self.n = n
- self.ks = ks
- self.stride = stride
- self.padding = padding
- s = stride
- p = padding
- if is_batchnorm:
- for i in range(1, n + 1):
- conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
- nn.BatchNorm2d(out_size),
- nn.ReLU(inplace=True), )
- setattr(self, 'conv%d' % i, conv)
- in_size = out_size
-
- else:
- for i in range(1, n + 1):
- conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
- nn.ReLU(inplace=True), )
- setattr(self, 'conv%d' % i, conv)
- in_size = out_size
-
-
- def forward(self, inputs):
- x = inputs
- for i in range(1, self.n + 1):
- conv = getattr(self, 'conv%d' % i)
- x = conv(x)
-
- return x
-
-
- class unetUp(nn.Module):
- def __init__(self, in_size, out_size, is_deconv, n_concat=2):
- super(unetUp, self).__init__()
- self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
- if is_deconv:
- self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0)
- else:
- self.up = nn.Sequential(
- nn.UpsamplingBilinear2d(scale_factor=2),
- nn.Conv2d(in_size, out_size, 1))
-
- def forward(self, high_feature, *low_feature):
- outputs0 = self.up(high_feature)
- for feature in low_feature:
- outputs0 = torch.cat([outputs0, feature], 1)
- return self.conv(outputs0)
-
-
- class UNet_learn(nn.Module):
-
- def __init__(self, in_channels=81, n_classes=10, feature_scale=2, is_deconv=True, is_batchnorm=True):
- super(UNet_learn, self).__init__()
- self.in_channels = in_channels
- self.feature_scale = feature_scale
- self.is_deconv = is_deconv
- self.is_batchnorm = is_batchnorm
-
- filters = [64, 128, 256, 512]
- filters = [int(x / self.feature_scale) for x in filters]
-
- # downsampling
- self.convin = nn.Conv2d(in_channels = 81, out_channels=81, kernel_size=2,padding=1)
- self.maxpool = nn.MaxPool2d(kernel_size=2)
- self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
- self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
- self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
- # self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
- self.center = unetConv2(filters[2], filters[3], self.is_batchnorm)
- # upsampling
- # self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
- self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
- self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
- self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
- # final conv (without any concat)
- self.final = nn.Conv2d(filters[0], n_classes, 1)
-
-
- def forward(self, inputs):
- inputs = self.convin(inputs)
- # print('input.shape:{}'.format(inputs.shape)) # 输入([2, 1, 64, 64]) | ([2, 10, 40, 200])
- conv1 = self.conv1(inputs) # 16*512*512
- # print('conv1.shape:{}'.format(conv1.shape)) # 输入([2, 1, 64, 64]) ---> ([2, 32, 64, 64]) | ([2, 10, 40, 200]) ---> ([2, 32, 40, 200])
- maxpool1 = self.maxpool(conv1) # 16*256*256
- # print('maxpool1.shape:{}'.format(maxpool1.shape)) # 输入([2, 32, 64, 64]) ---> ([2, 32, 32, 32]) | ([2, 32, 40, 200]) ---> ([2, 32, 20, 100])
- conv2 = self.conv2(maxpool1) # 32*256*256
- # print('conv2.shape:{}'.format(conv2.shape)) # 输入([2, 32, 32, 32]) ---> ([2, 64, 32, 32]) | ([2, 32, 20, 100]) ---> ([2, 64, 20, 100])
- maxpool2 = self.maxpool(conv2) # 32*128*128
- # print('maxpool2.shape:{}'.format(maxpool2.shape)) # 输入([2, 64, 32, 32]) ---> ([2, 64, 16, 16]) | ([2, 64, 20, 100]) ---> ([2, 64, 10, 50])
- conv3 = self.conv3(maxpool2) # 64*128*128
- # print('conv3.shape:{}'.format(conv3.shape)) # 输入([2, 64, 16, 16]) ---> ([2, 128, 16, 16]) | ([2, 64, 10, 50]) ---> ([2, 128, 10, 50])
- maxpool3 = self.maxpool(conv3) # 64*64*64
- # print('maxpool3.shape:{}'.format(maxpool3.shape)) # 输入([2, 128, 16, 16]) ---> ([2, 128, 8, 8]) | ([2, 128, 10, 50]) --->([2, 128, 5, 25])
- # conv4 = self.conv4(maxpool3) # 128*64*64
- # # print('conv4.shape:{}'.format(conv4.shape)) # 输入([2, 128, 8, 8])---> ([2, 256, 8, 8])
- # maxpool4 = self.maxpool(conv4) # 128*32*32
- # # print('maxpool4.shape:{}'.format(maxpool4.shape)) # 输入([2, 256, 8, 8])--> ([2, 256, 4, 4])
-
- center = self.center(maxpool3) # 256*32*32
- # print('center.shape:{}'.format(center.shape)) # 输入([2, 256, 4, 4])--> ([2, 512, 4, 4]) | ([2, 128, 5, 25]) ---> ([2, 128, 5, 25])
- # up4 = self.up_concat4(center, conv3) # 128*64*64
- # print('up4.shape:{}'.format(up4.shape)) # 输入([2, 512, 4, 4])--> ([2, 256, 8, 8])
- up3 = self.up_concat3(center, conv3) # 64*128*128
- # print('up3.shape:{}'.format(up3.shape)) # 输入([2, 256, 8, 8])--> ([2, 128, 16, 16])
- up2 = self.up_concat2(up3, conv2) # 32*256*256
- # print('up2.shape:{}'.format(up2.shape)) # 输入([2, 128, 16, 16])-> ([2, 64, 32, 32])
- up1 = self.up_concat1(up2, conv1) # 16*512*512
- # print('up1.shape:{}'.format(up1.shape)) # 输入([2, 64, 32, 32])-> ([2, 32, 64, 64])
-
- final = self.final(up1)
-
- return final
-
-
-
-
-
-
-
- try:
- from batchnorm import SynchronizedBatchNorm3d
- except:
- pass
-
-
- def normalization(planes, norm='bn'):
- if norm == 'bn':
- m = nn.BatchNorm3d(planes)
- elif norm == 'gn':
- m = nn.GroupNorm(4, planes)
- elif norm == 'in':
- m = nn.InstanceNorm3d(planes)
- elif norm == 'sync_bn':
- m = SynchronizedBatchNorm3d(planes)
- else:
- raise ValueError('normalization type {} is not supported'.format(norm))
- return m
-
-
- class Conv3d_Block(nn.Module):
- def __init__(self, num_in, num_out, kernel_size=1, stride=1, g=1, padding=None, norm=None):
- super(Conv3d_Block, self).__init__()
- if padding == None:
- padding = (kernel_size - 1) // 2
- self.bn = normalization(num_in, norm=norm)
- self.act_fn = nn.ReLU(inplace=True)
- self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding, stride=stride, groups=g,
- bias=False)
-
- def forward(self, x): # BN + Relu + Conv
- h = self.act_fn(self.bn(x))
- h = self.conv(h)
- return h
-
-
- class DilatedConv3DBlock(nn.Module):
- def __init__(self, num_in, num_out, kernel_size=(1, 1, 1), stride=1, g=1, d=(1, 1, 1), norm=None):
- super(DilatedConv3DBlock, self).__init__()
- assert isinstance(kernel_size, tuple) and isinstance(d, tuple)
-
- padding = tuple(
- [(ks - 1) // 2 * dd for ks, dd in zip(kernel_size, d)]
- )
-
- self.bn = normalization(num_in, norm=norm)
- self.act_fn = nn.ReLU(inplace=True)
- self.conv = nn.Conv3d(num_in, num_out, kernel_size=kernel_size, padding=padding, stride=stride, groups=g,
- dilation=d, bias=False)
-
- def forward(self, x):
- h = self.act_fn(self.bn(x))
- h = self.conv(h)
- return h
-
-
- class MFunit(nn.Module):
- def __init__(self, num_in, num_out, g=1, stride=1, d=(1, 1), norm=None):
- """ The second 3x3x1 group conv is replaced by 3x3x3.
- :param num_in: number of input channels
- :param num_out: number of output channels
- :param g: groups of group conv.
- :param stride: 1 or 2
- :param d: tuple, d[0] for the first 3x3x3 conv while d[1] for the 3x3x1 conv
- :param norm: Batch Normalization
- """
- super(MFunit, self).__init__()
- num_mid = num_in if num_in <= num_out else num_out
- self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
- self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)
- self.conv3x3x3_m1 = DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride, g=g,
- d=(d[0], d[0], d[0]), norm=norm) # dilated
- self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=1, g=g,
- d=(d[1], d[1], 1), norm=norm)
- # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out,num_out,kernel_size=(1,3,3),stride=1,g=g,d=(1,d[1],d[1]),norm=norm)
-
- # skip connection
- if num_in != num_out or stride != 1:
- if stride == 1:
- self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
- if stride == 2:
- # if MF block with stride=2, 2x2x2
- self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0,
- norm=norm) # params
-
- def forward(self, x):
- x1 = self.conv1x1x1_in1(x)
- x2 = self.conv1x1x1_in2(x1)
- x3 = self.conv3x3x3_m1(x2)
- x4 = self.conv3x3x3_m2(x3)
-
- shortcut = x
-
- if hasattr(self, 'conv1x1x1_shortcut'):
- shortcut = self.conv1x1x1_shortcut(shortcut)
- if hasattr(self, 'conv2x2x2_shortcut'):
- shortcut = self.conv2x2x2_shortcut(shortcut)
-
- return x4 + shortcut
-
-
- class DMFUnit(nn.Module):
- # weighred add
- def __init__(self, num_in, num_out, g=1, stride=1, norm=None, dilation=None):
- super(DMFUnit, self).__init__()
- self.weight1 = nn.Parameter(torch.ones(1))
- self.weight2 = nn.Parameter(torch.ones(1))
- self.weight3 = nn.Parameter(torch.ones(1))
-
- num_mid = num_in if num_in <= num_out else num_out
-
- self.conv1x1x1_in1 = Conv3d_Block(num_in, num_in // 4, kernel_size=1, stride=1, norm=norm)
- self.conv1x1x1_in2 = Conv3d_Block(num_in // 4, num_mid, kernel_size=1, stride=1, norm=norm)
-
- self.conv3x3x3_m1 = nn.ModuleList()
- if dilation == None:
- dilation = [1, 2, 3]
- for i in range(3):
- self.conv3x3x3_m1.append(
- DilatedConv3DBlock(num_mid, num_out, kernel_size=(3, 3, 3), stride=stride, g=g,
- d=(dilation[i], dilation[i], dilation[i]), norm=norm)
- )
-
- # It has not Dilated operation
- self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(3, 3, 1), stride=(1, 1, 1), g=g,
- d=(1, 1, 1), norm=norm)
- # self.conv3x3x3_m2 = DilatedConv3DBlock(num_out, num_out, kernel_size=(1, 3, 3), stride=(1,1,1), g=g,d=(1,1,1), norm=norm)
-
- # skip connection
- if num_in != num_out or stride != 1:
- if stride == 1:
- self.conv1x1x1_shortcut = Conv3d_Block(num_in, num_out, kernel_size=1, stride=1, padding=0, norm=norm)
- if stride == 2:
- self.conv2x2x2_shortcut = Conv3d_Block(num_in, num_out, kernel_size=2, stride=2, padding=0, norm=norm)
-
- def forward(self, x):
- # print('x1.shape:{}'.format(x.shape))
- x1 = self.conv1x1x1_in1(x)
- x2 = self.conv1x1x1_in2(x1)
- x3 = self.weight1 * self.conv3x3x3_m1[0](x2) + self.weight2 * self.conv3x3x3_m1[1](x2) + self.weight3 * \
- self.conv3x3x3_m1[2](x2)
- x4 = self.conv3x3x3_m2(x3)
- shortcut = x
- if hasattr(self, 'conv1x1x1_shortcut'):
- shortcut = self.conv1x1x1_shortcut(shortcut)
- if hasattr(self, 'conv2x2x2_shortcut'):
- shortcut = self.conv2x2x2_shortcut(shortcut)
- return x4 + shortcut
-
-
- class MFNet(nn.Module): #
- # [96] Flops: 13.361G & Params: 1.81M
- # [112] Flops: 16.759G & Params: 2.46M
- # [128] Flops: 20.611G & Params: 3.19M
- def __init__(self, c=13, n=32, channels=128, groups=16, norm='bn', num_classes=4):
- super(MFNet, self).__init__()
-
- # 时间维度变大
- # self.time_dimison = nn.Conv3d( c, n, kernel_size=(1,1,1), padding=(0,0,3), stride=(1,1,1), bias=False)
-
- # Entry flow
- self.encoder_block1 = nn.Conv3d(13, n, kernel_size=3, padding=1, stride=2, bias=False) # H//2
- self.encoder_block2 = nn.Sequential(
- MFunit(n, channels, g=groups, stride=2, norm=norm), # H//4 down
- MFunit(channels, channels, g=groups, stride=1, norm=norm),
- MFunit(channels, channels, g=groups, stride=1, norm=norm)
- )
- #
- self.encoder_block3 = nn.Sequential(
- MFunit(channels, channels * 2, g=groups, stride=2, norm=norm), # H//8
- MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm),
- MFunit(channels * 2, channels * 2, g=groups, stride=1, norm=norm)
- )
-
- self.encoder_block4 = nn.Sequential( # H//8,channels*4
- MFunit(channels * 2, channels * 3, g=groups, stride=2, norm=norm), # H//16
- MFunit(channels * 3, channels * 3, g=groups, stride=1, norm=norm),
- MFunit(channels * 3, channels * 2, g=groups, stride=1, norm=norm),
- )
-
- self.upsample1 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//8
- self.decoder_block1 = MFunit(channels * 2 + channels * 2, channels * 2, g=groups, stride=1, norm=norm)
-
- self.upsample2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//4
- self.decoder_block2 = MFunit(channels * 2 + channels, channels, g=groups, stride=1, norm=norm)
-
- self.upsample3 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H//2
- self.decoder_block3 = MFunit(channels + n, n, g=groups, stride=1, norm=norm)
- self.upsample4 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) # H
- self.seg = nn.Conv3d(n, num_classes, kernel_size=1, padding=0, stride=1, bias=False)
-
- self.softmax = nn.Softmax(dim=1)
-
- # Initialization
- for m in self.modules():
- if isinstance(m, nn.Conv3d):
- torch.nn.init.torch.nn.init.kaiming_normal_(m.weight) #
- elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm) or isinstance(m, SynchronizedBatchNorm3d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- def forward(self, x):
- # Encoder
- x = self.time_exten_dimison(x)
-
- # print('x0.shape:{}'.format(x.shape)) # (1,4,128,128,128)) ---> (1,4,128,128,128)
- x1 = self.encoder_block1(x) # H//2 down # (1,4,128,128,128) ---> torch.Size([1, 32, 64, 64, 64])
- # print('x8.shape:{}'.format(x1.shape))
- x2 = self.encoder_block2(x1) # H//4 down # ([1, 32, 64, 64, 64]) ---> torch.Size([1, 128, 32, 32, 32])
- # print(x2.shape)
- x3 = self.encoder_block3(
- x2) # H//8 down # torch.Size([1, 128, 32, 32, 32]) ---> torch.Size([1, 256, 16, 16, 16])
- # print('x3.shape:{}'.format(x.shape)) # x4 = self.encoder_block4(x3) # H//16 # torch.Size([1, 256, 16, 16, 16]) ---> torch.Size([1, 256, 8, 8, 8])
- # print(x4.shape)
- # Decoder
- # y1 = self.upsample1(x4)# H//8 # torch.Size([1, 256, 8, 8, 8]) ---> torch.Size([1, 256, 16, 16, 16])
- # print(y1.shape)
- # y1 = torch.cat([x3,y1],dim=1) # torch.Size([1, 256, 16, 16, 16]) ---> torch.Size([1, 512, 16, 16, 16])
- # print(y1.shape)
- # y1 = self.decoder_block1(y1) # torch.Size([1, 512, 16, 16, 16]) ---> torch.Size([1, 256, 16, 16, 16])
- # print(y1.shape)
- y2 = self.upsample2(x3) # H//4 # torch.Size([1, 256, 16, 16, 16]) ---> torch.Size([1, 256, 32, 32, 32])
- # print(y2.shape)
- y2 = torch.cat([x2, y2], dim=1) # torch.Size([1, 256, 32, 32, 32]) ---> torch.Size([1, 384, 32, 32, 32])
- # print(y2.shape)
- y2 = self.decoder_block2(y2) # torch.Size([1, 384, 32, 32, 32]) ---> torch.Size([1, 128, 32, 32, 32])
- # print(y2.shape)
- y3 = self.upsample3(y2) # H//2 # torch.Size([1, 128, 32, 32, 32]) ---> torch.Size([1, 128, 64, 64, 64])
- # print(y3.shape)
- y3 = torch.cat([x1, y3], dim=1) # torch.Size([1, 128, 64, 64, 64]) ---> torch.Size([1, 160, 64, 64, 64])
- # print(y3.shape)
- y3 = self.decoder_block3(y3) # torch.Size([1, 160, 64, 64, 64]) --> torch.Size([1, 32, 64, 64, 64])
- # print(y3.shape)
- y4 = self.upsample4(y3) # torch.Size([1, 32, 64, 64, 64])---> torch.Size([1, 32, 128, 128, 128])
- # print(y4.shape)
- y4 = self.seg(y4) # torch.Size([1, 32, 128, 128, 128]) ---> torch.Size([1, 4, 128, 128, 128])
- # print('y4.shape:{}'.format(y4.shape))
- if hasattr(self, 'softmax'):
- y4 = self.softmax(y4)
- # print('y4.shape:{}'.format(y4.shape))
- y4 = self.time_descent_dimison(y4)
- return y4
-
-
- class time_exten_dimison(nn.Module): #
- # [96] Flops: 13.361G & Params: 1.81M
- # [112] Flops: 16.759G & Params: 2.46M
- # [128] Flops: 20.611G & Params: 3.19M
- def __init__(self, c=13, n=13, channels=128, groups=16, norm='bn', num_classes=4):
- super(time_exten_dimison, self).__init__()
-
- # 时间维度变大
- self.time_exten_dimison = nn.Conv3d(c, n, kernel_size=(1, 1, 3), padding=(0, 0, 0), stride=(1, 1, 1),
- bias=False)
-
- def forward(self, x):
- # print('x00.shape:{}'.format(x.shape))
- x = self.time_exten_dimison(x)
-
- return x
-
-
- class time_descent_dimison(nn.Module): #
- # [96] Flops: 13.361G & Params: 1.81M
- # [112] Flops: 16.759G & Params: 2.46M
- # [128] Flops: 20.611G & Params: 3.19M
- def __init__(self, c=4, n=2, channels=128, groups=16, norm='bn', num_classes=4):
- super(time_descent_dimison, self).__init__()
-
- # 时间维度变大
- self.time_descent_dimison = nn.Conv3d(c, n, kernel_size=(1, 1, 5), padding=(0, 0, 0), stride=(1, 1, 4),
- bias=False)
-
- def forward(self, x):
- # print('x00.shape:{}'.format(x.shape))
- x = self.time_descent_dimison(x)
-
- return x
-
-
- class DMFNet(MFNet): # softmax
- # [128] Flops: 27.045G & Params: 3.88M
- def __init__(self, c=13, n=32, channels=128, groups=16, norm='bn', num_classes=4):
- super(DMFNet, self).__init__(c, n, channels, groups, norm, num_classes)
-
- self.time_exten_dimison = time_exten_dimison()
- # print('xxxx.shape:{}'.format(x.shape))
- self.encoder_block2 = nn.Sequential(
- DMFUnit(n, channels, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]), # H//4 down
- DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]), # Dilated Conv 3
- DMFUnit(channels, channels, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
- )
-
- self.encoder_block3 = nn.Sequential(
- DMFUnit(channels, channels * 2, g=groups, stride=2, norm=norm, dilation=[1, 2, 3]), # H//8
- DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3]), # Dilated Conv 3
- DMFUnit(channels * 2, channels * 2, g=groups, stride=1, norm=norm, dilation=[1, 2, 3])
- )
-
- self.time_descent_dimison = time_descent_dimison()
-
-
- def setup_seed(seed):
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- np.random.seed(seed)
- random.seed(seed)
- torch.backends.cudnn.deterministic = True
-
-
- # 设置随机数种子
- setup_seed(1)
-
- # 需要 mld u v sss temp 降水 蒸发 混合层下的盐度
-
- data = np.load(r'/dataset/10day_for_14day_all_variables_surface_pacific_10_19_Qfenjie_we.npz')
- print(data.files) #['surface_latent_heat_flux', 'surface_sensible_heat_flux', 'surface_net_radiation', 'evaporation', 'total_precipitation', 'mld', 'sst_surface', 'sss_surface', 'u_surface', 'v_surface',
- #'T_d', 'S_d', 'u_d', 'v_d', 'xx', 'sst_surface_label', 'sss_surface_label', 'dS_dt', 'dT_dt', 'dS_dx', 'dT_dx', 'dS_dy', 'dT_dy']
-
- # 2000-01-01 ---> 2019-12-31
- surface_latent_heat_flux = data['surface_latent_heat_flux'][:] # (7305, 10, 40, 200)
- surface_sensible_heat_flux = data['surface_sensible_heat_flux'][:] # (9851, 10, 40, 200)
- surface_net_radiation = data['surface_net_radiation'][:] # (9851, 10, 40, 200)
- evaporation = data['evaporation'][:]
- total_precipitation = data['total_precipitation'][:]
- mld = data['mld'][:]
- sst_surface = data['sst_surface'][:]
- sss_surface = data['sss_surface'][:]
- u_surface = data['u_surface'][:]
- v_surface = data['v_surface'][:]
- T_d = data['T_d'][:]
- S_d = data['S_d'][:]
- u_d = data['u_d'][:]
- v_d = data['v_d'][:]
-
- sst_surface_label = data['sst_surface_label'][:]
- sss_surface_label = data['sss_surface_label'][:]
-
- # print(surface_latent_heat_flux.shape)
- # print(sst_surface.shape)
- # (2203, 10, 40, 200)
- surface_latent_heat_flux = surface_latent_heat_flux.transpose(0, 2, 3, 1)
- surface_sensible_heat_flux = surface_sensible_heat_flux.transpose(0, 2, 3, 1)
- surface_net_radiation = surface_net_radiation.transpose(0, 2, 3, 1)
- evaporation = evaporation.transpose(0, 2, 3, 1)
- total_precipitation = total_precipitation.transpose(0, 2, 3, 1)
- mld = mld.transpose(0, 2, 3, 1)
- sst_surface = sst_surface.transpose(0, 2, 3, 1)
- sss_surface = sss_surface.transpose(0, 2, 3, 1)
- u_surface = u_surface.transpose(0, 2, 3, 1)
- v_surface = v_surface.transpose(0, 2, 3, 1)
- T_d = T_d.transpose(0, 2, 3, 1)
- S_d = S_d.transpose(0, 2, 3, 1)
- u_d = u_d.transpose(0, 2, 3, 1)
- v_d = v_d.transpose(0, 2, 3, 1)
-
- # print(sst_surface.shape) # (2203, 40, 200, 10)
- # print(sss_surface.shape) # (2203, 40, 200, 10) B LAT LON T
-
-
- dS_dt = data['dS_dt'][:] # dS_dt.shape:(2203, 40, 200, 10)
- dT_dt = data['dT_dt'][:]
- dS_dx = data['dS_dx'][:]
- dS_dy = data['dS_dy'][:]
- dT_dx = data['dT_dx'][:]
- dT_dy = data['dT_dy'][:]
- print('dS_dt.shape:{}'.format(dS_dt.shape))
-
- # 归一化
-
- scaler1 = MinMaxScaler()
- scaler2 = MinMaxScaler()
- surface_latent_heat_flux1 = surface_latent_heat_flux.reshape(-1, 1)
- surface_sensible_heat_flux1 = surface_sensible_heat_flux.reshape(-1, 1)
- surface_net_radiation1 = surface_net_radiation.reshape(-1, 1)
- evaporation1 = evaporation.reshape(-1, 1)
- total_precipitation1 = total_precipitation.reshape(-1, 1)
- mld1 = mld.reshape(-1, 1)
- sst_surface1 = sst_surface.reshape(-1, 1)
- sss_surface1 = sss_surface.reshape(-1, 1)
- u_surface1 = u_surface.reshape(-1, 1)
- v_surface1 = v_surface.reshape(-1, 1)
- T_d1 = T_d.reshape(-1, 1)
- S_d1 = S_d.reshape(-1, 1)
- u_d1 = u_d.reshape(-1, 1)
- v_d1 = v_d.reshape(-1, 1)
- sst_surface_label1 = sst_surface_label.reshape(-1, 1)
- sss_surface_label1 = sss_surface_label.reshape(-1, 1)
-
-
-
- surface_latent_heat_flux = scaler1.fit_transform(surface_latent_heat_flux1).reshape(surface_latent_heat_flux.shape)
- surface_sensible_heat_flux = scaler1.fit_transform(surface_sensible_heat_flux1).reshape(surface_sensible_heat_flux.shape)
- surface_net_radiation1 = scaler1.fit_transform(surface_net_radiation1).reshape(surface_net_radiation.shape)
- evaporation = scaler1.fit_transform(evaporation1).reshape(evaporation.shape)
- total_precipitation = scaler1.fit_transform(total_precipitation1).reshape(total_precipitation.shape)
- mld = scaler1.fit_transform(mld1).reshape(mld.shape)
- sst_surface = scaler1.fit_transform(sst_surface1).reshape(sst_surface.shape)
- sss_surface = scaler1.fit_transform(sss_surface1).reshape(sss_surface.shape)
- u_surface = scaler1.fit_transform(u_surface1).reshape(u_surface.shape)
- v_surface = scaler1.fit_transform(v_surface1).reshape(v_surface.shape)
- T_d = scaler1.fit_transform(T_d1).reshape(T_d.shape)
- S_d = scaler1.fit_transform(S_d1).reshape(S_d.shape)
- u_d = scaler1.fit_transform(u_d1).reshape(u_d.shape)
- v_d = scaler1.fit_transform(v_d1).reshape(v_d.shape)
- sst_surface_label = scaler1.fit_transform(sst_surface_label1).reshape(sst_surface_label.shape)
- sss_surface_label = scaler1.fit_transform(sss_surface_label1).reshape(sss_surface_label.shape)
-
- Q_net = surface_latent_heat_flux + surface_sensible_heat_flux + surface_net_radiation1
-
-
-
-
-
- dS_dt1 = dS_dt.reshape(-1, 1)
- dT_dt1 = dT_dt.reshape(-1, 1)
- dS_dx1 = dS_dx.reshape(-1, 1)
- dS_dy1 = dS_dy.reshape(-1, 1)
- dT_dx1 = dT_dx.reshape(-1, 1)
- dT_dy1 = dT_dy.reshape(-1, 1)
-
- dS_dt = scaler2.fit_transform(dS_dt1).reshape(dS_dt.shape)
- dT_dt = scaler2.fit_transform(dT_dt1).reshape(dT_dt.shape)
- dS_dx = scaler2.fit_transform(dS_dx1).reshape(dS_dx.shape)
- dS_dy = scaler2.fit_transform(dS_dy1).reshape(dS_dy.shape)
- dT_dx = scaler2.fit_transform(dT_dx1).reshape(dT_dx.shape)
- dT_dy = scaler2.fit_transform(dT_dy1).reshape(dT_dy.shape)
- u_surface_we = scaler2.fit_transform(u_surface1).reshape(u_surface.shape)
- v_surface_we = scaler2.fit_transform(v_surface1).reshape(v_surface.shape)
- mld_we = scaler2.fit_transform(mld1).reshape(mld.shape)
-
- train_size = 1312
- valid_size = 1760
-
-
-
- # 前20% 作为验证 剩下的20%的作为测试
- # surface_latent_heat_flux = surface_latent_heat_flux + surface_sensible_heat_flux + surface_net_radiation
-
- u_surface_we = u_surface_we.reshape(-1, 1, 40, 200, 10)
- u_surface_we = torch.Tensor(u_surface_we)
- u_surface_we_train = u_surface_we[0:train_size, :, :, :, :]
- u_surface_we_valid = u_surface_we[train_size:valid_size, :, :, :, :]
- # Q_net_test = Q_net[valid_size:,:,:,:,:]
-
- v_surface_we = v_surface_we.reshape(-1, 1, 40, 200, 10)
- v_surface_we = torch.Tensor(v_surface_we)
- v_surface_we_train = v_surface_we[0:train_size, :, :, :, :]
- v_surface_we_valid = v_surface_we[train_size:valid_size, :, :, :, :]
-
- mld_we = mld_we.reshape(-1, 1, 40, 200, 10)
- mld_we = torch.Tensor(mld_we)
- mld_we_train = mld_we[0:train_size, :, :, :, :]
- mld_we_valid = mld_we[train_size:valid_size, :, :, :, :]
-
-
- Q_net = Q_net.reshape(-1, 1, 40, 200, 10)
- Q_net = torch.Tensor(Q_net)
- Q_net_train = Q_net[0:train_size, :, :, :, :]
- Q_net_valid = Q_net[train_size:valid_size, :, :, :, :]
- # Q_net_test = Q_net[valid_size:,:,:,:,:]
-
- surface_latent_heat_flux = surface_latent_heat_flux.reshape(-1, 1, 40, 200, 10)
- surface_latent_heat_flux = torch.Tensor(surface_latent_heat_flux)
- surface_latent_heat_flux_train = surface_latent_heat_flux[0:train_size, :, :, :, :]
- surface_latent_heat_flux_valid = surface_latent_heat_flux[train_size:valid_size, :, :, :, :]
- # surface_latent_heat_flux_test = surface_latent_heat_flux[valid_size:,:,:,:,:]
-
-
- surface_sensible_heat_flux = surface_sensible_heat_flux.reshape(-1, 1, 40, 200, 10)
- surface_sensible_heat_flux = torch.Tensor(surface_sensible_heat_flux)
- surface_sensible_heat_flux_train = surface_sensible_heat_flux[0:train_size, :, :, :, :]
- surface_sensible_heat_flux_valid = surface_sensible_heat_flux[train_size:valid_size, :, :, :, :]
- # # # surface_sensible_heat_flux_test = surface_sensible_heat_flux[valid_size:,:,:,:,:]
-
-
- surface_net_radiation = surface_net_radiation.reshape(-1, 1, 40, 200, 10)
- surface_net_radiation = torch.Tensor(surface_net_radiation)
- surface_net_radiation_train = surface_net_radiation[0:train_size, :, :, :, :]
- surface_net_radiation_valid = surface_net_radiation[train_size:valid_size, :, :, :, :]
- # # surface_net_radiation_test = surface_net_radiation[valid_size:,:,:,:,:]
-
-
- evaporation = evaporation.reshape(-1, 1, 40, 200, 10)
- evaporation = torch.Tensor(evaporation)
- evaporation_train = evaporation[0:train_size, :, :, :, :]
- evaporation_valid = evaporation[train_size:valid_size, :, :, :, :]
- # evaporation_test = evaporation[valid_size:,:,:,:,:]
-
-
- total_precipitation = total_precipitation.reshape(-1, 1, 40, 200, 10)
- total_precipitation = torch.Tensor(total_precipitation)
- total_precipitation_train = total_precipitation[0:train_size, :, :, :, :]
- total_precipitation_valid = total_precipitation[train_size:valid_size, :, :, :, :]
- # total_precipitation_test = total_precipitation[valid_size:,:,:,:,:]
-
-
- mld = mld.reshape(-1, 1, 40, 200, 10)
- mld = torch.Tensor(mld)
- mld_train = mld[0:train_size, :, :, :, :]
- mld_valid = mld[train_size:valid_size, :, :, :, :]
- # mld_test = mld[valid_size:,:,:,:,:]
-
-
- sst_surface = sst_surface.reshape(-1, 1, 40, 200, 10)
- sst_surface = torch.Tensor(sst_surface)
- sst_surface_train = sst_surface[0:train_size, :, :, :, :]
- sst_surface_valid = sst_surface[train_size:valid_size, :, :, :, :]
- # sst_surface_test = sst_surface[valid_size:,:,:,:,:]
-
-
- sss_surface = sss_surface.reshape(-1, 1, 40, 200, 10)
- sss_surface = torch.Tensor(sss_surface)
- sss_surface_train = sss_surface[0:train_size, :, :, :, :]
- sss_surface_valid = sss_surface[train_size:valid_size, :, :, :, :]
- # sss_surface_test = sss_surface[valid_size:,:,:,:,:]
-
-
- u_surface = u_surface.reshape(-1, 1, 40, 200, 10)
- u_surface = torch.Tensor(u_surface)
- u_surface_train = u_surface[0:train_size, :, :, :, :]
- u_surface_valid = u_surface[train_size:valid_size, :, :, :, :]
- # u_surface_test = u_surface[valid_size:,:,:,:,:]
-
-
- v_surface = v_surface.reshape(-1, 1, 40, 200, 10)
- v_surface = torch.Tensor(v_surface)
- v_surface_train = v_surface[0:train_size, :, :, :, :]
- v_surface_valid = v_surface[train_size:valid_size, :, :, :, :]
- # # v_surface_test = v_surface[valid_size:,:,:,:,:]
-
-
- T_d = T_d.reshape(-1, 1, 40, 200, 10)
- T_d = torch.Tensor(T_d)
- T_d_train = T_d[0:train_size, :, :, :, :]
- T_d_valid = T_d[train_size:valid_size, :, :, :, :]
- # T_d_test = T_d[valid_size:,:,:,:,:]
-
-
- S_d = S_d.reshape(-1, 1, 40, 200, 10)
- S_d = torch.Tensor(S_d)
- S_d_train = S_d[0:train_size, :, :, :, :]
- S_d_valid = S_d[train_size:valid_size, :, :, :, :]
- # S_d_test = S_d[valid_size:,:,:,:,:]
-
-
- u_d = u_d.reshape(-1, 1, 40, 200, 10)
- u_d = torch.Tensor(u_d)
- u_d_train = u_d[0:train_size, :, :, :, :]
- u_d_valid = u_d[train_size:valid_size, :, :, :, :]
- # u_d_test = u_d[valid_size:,:,:,:,:]
-
-
- v_d = v_d.reshape(-1, 1, 40, 200, 10)
- v_d = torch.Tensor(v_d)
- v_d_train = v_d[0:train_size, :, :, :, :]
- v_d_valid = v_d[train_size:valid_size, :, :, :, :]
- # # v_d_test = v_d[valid_size:,:,:,:,:]
-
- # xx = xx.reshape(-1, 1, 40, 200, 10)
- # xx = torch.Tensor(xx)
- # xx_train = xx[0:train_size,:,:,:,:]
- # xx_valid = xx[train_size:valid_size,:,:,:,:]
- # # xx_test = xx[valid_size:,:,:,:,:]
-
-
- # yy = yy.reshape(-1, 1, 40, 200, 10)
- # yy = torch.Tensor(yy)
- # yy_train = yy[0:train_size,:,:,:,:]
- # yy_valid = yy[train_size:valid_size,:,:,:,:]
- # yy_test = yy[valid_size:,:,:,:,:]
-
-
-
- dS_dt = dS_dt.reshape(-1, 1, 40, 200, 10)
- dS_dt = torch.Tensor(dS_dt)
- dS_dt_train = dS_dt[0:train_size, :, :, :]
- dS_dt_valid = dS_dt[train_size:valid_size, :, :, :]
-
- dT_dt = dT_dt.reshape(-1, 1, 40, 200, 10)
- dT_dt = torch.Tensor(dT_dt)
- dT_dt_train = dT_dt[0:train_size, :, :, :]
- dT_dt_valid = dT_dt[train_size:valid_size, :, :, :]
-
- dS_dx = dS_dx.reshape(-1, 1, 40, 200, 10)
- dS_dx = torch.Tensor(dS_dx)
- dS_dx_train = dS_dx[0:train_size, :, :, :]
- dS_dx_valid = dS_dx[train_size:valid_size, :, :, :]
-
- dS_dy = dS_dy.reshape(-1, 1, 40, 200, 10)
- dS_dy = torch.Tensor(dS_dy)
- dS_dy_train = dS_dy[0:train_size, :, :, :]
- dS_dy_valid = dS_dy[train_size:valid_size, :, :, :]
-
- dT_dx = dT_dx.reshape(-1, 1, 40, 200, 10)
- dT_dx = torch.Tensor(dT_dx)
- dT_dx_train = dT_dx[0:train_size, :, :, :]
- dT_dx_valid = dT_dx[train_size:valid_size, :, :, :]
-
- dT_dy = dT_dy.reshape(-1, 1, 40, 200, 10)
- dT_dy = torch.Tensor(dT_dy)
- dT_dy_train = dT_dy[0:train_size, :, :, :]
- dT_dy_valid = dT_dy[train_size:valid_size, :, :, :]
-
-
- # train_data = torch.cat((surface_latent_heat_flux_train, surface_sensible_heat_flux_train, surface_net_radiation_train,
- # evaporation_train, total_precipitation_train, mld_train, sst_surface_train, sss_surface_train,
- # u_surface_train, v_surface_train, T_d_train, S_d_train, u_d_train, v_d_train,
- # xx_train, yy_train), dim=2) # train_data.shape:torch.Size([5920, 10, 16, 40, 200])
-
- # valid_data = torch.cat((surface_latent_heat_flux_valid, surface_sensible_heat_flux_valid, surface_net_radiation_valid,
- # evaporation_valid, total_precipitation_valid, mld_valid, sst_surface_valid, sss_surface_valid,
- # u_surface_valid, v_surface_valid, T_d_valid, S_d_valid, u_d_valid, v_d_valid,
- # xx_valid, yy_valid), dim=2)
-
- # test_data = torch.cat((surface_latent_heat_flux_test, surface_sensible_heat_flux_test, surface_net_radiation_test,
- # evaporation_test, total_precipitation_test, mld_test, sst_surface_test, sss_surface_test,
- # u_surface_test, v_surface_test, T_d_test, S_d_test, u_d_test, v_d_test,
- # xx_test, yy_test), dim=2)
-
-
- train_data = torch.cat((Q_net_train,
- evaporation_train, total_precipitation_train, mld_train, sst_surface_train, sss_surface_train,
- u_surface_train, v_surface_train, T_d_train, S_d_train, u_d_train, v_d_train, dS_dt_train, dT_dt_train, dS_dx_train, dS_dy_train, dT_dx_train, dT_dy_train
- ,u_surface_we_train, v_surface_we_train,mld_we_train), dim=1) # train_data.shape:torch.Size([5920, 10, 16, 40, 200])
-
- valid_data = torch.cat((Q_net_valid,
- evaporation_valid, total_precipitation_valid, mld_valid, sst_surface_valid, sss_surface_valid,
- u_surface_valid, v_surface_valid, T_d_valid, S_d_valid, u_d_valid, v_d_valid, dS_dt_valid, dT_dt_valid, dS_dx_valid, dS_dy_valid, dT_dx_valid, dT_dy_valid
- ,u_surface_we_valid, v_surface_we_valid, mld_we_valid), dim=1)
-
- # train_data = torch.cat((sst_surface_train, sss_surface_train,), dim=1) # train_data.shape:torch.Size([5920, 10, 16, 40, 200])
-
- # valid_data = torch.cat((sst_surface_valid, sss_surface_valid,), dim=1)
-
- print(train_data.shape)
- print(valid_data.shape)
-
- sst_train_label = sst_surface_label[14:train_size + 14, :, :]
- sst_valid_label = sst_surface_label[train_size + 14: valid_size + 14, :, :]
- # sst_test_label = sst_surface_label[valid_size + 14:,:,:]
-
-
- sss_train_label = sss_surface_label[14:train_size + 14, :, :]
- sss_valid_label = sss_surface_label[train_size + 14: valid_size + 14, :, :]
- # sss_test_label = sss_surface_label[valid_size + 14:,:,:]
-
- sst_train_label = sst_train_label.reshape(-1, 1, 40, 200)
- sst_valid_label = sst_valid_label.reshape(-1, 1, 40, 200)
-
- sss_train_label = sss_train_label.reshape(-1, 1, 40, 200)
- sss_valid_label = sss_valid_label.reshape(-1, 1, 40, 200)
-
- train_label = np.concatenate((sst_train_label, sss_train_label), axis=1)
- valid_label = np.concatenate((sst_valid_label, sss_valid_label), axis=1)
-
-
- # train_label = sst_train_label
- # valid_label = sst_valid_label
-
- # print('train_label.shape:{}'.format(train_label.shape))
- # print('valid_label.shape:{}'.format(valid_label.shape))
- # print('train_data.shape:{}'.format(train_data.shape))
- # print('valid_data.shape:{}'.format(valid_data.shape))
-
- # print('train_data.shape:{}'.format(train_data.shape)) # train_data.shape:torch.Size([5920, 10, 16, 40, 200])
- # print('valid_data.shape:{}'.format(valid_data.shape)) # valid_data.shape:torch.Size([1952, 10, 16, 40, 200])
- # print('test_data.shape:{}'.format(test_data.shape)) # test_data.shape:torch.Size([1979, 10, 16, 40, 200])
- # print('sst_train_label.shape:{}'.format(sst_train_label.shape)) # sst_train_label.shape:(5920, 14, 40, 200)
- # print('sst_valid_label.shape:{}'.format(sst_valid_label.shape)) # sst_valid_label.shape:(1952, 14, 40, 200)
- # print('sst_test_label.shape:{}'.format(sst_test_label.shape)) # sst_test_label.shape:(1961, 14, 40, 200)
- # print('sss_train_label.shape:{}'.format(sss_train_label.shape)) # sss_train_label.shape:(5920, 14, 40, 200)
- # print('sss_valid_label.shape:{}'.format(sss_valid_label.shape)) # sss_valid_label.shape:(1952, 14, 40, 200)
- # print('sss_test_label.shape:{}'.format(sss_test_label.shape)) # sss_test_label.shape:(1961, 14, 40, 200)
-
-
- # 构建数据管道
- class MyDataset(Dataset):
- def __init__(self, data, label):
- self.data = torch.Tensor(data)
- self.label = torch.Tensor(label)
-
- def __len__(self):
- return len(self.label)
-
- def __getitem__(self, idx):
- return self.data[idx], self.label[idx]
-
-
- batch_size1 = 32
- batch_size2 = 32
- batch_size3 = 3000
-
- trainset = MyDataset(train_data, train_label)
- trainloader = DataLoader(trainset, batch_size=batch_size1, shuffle=True, drop_last=False, pin_memory=True,
- num_workers=4)
-
- validset = MyDataset(valid_data, valid_label)
- validloader = DataLoader(validset, batch_size=batch_size2, shuffle=True, drop_last=False, pin_memory=True,
- num_workers=4)
-
- # testset = MyDataset(test_data, sss_test_label)
- # testloader = DataLoader(testset, batch_size=batch_size3, shuffle=False, drop_last=False,pin_memory=True, num_workers=0)
-
- # print('Qnet_train.shape:{}'.format(sst1_train.shape))
-
-
- model_weights1 = '/model/epo300_lay3_lr0.001_e5_forecastnet_14day_model_weights.pth'
- torch.backends.cudnn.enabled = False
-
- model = DMFNet(c=12, groups=16, norm='sync_bn', num_classes=4).cuda()
-
- model1 = UNet_learn().cuda()
-
- criterion = nn.MSELoss()
- # 定义优化器
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
- epochs = 200
- train_losses, valid_losses = [], []
- # best_loss = 2
- best_score = float('inf')
- best_score1 = float('inf')
-
- pred_val = np.zeros((448, 2, 40, 200))
-
- sores = []
-
-
- def rmse(y_true, y_preds):
- return np.sqrt(mean_squared_error(y_pred=y_preds, y_true=y_true))
-
-
- for epoch in range(epochs):
- print('Epoch: {}/{}'.format(epoch + 1, epochs))
- # print(var_y)
- # 模型训练
- model.train()
- losses = 0
- loss1 = 0
- for i, data in tqdm(enumerate(trainloader)):
- data, label = data
- data = data.cuda()
- label = label.cuda()
- print('data.shape:{}'.format(data.shape)) #data.shape:torch.Size([32, 18, 40, 200, 10])
-
-
- # dS_dt dS_dx dS_dy dT_dt dT_dx dT_dy
-
- data2 = data[:,12:,:,:,:]
-
- data2 = data2[:,:,:39,:199,:9]
- print('data2.shape:{}'.format(data2.shape))
- B,C,H,W,T = data2.size()
- data3 = data2.reshape(B,C*T,H,W)
- print('data3.shape:{}'.format(data3.shape))
- out_we = model1(data3)
- # print(out_we.shape) #torch.Size([32, 10, 40, 200])
-
- out_we = out_we.permute(0,2,3,1)
- print('out_we.shape:{}'.format(out_we.shape)) #torch.Size([32, 10, 40, 200])
- out_we = out_we.reshape(-1,1,40,200,10)
- # out_we1 = out_we.reshape(-1,1).detach().cpu().numpy()
- # out_we = scaler.fit_transform(out_we1).reshape(out_we.shape)
- # out_we = torch.Tensor(out_we).cuda()
-
- print(out_we)
- data1 = data[:,:12,:,:,:]
-
- data_train = torch.cat((data1, out_we), dim = 1)
- print('data_train.shape:{}'.format(data_train.shape)) #torch.Size([32, 10, 40, 200])
-
- # data_train1 = data_train.reshape(-1,1).detach().cpu().numpy()
- # data_train = scaler.fit_transform(data_train1).reshape(data_train.shape)
- # data_train = torch.Tensor(data_train).cuda()
-
- out = model(data_train)
-
-
- # print('out.shape:{}'.format(out.shape))
- # print('label.shape:{}'.format(label.shape))
- # print(out)
- # 偏S/偏t - (E - P) * (S / h) - [u * 偏S/偏x + v * 偏S/偏y ] + H * (w_h + dh/dt * ((S - S_h) / h)) = 0 loss1
- # 偏T/偏t - Q / (p * C_p * h_m) - u * (偏T/偏x) - v * (偏T/偏y) + w_e * ((T - T_d) / h) = 0
-
- # sst_label = label[:,0,:,:]
- # sss_label = label[:,1,:,:]
-
- # sst_out = out[:,0,:,:]
- # sss_out = out[:,1,:,:]
- out = out.reshape(-1, 2, 40, 200)
- loss = criterion(out, label)
- # loss2 = criterion(sss_out, sss_label)
-
- losses += loss
-
- loss.backward()
- optimizer.step()
- train_loss = losses / len(trainloader)
- train_losses.append(train_loss)
-
- print('Training Loss: {:.10f}'.format((train_loss)))
-
- model.eval()
- losses = 0
-
- for i, data in tqdm(enumerate(validloader)):
- data, label = data
- data = data.cuda()
- label = label.cuda()
- optimizer.zero_grad()
-
- # h_m = data[:,3,:,:,:]
- # u_v_surface_train = data[:,6:8,:,:,:]
- data2 = data[:,12:,:,:,:]
-
- # h_m.shape:torch.Size([32, 40, 200, 10])
- # u_v_surface_train.shape:torch.Size([32, 2, 40, 200, 10])
- # data2.shape:torch.Size([32, 6, 40, 200, 10])
- # h_m = h_m.reshape(-1,1,40,200,10)
- # data2 = torch.cat((h_m, u_v_surface_train, data2), dim = 1)
- data2 = data2[:,:,:39,:199,:9]
- print('data2.shape:{}'.format(data2.shape))
- B,C,H,W,T = data2.size()
- data3 = data2.reshape(B,C*T,H,W)
- print('data3.shape:{}'.format(data3.shape))
- out_we = model1(data3)
- # print(out_we.shape) #torch.Size([32, 10, 40, 200])
-
- out_we = out_we.permute(0,2,3,1)
- print('out_we.shape:{}'.format(out_we.shape)) #torch.Size([32, 10, 40, 200])
- out_we = out_we.reshape(-1,1,40,200,10)
-
- # out_we1 = out_we.reshape(-1,1).detach().cpu().numpy()
- # out_we = scaler.fit_transform(out_we1).reshape(out_we.shape)
- # out_we = torch.Tensor(out_we).cuda()
-
- data1 = data[:,:12,:,:,:]
-
- data_train = torch.cat((data1, out_we), dim = 1)
- print('data_train.shape:{}'.format(data_train.shape)) #torch.Size([32, 10, 40, 200])
-
- # data_train1 = data_train.reshape(-1,1).detach().cpu().numpy()
- # data_train = scaler.fit_transform(data_train1).reshape(data_train.shape)
- # data_train = torch.Tensor(data_train).cuda()
-
- out = model(data_train)
- out = out.reshape(-1, 2, 40, 200)
-
- loss = criterion(out, label)
-
- losses += float(loss)
-
- out1 = out.detach().cpu().numpy()
- pred_val[i * batch_size2:(i + 1) * batch_size2] = np.array(out1)
-
- valid_loss = losses / len(validloader)
- valid_losses.append(valid_loss)
-
- valid_label1 = valid_label.reshape(-1, 1)
- preds1 = pred_val.reshape(-1, 1)
-
- s = rmse(valid_label1, preds1)
- sores.append(s)
- print('Score: {:.3f}'.format(s))
-
- if valid_loss < best_score1: # 求s的最小值 ---》最大值反过来 inf符号也要反过来
- best_score1 = valid_loss
- checkpoint = {'best_score': valid_loss,
- 'state_dict': model.state_dict()}
- torch.save(checkpoint, model_weights1) # if valid_loss < best_loss:
- best_loss = valid_loss
- torch.save(model.state_dict(),
- '/model/pinn_likeICDM_lr0.005_model_300_layer3_2day_e5.pt')
-
- print(sores)
- print(best_score)
- print(s)
|