|
- import math
- from functools import partial
- from collections import OrderedDict
- from copy import Error, deepcopy
- from re import S
- from numpy.lib.arraypad import pad
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- #from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.models.layers import DropPath, trunc_normal_
- import torch.fft
- from torch.nn.modules.container import Sequential
- from torch.utils.checkpoint import checkpoint_sequential
- from einops import rearrange, repeat
- from einops.layers.torch import Rearrange
- # from utils.img_utils import PeriodicPad2d
- import gc
- import time
- # load dataset
- import numpy as np
- import os
# Make GPUs 1-3 visible before torch/lightning initialize CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
import torch.optim as optim

# Per-process batch size used by every DataLoader below.
batch_size = 7
print("start loading Dataset")
import h5py
# NOTE(review): N_TRAIN is never read in this file — presumably a leftover
# cap on the number of training samples (-1 meaning "all").
N_TRAIN= -1
import datetime
# sevir generator
import warnings
warnings.filterwarnings("ignore")
# Make sure you add SEVIR module to your path
import sys
sys.path.append('/home/movis/ldw/sevir') # enter path to sevir module if not installed.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
# SEVIR image types and default catalog/data locations.
TYPES = ['vis', 'ir069', 'ir107', 'vil', 'lght']
DEFAULT_CATALOG = '/mnt/data/sevir_data/CATALOG.csv'
DEFAULT_DATA_HOME = '/mnt/data/sevir_data/'
# Frame timestamps in seconds relative to the event time:
# -120 min .. +115 min in 5-minute steps.
FRAME_TIMES = np.arange(-120.0, 120.0, 5) * 60
-
- # A keras.Sequece class for SEVIR
- import numpy as np
- from generator import SEVIRGenerator
-
-
def prep_clf(sim, obs, threshold=0.00001):
    """Binarize forecast and observation at `threshold` and return the
    2x2 contingency-table counts.

    Generalized: accepts both torch tensors (possibly on GPU / requiring
    grad) and plain numpy arrays; the original required tensors because it
    always called .cpu().detach().

    Parameters
    ----------
    sim, obs : torch.Tensor or array-like
        Forecast (sim) and ground truth (obs).
    threshold : float
        Values >= threshold count as events (1), others as non-events (0).

    Returns
    -------
    (hits, misses, falsealarms, correctnegatives) : TP, FN, FP, TN counts.
    """
    def _to_numpy(a):
        # Detach/move torch tensors; pass numpy-compatible input through.
        if isinstance(a, torch.Tensor):
            a = a.cpu().detach().numpy()
        return np.asarray(a)

    obs = np.where(_to_numpy(obs) >= threshold, 1, 0)
    sim = np.where(_to_numpy(sim) >= threshold, 1, 0)

    # True positive (TP)
    hits = np.sum((obs == 1) & (sim == 1))
    # False negative (FN)
    misses = np.sum((obs == 1) & (sim == 0))
    # False positive (FP)
    falsealarms = np.sum((obs == 0) & (sim == 1))
    # True negative (TN)
    correctnegatives = np.sum((obs == 0) & (sim == 0))

    return hits, misses, falsealarms, correctnegatives
-
-
def CSI(sim, obs, threshold=0.00001):
    """Critical Success Index: hits / (hits + misses + false alarms)."""
    hits, misses, falsealarms, _ = prep_clf(obs=obs, sim=sim,
                                            threshold=threshold)
    return (hits / (hits + misses + falsealarms)).mean()
-
-
-
-
-
-
-
-
-
class PeriodicPad2d(nn.Module):
    """Circular padding along the longitude (width) axis combined with
    zero padding along the latitude (height) axis.
    """
    def __init__(self, pad_width):
        super(PeriodicPad2d, self).__init__()
        self.pad_width = pad_width

    def forward(self, x):
        # Wrap left/right first (longitude is periodic) ...
        wrapped = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular")
        # ... then zero-fill top/bottom (latitude is not periodic).
        return F.pad(wrapped, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0)
-
-
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> act -> dropout -> Linear -> dropout.

    hidden_features and out_features default to in_features when omitted.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Same dropout module applied after each linear layer.
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
-
-
class AFNO2D(nn.Module):
    """Adaptive Fourier Neural Operator 2D token mixer.

    Mixes tokens globally by applying a block-diagonal, two-layer complex
    MLP to the 2D FFT coefficients of the input, soft-shrinking the result
    for sparsity, and transforming back. Input and output are (B, H, W, C);
    the input is added back as a residual.
    """
    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        # Soft-shrink lambda applied to the spectral MLP output.
        self.sparsity_threshold = sparsity_threshold
        # Channels are split into num_blocks independent blocks of size
        # block_size, making the spectral weights block-diagonal.
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        # Fraction of frequency modes kept (1 = keep all modes).
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        # Leading dim of 2 stores the (real, imag) parts of complex weights.
        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        # Keep the input for the residual connection added at the end.
        bias = x

        # FFT runs in float32; the result is cast back to the input dtype.
        dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape

        # Real 2D FFT over the spatial dims; W shrinks to W//2+1 by symmetry.
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)

        # Buffers for the first (o1) and second (o2) spectral-MLP layers;
        # modes outside the kept band stay zero.
        o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)


        # Hard thresholding: only the lowest `kept_modes` frequencies (a band
        # around the H-axis center and a prefix on the W-axis) are mixed.
        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        # Layer 1: complex matmul expanded into real/imag einsums, then ReLU.
        # (a+bi)(c+di): real part = ac - bd.
        o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )

        # Imag part = bc + ad.
        o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )

        # Layer 2: same complex matmul pattern, no activation.
        o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )

        o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        # Reassemble complex coefficients, sparsify, and invert the FFT.
        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, H, W // 2 + 1, C)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho")
        x = x.type(dtype)

        return x + bias
-
-
class Block(nn.Module):
    """AFNO transformer block: (norm -> AFNO2D filter) followed by
    (norm -> MLP -> drop-path), each wrapped in a residual connection when
    double_skip is enabled (otherwise only the final residual is applied).
    """
    def __init__(
        self,
        dim,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        double_skip=True,
        num_blocks=8,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
        # Identity when no stochastic depth is requested.
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)
        self.double_skip = double_skip

    def forward(self, x):
        residual = x
        out = self.filter(self.norm1(x))

        if self.double_skip:
            # Add the first skip and restart the residual from here.
            out = out + residual
            residual = out

        out = self.drop_path(self.mlp(self.norm2(out)))
        return out + residual
-
class PrecipNet(nn.Module):
    """Precipitation head: runs the backbone, then applies periodic padding,
    a 3x3 convolution and ReLU over the 20 output channels.
    """
    def __init__(self, backbone):
        super().__init__()
        self.in_chans = 5
        self.out_chans = 20
        self.backbone = backbone
        self.ppad = PeriodicPad2d(1)
        self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True)
        self.act = nn.ReLU()

    def forward(self, x):
        # The 1-pixel pad restores the spatial size consumed by the 3x3 conv.
        features = self.backbone(x)
        return self.act(self.conv(self.ppad(features)))
-
class AFNONet(nn.Module):
    """AFNO (Adaptive Fourier Neural Operator) backbone.

    Pipeline: PatchEmbed -> + learned positional embedding -> dropout ->
    reshape tokens onto an (h, w) grid -> `depth` AFNO Blocks -> linear head
    that decodes each token into a patch_size[0] x patch_size[1] patch of
    out_chans values, reassembled into a (B, out_chans, H, W) image.

    NOTE(review): PatchEmbed below hard-codes 4x4 patching and a 96*96 patch
    count, so only img_size=(384, 384) with patch_size=(4, 4) — the
    configuration actually used in this file — works end to end.
    """
    def __init__(
        self,

        img_size=(720, 1440),
        patch_size=(16, 16),
        in_chans=2,
        out_chans=2,
        embed_dim=768,
        depth=12,
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.,
        num_blocks=16,
        sparsity_threshold=0.01,
        hard_thresholding_fraction=1.0,
    ):
        super().__init__()
        self.img_size = img_size
        # Bug fix: the original did (patch_size, patch_size) unconditionally,
        # turning an already-tuple argument into a nested tuple ((4,4),(4,4)).
        self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learned absolute positional embedding, one vector per patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic-depth rates grow linearly with block index.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Token-grid size. Was hard-coded to 384 // 4 = 96; now derived from
        # the actual image/patch sizes (still 96 for the config used here).
        self.h = img_size[0] // self.patch_size[0]
        self.w = img_size[1] // self.patch_size[1]

        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction)
        for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        # Each token is decoded into a full patch of out_chans values.
        self.head = nn.Linear(embed_dim, self.out_chans * self.patch_size[0] * self.patch_size[1], bias=False)

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; unit/zero for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names an optimizer setup may exclude from weight decay.
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Embed, add positions, and run the AFNO blocks; returns (B, h, w, C)."""
        B = x.shape[0]

        x = self.patch_embed(x)

        x = x + self.pos_embed

        x = self.pos_drop(x)

        # Put tokens back onto their 2D grid for the Fourier mixer.
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        # Fold per-token patches back into a full-resolution image.
        x = rearrange(
            x,
            "b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)",
            p1=self.patch_size[0],
            p2=self.patch_size[1],
            h=self.h,
            w=self.w,
        )
        return x
-
-
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping 4x4 patches and linearly embed
    each patch with a strided convolution.

    NOTE(review): the 4x4 patching and the 96*96 patch count are hard-coded;
    patch_size is stored but not used, and img_size is only enforced by the
    forward() assertion.
    """
    def __init__(self, img_size=(384, 384), patch_size=(4, 4), in_chans=5, embed_dim=384):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (384 // 4) * (384 // 4)  # 96 * 96 = 9216 tokens
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=4, stride=4)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, H, W) -> (B, embed_dim, H/4, W/4) -> (B, num_patches, embed_dim)
        embedded = self.proj(x)
        return embedded.flatten(2).transpose(1, 2)
-
class AF_PL(pl.LightningModule):
    """LightningModule wrapping PrecipNet(AFNONet) for VIL nowcasting:
    5 input frames -> 20 predicted frames, trained with MSE and monitored
    with CSI at threshold 74.
    """
    def __init__(self, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = PrecipNet(AFNONet(img_size=(384, 384), patch_size=(4,4), in_chans=5, out_chans=20))


    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # AdamW with step decay (x0.1) at epochs 100 and 150.
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def _calculate_loss(self, batch, mode="train"):
        """Shared step: MSE loss plus CSI logged under the given mode prefix."""
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.mse_loss(preds, labels)
        # CSI at VIL pixel value 74 — TODO confirm this matches the intended
        # event threshold for the evaluation.
        csi = CSI(preds, labels, threshold=74)

        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_csi", csi)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")
-
-
-
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
def train_model(**kwargs):
    """Train (or resume) the AFNO precipitation model and score it.

    Relies on module-level globals that must exist before calling:
    CHECKPOINT_PATH, train_loader, val_loader, test_loader.

    Returns
    -------
    (model, result): the best checkpointed AF_PL model and a dict with CSI
    measured on the validation ("val") and test ("test") loaders.
    """
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "nowcast_test"),
        gpus=3, # if str(device) == "cuda:0" else 0,
        strategy='ddp',
        max_epochs=50,
        callbacks=[
            # Bug fix: CSI is a skill score (higher is better), so checkpoint
            # on the maximum. The original used mode="min", which kept the
            # WORST validation checkpoint; EarlyStopping below already
            # (correctly) used mode="max" on the same metric.
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_csi"),
            LearningRateMonitor("epoch"),
            EarlyStopping(monitor="val_csi", mode="max", patience=3),
        ],
        progress_bar_refresh_rate=1,
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether a pretrained model exists. If yes, resume from it;
    # otherwise train from scratch with a fixed seed.
    pretrained_filename = '/home/movis/ldw/AFNOnet/saved_models/xxxxAFNO1.ckpt'
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model at %s, loading..." % pretrained_filename)
        # Automatically loads the model with the saved hyperparameters
        model = AF_PL.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducable
        model = AF_PL(lr=3e-4)

    trainer.fit(model, train_loader, val_loader)
    # Load best checkpoint after training
    model = AF_PL.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # Score the best model on the validation and test sets. Both calls run
    # through test_step, so the reported key is "test_csi" in both cases.
    val_result = trainer.test(model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_csi"], "val": val_result[0]["test_csi"]}
    trainer.save_checkpoint('/home/movis/ldw/AFNOnet/saved_models/1119AFNO_precip_valcsi1.ckpt')
    return model, result
-
-
-
-
if __name__ == "__main__":

    # Training split: SEVIR VIL events from 2017-01-01 to 2019-07-01.
    vil_gen = SEVIRGenerator(x_img_types=['vil'],batch_size=1,unwrap_time=False,
                             start_date=datetime.datetime(2017,1,1),
                             end_date=datetime.datetime(2019,7,1))

    # Held-out split (2019-07-01 .. 2019-12-31); the same generator backs
    # both the validation and the test loader below.
    vil_gen1 = SEVIRGenerator(x_img_types=['vil'],batch_size=1,unwrap_time=False,
                              start_date=datetime.datetime(2019,7,1),
                              end_date=datetime.datetime(2019,12,31))


    train_loader=torch.utils.data.DataLoader(vil_gen, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)
    val_loader=torch.utils.data.DataLoader(vil_gen1, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=0)
    test_loader=torch.utils.data.DataLoader(vil_gen1, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=0)
    CHECKPOINT_PATH = "/home/movis/ldw/AFNOnet/saved_models"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Device:", device)

    # Wall-clock timing of the full train + evaluate run.
    t1=time.time()

    model, result = train_model()


    t2=time.time()

    t=t2-t1

    print('The process is running for {:.0f} hours {:.0f} minutes {:.0f} seconds'.format(t//3600,t%3600//60,t%60))

    print("AFNO results", result)


    print('Congratulations,training is done!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
-
-
-
######################################################
print("Starting predicting")

def torch_to_nppixel(results):
    """Clamp predictions to the displayable [0, 255] pixel range, uint8.

    Bug fix: the original called np.clip(0, 255, results), i.e. with
    arguments bound as (a=0, a_min=255, a_max=results), which evaluates to
    minimum(255, results) and never floors negative values at 0. The
    intended call is np.clip(results, 0, 255).
    """
    return np.uint8(np.clip(results, 0, 255))
-
-
- #metrics
- def prep_clf(sim, obs,threshold=0.1):
- obs=np.asarray(obs)
- sim=np.asarray(sim)
- obs = np.where(obs >= threshold, 1, 0)
- sim = np.where(sim >= threshold, 1, 0)
- # True positive (TP)
- hits = np.sum((obs == 1) & (sim == 1))
- # False negative (FN)
- misses = np.sum((obs == 1) & (sim == 0))
- # False positive (FP)
- falsealarms = np.sum((obs == 0) & (sim == 1))
- # True negative (TN)
- correctnegatives = np.sum((obs == 0) & (sim == 0))
- return hits, misses, falsealarms, correctnegatives
-
-
-
def CSI(sim, obs, threshold=0.1):
    """Critical Success Index: hits / (hits + misses + false alarms)."""
    hits, misses, falsealarms, _ = prep_clf(obs=obs, sim=sim,
                                            threshold=threshold)
    return (hits / (hits + misses + falsealarms)).mean()
-
def RMSE(obs, sim):
    """Root-mean-square error between the flattened obs and sim arrays."""
    diff = obs.flatten() - sim.flatten()
    return np.sqrt(np.mean(diff ** 2))
-
def FAR(obs, sim, threshold=0.1):
    """False Alarm Ratio: false alarms / (hits + false alarms)."""
    hits, _, falsealarms, _ = prep_clf(obs=obs, sim=sim,
                                       threshold=threshold)
    return falsealarms / (hits + falsealarms)
-
def HSS(obs, sim, threshold=0.1):
    """Heidke Skill Score computed from the 2x2 contingency table."""
    hits, misses, falsealarms, correctnegatives = prep_clf(obs=obs, sim=sim,
                                                           threshold=threshold)

    numerator = 2 * (hits * correctnegatives - misses * falsealarms)
    denominator = (misses**2 + falsealarms**2 + 2*hits*correctnegatives +
                   (misses + falsealarms)*(hits + correctnegatives))
    return numerator / denominator
-
def evaluate_models(model_sim, obs, metrics, thres):
    """Score predictions frame by frame along axis 1.

    Parameters
    ----------
    model_sim, obs : ndarray
        Prediction / ground truth; frames are indexed on axis 1
        (shape (samples, frames, H, W) in this file).
    metrics : str
        Metric selector, matched by substring with priority
        'rmse' > 'csi' > 'far' > 'hss'.
    thres : float
        Event threshold passed to the categorical metrics (ignored for RMSE).

    Returns
    -------
    np.ndarray with one score per frame; empty if no metric name matches.
    """
    # Generalized: score every frame present instead of a hard-coded 20.
    n_frames = obs.shape[1]
    scores = []
    if 'rmse' in metrics:
        for i in range(n_frames):
            scores.append(RMSE(obs[:, i], model_sim[:, i]))
    elif 'csi' in metrics:
        for i in range(n_frames):
            scores.append(CSI(obs[:, i], model_sim[:, i], threshold=thres))
    elif 'far' in metrics:
        for i in range(n_frames):
            scores.append(FAR(obs[:, i], model_sim[:, i], threshold=thres))
    elif 'hss' in metrics:
        for i in range(n_frames):
            scores.append(HSS(obs[:, i], model_sim[:, i], threshold=thres))

    return np.asarray(scores)
-
-
-
# Evaluation data: full-year 2019 SEVIR VIL events, loaded in batches of 10.
vil_gen = SEVIRGenerator(x_img_types=['vil'],batch_size=10,unwrap_time=False,
                         start_date=datetime.datetime(2019,1,1),
                         end_date=datetime.datetime(2019,12,31))

final = vil_gen.load_batches(n_batches=300,offset=500,progress_bar=True) #offset=500, 6 7 already: 500 600 700
print(final[0].shape)
# Add a trailing channel axis, then reorder to (sample, time, channel, H, W).
# NOTE(review): assumes the generator yields (sample, H, W, time) — confirm
# against SEVIRGenerator.load_batches.
final=np.expand_dims(final[0],axis=-1)
final=np.transpose(final,(0,3,4,1,2))
final=np.asarray(final[:,:25,])
# First 5 frames are the model input, the following 20 the ground truth.
input=torch.Tensor(final[:,:5,])  # NOTE(review): shadows the builtin `input`
ground_truth=final[:,5:,0]
print('transposed_final_data shape: ',final.shape)
# Free the large staging array before inference.
del(final)

res_input=torch.utils.data.TensorDataset(input[:,:,0])

res_loader=torch.utils.data.DataLoader(res_input, batch_size=40, shuffle=False)
# NOTE(review): CUDA was already initialized during training above, so
# changing CUDA_VISIBLE_DEVICES here likely has no effect.
os.environ['CUDA_VISIBLE_DEVICES']='0'
# pretrained_filename='/home/movis/ldw/restormer/saved_models/1110restomer1.ckpt'
# model = Res.load_from_checkpoint(pretrained_filename)
model.to(device)
model.eval()
dataiter=iter(res_loader)
test_len=res_loader.__len__()
results=[]
# Batched inference; GPU tensors are freed eagerly to bound memory use.
with torch.no_grad():
    for i in range(test_len):
        in1=next(dataiter)
        out1=model(in1[0].cuda()) #.cuda()
        results.append(out1.cpu().detach().numpy())#.cpu().detach()
        del(in1)
        del(out1)
        gc.collect()
# Stack per-batch outputs into (N, 20, 384, 384).
results=np.array(results).reshape(-1,20,384,384)
results.shape  # NOTE(review): no-op expression, notebook leftover

# Clamp to [0, 255] and persist the raw predictions.
results=torch_to_nppixel(results)
root_path='/home/movis/ldw/final_prediction_results/3000samples'
model_name='afno/'
path=os.path.join(root_path,model_name)
if not os.path.exists(path):
    os.makedirs(path)
np.save(path+'afno_precip_1119data.npy',results)

import numpy as np
# Per-frame RMSE / FAR / HSS scores, saved per metric and averaged for display.
# NOTE(review): the 'metrics_valcsi' subdirectory is never created above, so
# np.save will fail unless it already exists.
metrics=['rmse','far','hss']
for i in metrics:
    res_mtr=evaluate_models(results,ground_truth,i,thres=0.1)
    np.save(path+'/metrics_valcsi/afno_{}.npy'.format(i),res_mtr)
    print(i,':','%.4f' % np.mean(res_mtr))


# NOTE(review): this rebinds the module-global `np` to CuPy, which also
# changes the `np` used inside evaluate_models/prep_clf/CSI defined above;
# requires cupy to be installed.
import cupy as np
os.environ['CUDA_VISIBLE_DEVICES']='0'
# CSI at several VIL pixel-value thresholds.
thres=[0.1,16,74] # [133,181]
for i in thres:

    res_csi=evaluate_models(results,ground_truth,'csi',thres=i)
    np.save(path+'/metrics_valcsi/{}afno_csi.npy'.format(int(i)),res_csi)
    print('CSI-{}:'.format(i),'%.4f' % np.mean(res_csi))

thres=[133,181] # [133,181]
for i in thres:

    res_csi=evaluate_models(results,ground_truth,'csi',thres=i)
    np.save(path+'/metrics_valcsi/{}afno_csi.npy'.format(int(i)),res_csi)
    print('CSI-{}:'.format(i),'%.4f' % np.mean(res_csi))
-
-
-
-
-
|