- from typing import Dict
- import numpy as np
- import torch.nn as nn
- import torch.nn.functional as F
- import random
- import torch
- from tqdm import tqdm
- from torch.autograd import Variable
- from torch.utils.data import Dataset, DataLoader
- from sklearn.metrics import mean_squared_error
- from torch.nn import Flatten
-
- class DoubleConvDS(nn.Module):
- """(convolution => [BN] => ReLU) * 2"""
- #
- def __init__(self, in_channels, out_channels, mid_channels=None, kernels_per_layer=1):
- super().__init__()
- if not mid_channels:
- mid_channels = out_channels
- self.double_conv = nn.Sequential(
- DepthwiseSeparableConv(in_channels, mid_channels, kernel_size=3, kernels_per_layer=kernels_per_layer, padding=1),
- nn.BatchNorm2d(mid_channels),
- nn.ReLU(inplace=True),
- DepthwiseSeparableConv(mid_channels, out_channels, kernel_size=3, kernels_per_layer=kernels_per_layer, padding=1),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True)
- )
-
- def forward(self, x):
- return self.double_conv(x)
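- # Note: DoubleConvDS is the depthwise-separable variant of this block (as in the original SmaAt-UNet);
- # the SmaAt_UNet class defined below builds on the plain DoubleConv instead, so DoubleConvDS and
- # kernels_per_layer are not actually exercised in this script.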
-
-
- class DoubleConv(nn.Sequential):
- def __init__(self, in_channels, out_channels, mid_channels=None):
- if mid_channels is None:
- mid_channels = out_channels
- super(DoubleConv, self).__init__(
- nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
- nn.BatchNorm2d(mid_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True)
- )
-
- class Down(nn.Sequential):
- def __init__(self, in_channels, out_channels):
- super(Down, self).__init__(
- nn.MaxPool2d(2, stride=2),
- DoubleConv(in_channels, out_channels)
- )
-
- class Up(nn.Module):
- def __init__(self, in_channels, out_channels, bilinear=True):
- super(Up, self).__init__()
- if bilinear:
- self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
- self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
- else:
- self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
- self.conv = DoubleConv(in_channels, out_channels)
-
- def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
- x1 = self.up(x1)
- # [N, C, H, W]
- diff_y = x2.size()[2] - x1.size()[2]
- diff_x = x2.size()[3] - x1.size()[3]
-
- # padding_left, padding_right, padding_top, padding_bottom
- x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
- diff_y // 2, diff_y - diff_y // 2])
-
- x = torch.cat([x2, x1], dim=1)
- x = self.conv(x)
- return x
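- # Commented-out sanity check (a minimal sketch; names here are illustrative only): the decoder feature x1
- # is upsampled and zero-padded so it can be concatenated with the skip feature x2 even when pooling did
- # not divide the size evenly (e.g. 41 -> 20 on the way down, 20 -> 40 -> padded back to 41 on the way up).
- # up_check = Up(128, 64, bilinear=True)
- # x1 = torch.randn(1, 64, 20, 40)   # low-resolution decoder feature
- # x2 = torch.randn(1, 64, 41, 81)   # encoder skip connection
- # print(up_check(x1, x2).shape)     # -> torch.Size([1, 64, 41, 81])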
-
- class OutConv(nn.Sequential):
- def __init__(self, in_channels, num_classes):
- super(OutConv, self).__init__(
- nn.Conv2d(in_channels, num_classes, kernel_size=1)
- )
-
- class ChannelAttention(nn.Module):
- def __init__(self, input_channels, reduction_ratio=16):
- super(ChannelAttention, self).__init__()
- self.input_channels = input_channels
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
- self.max_pool = nn.AdaptiveMaxPool2d(1)
- # https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py
- # uses Convolutions instead of Linear
- self.MLP = nn.Sequential(
- Flatten(),
- nn.Linear(input_channels, input_channels // reduction_ratio),
- nn.ReLU(),
- nn.Linear(input_channels // reduction_ratio, input_channels)
- )
-
- def forward(self, x):
- # Take the input and apply average and max pooling
- avg_values = self.avg_pool(x)
- max_values = self.max_pool(x)
- out = self.MLP(avg_values) + self.MLP(max_values)
- scale = x * torch.sigmoid(out).unsqueeze(2).unsqueeze(3).expand_as(x)
- return scale
-
-
- class SpatialAttention(nn.Module):
- def __init__(self, kernel_size=7):
- super(SpatialAttention, self).__init__()
- assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
- padding = 3 if kernel_size == 7 else 1
- self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
- self.bn = nn.BatchNorm2d(1)
-
- def forward(self, x):
- avg_out = torch.mean(x, dim=1, keepdim=True)
- max_out, _ = torch.max(x, dim=1, keepdim=True)
- out = torch.cat([avg_out, max_out], dim=1)
- out = self.conv(out)
- out = self.bn(out)
- scale = x * torch.sigmoid(out)
- return scale
-
-
- class CBAM(nn.Module):
- def __init__(self, input_channels, reduction_ratio=16, kernel_size=7):
- super(CBAM, self).__init__()
- self.channel_att = ChannelAttention(input_channels, reduction_ratio=reduction_ratio)
- self.spatial_att = SpatialAttention(kernel_size=kernel_size)
-
- def forward(self, x):
- out = self.channel_att(x)
- out = self.spatial_att(out)
- return out
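- # Commented-out sanity check (illustrative names): CBAM only re-weights the feature map, first per
- # channel and then per spatial location, so its output shape always matches its input shape.
- # cbam_check = CBAM(64, reduction_ratio=16, kernel_size=7)
- # feats = torch.randn(2, 64, 41, 81)
- # print(cbam_check(feats).shape)    # -> torch.Size([2, 64, 41, 81])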
-
- # Depthwise separable convolution
- class DepthwiseSeparableConv(nn.Module):
- def __init__(self, in_channels, output_channels, kernel_size, padding=0, kernels_per_layer=1):
- super(DepthwiseSeparableConv, self).__init__()
- # In Tensorflow DepthwiseConv2D has depth_multiplier instead of kernels_per_layer
- self.depthwise = nn.Conv2d(in_channels, in_channels * kernels_per_layer, kernel_size=kernel_size,
- padding=padding,
- groups=in_channels)
- self.pointwise = nn.Conv2d(in_channels * kernels_per_layer, output_channels, kernel_size=1)
-
- def forward(self, x):
- x = self.depthwise(x)
- x = self.pointwise(x)
- return x
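- # Rough parameter comparison (a sketch with assumed 64 -> 128 channels and a 3x3 kernel): the
- # depthwise + pointwise factorization needs far fewer weights than a standard convolution, which is
- # the point of using these blocks in DoubleConvDS.
- # std_conv = nn.Conv2d(64, 128, kernel_size=3, padding=1)
- # ds_conv = DepthwiseSeparableConv(64, 128, kernel_size=3, padding=1)
- # print(sum(p.numel() for p in std_conv.parameters()))  # 73856
- # print(sum(p.numel() for p in ds_conv.parameters()))   # 8960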
-
-
- class SmaAt_UNet(nn.Module):
- def __init__(self, n_channels=175, n_classes=70, kernels_per_layer=2, bilinear=True, reduction_ratio=16):
- super(SmaAt_UNet, self).__init__()
- self.n_channels = n_channels
- self.n_classes = n_classes
- self.kernels_per_layer = kernels_per_layer
- self.bilinear = bilinear
- self.reduction_ratio = reduction_ratio
-
- self.inc = DoubleConv(self.n_channels, 64)
- self.cbam1 = CBAM(64, reduction_ratio=reduction_ratio)
- self.down1 = Down(64, 128)
- self.cbam2 = CBAM(128, reduction_ratio=reduction_ratio)
- self.down2 = Down(128, 256)
- self.cbam3 = CBAM(256, reduction_ratio=reduction_ratio)
- self.down3 = Down(256, 512)
- self.cbam4 = CBAM(512, reduction_ratio=reduction_ratio)
- factor = 2 if self.bilinear else 1
- self.down4 = Down(512, 1024 // factor)
- self.cbam5 = CBAM(1024 // factor, reduction_ratio=reduction_ratio)
- self.up1 = Up(1024, 512 // factor, self.bilinear)
- self.up2 = Up(512, 256 // factor, self.bilinear)
- self.up3 = Up(256, 128 // factor, self.bilinear)
- self.up4 = Up(128, 64, self.bilinear)
-
- self.outc = OutConv(64, self.n_classes)
-
- def forward(self, x):
- b, c, t1, h, w = x.size()
- x = x.reshape(b, t1*c, h, w)
- x1 = self.inc(x)
- x1Att = self.cbam1(x1)
- x2 = self.down1(x1)
- x2Att = self.cbam2(x2)
- x3 = self.down2(x2)
- x3Att = self.cbam3(x3)
- x4 = self.down3(x3)
- x4Att = self.cbam4(x4)
- x5 = self.down4(x4)
- x5Att = self.cbam5(x5)
- x = self.up1(x5Att, x4Att)
- x = self.up2(x, x3Att)
- x = self.up3(x, x2Att)
- x = self.up4(x, x1Att)
- logits = self.outc(x)
- logits = logits.reshape(b,7,10,h,w)
-
- return logits
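- # Commented-out forward-pass check (a sketch, assuming the 25-variable x 7-day input assembled further
- # down, i.e. 25 * 7 = 175 input channels and 7 * 10 = 70 output channels):
- # model_check = SmaAt_UNet(n_channels=175, n_classes=70)
- # dummy = torch.randn(2, 25, 7, 41, 81)   # (batch, variables/depths, days, lat, lon)
- # print(model_check(dummy).shape)         # -> torch.Size([2, 7, 10, 41, 81])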
-
- for iii in range(1,6):
- for epoch111 in range(100,600,100):
-
- def setup_seed(seed):
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- np.random.seed(seed)
- random.seed(seed)
- torch.backends.cudnn.deterministic = True
-
-
- # Set the random seed
- setup_seed(iii)
-
- data = np.load(r'/dataset/7day_for_Nday_data_openi_09_17_atlantic_area_last.npz')
- print(data.files) # ['hycom_temp', 'slfh', 'sshf', 'ssr', 'str', 'mld', 'analysis_temp', 'u', 'v', 'T_d', 'u_d', 'v_d', 'xx', 'yy']
-
- hycom_temp = data['hycom_temp'][:] # (3281, 7, 7, 41, 201)
- slfh = data['slfh'][:] # (3281, 7, 41, 201)
- sshf = data['sshf'][:] # (3281, 7, 41, 201)
- ssr = data['ssr'][:] # (3281, 7, 41, 201)
- str_rad = data['str'][:] # (3281, 7, 41, 201); renamed from "str" so the Python built-in is not shadowed
- mld = data['mld'][:] # (3281, 7, 41, 201)
- analysis_temp = data['analysis_temp'][:] # (3281, 7, 7, 41, 201)
- u = data['u'][:] # (3281, 7, 7, 41, 201)
- v = data['v'][:] # (3281, 7, 7, 41, 201)
- T_d = data['T_d'][:] # (3281, 7, 41, 201)
- u_d = data['u_d'][:] # (3281, 7, 41, 201)
- v_d = data['v_d'][:] # (3281, 7, 41, 201)
- xx = data['xx'][:] # (3281, 7, 41, 201)
- yy = data['yy'][:] # (3281, 7, 41, 201)
-
-
- analysis_temp = analysis_temp.transpose(0, 2, 1, 3, 4)
-
-
- train_size = 1952 # end index of the training split (roughly the first 60% of samples)
- valid_size = 2624 # end index of the validation split (next ~20%); the remaining ~20% is the test set
-
- Q_net = (slfh + sshf + ssr + str_rad) / 86400
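- # Assumption: the four flux fields above appear to be daily-accumulated values (J m^-2 per day),
- # so dividing their sum by 86400 s would convert the net surface heat flux to W m^-2.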
-
- # print(Q_net)
- hycom_temp = hycom_temp.transpose(0, 2, 1, 3, 4)
- hycom_temp = hycom_temp.reshape(-1,7,7,41,81)
- hycom_temp = torch.Tensor(hycom_temp)
- hycom_temp_train = hycom_temp[0:train_size,:,:,:,:]
- hycom_temp_valid = hycom_temp[train_size:valid_size,:,:,:,:]
- hycom_temp_test = hycom_temp[valid_size:,:,:,:,:]
-
-
- # slfh = slfh.reshape(-1,1,1,7,41,201)
- # slfh = torch.Tensor(slfh)
- # slfh_train = slfh[0:train_size,:,:,:,:]
- # slfh_valid = slfh[train_size:valid_size,:,:,:,:]
- # slfh_test = slfh[valid_size:,:,:,:,:]
- #
- # sshf = sshf.reshape(-1,1,1,7,41,201)
- # sshf = torch.Tensor(sshf)
- # sshf_train = sshf[0:train_size,:,:,:,:]
- # sshf_valid = sshf[train_size:valid_size,:,:,:,:]
- # sshf_test = sshf[valid_size:,:,:,:,:]
- #
- #
- # ssr = ssr.reshape(-1,1,1,7,41,201)
- # ssr = torch.Tensor(ssr)
- # ssr_train = ssr[0:train_size,:,:,:,:]
- # ssr_valid = ssr[train_size:valid_size,:,:,:,:]
- # ssr_test = ssr[valid_size:,:,:,:,:]
- #
- #
- # str = str.reshape(-1,1,1,7,41,201)
- # str = torch.Tensor(str)
- # str_train = str[0:train_size,:,:,:,:]
- # str_valid = str[train_size:valid_size,:,:,:,:]
- # str_test = str[valid_size:,:,:,:,:]
-
- Q_net = Q_net.reshape(-1,1,7,41,81)
- Q_net = torch.Tensor(Q_net)
- Q_net_train = Q_net[0:train_size,:,:,:,:]
- Q_net_valid = Q_net[train_size:valid_size,:,:,:,:]
- Q_net_test = Q_net[valid_size:,:,:,:,:]
-
-
- mld = mld.reshape(-1,1,7,41,81)
- mld = torch.Tensor(mld)
- mld_train = mld[0:train_size,:,:,:,:]
- mld_valid = mld[train_size:valid_size,:,:,:,:]
- mld_test = mld[valid_size:,:,:,:,:]
-
-
- analysis_temp = analysis_temp.reshape(-1,7,10,41,81)
- analysis_temp = torch.Tensor(analysis_temp)
- # analysis_temp_train = analysis_temp[0:train_size,:,:,:,:]
- # analysis_temp_valid = analysis_temp[train_size:valid_size,:,:,:,:]
- # analysis_temp_test = analysis_temp[valid_size:,:,:,:,:]
-
- u = u.transpose(0, 2, 1, 3, 4)
- u = u.reshape(-1,7,7,41,81)
- u = torch.Tensor(u)
- u_train = u[0:train_size,:,:,:,:]
- u_valid = u[train_size:valid_size,:,:,:,:]
- u_test = u[valid_size:,:,:,:,:]
-
- v = v.transpose(0, 2, 1, 3, 4)
- v = v.reshape(-1,7,7,41,81)
- v = torch.Tensor(v)
- v_train = v[0:train_size,:,:,:,:]
- v_valid = v[train_size:valid_size,:,:,:,:]
- v_test = v[valid_size:,:,:,:,:]
-
-
- T_d = T_d.reshape(-1,1,7,41,81)
- T_d = torch.Tensor(T_d)
- T_d_train = T_d[0:train_size,:,:,:,:]
- T_d_valid = T_d[train_size:valid_size,:,:,:,:]
- T_d_test = T_d[valid_size:,:,:,:,:]
-
-
- u_d = u_d.reshape(-1,1,7,41,81)
- u_d = torch.Tensor(u_d)
- u_d_train = u_d[0:train_size,:,:,:,:]
- u_d_valid = u_d[train_size:valid_size,:,:,:,:]
- u_d_test = u_d[valid_size:,:,:,:,:]
-
-
- v_d = v_d.reshape(-1,1,7,41,81)
- v_d = torch.Tensor(v_d)
- v_d_train = v_d[0:train_size,:,:,:,:]
- v_d_valid = v_d[train_size:valid_size,:,:,:,:]
- v_d_test = v_d[valid_size:,:,:,:,:]
-
-
- xx = xx.reshape(-1,1,7,41,81)
- xx = torch.Tensor(xx)
- xx_train = xx[0:train_size,:,:,:,:]
- xx_valid = xx[train_size:valid_size,:,:,:,:]
- xx_test = xx[valid_size:,:,:,:,:]
-
-
- yy = yy.reshape(-1,1,7,41,81)
- yy = torch.Tensor(yy)
- yy_train = yy[0:train_size,:,:,:,:]
- yy_valid = yy[train_size:valid_size,:,:,:,:]
- yy_test = yy[valid_size:,:,:,:,:]
-
-
- train_data = torch.cat((hycom_temp_train, Q_net_train, mld_train, u_train, v_train, T_d_train, v_d_train), dim=1) # train_data.shape: torch.Size([1952, 25, 7, 41, 81]); dim 0: time, dim 1: depths/variables, dim 2: 7-day input sequence, dims 3-4: lat, lon
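- # Channel bookkeeping: 7 (hycom_temp depths) + 1 (Q_net) + 1 (mld) + 7 (u) + 7 (v) + 1 (T_d) + 1 (v_d)
- # = 25 variables per day, and with the 7-day sequence this gives the 25 * 7 = 175 input channels
- # expected by SmaAt_UNet(n_channels=175).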
-
-
-
- valid_data = torch.cat((hycom_temp_valid, Q_net_valid, mld_valid, u_valid, v_valid, T_d_valid, v_d_valid), dim=1) # valid_data.shape: torch.Size([672, 25, 7, 41, 81])
-
-
- test_data = torch.cat((hycom_temp_test, Q_net_test, mld_test, u_test, v_test, T_d_test, v_d_test), dim=1) #
-
-
- train_label = analysis_temp[7:train_size + 7,:,:,:] # train_label.shape: torch.Size([1952, 7, 10, 41, 81])
-
- valid_label = analysis_temp[train_size + 7:valid_size + 7,:,:,:] # valid_label.shape: torch.Size([672, 7, 10, 41, 81])
-
-
- test_label = analysis_temp[valid_size + 7: 3646,:,:,:] # test_label.shape: torch.Size([650, 7, 10, 41, 81])
-
-
- train_label = torch.Tensor(train_label)
- valid_label = torch.Tensor(valid_label)
- test_label = torch.Tensor(test_label)
-
- test_label11 = test_label
- valid_label11 = valid_label
-
- # Build the data pipeline (Dataset / DataLoader)
- class MyDataset(Dataset):
- def __init__(self, data, label):
- self.data = torch.Tensor(data)
- self.label = torch.Tensor(label)
-
- def __len__(self):
- return len(self.label)
-
- def __getitem__(self, idx):
- return self.data[idx], self.label[idx]
-
-
- batch_size1 = 32
- batch_size2 = 32
- batch_size3 = 3000
- # trainset = MyDataset(input_data, sst)
- # trainset = MyDataset(train_data,train_label)
- # trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
- # batch_size = 1826
- trainset = MyDataset(train_data, train_label)
- trainloader = DataLoader(trainset, batch_size=batch_size1, shuffle=True, drop_last=False, pin_memory=True, num_workers=4)
-
- validset = MyDataset(valid_data, valid_label)
- validloader = DataLoader(validset, batch_size=batch_size2, shuffle=False, drop_last=False, pin_memory=True, num_workers=4) # keep order so preds[] stays aligned with valid_label for the RMSE
-
- testset = MyDataset(test_data, test_label)
- testloader = DataLoader(testset, batch_size=batch_size3, shuffle=False, drop_last=False, pin_memory=True, num_workers=4)
-
- # print('Qnet_train.shape:{}'.format(sst1_train.shape))
-
-
- model_weights1 = '/model/Unet_CAMB_black_WM__epo{}_lay1_e{}_black_ahead_10_day_model_weights.pth'.format(epoch111,iii)
-
- print('-----------------------train_black_att_3dcnn_convlstm--------------------------')
-
-
- model = SmaAt_UNet().cuda()
- # model = model.cuda()
- # out = model(a) #torch.Size([1826, 12, 15])
- # print(out.shape)
- criterion = nn.MSELoss()
- # Define the optimizer
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
- epochs = epoch111
- train_losses, valid_losses = [], []
- # best_loss = 2
- best_score = float('inf')
- best_score1 = float('inf')
- preds = np.zeros((672,7,10,41,81))
- # preds = np.expand_dims(preds, axis=1)#preds.reshape(100,1,12,15)
- # preds = np.zeros((18000,1))
- print(preds.shape)
- scores = []
- def rmse(y_true, y_preds):
- return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))
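- # e.g. rmse(np.array([0., 1.]), np.array([0., 3.])) == np.sqrt(((0 - 0)**2 + (1 - 3)**2) / 2) ~ 1.414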
-
- for epoch in range(epochs):
- print('Epoch: {}/{}'.format(epoch + 1, epochs))
- # print(var_y)
- # Train the model
- model.train()
- losses = 0
- for data, label in tqdm(trainloader):
- data = data.cuda()
- label = label.cuda()
-
- optimizer.zero_grad()
- # print('data.shape:{}'.format(data.shape)) #torch.Size([160, 5, 1, 12, 15])
- out = model(data)
-
- loss = criterion(out, label)
- losses += loss.item()  # accumulate as a plain float so the computation graph is not kept alive
- # 反向传播
- loss.backward()
- optimizer.step()
- train_loss = losses / len(trainloader)
- train_losses.append(train_loss)
- print('Training Loss: {:.3f}'.format(train_loss))
-
- # Validate the model
- model.eval()  # use running BatchNorm statistics during validation
- losses = 0
-
- with torch.no_grad():
- for i, data in tqdm(enumerate(validloader)):
- # print('i:{}'.format(i))
- data, labels = data
- data = data.cuda()
- labels = labels.cuda()
-
- pred = model(data)
-
- # print('pred_shape:{}'.format(pred.shape)) #torch.Size([2, 1, 12, 15])
- loss = criterion(pred, labels)
- losses += loss.item()
-
- preds[i * batch_size2:(i + 1) * batch_size2] = pred.cpu()
- # print(preds.shape)
- # preds[i * batch_size2:(i + 1) * batch_size2] = np.array(tmp)
- valid_loss = losses / len(validloader)
- valid_losses.append(valid_loss)
- print('Validation Loss: {:.3f}'.format(valid_loss))
-
- valid_label1 = valid_label.reshape(-1,1)
- preds1 = preds.reshape(-1,1)
- # print('valid_label1.shape:{}'.format(valid_label1.shape))#(18000,1)
- # print('preds.shape:{}0'.format(preds.shape)) #(360,1)
- s = rmse(valid_label1,preds1)
- scores.append(s)
- print('Score: {:.3f}'.format(s))
- # 保存最佳模型权重
- # if s < best_score: # track the minimum of s (for a score to maximize, flip the comparison and the inf initialization)
- # best_score = s
- # checkpoint = {'best_score': s,
- # 'state_dict': model.state_dict()}
- # torch.save(checkpoint, model_weights) # if valid_loss < best_loss:
- # best_loss = valid_loss
- # torch.save(model.state_dict(), './model_epo300_5e-4__hidden64_lay3e2_14day.pt')
- if valid_loss < best_score1: # keep the checkpoint with the lowest validation loss (for a score to maximize, flip the comparison and the inf initialization)
- best_score1 = valid_loss
- checkpoint = {'best_score': valid_loss,
- 'state_dict': model.state_dict()}
- torch.save(checkpoint, model_weights1)
- best_loss = valid_loss
- torch.save(model.state_dict(),
- '/model/Unet_CAMB_black_WM_e{}_layer1_ahead_1_day_e{}.pt'.format(epoch111,iii))
-
- print(scores)
- print('best_score1:{}'.format(best_score1))
- print('s:{}'.format(s))
- print('valid_losses: {}'.format(valid_losses))
-
-
- # Surface data ---- days 1 through 7
- # list1 = [] # store dT_1 ... dT_6 and the computed loss
- # for j in range(7): # 7 depth levels
- # for i in range(7):
- # T_m_dj_i = T_m[:, j, i, :, :].reshape(-1, 1, 40, 80)
- # u_dj_i = u[:, j, i, :, :].reshape(-1, 1, 40, 80)
- # v_dj_i = v[:, j, i, :, :].reshape(-1, 1, 40, 80)
- # # dT_dt_dj_i = dT_dt[:, j, i, :, :].reshape(-1, 1, 40, 80) # dT_dt is also unknown over the past seven days
- # dT_dx_dj_i = dT_dx[:, j, i, :, :].reshape(-1, 1, 40, 80)
- # dT_dy_dj_i = dT_dy[:, j, i, :, :].reshape(-1, 1, 40, 80)
- # Q_i = Q[:, :, i, :, :].reshape(-1, 1, 40, 80)
- # h_m_i = h_m[:, :, i, :, :].reshape(-1, 1, 40, 80)
- # T_d_i = T_d[:, :, i, :, :].reshape(-1, 1, 40, 80)
- # w_e_i = w_e[:, :, i, :, :].reshape(-1, 1, 40, 80)
-
- # data_j_i = u_dj_i * (dT_dx_dj_i)
- # # print('data{}.shape:{}'.format(i, datai.shape)) # data0.shape:torch.Size([32, 1, 40, 80])
- # out_j_i = model(data_j_i)
- # # print('out{}.shape:{}'.format(i, outi.shape)) # data0.shape:torch.Size([32, 1, 40, 80])
-
- # dT_dti = Q_i / (1025 * 4000 * h_m_i) + v_dj_i * (dT_dy_dj_i) + out_j_i + w_e_i * ((T_m_dj_i - T_d_i) / h_m_i)
- # # print(dT_dti.shape)
- # list1.append(dT_dti)
-
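- # The unrolled blocks below appear to evaluate a mixed-layer temperature budget of the form
- #   dT/dt = Q / (rho * c_p * h_m) + v * dT_dy + model(u * dT_dx) + w_e * (T_m - T_d) / h_m
- # where 1025 and 4000 presumably stand for the seawater density rho (kg m^-3) and heat capacity
- # c_p (J kg^-1 K^-1). Note that T_m, dT_dx, dT_dy, Q, h_m and w_e are not defined earlier in this
- # script, and model() is called here on 4-D single-channel tensors rather than the 5-D input the
- # SmaAt_UNet above expects, so this part looks like it belongs to a separate experiment.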
- # depth 0, day 1
- T_m_d0_1 = T_m[:, 0, 0, :, :].reshape(-1, 1, 40, 80)
- u_d0_1 = u[:, 0, 0, :, :].reshape(-1, 1, 40, 80)
- v_d0_1 = v[:, 0, 0, :, :].reshape(-1, 1, 40, 80)
- dT_dx_d0_1 = dT_dx[:, 0, 0, :, :].reshape(-1, 1, 40, 80)
- dT_dy_d0_1 = dT_dy[:, 0, 0, :, :].reshape(-1, 1, 40, 80)
- Q_1 = Q[:, :, 0, :, :].reshape(-1, 1, 40, 80)
- h_m_1 = h_m[:, :, 0, :, :].reshape(-1, 1, 40, 80)
- T_d_1 = T_d[:, :, 0, :, :].reshape(-1, 1, 40, 80)
- w_e_1 = w_e[:, :, 0, :, :].reshape(-1, 1, 40, 80)
-
- data_0_1 = u_d0_1 * (dT_dx_d0_1)
- out_0_1 = model(data_0_1)
-
- dT_dt_0_1 = Q_1 / (1025 * 4000 * h_m_1) + v_d0_1 * (dT_dy_d0_1) + out_0_1 + w_e_1 * ((T_m_d0_1 - T_d_1) / h_m_1)
-
-
- # depth 0, day 2
- T_m_d0_2 = T_m[:, 0, 1, :, :].reshape(-1, 1, 40, 80)
- u_d0_2 = u[:, 0, 1, :, :].reshape(-1, 1, 40, 80)
- v_d0_2 = v[:, 0, 1, :, :].reshape(-1, 1, 40, 80)
- dT_dx_d0_2 = dT_dx[:, 0, 1, :, :].reshape(-1, 1, 40, 80)
- dT_dy_d0_2 = dT_dy[:, 0, 1, :, :].reshape(-1, 1, 40, 80)
- Q_2 = Q[:, :, 1, :, :].reshape(-1, 1, 40, 80)
- h_m_2 = h_m[:, :, 1, :, :].reshape(-1, 1, 40, 80)
- T_d_2 = T_d[:, :, 1, :, :].reshape(-1, 1, 40, 80)
- w_e_2 = w_e[:, :, 1, :, :].reshape(-1, 1, 40, 80)
-
- data_0_2 = u_d0_2 * (dT_dx_d0_2)
- out_0_2 = model(data_0_2)
-
- dT_dt_0_2 = Q_2 / (1025 * 4000 * h_m_2) + v_d0_2 * (dT_dy_d0_2) + out_0_2 + w_e_2 * ((T_m_d0_2 - T_d_2) / h_m_2)
-
-
- # depth 0, day 3
- T_m_d0_3 = T_m[:, 0, 2, :, :].reshape(-1, 1, 40, 80)
- u_d0_3 = u[:, 0, 2, :, :].reshape(-1, 1, 40, 80)
- v_d0_3 = v[:, 0, 2, :, :].reshape(-1, 1, 40, 80)
- dT_dx_d0_3 = dT_dx[:, 0, 2, :, :].reshape(-1, 1, 40, 80)
- dT_dy_d0_3 = dT_dy[:, 0, 2, :, :].reshape(-1, 1, 40, 80)
- Q_3 = Q[:, :, 2, :, :].reshape(-1, 1, 40, 80)
- h_m_3 = h_m[:, :, 2, :, :].reshape(-1, 1, 40, 80)
- T_d_3 = T_d[:, :, 2, :, :].reshape(-1, 1, 40, 80)
- w_e_3 = w_e[:, :, 2, :, :].reshape(-1, 1, 40, 80)
-
- data_0_3 = u_d0_3 * (dT_dx_d0_3)
- out_0_3 = model(data_0_3)
-
- dT_dt_0_3 = Q_3 / (1025 * 4000 * h_m_3) + v_d0_3 * (dT_dy_d0_3) + out_0_3 + w_e_3 * ((T_m_d0_3 - T_d_3) / h_m_3)
-
-
- # depth 0, day 4
- T_m_d0_4 = T_m[:, 0, 3, :, :].reshape(-1, 1, 40, 80)
- u_d0_4 = u[:, 0, 3, :, :].reshape(-1, 1, 40, 80)
- v_d0_4 = v[:, 0, 3, :, :].reshape(-1, 1, 40, 80)
- dT_dx_d0_4 = dT_dx[:, 0, 3, :, :].reshape(-1, 1, 40, 80)
- dT_dy_d0_4 = dT_dy[:, 0, 3, :, :].reshape(-1, 1, 40, 80)
- Q_4 = Q[:, :, 3, :, :].reshape(-1, 1, 40, 80)
- h_m_4 = h_m[:, :, 3, :, :].reshape(-1, 1, 40, 80)
- T_d_4 = T_d[:, :, 3, :, :].reshape(-1, 1, 40, 80)
- w_e_4 = w_e[:, :, 3, :, :].reshape(-1, 1, 40, 80)
-
- data_0_4 = u_d0_4 * (dT_dx_d0_4)
- out_0_4 = model(data_0_4)
-
- dT_dt_0_4 = Q_4 / (1025 * 4000 * h_m_4) + v_d0_4 * (dT_dy_d0_4) + out_0_4 + w_e_4 * ((T_m_d0_4 - T_d_4) / h_m_4)
-
-
- # depth 0, day 5
- T_m_d0_5 = T_m[:, 0, 4, :, :].reshape(-1, 1, 40, 80)
- u_d0_5 = u[:, 0, 4, :, :].reshape(-1, 1, 40, 80)
- v_d0_5 = v[:, 0, 4, :, :].reshape(-1, 1, 40, 80)
- dT_dx_d0_5 = dT_dx[:, 0, 4, :, :].reshape(-1, 1, 40, 80)
- dT_dy_d0_5 = dT_dy[:, 0, 4, :, :].reshape(-1, 1, 40, 80)
- Q_5 = Q[:, :, 4, :, :].reshape(-1, 1, 40, 80)
- h_m_5 = h_m[:, :, 4, :, :].reshape(-1, 1, 40, 80)
- T_d_5 = T_d[:, :, 4, :, :].reshape(-1, 1, 40, 80)
- w_e_5 = w_e[:, :, 4, :, :].reshape(-1, 1, 40, 80)
-
- data_0_5 = u_d0_5 * (dT_dx_d0_5)
- out_0_5 = model(data_0_5)
-
- dT_dt_0_5 = Q_5 / (1025 * 4000 * h_m_5) + v_d0_5 * (dT_dy_d0_5) + out_0_5 + w_e_5 * ((T_m_d0_5 - T_d_5) / h_m_5)
-
-
- # depth 0, day 6
- T_m_d0_6 = T_m[:, 0, 5, :, :].reshape(-1, 1, 40, 80)
- u_d0_6 = u[:, 0, 5, :, :].reshape(-1, 1, 40, 80)
- v_d0_6 = v[:, 0, 5, :, :].reshape(-1, 1, 40, 80)
- dT_dx_d0_6 = dT_dx[:, 0, 5, :, :].reshape(-1, 1, 40, 80)
- dT_dy_d0_6 = dT_dy[:, 0, 5, :, :].reshape(-1, 1, 40, 80)
- Q_6 = Q[:, :, 5, :, :].reshape(-1, 1, 40, 80)
- h_m_6 = h_m[:, :, 5, :, :].reshape(-1, 1, 40, 80)
- T_d_6 = T_d[:, :, 5, :, :].reshape(-1, 1, 40, 80)
- w_e_6 = w_e[:, :, 5, :, :].reshape(-1, 1, 40, 80)
-
- data_0_6 = u_d0_6 * (dT_dx_d0_6)
- out_0_6 = model(data_0_6)
-
- dT_dt_0_6 = Q_6 / (1025 * 4000 * h_m_6) + v_d0_6 * (dT_dy_d0_6) + out_0_6 + w_e_6 * ((T_m_d0_6 - T_d_6) / h_m_6)
-
-
- # depth 0, day 7
- T_m_d0_7 = T_m[:, 0, 6, :, :].reshape(-1, 1, 40, 80)
- u_d0_7 = u[:, 0, 6, :, :].reshape(-1, 1, 40, 80)
- v_d0_7 = v[:, 0, 6, :, :].reshape(-1, 1, 40, 80)
- dT_dx_d0_7 = dT_dx[:, 0, 6, :, :].reshape(-1, 1, 40, 80)
- dT_dy_d0_7 = dT_dy[:, 0, 6, :, :].reshape(-1, 1, 40, 80)
- Q_7 = Q[:, :, 6, :, :].reshape(-1, 1, 40, 80)
- h_m_7 = h_m[:, :, 6, :, :].reshape(-1, 1, 40, 80)
- T_d_7 = T_d[:, :, 6, :, :].reshape(-1, 1, 40, 80)
- w_e_7 = w_e[:, :, 6, :, :].reshape(-1, 1, 40, 80)
-
- data_0_7 = u_d0_7 * (dT_dx_d0_7)
- out_0_7 = model(data_0_7)
-
- dT_dt_0_7 = Q_7 / (1025 * 4000 * h_m_7) + v_d0_7 * (dT_dy_d0_7) + out_0_7 + w_e_7 * ((T_m_d0_7 - T_d_7) / h_m_7)
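- # A possible collection step (a sketch mirroring the commented-out list1 loop above; not in the
- # original code): the seven daily tendencies for depth 0 could be stacked along a new day axis.
- # dT_dt_0 = torch.stack([dT_dt_0_1, dT_dt_0_2, dT_dt_0_3, dT_dt_0_4,
- #                        dT_dt_0_5, dT_dt_0_6, dT_dt_0_7], dim=2)  # (-1, 1, 7, 40, 80)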
-
-
- # # depth 1, day 1
- # T_m_d1_1 = T_m[:, 1, 0, :, :].reshape(-1, 1, 40, 80)
- # u_d1_1 = u[:, 1, 0, :, :].reshape(-1, 1, 40, 80)
- # v_d1_1 = v[:, 1, 0, :, :].reshape(-1, 1, 40, 80);
- # dT_dx_d1_1 = dT_dx[:, 1, 0, :, :].reshape(-1, 1, 40, 80)
- # dT_dy_d1_1 = dT_dy[:, 1, 0, :, :].reshape(-1, 1, 40, 80)
- # Q_1 = Q[:, :, 0, :, :].reshape(-1, 1, 40, 80)
- # h_m_1 = h_m[:, :, 0, :, :].reshape(-1, 1, 40, 80)
- # T_d_1 = T_d[:, :, 0, :, :].reshape(-1, 1, 40, 80)
- # w_e_1 = w_e[:, :, 0, :, :].reshape(-1, 1, 40, 80)
-
- # data_1_1 = u_d1_1 * (dT_dx_d1_1)
- # out_1_1 = model(data_1_1)
-
- # dT_dt_1_1 = Q_1 / (1025 * 4000 * h_m_1) + v_d1_1 * (dT_dy_d1_1) + out_1_1 + w_e_1 * ((T_m_d1_1 - T_d_1) / h_m_1)
-
-
- # # depth 1, day 2
- # T_m_d1_2 = T_m[:, 1, 1, :, :].reshape(-1,