|
- import os
- import torch
- import argparse
- from dataset_tools.dataset_ThisTR_frontBPSK_NextTR_NextBPSK_1012 import ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset
- from torch.utils.data import DataLoader
- from CSOMP_modeling_1024 import position_prediction_model
-
- import torch.utils.data as D
- import torch.optim as optim
- import torch.nn as nn
- from torch.optim import lr_scheduler
-
- from dataset_tools.dataset_newer_1012 import toeplitz
- from yizx_models.model_channel_estimation_1014 import CS_OMP_FindPostion
-
-
def count_parameters(model):
    """Return how many scalar weights of *model* have ``requires_grad`` set."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)
-
def get_args(argv=None):
    """Parse the hyper-parameters for position-prediction training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like a bare ``parse_args()`` call, so
            existing callers are unaffected.

    Returns:
        argparse.Namespace with ``input_dim``, ``output_dim``, ``lr`` and
        ``weight_decay``.
    """
    parser = argparse.ArgumentParser()
    # Widths handed to position_prediction_model(input_dim, output_dim).
    parser.add_argument('--input_dim', default=2560, type=int)
    parser.add_argument('--output_dim', default=2560, type=int)

    parser.add_argument('--lr', default=0.015, type=float)
    # NOTE(review): parsed but not used by train(), which builds SGD without
    # weight_decay.
    parser.add_argument('--weight_decay', default=0.01, type=float)

    return parser.parse_args(argv)
-
def do_eval(model, test_dataloader, max_batches=10, device='cuda'):
    """Print and return the mean element-wise binary accuracy over the first batches.

    Each batch is ``(wav_data, wav_labels, wave_names)``. The model output is
    thresholded at 0.5, compared element-wise with ``wav_labels``, and the
    per-batch accuracies are averaged over the batches actually evaluated.

    NOTE(review): this calls ``model(wav_data)`` with a single input, unlike
    the three-argument call used by the training/test loops in this file, and
    nothing in the visible code calls it -- it looks like a leftover from an
    older model; confirm before relying on it.

    Args:
        model: Module mapping ``wav_data`` to scores comparable with ``wav_labels``.
        test_dataloader: Iterable of ``(wav_data, wav_labels, wave_names)`` batches.
        max_batches: Stop after this many batches (previously hard-coded to 10).
        device: Device the batch tensors are moved to (previously always CUDA).

    Returns:
        Accuracy in percent; 0.0 when no batch was evaluated.
    """
    model.eval()
    acc_sum = 0.0
    n_batches = 0

    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            if i == max_batches:
                break
            wav_data, wav_labels, _ = data
            wav_data = wav_data.to(device)
            wav_labels = wav_labels.to(device)

            pred = model(wav_data).ge(0.5).long()
            pred_mask = pred == wav_labels
            # numel() works for any label rank, not only 2-D batches.
            acc_sum += pred_mask.sum().item() / pred_mask.numel()
            n_batches += 1

    # Average over the batches actually seen; the old "* 10" was only correct
    # when exactly 10 batches were available.
    test_acc = 100.0 * acc_sum / n_batches if n_batches else 0.0
    print('Eval on test_dataset, acc is: [{}%]'.format(test_acc))
    return test_acc
-
def do_valid(model, valid_dataloader, is_valid=True):
    """Switch *model* to eval mode, score it with model_test_func and print the result.

    ``is_valid`` only selects the printed tag: ``[Valid]`` when True,
    ``[Test]`` otherwise.
    """
    model = model.eval()
    accuracy = model_test_func(model, valid_dataloader)
    template = '> [Valid] ACC is {}%.' if is_valid else '> [Test] ACC is {}%.'
    print(template.format(accuracy))
-
def model_test_func(model, test_dataloader):
    """Return the percentage of samples whose argmax prediction equals the OMP position.

    For every batch the reference position is recomputed with
    ``CS_OMP_FindPostion`` on a Toeplitz matrix built from the BPSK hint, and
    a sample counts as correct when ``argmax(model output)`` equals it.

    NOTE(review): assumes batch_size == 1 (``if is_front_slice`` and the
    flattened argmax only make sense for a single sample) -- confirm.

    Args:
        model: Called as ``model(tr_slice, bpsk_slice, is_front_slice)``; the
            caller is responsible for putting it in eval mode.
        test_dataloader: Iterable of 6-tuples
            ``(this_TR, this_bpsk, next_TR, next_label, is_front, filename)``.

    Returns:
        Accuracy in percent; 0.0 for an empty dataloader (previously this
        raised ZeroDivisionError).
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_dataloader:
            (this_TR_realwav_slice, this_bpsk_hintwav_slice,
             _next_TR, _next_label, is_front_slice, _filename) = data

            this_TR_realwav_slice = this_TR_realwav_slice.float().cuda()
            this_bpsk_hintwav_slice = this_bpsk_hintwav_slice.cuda()

            output = model(this_TR_realwav_slice, this_bpsk_hintwav_slice, is_front_slice)

            # Second toeplitz() argument: zeros for the front slice, otherwise
            # the reversed BPSK hint (presumably the matrix's first row,
            # scipy-style toeplitz(c, r) -- confirm against dataset_tools).
            if is_front_slice:
                p = torch.zeros((1, this_TR_realwav_slice.shape[-1]))
            else:
                p = torch.flip(this_bpsk_hintwav_slice, dims=[-1])
            ChuanGAN_matrix = toeplitz(this_bpsk_hintwav_slice.view(-1, 1).cpu().numpy(),
                                       p.view(1, -1).cpu().numpy())
            matlab_position = CS_OMP_FindPostion(ChuanGAN_matrix,
                                                 this_TR_realwav_slice.cpu().numpy().reshape(-1, 1))
            prediction = torch.argmax(output.view(-1))

            if prediction == matlab_position:
                correct += 1
            total += 1

    if total == 0:
        return 0.0
    return 100.0 * correct / total
-
-
def _split_sizes(n_items, scales=(90, 7, 3)):
    """Split *n_items* into counts proportional to *scales*; the last part takes the remainder."""
    total_scale = sum(scales)
    sizes = [n_items * scale // total_scale for scale in scales[:-1]]
    sizes.append(n_items - sum(sizes))
    return tuple(sizes)


def setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args):
    """Build the train/valid/test dataloaders (90/7/3 random split) and the CUDA model.

    Args:
        raw_train_data_folder, raw_test_data_folder: accepted for interface
            compatibility; the dataset is constructed without them.
        args: namespace providing ``input_dim`` and ``output_dim``.

    Returns:
        ``(model, train_dataloader, valid_dataloader, test_dataloader)``.
    """
    dataset = ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset()
    # Scale first, then divide: the previous "len // 100 * scale" floored
    # before scaling, so e.g. fewer than 100 samples all landed in the test
    # split and 150 samples split as 90/7/53.
    train_nums, valid_nums, test_nums = _split_sizes(len(dataset), (90, 7, 3))
    train_dataset, valid_dataset, test_dataset = D.random_split(
        dataset, [train_nums, valid_nums, test_nums])

    # batch_size=1 throughout: the training/eval loops treat each batch as a
    # single sample.
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    my_model = position_prediction_model(args.input_dim, args.output_dim)
    my_model.cuda()

    return my_model, train_dataloader, valid_dataloader, test_dataloader
-
def train(raw_train_data_folder, raw_test_data_folder, args):
    """Train the position-prediction model against OMP-derived position labels.

    For each training sample the target class is the position returned by
    ``CS_OMP_FindPostion`` on a Toeplitz matrix built from the BPSK hint, and
    the model output is fit to it with cross-entropy (SGD + cosine LR).
    On every 5th epoch, every 150 iterations, validation/test accuracy is
    printed and a checkpoint is written to ``saving_dir``.

    Args:
        raw_train_data_folder, raw_test_data_folder: forwarded to
            ``setup_model_with_datasets``.
        args: namespace providing ``input_dim``, ``output_dim`` and ``lr``.
    """
    my_model, train_dataloader, valid_dataloader, test_dataloader = \
        setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args)

    print(f'The model has {count_parameters(my_model):,} trainable parameters')

    num_epochs = 200
    # NOTE(review): args.weight_decay is parsed but not applied here.
    optimizer = optim.SGD(my_model.parameters(), lr=args.lr)
    # T_max counts scheduler.step() calls. The scheduler is stepped once per
    # epoch (end of the epoch loop) so the cosine decay spans the whole run;
    # stepping it every iteration made the LR oscillate every 2*T_max iters.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=5e-4)

    criterion = nn.CrossEntropyLoss()
    saving_dir = './position_Prediction_MLP_NotSorted_sample400'
    os.makedirs(saving_dir, exist_ok=True)

    for epoch in range(num_epochs):
        epoch_loss = 0
        for i, data in enumerate(train_dataloader):
            # do_valid() switches the model to eval mode, so re-enable
            # training mode on every iteration.
            my_model.train()
            optimizer.zero_grad()

            (this_TR_realwav_slice, this_bpsk_hintwav_slice,
             _next_TR, _next_label, is_front_slice, _filename) = data

            this_TR_realwav_slice = this_TR_realwav_slice.float().cuda()
            this_bpsk_hintwav_slice = this_bpsk_hintwav_slice.cuda()

            output = my_model(this_TR_realwav_slice, this_bpsk_hintwav_slice, is_front_slice)

            # Label: position found by OMP on the Toeplitz matrix of the hint.
            if is_front_slice:
                p = torch.zeros((1, this_TR_realwav_slice.shape[-1]))
            else:
                p = torch.flip(this_bpsk_hintwav_slice, dims=[-1])
            ChuanGAN_matrix = toeplitz(this_bpsk_hintwav_slice.view(-1, 1).cpu().numpy(),
                                       p.view(1, -1).cpu().numpy())
            matlab_position = CS_OMP_FindPostion(ChuanGAN_matrix,
                                                 this_TR_realwav_slice.cpu().numpy().reshape(-1, 1))
            position_label = torch.LongTensor([matlab_position]).cuda()

            loss = criterion(output, position_label)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            if (i + 1) % 100 == 0:
                print('epoch[{}], iter[{}]/[{}],'
                      ' loss is: {}, lr is: {}'.format(epoch,
                                                       i,
                                                       len(train_dataloader),
                                                       epoch_loss / (i + 1),
                                                       optimizer.param_groups[0]['lr']))
            if (i + 1) % 150 == 0 and epoch % 5 == 0:
                do_valid(my_model, valid_dataloader, is_valid=True)
                do_valid(my_model, test_dataloader, is_valid=False)
                torch.save(my_model.state_dict(),
                           os.path.join(saving_dir, 'epoch{}_iter{}.pth'.format(epoch, i)))

        scheduler.step()
-
-
- #
- # def do_eval(args, test_TR_files, test_bpsk_front_files, do_TR_norm=True, is_train_TR=True, K=50):
- # my_model = channel_estimation_model(args.input_dim)
- #
- # state_dict = torch.load(args.load_path)
- # my_model.load_state_dict(state_dict)
- # print("> loading pretrained_model from {} passed! >>>".format(args.load_path))
- # my_model.cuda()
- #
- # assert len(test_TR_files) == 50
- # assert len(test_bpsk_front_files) == 1
- #
- # for i in range(K):
- # if is_train_TR:
- # TR_wav_slice = np.load(test_TR_files[i])
- # else:
- # TR_wav_slice = load_TR_mat(test_TR_files[i])
- # if do_TR_norm:
- # TR_wav_slice = normalization(TR_wav_slice)
- #
- # front_bpsk_sig_file = test_bpsk_front_files[i] if i == 0 else None
- # front_bpsk_sig = torch.from_numpy(np.load(front_bpsk_sig_file)).float() if front_bpsk_sig_file is not None else None
- # TR_wav_slice = torch.from_numpy(TR_wav_slice).float().view(1, -1).cuda()
- # my_model.eval(TR_wav_slice, front_bpsk_sig)
- # break
- #
- #
- #
- #
-
-
-
-
-
if __name__ == '__main__':
    # Raw data layout. The train_wav_* directories are only referenced by the
    # disabled visualization snippet below.
    raw_train_data_folder, raw_test_data_folder = './raw_data/train', './raw_data/test'
    train_wav_dataset_dir, train_wav_labels_dir = './raw_data/train/data', './raw_data/train/labels'

    # Optional waveform plots (needs a `vis_wave` helper, which is not
    # imported above):
    # saving_plots_dir = './wave_plots'
    # if not os.path.exists(saving_plots_dir):
    #     os.mkdir(saving_plots_dir)
    # for item in os.listdir(train_wav_dataset_dir):
    #     this_f = os.path.join(train_wav_dataset_dir, item)
    #     vis_wave(this_f, saving_plots_dir)

    args = get_args()
    train(raw_train_data_folder, raw_test_data_folder, args)
-
-
- # #########################################################
- # #file_prefix_list = ["comm0035_10_slice", "comm0001_20_slice"]
- # file_prefix_list = ["comm0929_20_slice"]
- # for file_prefix in file_prefix_list:
- # test_train_TR_files = ['/userhome/wave_training/raw_data/train' \
- # '/TR_wav_slices/{}{}.npy'.format(file_prefix, i) for i in range(50)]
- # test_bpsk_files = ['/userhome/wave_training/raw_data/train' \
- # '/bpsk_hint_slices/{}{}.npy'.format(file_prefix, 0)]
- # do_eval(args, test_train_TR_files, test_bpsk_files, is_train_TR=True, K=50)
-
-
-
|