|
- import dask
- import numpy as np
- import datetime
- from dateutil.relativedelta import relativedelta
- import xarray as xr
- import pickle
- from pathlib import Path
- import calendar
- from ffrecord import FileWriter
-
# Seed NumPy's global RNG so any random draws in this script are repeatable.
np.random.seed(2022)

# Root directory holding one sub-directory of NetCDF files per variable.
DATADIR = '/public/home/studentresearch/WeatherBench/5.625deg'

# Sub-directory names under DATADIR, one per variable that is loaded.
DATANAMES = [
    'geopotential',
    'relative_humidity',
    'temperature',
    'u_component_of_wind',
    'v_component_of_wind',
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
    '2m_temperature',
]

# Directory name -> short variable name. Only the variables whose directory
# name does not start with "10m"/"2m" are listed here.
DATAMAP = dict(
    geopotential='z',
    relative_humidity='r',
    temperature='t',
    u_component_of_wind='u',
    v_component_of_wind='v',
)

# Short variable names, in the order they are stacked into each sample.
VARIABLE_NAMES = [
    'z', 'r', 't', 'u', 'v',
    'u10', 'v10', 't2m',
]

# Level values selected from the multi-level variables
# (presumably pressure levels in hPa -- TODO confirm against the data files).
LEVELS = [50, 500, 850, 1000]
-
def dataset_to_sample(raw_data, mean, std):
    """Normalize a sequence and pair every entry with the one that follows it.

    Args:
        raw_data: Sliceable array-like; its first axis is the one that gets
            split into consecutive pairs.
        mean: Value subtracted from ``raw_data``.
        std: Value the centred data is divided by.

    Returns:
        A tuple ``(x_input, y_output)``. ``x_input`` is the normalized data
        without its last entry and ``y_output`` is the normalized data
        without its first entry, so ``y_output[i]`` is the entry right after
        ``x_input[i]``.
    """
    normalized = (raw_data - mean) / std
    # Shift by one along the first axis: inputs drop the tail, targets the head.
    return normalized[:-1], normalized[1:]
-
def write_dataset(x_input, y_output, out_file):
    """Write paired samples into one ffrecord file.

    Every ``(x, y)`` pair taken along the first axis of the two inputs is
    pickled and stored as a single record.

    Args:
        x_input: Array of inputs; ``x_input.shape[0]`` is the number of
            records the writer is created for.
        y_output: Array of targets with the same first-axis length as
            ``x_input``.
        out_file: Destination path handed to ``FileWriter``.

    Raises:
        ValueError: If ``x_input`` and ``y_output`` do not contain the same
            number of samples. ``zip`` would otherwise silently write fewer
            records than the writer was told to expect.
    """
    n_sample = x_input.shape[0]
    if len(y_output) != n_sample:
        raise ValueError(
            f"x_input has {n_sample} samples but y_output has {len(y_output)}"
        )

    # The writer is created for exactly n_sample records.
    writer = FileWriter(out_file, n_sample)
    try:
        for item in zip(x_input, y_output):
            writer.write_one(pickle.dumps(item))
    finally:
        # Always release the file handle, even if pickling or writing fails.
        writer.close()
-
def load_ndf(time_scale):
    """Open every variable listed in DATANAMES and merge them into one dataset.

    Args:
        time_scale: Selector passed to ``.sel(time=...)`` on each opened
            dataset (the caller in this file passes a ``slice`` of two
            timestamp strings).

    Returns:
        The result of ``xr.merge`` over the per-variable datasets, using
        ``compat="identical"`` and ``join="inner"``.
    """
    per_variable = []
    for dirname in DATANAMES:
        opened = xr.open_mfdataset(f'{DATADIR}/{dirname}/*.nc', combine='by_coords')
        subset = opened.sel(time=time_scale)

        prefix = dirname.split('_')[0]
        if prefix != '10m' and prefix != '2m':
            short_name = DATAMAP[dirname]
            # Maps the short name onto itself, so the variable name is unchanged.
            subset = subset.rename_vars({short_name: f'{short_name}'})

        per_variable.append(subset)

    # The dask slicing option is only overridden for the duration of the merge.
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        merged = xr.merge(per_variable, compat="identical", join="inner")

    return merged
-
def fetch_dataset(Mean, Std, cursor_time, step, out_dir, slice_time):
    """Normalize one time window of data and write it to ``<out_dir>/<step>.ffr``.

    For every name in VARIABLE_NAMES the data is loaded, normalized with the
    matching mean/std, and split into (input, next-step target) pairs. All
    variables are then concatenated along axis 1 and written as one file.

    Args:
        Mean: Mapping from variable name to the value subtracted from it.
        Std: Mapping from variable name to the value it is divided by.
        cursor_time: Not used by this function; kept so existing callers
            keep working.
        step: Integer used, zero-padded to three digits, as the file name.
        out_dir: Directory path object (must support the ``/`` operator).
        slice_time: Time selector forwarded to ``load_ndf``.
    """
    valid_data = load_ndf(slice_time)

    inputs, targets = [], []
    for name in VARIABLE_NAMES:
        data = valid_data[name]
        if name not in ('u10', 'v10', 't2m'):
            # Keep only the configured levels. The positional .loc assumes the
            # level dimension is the second one -- TODO confirm against the files.
            data = data.loc[:, LEVELS, :, :]

        if len(data.shape) == 3:
            # Give 3-D variables a length-1 axis at position 1 so every
            # variable can be concatenated along axis 1 below.
            data = data.expand_dims({'hight': 1}, axis=1).assign_coords(coords={'hight': [10.0,]})

        # split sample data
        x_input, y_output = dataset_to_sample(data, Mean[name], Std[name])
        print(f"{name} | x_input.shape: {x_input.shape}, y_output.shape: {y_output.shape}, mean: {Mean[name]}, std: {Std[name]}\n")

        inputs.append(x_input)
        targets.append(y_output)

    # Stack all variables along axis 1.
    X_input = np.concatenate(inputs, axis=1)
    Y_output = np.concatenate(targets, axis=1)

    print(f"X_input.shape: {X_input.shape}, Y_output.shape: {Y_output.shape}\n")
    write_dataset(X_input, Y_output, out_dir / f"{step:03d}.ffr")
-
def dump_weatherbench(out_dir, out_dir_scaler):
    """Write one ffrecord file per calendar month, grouped by split.

    Walks month by month from 2008-01-01 up to (but excluding) 2018-12-31,
    assigns every month to ``train``, ``val`` or ``test``, and calls
    ``fetch_dataset`` to write it to ``<out_dir>/<split>.ffr/<step>.ffr``.

    Args:
        out_dir: Base output directory (a ``pathlib.Path``); created if
            missing.
        out_dir_scaler: Directory (a ``pathlib.Path``) containing
            ``scaler.pkl``, a pickled mapping with the keys
            ``'global_means'`` and ``'global_stds'``.
    """
    base_dir = out_dir
    base_dir.mkdir(exist_ok=True, parents=True)

    start_time = datetime.date(2008, 1, 1)
    end_time = datetime.date(2018, 12, 31)

    # NOTE: pickle.load can execute arbitrary code; only point this at a
    # scaler file you produced yourself.
    with open(out_dir_scaler / "scaler.pkl", "rb") as f:
        scaler = pickle.load(f)
    global_means = scaler['global_means']
    global_stds = scaler['global_stds']

    # Month counter relative to January 1979 (= step 1).
    # Step 444 is December 2015 and step 468 is December 2017.
    last_train_step = 444
    last_val_step = 468

    one_month = relativedelta(months=1)
    cursor_time = start_time
    while cursor_time < end_time:
        step = (cursor_time.year - 1979) * 12 + cursor_time.month
        # cursor_time is a date, so both bounds are formatted at 00:00:00;
        # the window ends on the first day of the following month.
        start = cursor_time.strftime('%Y-%m-%d %H:%M:%S')
        end = (cursor_time + one_month).strftime('%Y-%m-%d %H:%M:%S')

        if step <= last_train_step:
            mode = 'train'
        elif step <= last_val_step:
            mode = 'val'
        else:
            mode = 'test'

        split_dir = base_dir / f'{mode}.ffr'
        split_dir.mkdir(exist_ok=True, parents=True)
        print(f'Write files. Step {step} | from {start} to {end}')
        fetch_dataset(global_means, global_stds, cursor_time, step, split_dir, slice(start, end))

        cursor_time += one_month
-
if __name__ == "__main__":
    # Script entry point: the scaler pickle is read from the data root and
    # the ffrecord files are written to a sub-directory of it.
    data_root = Path("/public/home/wangwuxing01/research/weatherbench/data")
    dump_weatherbench(data_root / "test_data.ffr", data_root)
|