|
- import dask
- import numpy as np
- import datetime
- from dateutil.relativedelta import relativedelta
- import xarray as xr
- import pickle
- from pathlib import Path
- import calendar
- from ffrecord import FileWriter
-
np.random.seed(2022)

# Root directory of the 5.625-degree WeatherBench dataset.
DATADIR = '/public/home/studentresearch/WeatherBench/5.625deg'
# Sub-directories (one per variable) containing the NetCDF files.
DATANAMES = ['geopotential', 'relative_humidity', 'temperature',
    'u_component_of_wind', 'v_component_of_wind', '10m_u_component_of_wind',
    '10m_v_component_of_wind', '2m_temperature',]

# Map from directory name to the short variable name used inside the NetCDF files.
DATAMAP = {
    'geopotential': 'z',
    'relative_humidity': 'r',
    'temperature': 't',
    'u_component_of_wind': 'u',
    'v_component_of_wind': 'v'
}

# Short names of all variables processed below ('u10', 'v10', 't2m' are surface fields).
VARIABLE_NAMES = ['z', 'r', 't', 'u', 'v', 'u10', 'v10', 't2m']

# Pressure levels (hPa) kept for the upper-air variables.
LEVELS = [50, 500, 850, 1000]
-
# Load and merge all WeatherBench variables for a given time selection.
def load_ndf(time_scale):
    """Open every variable's NetCDF files, restrict to *time_scale*, and
    merge everything into one dataset.

    Parameters
    ----------
    time_scale : slice or array-like
        Passed directly to ``.sel(time=...)`` (e.g. ``slice(start, end)``).

    Returns
    -------
    xarray.Dataset
        Inner-joined dataset containing all variables from ``DATANAMES``.
    """
    datas = []
    for file in DATANAMES:
        tmp = xr.open_mfdataset(f'{DATADIR}/{file}/*.nc', combine='by_coords').sel(time=time_scale)
        # NOTE: the original code renamed each upper-air variable to its own
        # name (an identity rename) — removed as dead code.
        datas.append(tmp)
    # Silence dask's "split large chunks" warning during the merge.
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        valid_data = xr.merge(datas, compat="identical", join="inner")

    return valid_data
-
# Accumulate sample-count-weighted statistics for one month of data.
def fetch_scaler(cursor_time, step, slice_time):
    """Compute weighted per-variable mean and variance sums for one month.

    Parameters
    ----------
    cursor_time : datetime.date
        First day of the month (kept for interface compatibility; unused).
    step : int
        Month index (kept for interface compatibility; unused).
    slice_time : slice
        Time range forwarded to ``load_ndf``.

    Returns
    -------
    (Mean, Std, n) : (dict, dict, int)
        ``Mean[name]`` is ``mean * n`` and ``Std[name]`` is ``variance * n``
        (shape ``(1, levels, 1, 1)``), where ``n`` is the number of time
        steps in this month, so the caller can sum across months and divide
        by the total number of hours.
    """
    Mean, Std = {}, {}

    valid_data = load_ndf(slice_time)

    for name in VARIABLE_NAMES:
        raw = valid_data[name]
        # Upper-air variables carry a level axis; subset it to LEVELS.
        if name in ['u10', 'v10', 't2m']:
            data = raw
        else:
            data = raw.loc[:, LEVELS, :, :]
        # Give surface fields a singleton level axis so every variable is 4-D
        # (time, level, lat, lon); 10.0 stands for the 10 m height.
        if len(data.shape) == 3:
            data = data.expand_dims({'hight': 1}, axis=1).assign_coords(coords={'hight': [10.0,]})

        # Reduce over time/lat/lon; keep the level axis via keepdims.
        mean = np.mean(data.values, keepdims=True, axis=(0, 2, 3))
        std = np.std(data.values, keepdims=True, axis=(0, 2, 3))
        n = data.values.shape[0]
        # Weight by the number of samples so months of different lengths
        # combine correctly.
        Mean[name] = mean * n
        # NOTE(review): summing per-month variances ignores between-month
        # variation of the means, so the resulting global std is an
        # approximation of the true pooled std.
        Std[name] = (std ** 2) * n

    return Mean, Std, n
-
# Walk month-by-month over the dataset and accumulate global statistics.
def dump_scaler():
    """Compute global per-variable means and stds over the training period.

    Iterates from 2008-01 up to (but not including) 2018-12-31, one month at
    a time; only training months (``step <= 444``, i.e. through 2015)
    contribute to the statistics.

    Returns
    -------
    (global_means, global_stds) : (dict, dict)
        Variable name -> array of shape ``(1, levels, 1, 1)``.
    """
    start_time = datetime.date(2008, 1, 1)
    end_time = datetime.date(2018, 12, 31)

    cursor_time = start_time

    global_means = {}
    global_stds = {}

    # Surface variables have one level; upper-air variables use the 4 LEVELS.
    for name in VARIABLE_NAMES:
        n_levels = 1 if name in ['u10', 'v10', 't2m'] else len(LEVELS)
        global_means[name] = np.zeros((1, n_levels, 1, 1))
        global_stds[name] = np.zeros((1, n_levels, 1, 1))

    total_months = 0  # fix: was incremented without ever being initialised (NameError)
    total_hours = 0

    while cursor_time < end_time:
        # Month index counted from January 1979; step 444 is the end of 2015.
        step = (cursor_time.year - 1979) * 12 + (cursor_time.month - 1) + 1
        start = cursor_time.strftime('%Y-%m-%d %H:%M:%S')
        end = (cursor_time + relativedelta(months=1, hours=0)).strftime('%Y-%m-%d %H:%M:%S')
        print(f'Calculate scalers. Step {step} | from {start} to {end}')
        mean, std, hours_this_month = fetch_scaler(cursor_time, step, slice(start, end))
        # Only the training split contributes to the global statistics.
        if step <= 444:
            for name in VARIABLE_NAMES:
                global_means[name] += mean[name]
                global_stds[name] += std[name]
            total_months += 1
            total_hours += hours_this_month

        cursor_time += relativedelta(months=1)

    # Normalise the weighted sums; std comes from the pooled variance.
    for name in VARIABLE_NAMES:
        global_means[name] = global_means[name] / total_hours
        global_stds[name] = np.sqrt(global_stds[name] / total_hours)

    return global_means, global_stds
-
if __name__ == "__main__":
    # Compute the global statistics once and persist them for later
    # normalisation of the training data.
    out_dir_scaler = Path("/public/home/wangwuxing01/research/weatherbench/data")
    global_means, global_stds = dump_scaler()
    payload = {'global_means': global_means, 'global_stds': global_stds}
    with open(out_dir_scaler / "scaler.pkl", "wb") as f:
        pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
|