|
- import os
- import torch
- import numpy as np
- import argparse
- from dataset_tools.dataset_ThisTR_frontBPSK_NextTR_NextBPSK_1012 import ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset
- from torch.utils.data import DataLoader
- from CSOMP_modeling_1024 import bpsk_demod_model
-
-
- import torch.optim as optim
- import torch.nn as nn
- from torch.optim import lr_scheduler
- import torch.utils.data as D
- from dataset_tools.dataset_newer_1012 import toeplitz
- from yizx_models.model_channel_estimation_1014 import CS_OMP_FindPostion, padding_shift_signal_to_origin
- from py_matlab_tools.bpsk_mod import symbols_to_bpsk_signal
-
# NOTE(review): not referenced anywhere in this file — possibly consumed by an
# importing module or dead; confirm before removing.
simple_pairs_training = True
-
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in `model`."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable)
-
def get_args(argv=None):
    """Parse command-line arguments for the BPSK demodulation runs.

    (The previous docstring said "Text generation arguments" — copy-paste
    residue; this script trains/tests a BPSK demodulation model.)

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:],
              so existing callers (`get_args()`) are unaffected.

    Returns:
        argparse.Namespace with the fields below.
    """
    parser = argparse.ArgumentParser()
    # model dimensions: input waveform length and number of output symbols
    parser.add_argument('--input_dim', default=2560, type=int)
    parser.add_argument('--output_dim', default=20, type=int)

    parser.add_argument('--lr', default=0.015, type=float)  # 0.0015 for Adam, 0.015 for SGD
    parser.add_argument('--weight_decay', default=0.01, type=float)

    # pretrained checkpoint to load for do_test / do_pipeline_prediction_BER
    parser.add_argument('--load_path',
                        default="/userhome/wave_training_old/"
                                "bpsk_demod_MLP_SGD_Notsorted_sample400/epoch175_iter149.pth", type=str)
    # "bpsk_demod_MLP_SGD_Notsorted_sample400"
    # "bpsk_demod_MLP_SGD_sorted_sample200"

    # comm-file prefix used to locate slice .npy files and the label .txt
    parser.add_argument('--pred_prefix',
                        default="comm0256_20", type=str)

    args = parser.parse_args(argv)
    return args
-
def do_valid(model, valid_dataloader, is_valid=True):
    """Evaluate `model` on `valid_dataloader` and print its BER.

    `is_valid` only selects the printed tag ([Valid] vs [Test]).
    """
    model = model.eval()
    test_ber = model_test_func(model, valid_dataloader)
    template = '> [Valid] BER is {}%.' if is_valid else '> [Test] BER is {}%.'
    print(template.format(test_ber))
-
-
def do_test(args):
    """Restore the checkpoint at `args.load_path` and print the test-set BER.

    Relies on module-level `raw_train_data_folder` / `raw_test_data_folder`
    (defined under __main__).
    """
    model, train_dataloader, valid_dataloader, test_dataloader = \
        setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args)
    # restore pretrained weights
    model.load_state_dict(torch.load(args.load_path))
    print("> loading pretrained_model from {} passed! >>>".format(args.load_path))
    model.cuda()
    model.eval()

    print('> [Test] BER is {}%.'.format(model_test_func(model, test_dataloader)))
-
-
def do_pipeline_prediction_BER(args):
    """Run the sequential demodulation pipeline on one comm file and print
    symbol accuracy against the ground-truth label file.

    Fixes vs. original: the label file is now closed even on error (`with`),
    and the 1000-label sanity check raises instead of `assert` (which is
    stripped under `python -O`).
    """
    model, train_dataloader, valid_dataloader, test_dataloader = \
        setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args)
    # load from pretrained_model
    state_dict = torch.load(args.load_path)
    model.load_state_dict(state_dict)
    print("> loading pretrained_model from {} passed! >>>".format(args.load_path))
    model.cuda()
    model.eval()

    prediction_result = give_commFileSlice0_output_symbolsBER(model, args)
    all_1klabel_file = "/userhome/wave_training/raw_data/train/labels/" + args.pred_prefix + '.txt'
    # `with` guarantees the file handle is released even if parsing below fails
    with open(all_1klabel_file, 'r') as all_f:
        all_labels = all_f.readlines()

    correct_num = 0
    all_nums = 0
    for i, each_line in enumerate(all_labels):
        this_label = int(each_line.strip().split('\n')[0])
        pred_label = prediction_result[i]

        if this_label == int(pred_label):
            correct_num += 1
        all_nums += 1
    # each comm file carries exactly 1000 symbols; fail loudly if not
    # (a bare assert would vanish under `python -O`)
    if all_nums != 1000:
        raise ValueError('expected 1000 labels, got {}'.format(all_nums))
    acc = 100.0 * correct_num / float(all_nums)
    print('Test acc on [{}] is {}%'.format(args.pred_prefix, acc))
-
-
-
def give_commFileSlice0_output_symbolsBER(model, args, K=49):
    """Sequentially demodulate K consecutive TR wave slices of one comm file.

    Slice 0 uses the precomputed BPSK hint waveform and its known 0/1 labels
    from disk; every later slice re-modulates the symbols predicted for the
    previous slice and uses that as the hint (so prediction errors can
    propagate forward).

    Fixes vs. original: the prediction dump file is closed even on error
    (`with`), directory creation is race-free (`makedirs(exist_ok=True)`),
    and the duplicated i==0 / i>0 branch tails are factored into one path —
    the executed operations are unchanged.

    Args:
        model: trained demodulation network (expected on GPU, eval mode).
        args: parsed CLI args; only `pred_prefix` is used here.
        K: number of slice steps; slice i predicts the symbols of slice i+1,
           so slice files 0..K are read.

    Returns:
        list of predicted 0/1 symbols for the file, starting with the known
        labels of slice 0. Predictions are also mirrored to
        ./demod_test1025/<prefix>_pred.txt, one symbol per line.
    """
    # pipeline_pred_file = "/userhome/wave_training/raw_data/train/TR_wav_slices/" + args.pred_prefix + "_slice0.npy"

    TR_data_dir = "/userhome/wave_training/raw_data/train/TR_wav_slices/"
    hint_data_dir = "/userhome/wave_training/raw_data/train/bpsk_hint_slices/"
    label01_data_dir = "/userhome/wave_training/raw_data/train/label01_slices/"
    saving_dir = './demod_test1025'
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race of the original
    os.makedirs(saving_dir, exist_ok=True)

    result = []
    prefix = args.pred_prefix
    cache_symbols = None  # symbols predicted for the next slice on the previous step

    with open(os.path.join(saving_dir, '{}_pred.txt'.format(args.pred_prefix)), 'w+') as saving_pred_f:

        def _record(items):
            # Append symbols to `result` and mirror them to the dump file.
            for item in items:
                result.append(item)
                saving_pred_f.write(str(item))
                saving_pred_f.write('\n')

        for i in range(K):
            this_TR_realwav_slice = np.load(TR_data_dir + prefix + '_slice{}.npy'.format(i))
            next_TR_realwav_slice = np.load(TR_data_dir + prefix + '_slice{}.npy'.format(i + 1))

            if i == 0:
                # First slice: hint waveform and ground-truth labels exist on
                # disk; emit the known labels directly.
                this_bpsk_hintwav_slice = np.load(hint_data_dir + prefix + '_slice0.npy')
                _record(np.load(label01_data_dir + prefix + '_slice0.npy'))

                # zero row vector as the toeplitz top row (no wrap-around hint)
                p = torch.zeros((1, this_TR_realwav_slice.shape[-1]))
                ChuanGAN_matrix = toeplitz(this_bpsk_hintwav_slice.reshape(-1, 1), p.view(1, -1).cpu().numpy())
            else:
                assert cache_symbols is not None
                # Later slices: rebuild the hint waveform from the symbols
                # predicted on the previous iteration.
                this_bpsk_hintwav_slice = symbols_to_bpsk_signal(cache_symbols)
                this_bpsk_hintwav_slice = torch.FloatTensor(this_bpsk_hintwav_slice).cuda()
                p = torch.flip(this_bpsk_hintwav_slice, dims=[-1])
                ChuanGAN_matrix = toeplitz(this_bpsk_hintwav_slice.view(-1, 1).cpu().numpy(), p.view(1, -1).cpu().numpy())

            # OMP locates the channel offset; shift the next slice back to the
            # origin before feeding the demodulation network.
            matlab_position = CS_OMP_FindPostion(ChuanGAN_matrix, this_TR_realwav_slice.reshape(-1, 1))
            new_next_TR_realwav_slice = torch.FloatTensor(next_TR_realwav_slice.reshape(1, -1)[:, matlab_position:])
            input_data = padding_shift_signal_to_origin(new_next_TR_realwav_slice, matlab_position).float().cuda()
            output = model(input_data)
            prediction = (output >= 0.5).long()  # hard-threshold probabilities to bits
            cache_symbols = prediction.cpu().numpy().reshape(-1)
            _record(cache_symbols)

    return result
-
-
-
-
def model_test_func(model, test_dataloader):
    """Evaluate `model` on up to 100 batches and return the mean BER in %.

    Bug fix: the original summed `prediction == label` matches, which is
    ACCURACY, not a bit error rate — a good model would have reported ~100%
    "BER". BER is the fraction of predicted bits that differ from the labels,
    so mismatches are counted now. Also returns a plain float (via .item())
    and 0.0 for an empty dataloader instead of dividing by zero.
    """
    test_ber = 0.0
    with torch.no_grad():
        batch_num = 0
        for i, data in enumerate(test_dataloader):
            if i == 100:  # cap evaluation cost at 100 batches
                break
            this_TR_realwav_slice, this_bpsk_hintwav_slice, \
                next_TR_realwav_slice, next_label01_slice, is_front_slice, this_filename = data

            this_TR_realwav_slice = this_TR_realwav_slice.float().cuda()
            this_bpsk_hintwav_slice = this_bpsk_hintwav_slice.cuda()
            next_label01_slice = next_label01_slice.cuda()

            # batch_size is 1 (see setup_model_with_datasets), so
            # `is_front_slice` is a single-element tensor and truthiness works
            if is_front_slice:
                # front slice: no previous symbols, use an all-zero hint row
                p = torch.zeros((1, this_TR_realwav_slice.shape[-1]))
            else:
                p = torch.flip(this_bpsk_hintwav_slice, dims=[-1])  # .view(1, -1)
            ChuanGAN_matrix = toeplitz(this_bpsk_hintwav_slice.view(-1, 1).cpu().numpy(), p.view(1, -1).cpu().numpy())
            matlab_position = CS_OMP_FindPostion(ChuanGAN_matrix, this_TR_realwav_slice.cpu().numpy().reshape(-1, 1))
            new_next_TR_realwav_slice = next_TR_realwav_slice[:, matlab_position:]
            input_data = padding_shift_signal_to_origin(new_next_TR_realwav_slice, matlab_position).float().cuda()
            output = model(input_data)
            prediction = (output >= 0.5).long()
            total_nums = next_label01_slice.size(0) * next_label01_slice.size(1)
            # count MISmatched bits: errors / total bits, as a percentage
            errors = (prediction != next_label01_slice)
            this_ber = torch.sum(errors).item() / float(total_nums) * 100.0

            test_ber += this_ber
            batch_num += 1

    if batch_num == 0:
        return 0.0  # empty dataloader: avoid ZeroDivisionError
    test_ber = test_ber / batch_num
    return test_ber
-
-
def setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args):
    """Build the slice-pair dataset, split it 90/7/3, and create the model
    plus train/valid/test dataloaders (batch_size=1).

    The folder arguments are accepted for interface compatibility; the
    dataset class locates its own data.
    """
    dataset = ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset()
    train_scale, valid_scale, test_scale = 90, 7, 3
    unit = len(dataset) // (train_scale + valid_scale + test_scale)
    train_nums = int(unit * train_scale)
    valid_nums = int(unit * valid_scale)
    # remainder goes to the test split so the three parts cover the dataset
    test_nums = len(dataset) - train_nums - valid_nums
    train_dataset, valid_dataset, test_dataset = D.random_split(dataset, [train_nums, valid_nums, test_nums])

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)  # , num_workers=20, drop_last=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    my_model = bpsk_demod_model(args.input_dim, args.output_dim)
    my_model.cuda()

    return my_model, train_dataloader, valid_dataloader, test_dataloader
-
def train(raw_train_data_folder, raw_test_data_folder, args):
    """Train the BPSK demodulation MLP with SGD and cosine LR annealing,
    periodically validating/testing and checkpointing.

    Fixes vs. original: checkpoint path built with os.path.join (the old
    './{}/...'.format('./dir') produced a double './/./' prefix) and
    race-free directory creation.
    """
    my_model, train_dataloader, valid_dataloader, test_dataloader = \
        setup_model_with_datasets(raw_train_data_folder, raw_test_data_folder, args)

    print(f'The model has {count_parameters(my_model):,} trainable parameters')
    optimizer = optim.SGD(my_model.parameters(), lr=args.lr)  # , weight_decay=args.weight_decay, betas=(0.9, 0.98), eps=1e-8)
    # NOTE(review): scheduler.step() is called once per *iteration* below,
    # while T_max=200 equals the epoch count — the cosine cycle therefore
    # completes every 200 iterations, not once over training. Confirm this
    # is intentional before changing.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200, eta_min=2e-3, verbose=False)  # eta_min 1e-4 for Adam; 2e-3 for SGD

    criterion = nn.BCELoss()  # model must emit sigmoid probabilities in [0, 1]
    saving_dir = './bpsk_demod_MLP_SGD_Notsorted_sample400'
    os.makedirs(saving_dir, exist_ok=True)  # race-free vs. exists()+mkdir()

    for epoch in range(200):
        epoch_loss = 0
        for (i, data) in enumerate(train_dataloader):
            # do_valid() below switches the model to eval(); restore train mode
            my_model.train()
            optimizer.zero_grad()

            this_TR_realwav_slice, this_bpsk_hintwav_slice, \
                next_TR_realwav_slice, next_label01_slice, is_front_slice, this_filename = data

            this_TR_realwav_slice = this_TR_realwav_slice.float().cuda()
            this_bpsk_hintwav_slice = this_bpsk_hintwav_slice.cuda()
            next_label01_slice = next_label01_slice.cuda()

            # batch_size is 1, so is_front_slice is a single-element tensor
            if is_front_slice:
                # front slice: no previous symbols, all-zero hint row
                p = torch.zeros((1, this_TR_realwav_slice.shape[-1]))
            else:
                p = torch.flip(this_bpsk_hintwav_slice, dims=[-1])  # .view(1, -1)
            ChuanGAN_matrix = toeplitz(this_bpsk_hintwav_slice.view(-1, 1).cpu().numpy(), p.view(1, -1).cpu().numpy())
            matlab_position = CS_OMP_FindPostion(ChuanGAN_matrix, this_TR_realwav_slice.cpu().numpy().reshape(-1, 1))
            new_next_TR_realwav_slice = next_TR_realwav_slice[:, matlab_position:]
            input_data = padding_shift_signal_to_origin(new_next_TR_realwav_slice, matlab_position).float().cuda()

            output = my_model(input_data)
            ## matlab bpsk_demod
            loss = criterion(output, next_label01_slice.float())
            loss.backward()

            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()

            if (i + 1) % 100 == 0:
                print('epoch[{}], iter[{}]/[{}],'
                      ' loss is: {}, lr is: {}'.format(epoch,
                                                       i,
                                                       len(train_dataloader),
                                                       epoch_loss / (i + 1),
                                                       optimizer.state_dict()['param_groups'][0]['lr']
                                                       ))
            if (i + 1) % 150 == 0 and (epoch % 5 == 0):
                do_valid(my_model, valid_dataloader, is_valid=True)
                do_valid(my_model, test_dataloader, is_valid=False)
                torch.save(my_model.state_dict(),
                           os.path.join(saving_dir, 'epoch{}_iter{}.pth'.format(epoch, i)))
-
-
-
if __name__ == '__main__':
    # Relative data locations. NOTE(review): the first two are unused in this
    # file; give_commFileSlice0_output_symbolsBER hard-codes absolute
    # /userhome paths internally — confirm which set is authoritative.
    train_wav_dataset_dir = './raw_data/train/data'
    train_wav_labels_dir = './raw_data/train/labels'
    raw_train_data_folder = './raw_data/train'
    raw_test_data_folder = './raw_data/test'

    args = get_args()

    # Alternative entry points, kept for manual switching:
    # train(raw_train_data_folder, raw_test_data_folder, args)
    # do_test(args)

    do_pipeline_prediction_BER(args)
-
-
|