|
- import os
- import torch
- import argparse
- from dataset_tools.dataset_newer_1012 import TR_label_SlicePairsDataset
- from torch.utils.data import DataLoader
-
- from yizx_models.model_MLP_TRSlice_Prediction import TRSlice_prediction_model
-
-
- import torch.optim as optim
- import torch.nn as nn
- from torch.optim import lr_scheduler
-
-
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
-
def get_args(argv=None):
    """Parse command-line arguments for TR-slice prediction training.

    (The previous docstring, "Text generation arguments", was a copy-paste
    leftover from another script.)

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
            Passing an explicit list keeps the function testable without
            touching the real command line (backward-compatible addition).

    Returns:
        argparse.Namespace with model dims, optimizer settings and the
        checkpoint path used by the (currently disabled) eval path.
    """
    parser = argparse.ArgumentParser(description='TR-slice prediction training')
    # Input/output feature width of the MLP (one flattened waveform slice).
    parser.add_argument('--input_dim', default=2560, type=int)
    parser.add_argument('--output_dim', default=2560, type=int)

    parser.add_argument('--lr', default=0.0015, type=float)
    # NOTE(review): weight_decay is parsed but not passed to the optimizer
    # in train() — confirm whether that is intentional.
    parser.add_argument('--weight_decay', default=0.01, type=float)

    # Checkpoint to resume from (only referenced by commented-out eval code).
    parser.add_argument('--load_path',
                        default="/userhome/wave_training_old/"
                                "channal_estimation_MLP_new/epoch0_iter14999.pth", type=str)

    args = parser.parse_args(argv)
    return args
-
def do_eval(model, test_dataloader):
    """Evaluate element-wise binary accuracy on up to 10 batches.

    Model outputs are thresholded at 0.5 and compared element-wise with the
    integer labels; per-batch accuracies are averaged over the batches
    actually evaluated.

    Args:
        model: network mapping a waveform batch to per-element scores;
            assumes output shape matches ``wav_labels`` — TODO confirm.
        test_dataloader: yields (wav_data, wav_labels, wave_names) batches.
    """
    model.eval()
    test_acc = 0.0
    num_batches = 0

    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            if i == 10:  # cap evaluation at 10 batches to keep it cheap
                break
            wav_data, wav_labels, wave_names = data
            wav_data = wav_data.cuda()
            wav_labels = wav_labels.cuda()

            output = model(wav_data)
            pred = output.ge(0.5).long()  # binarize scores at 0.5

            pred_mask = (pred == wav_labels)
            correct_num = pred_mask.sum().item()
            test_acc += correct_num / (pred_mask.shape[0] * pred_mask.shape[1])
            num_batches += 1

    # Average over batches actually seen. The original hard-coded "* 10"
    # (i.e. sum/10*100), which mis-reported accuracy whenever the loader
    # yielded fewer than 10 batches.
    denom = num_batches or 1
    print('Eval on test_dataset, acc is: [{}%]'.format(test_acc / denom * 100))
-
-
def setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args):
    """Build the training dataloader and the GPU-resident prediction model.

    NOTE(review): raw_train_data_folder / raw_test_data_folder are currently
    unused — the dataset class appears to locate its own data — and the
    test-side dataset/dataloader are stubbed out as None.

    Returns:
        (model, train_dataloader, test_dataloader) where test_dataloader
        is None for now.
    """
    train_loader = DataLoader(
        TR_label_SlicePairsDataset(),
        batch_size=32,
        shuffle=True,
    )  # num_workers / drop_last intentionally left at defaults

    # Test pipeline is disabled for now (previously: AudioDataset + loader).
    test_dataset = None
    test_loader = None

    model = TRSlice_prediction_model(args.input_dim, args.output_dim)
    model.cuda()

    return model, train_loader, test_loader
-
def train(raw_train_data_folder, raw_test_data_folder, args):
    """Train the TR-slice prediction MLP with MSE loss.

    Runs 200 epochs over the training dataloader, logging the running mean
    loss every 200 iterations and checkpointing every 1000 iterations.

    Args:
        raw_train_data_folder: forwarded to setup_model_with_datasets
            (currently unused there).
        raw_test_data_folder: forwarded likewise (test path disabled).
        args: parsed namespace providing input_dim, output_dim and lr.
    """
    my_model, train_dataloader, test_dataloader = \
        setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args)

    print(f'The model has {count_parameters(my_model):,} trainable parameters')
    optimizer = optim.Adam(my_model.parameters(), lr=args.lr)
    # NOTE(review): scheduler is stepped per *iteration* while T_max=200
    # matches the epoch count — confirm the intended annealing horizon.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-5, verbose=False)

    criterion = nn.MSELoss()
    saving_dir = './TRSlice_Prediction_MLP_addDropout'
    # makedirs(exist_ok=True) is race-free, unlike the exists()+mkdir pair.
    os.makedirs(saving_dir, exist_ok=True)

    for epoch in range(200):
        epoch_loss = 0.0
        for i, data in enumerate(train_dataloader):
            my_model.train()
            optimizer.zero_grad()

            this_TR_realwav_slice, this_send01_bpsk_slice = data
            this_TR_realwav_slice = this_TR_realwav_slice.float().cuda()
            this_send01_bpsk_slice = this_send01_bpsk_slice.cuda()

            out_vector = my_model(this_TR_realwav_slice)
            loss = criterion(out_vector, this_send01_bpsk_slice.float())

            loss.backward()
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()

            if (i + 1) % 200 == 0:
                # param_groups[0]['lr'] reads the lr directly instead of
                # rebuilding the whole optimizer state_dict every log line.
                print('epoch[{}], iter[{}]/[{}],'
                      ' loss is: {}, lr is: {}'.format(epoch,
                                                       i,
                                                       len(train_dataloader),
                                                       epoch_loss / (i + 1),
                                                       optimizer.param_groups[0]['lr']
                                                       ))
            if (i + 1) % 1000 == 0:
                # os.path.join avoids the original double prefix:
                # './{}/...'.format('./dir') produced '././dir/...'.
                ckpt_path = os.path.join(saving_dir,
                                         'epoch{}_iter{}.pth'.format(epoch, i))
                torch.save(my_model.state_dict(), ckpt_path)
                # do_eval(my_model, test_dataloader)
-
- # output, hidden, _ = encoder(wav_data)
- # print(output.shape, hidden.shape) # [411921, 1, 1024] [2, 1, 1024]
-
- #
- # def do_eval(args, test_TR_files, test_bpsk_front_files, do_TR_norm=True, is_train_TR=True, K=50):
- # my_model = channel_estimation_model(args.input_dim)
- #
- # state_dict = torch.load(args.load_path)
- # my_model.load_state_dict(state_dict)
- # print("> loading pretrained_model from {} passed! >>>".format(args.load_path))
- # my_model.cuda()
- #
- # assert len(test_TR_files) == 50
- # assert len(test_bpsk_front_files) == 1
- #
- # for i in range(K):
- # if is_train_TR:
- # TR_wav_slice = np.load(test_TR_files[i])
- # else:
- # TR_wav_slice = load_TR_mat(test_TR_files[i])
- # if do_TR_norm:
- # TR_wav_slice = normalization(TR_wav_slice)
- #
- # front_bpsk_sig_file = test_bpsk_front_files[i] if i == 0 else None
- # front_bpsk_sig = torch.from_numpy(np.load(front_bpsk_sig_file)).float() if front_bpsk_sig_file is not None else None
- # TR_wav_slice = torch.from_numpy(TR_wav_slice).float().view(1, -1).cuda()
- # my_model.eval(TR_wav_slice, front_bpsk_sig)
- # break
- #
- #
- #
- #
-
-
-
-
-
if __name__ == '__main__':
    # On-disk data layout (labels dir kept for reference alongside data).
    train_wav_dataset_dir = './raw_data/train/data'
    train_wav_labels_dir = './raw_data/train/labels'
    raw_train_data_folder = './raw_data/train'
    raw_test_data_folder = './raw_data/test'

    # (A one-off waveform visualization pass and a standalone TR-slice eval
    # run used to live here; both are disabled.)

    args = get_args()
    train(raw_train_data_folder, raw_test_data_folder, args)
|