|
- import pickle
- import torch
- from ffrecord.torch import Dataset as FFDataset
- from pathlib import Path
- from ffrecord import FileReader
- import time
-
- """
- 数据组织结构:
- [data_dir]
- scaler.pkl 统计信息
- train.ffr/
- 001.ffr 第1个月的数据
- 002.ffr 第2个月的数据
- ...
- val.ffr/
- ...
- test.ffr/
- ...
- """
-
class StandardScaler:
    """Hold mean/std statistics and undo a standardisation with them.

    A freshly constructed scaler is the identity transform (mean 0.0,
    std 1.0); :meth:`load` replaces both values with the ones stored in a
    pickle file.
    """

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

    def load(self, scaler_dir):
        """Read the statistics from a pickle file.

        Args:
            scaler_dir: Path of the pickle file (despite the name, the value
                is passed straight to ``open``). The unpickled object must
                be subscriptable with the keys ``'global_means'`` and
                ``'global_stds'``.

        Returns:
            The ``(mean, std)`` pair that was just loaded.
        """
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at statistics files from a trusted source.
        with open(scaler_dir, 'rb') as f:
            pkl = pickle.load(f)
        self.mean = pkl['global_means']
        self.std = pkl['global_stds']
        return self.mean, self.std

    def inverse_transform(self, data):
        """Map standardised values back to the original scale.

        Args:
            data: A ``torch.Tensor`` or any object that supports ``*`` and
                ``+`` with the stored statistics (e.g. a NumPy array or a
                float).

        Returns:
            ``data * std + mean``. For tensor input the statistics are first
            converted to ``data``'s dtype and device.
        """
        if torch.is_tensor(data):
            # as_tensor accepts floats, ndarrays and tensors alike, so the
            # default (unloaded) scalar statistics work too; from_numpy would
            # raise TypeError for anything that is not an ndarray.
            mean = torch.as_tensor(self.mean, dtype=data.dtype, device=data.device)
            std = torch.as_tensor(self.std, dtype=data.dtype, device=data.device)
        else:
            mean, std = self.mean, self.std
        return (data * std) + mean
-
class WeatherBench(FFDataset):
    """WeatherBench samples stored in an FFRecord file.

    Per the original project notes, WeatherBench is a global atmospheric
    reanalysis dataset built and open-sourced by the Pangeo earth-science
    big-data community: hourly global reanalysis from 1979 to 2018 at
    5.625, 2.8125 and 1.40625 degree resolution. It contains the variables
    "10m_u_component_of_wind", "10m_v_component_of_wind", "2m_temperature",
    "geopotential", "geopotential_500", "potential_vorticity",
    "relative_humidity", "specific_humidity", "temperature",
    "temperature_850", "toa_incident_solar_radiation", "total_cloud_cover",
    "total_precipitation", "u_component_of_wind", "v_component_of_wind".

    Files read by this class::

        <data_dir>/scaler.pkl            statistics used by the scaler
        <data_dir>/data.ffr/<mode>.ffr   records of the selected split

    Args:
        fname (str): Ignored. The record path is always derived from
            ``data_dir`` and ``mode``; the parameter is kept only so that
            existing callers keep working.
        data_dir (str): Directory that holds the data.
        mode (str): Which split to read: ``'train'``, ``'val'`` or ``'test'``.
        miniset (bool): If true, read from the ``mini/`` sub-directory of
            ``data_dir`` instead. Defaults to ``False``.
        check_data (bool): Forwarded to ``FileReader``. Defaults to ``True``.

    Raises:
        ValueError: If ``mode`` is not one of the three supported splits.

    Indexing:
        ``dataset[indices]`` returns a list with one ``(x_input, y_output)``
        tuple per record read. According to the original notes the two
        elements are the variables at time t and at time t+1.

    Example:
        .. code-block:: python

            from dataset import WeatherBench

            data_set = WeatherBench(fname='', data_dir='../../data/', mode='train')
    """

    def __init__(self, fname: str, data_dir: str, mode: str, miniset: bool = False, check_data: bool = True) -> None:
        # NOTE(review): the base-class __init__ is not called here (it was
        # commented out originally) -- confirm FFDataset needs no set-up.
        if miniset:
            # Insert the separator when it is missing so that 'data' and
            # 'data/' both resolve to 'data/mini/'.
            if data_dir and not data_dir.endswith("/"):
                data_dir += "/"
            data_dir += "mini/"
        self.data_dir = data_dir

        # Raise explicitly: an assert would be stripped under `python -O`.
        if mode not in ('train', 'val', 'test'):
            raise ValueError(
                f"mode must be one of 'train', 'val' or 'test', got {mode!r}"
            )
        self.mode = mode

        # The `fname` argument is deliberately not used; see the class docstring.
        self.fname = str(Path(self.data_dir) / f"data.ffr/{mode}.ffr")
        self.reader = FileReader(self.fname, check_data)

        self.scaler = StandardScaler()
        self.scaler.load(str(Path(self.data_dir) / "scaler.pkl"))

    def __len__(self):
        """Return the number of records reported by the reader."""
        return self.reader.n

    def __getitem__(self, indices):
        """Read and unpickle the records selected by ``indices``.

        Args:
            indices: Passed unchanged to ``FileReader.read``; presumably a
                list of record indices -- confirm against the ffrecord docs.

        Returns:
            list: One ``(x_input, y_output)`` tuple per record read.
        """
        samples = []
        for raw in self.reader.read(indices):
            # Every record is a pickled pair; the unpacking fails loudly if a
            # record does not hold exactly two items.
            x_input, y_output = pickle.loads(raw)
            samples.append((x_input, y_output))
        return samples

    def get_scaler(self):
        """Return the statistics object of this dataset.

        Returns:
            StandardScaler: Scaler holding the mean (``mean``) and standard
            deviation (``std``) loaded from ``scaler.pkl``.
        """
        return self.scaler
-
if __name__ == '__main__':
    # Manual smoke test: open the training split and read a single record.
    dataset = WeatherBench(
        # `fname` is a required parameter of WeatherBench.__init__ even though
        # the constructor derives the real path from data_dir and mode, so a
        # placeholder is enough; omitting it raises TypeError.
        fname='train.ffr',
        data_dir='/public/home/wangwuxing01/research/weatherbench/data/',
        mode='train',
    )
    # __getitem__ forwards its argument to FileReader.read, a batch read that
    # (per the ffrecord docs) takes a list of indices, so wrap the index.
    dataset[[5832]]
|