import os

import librosa
import torch
import numpy as np
import argparse
from torch.utils.data import DataLoader
from yizx_models.total_channel_estimation_model import channel_estimation_model, channel_estimation_L0_model
from dataset_tools.dataset_channel_estimation import sendLFM_receivedLFM_pairs, normalization
import torch.optim as optim
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.utils.data as D
import mat4py

from py_matlab_tools.LFMgen import LFMgen, get_startpoint_extract_recvLFM
import torch.nn.functional as F
import matplotlib.pyplot as plt
import scipy.io as scio
from py_matlab_tools.FIR_python_vs_matlab import py_fir3

from py_matlab_tools.TR_yizx import TR_func
from py_matlab_tools.bpsk_mod import symbols_to_bpsk_signal, BPSKdemod
from dataset_tools.dataset_newer_1012 import toeplitz
from yizx_models.model_channel_estimation_1014 import CS_OMP_FindPostion

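# End-to-end flow of this script:
#   1. train():   fit an MLP that maps (sent LFM, received LFM) to a channel
#                 impulse response, supervised by reconstructing the received
#                 LFM as a convolution of the sent LFM with the prediction.
#   2. do_test(): band-pass filter each received .wav, extract the received
#                 LFM via sync-header correlation, predict the channel with
#                 the trained model, apply time reversal (TR) to suppress
#                 multipath, then demodulate with iter_CSOMP_demod().
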
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def signal_conv_torch(x, y, mode='full'):
    # F.conv1d computes cross-correlation, so flip the kernel first to get a
    # true convolution (matching np.convolve / MATLAB conv semantics)
    x_tr = x.unsqueeze(0)
    y_tr = torch.flip(y, dims=[-1]).unsqueeze(0)
    if mode == 'valid':
        res = F.conv1d(x_tr, y_tr)
    elif mode == 'full':
        padding_nums = int(y_tr.shape[-1] - 1)
        res = F.conv1d(x_tr, y_tr, padding=padding_nums)
    elif mode == 'same':
        padding_nums = int(y_tr.shape[-1] // 2)
        res = F.conv1d(x_tr, y_tr, padding=padding_nums)
    else:
        raise ValueError("mode must be one of 'valid', 'full', 'same'")
    return res
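
# Minimal sanity check (illustrative helper, not called by the pipeline):
# for mode='full' the result should match numpy's full convolution, with
# output length len(x) + len(y) - 1.
def _check_signal_conv_torch():
    x = torch.randn(1, 64, dtype=torch.float64)
    y = torch.randn(1, 17, dtype=torch.float64)
    ours = signal_conv_torch(x, y, mode='full').view(-1).numpy()
    ref = np.convolve(x.view(-1).numpy(), y.view(-1).numpy(), mode='full')
    assert ours.shape == ref.shape and np.allclose(ours, ref)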


def str2bool(v):
    # argparse's `type=bool` is a pitfall: bool('False') is True because any
    # non-empty string is truthy, so parse the string explicitly
    return str(v).lower() in ('true', '1', 'yes')


def get_args():
    """Channel-estimation and demodulation arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dim', default=10240 + 5120, type=int)
    parser.add_argument('--output_dim', default=5121, type=int)
    parser.add_argument('--lr', default=0.0015, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--use_sparsity_model', default=False, type=str2bool)
    parser.add_argument('--l0_scale', default=1e-3, type=float)
    parser.add_argument('--use_double_type', default=True, type=str2bool)

    # parser.add_argument('--load_path',
    #                     default="", type=str)

    parser.add_argument('--load_path',
                        default="/userhome/wave_training_old/ckpts_cache/"
                                "total_channal_estimation_MLP_normdata_largerLR_double/epoch499_iter999.pth", type=str)
    # "total_channal_estimation_MLP_normdata_largerLR/epoch499_iter999.pth", type=str)
    # "total_channal_estimation_MLP/epoch12_iter999.pth", type=str)

    # parser.add_argument('--test_dir',
    #                     default="/userhome/wave_training/raw_data/train/"
    #                             "lfm_esti_channel_norm/", type=str)  # lfm_esti_channel
    parser.add_argument('--test_dir',
                        default="/userhome/wave_training/juesai_yizx/"
                                "juesai_esti_channel_norm/", type=str)  # lfm_esti_channel

    parser.add_argument('--save_dir',
                        default="/userhome/wave_training_old/"
                                "juesai_all_symbols_prediction_MLP_Res/", type=str)
    # "juesai_channal_estimation_sparse_MLP_testRes/", type=str)
    # "channal_estimation_MLP_testRes/", type=str)
    parser.add_argument('--sample_nums', default=3, type=int)

    args = parser.parse_args()
    return args
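
# Example invocation (illustrative; the script name below is a placeholder,
# and the default paths point at the original training environment):
#   python run_channel_estimation.py --use_sparsity_model false --sample_nums 3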

def setup_model(args):
    if args.use_sparsity_model:
        if args.load_path == "":
            my_model = channel_estimation_L0_model(input_dim=args.input_dim, output_dim=args.output_dim, is_train=True)
        else:
            my_model = channel_estimation_L0_model(input_dim=args.input_dim, output_dim=args.output_dim, is_train=False)
    else:
        my_model = channel_estimation_model(input_dim=args.input_dim, output_dim=args.output_dim, type_double=args.use_double_type)

    if args.load_path != "":
        state_dict = torch.load(args.load_path)
        my_model.load_state_dict(state_dict)
        print("> loading pretrained model from {} passed! >>>".format(args.load_path))
    my_model.cuda()
    return my_model
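
# Note: channel_estimation_model / channel_estimation_L0_model come from
# yizx_models (not shown in this file). From their usage here, the L0 variant
# returns (output_vector, l0_regularization) while the plain model returns
# only the output vector.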

def setup_model_with_datasets(args):
    dataset = sendLFM_receivedLFM_pairs(do_norm=True)
    # 90% / 7% / 3% train / valid / test split
    train_scale, valid_scale, test_scale = 90, 7, 3
    train_nums = int(len(dataset) // (train_scale + valid_scale + test_scale) * train_scale)
    valid_nums = int(len(dataset) // (train_scale + valid_scale + test_scale) * valid_scale)
    test_nums = len(dataset) - train_nums - valid_nums
    train_dataset, valid_dataset, test_dataset = D.random_split(dataset, [train_nums, valid_nums, test_nums])

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)  # , num_workers=20, drop_last=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=True)  # , num_workers=20, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)  # , num_workers=20, drop_last=True)

    my_model = setup_model(args)

    return my_model, train_dataloader, valid_dataloader, test_dataloader

def valid_loss_forward(args, my_model, test_dataloader, is_valid=True):
    my_model.eval()
    epoch_loss = 0
    l0_reg = None
    criterion = nn.MSELoss()
    with torch.no_grad():
        for (i, data) in enumerate(test_dataloader):
            send_lfm, rec_lfm, send01, f_name = data
            send_lfm = send_lfm.cuda()  # .float().cuda()
            rec_lfm = rec_lfm.cuda()  # .float().cuda()

            if args.use_sparsity_model:
                out_vector, l0_reg = my_model(send_lfm, rec_lfm)
            else:
                out_vector = my_model(send_lfm, rec_lfm)

            reconstruct_lfm = signal_conv_torch(send_lfm, out_vector, mode='full').view(1, -1)
            loss = criterion(reconstruct_lfm, rec_lfm)
            epoch_loss += loss.item()

    # average over the whole loader rather than the last loop index, which is
    # off by one and fails on single-batch loaders
    avg_loss = epoch_loss / len(test_dataloader)
    output_string = '> [Valid] avg loss is: {}'.format(avg_loss) if is_valid else '> [Test] avg loss is: {}'.format(avg_loss)
    if args.use_sparsity_model:
        output_string += ", l0_reg_loss is {}".format(args.l0_scale * l0_reg)
    print(output_string)
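
# Loss design used above: the model predicts a channel vector h, and the
# objective is MSE(send_lfm conv h, rec_lfm), i.e. supervision flows through
# the linear convolution model of the channel rather than through h directly.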

def train(args):

    my_model, train_dataloader, valid_dataloader, test_dataloader = \
        setup_model_with_datasets(args)

    print(f'The model has {count_parameters(my_model):,} trainable parameters')
    optimizer = optim.SGD(my_model.parameters(), lr=args.lr)  # , weight_decay=args.weight_decay, betas=(0.9, 0.98), eps=1e-8)
    # the scheduler is stepped once per iteration (not once per epoch),
    # so T_max=500 means one cosine cycle every 500 iterations
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=500, eta_min=2e-4, verbose=False)

    criterion = nn.MSELoss()
    if not args.use_sparsity_model:
        saving_dir = './total_channal_estimation_MLP_normdata_largerLR'
        if args.use_double_type:
            saving_dir += '_double'
    else:
        saving_dir = './sparse_{}_total_channal_estimation_MLP_normdata_largerLR'.format(str(args.l0_scale))

    if not os.path.exists(saving_dir):
        os.mkdir(saving_dir)

    for epoch in range(500):
        epoch_loss = 0
        for (i, data) in enumerate(train_dataloader):
            my_model.train()
            optimizer.zero_grad()

            send_lfm, rec_lfm, send01, f_name = data
            send_lfm = send_lfm.cuda()  # .float().cuda()
            rec_lfm = rec_lfm.cuda()  # .float().cuda()

            if not args.use_sparsity_model:
                out_vector = my_model(send_lfm, rec_lfm)
            else:
                out_vector, l0_reg = my_model(send_lfm, rec_lfm)

            reconstruct_lfm = signal_conv_torch(send_lfm, out_vector, mode='full').view(1, -1)
            loss = criterion(reconstruct_lfm, rec_lfm)

            if args.use_sparsity_model:
                loss += args.l0_scale * l0_reg

            loss.backward()

            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()

            if (i + 1) % 300 == 0:
                print('epoch[{}], iter[{}]/[{}],'
                      ' loss is: {}, lr is: {}'.format(epoch,
                                                       i,
                                                       len(train_dataloader),
                                                       epoch_loss / (i + 1),
                                                       optimizer.state_dict()['param_groups'][0]['lr']
                                                       ))
            if (i + 1) % 1000 == 0:
                valid_loss_forward(args, my_model, valid_dataloader, is_valid=True)
                valid_loss_forward(args, my_model, test_dataloader, is_valid=False)
                torch.save(my_model.state_dict(), os.path.join(saving_dir, 'epoch{}_iter{}.pth'.format(epoch, i)))
                # do_eval(my_model, test_dataloader)
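

# do_test() runs the full receive chain on the last `sample_nums` channel
# files: load and normalize the received .wav, band-pass filter (py_fir3),
# extract the received LFM by sync-header correlation, predict the channel
# with the trained model, time-reverse (TR_func) to suppress multipath, and
# demodulate iteratively with iter_CSOMP_demod(). Decoded symbols are written
# to one .txt per input file under args.save_dir.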
def do_test(args, do_vis=False):
    # rec_lfm_folder = '/userhome/wave_training/raw_data/train/raw_train_rec_lfm_norm/'
    rec_lfm_folder = '/userhome/wave_training/juesai_yizx/juesai_lfm_norm/'
    rec_wav_folder = '/userhome/wave_training/juesai_data/test/data/'
    pilot_symbols_folder = '/userhome/wave_training/juesai_data/test/pilot_20symbols/'

    my_model = setup_model(args)
    my_model.eval()

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    input_globs = os.listdir(args.test_dir)
    input_globs.sort()
    input_files = input_globs[-args.sample_nums:]
    # input_files = random.sample(input_globs, args.sample_nums)
    for each_channel_f in input_files:
        if each_channel_f.endswith('.mat') and 'comm' in each_channel_f:
            common_file_name = 'comm' + each_channel_f.split('.mat')[0].split('comm')[-1]

            saving_f = os.path.join(args.save_dir, common_file_name + '.txt')

            this_recv_signal_f = os.path.join(rec_wav_folder, common_file_name + '.wav')
            this_pilot_20symbols_f = os.path.join(pilot_symbols_folder, common_file_name + '.txt')

            this_esti_channel_f = os.path.join(args.test_dir, each_channel_f)
            this_LFM_signal_f = os.path.join(rec_lfm_folder, common_file_name + '.mat')

            #############################################################################
            recv_wav_signal, _ = librosa.load(this_recv_signal_f, sr=128000)
            recv_wav_signal = normalization(recv_wav_signal)
            with open(this_pilot_20symbols_f, 'r') as symbols_f:
                symbols_all = symbols_f.readlines()
            # strip() alone is enough; slicing off the last character first
            # would drop a digit if the file lacks a trailing newline
            this_20symbols = [int(each_line.strip()) for each_line in symbols_all]
            assert len(this_20symbols) == 20
            pilot_symbols = np.array(this_20symbols)
            ####################################################################
            # work from the raw received signal
            # step 1: band-pass filtering, keeping the baseband and carrier components
            ####### recv_wav_signal = get_baseWithCarrier_bandpass(recv_wav_signal)
            recv_wav_signal = py_fir3(recv_wav_signal)
            # step 2: correlate with the LFM sync header to locate and extract
            # the received LFM signal
            recvLFM_py = get_startpoint_extract_recvLFM(recv_wav_signal)
            #############################################################################

            send_LFM = LFMgen(128000)
            esti_channel = np.array(mat4py.loadmat(this_esti_channel_f)['h_lfm']).reshape(-1)
            recv_LFM = np.array(mat4py.loadmat(this_LFM_signal_f)['r']).reshape(-1)

            assert recvLFM_py.shape == recv_LFM.shape
            # plt.subplot(211)
            # plt.plot(recv_LFM)
            # plt.subplot(212)
            # plt.plot(recvLFM_py)
            # plt.show()

            # recv_LFM = normalization(recv_LFM)  # do norm
            # send_LFM = normalization(send_LFM)  # do norm

            if args.use_double_type:
                send_LFM = torch.DoubleTensor(send_LFM).view(1, -1).cuda()
                recv_LFM = torch.DoubleTensor(recvLFM_py).view(1, -1).cuda()
                # recv_LFM = torch.DoubleTensor(recv_LFM).view(1, -1).cuda()
            else:
                send_LFM = torch.FloatTensor(send_LFM).view(1, -1).cuda()
                recv_LFM = torch.FloatTensor(recvLFM_py).view(1, -1).cuda()
                # recv_LFM = torch.FloatTensor(recv_LFM).view(1, -1).cuda()

            with torch.no_grad():
                if args.use_sparsity_model:
                    prediction_channel, _ = my_model(send_LFM, recv_LFM)
                else:
                    prediction_channel = my_model(send_LFM, recv_LFM)
                reconstruct_lfm = signal_conv_torch(send_LFM, prediction_channel, mode='full').view(1, -1)
                prediction_channel_npy = prediction_channel.cpu().numpy().reshape(-1)

            if do_vis:
                scio.savemat('{}/{}_h_lfm.mat'.format(args.save_dir, common_file_name),
                             mdict={'ai_h_lfm': prediction_channel_npy.reshape(-1, 1)})
                plt.subplot(211)
                plt.plot(esti_channel, label='estim_channel')
                plt.ylim([-0.1, 0.1])
                plt.title('{}, Channel esti'.format(common_file_name))
                plt.legend()
                plt.subplot(212)
                plt.plot(prediction_channel_npy, label='[AI] estim_channel')
                plt.legend()
                plt.show()

                plt.subplot(211)
                plt.plot(recv_LFM.cpu().numpy().reshape(-1), label='recv_lfm')
                plt.title('{}, reconstruct LFM'.format(common_file_name))
                plt.legend()
                plt.subplot(212)
                plt.plot(reconstruct_lfm.cpu().numpy().reshape(-1), label='[AI] reconstruct_lfm')
                plt.legend()
                plt.show()

            #######################################################################################
            # step 3: time-reversal (TR) processing to suppress multipath
            sync_recv_wav_signal = TR_func(recv_wav_signal, pilot_symbols, prediction_channel_npy)
            print('> [done] TR time-reversal multipath suppression')
            # step 4: iterative CS-OMP alignment and BPSK demodulation
            iter_CSOMP_demod(sync_recv_wav_signal, pilot_symbols, saving_f)
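

# iter_CSOMP_demod implements a sliding-window, decision-feedback receiver:
# the first 20-symbol window is the known pilot block; for each later window
# the previously decoded symbols are re-modulated to BPSK, a Toeplitz matrix
# built from that waveform is passed to CS_OMP_FindPostion (an OMP-based
# sparse search) to estimate the residual delay and a possible sign flip, and
# the next slice is realigned accordingly before BPSK demodulation. Assuming
# BPSKdemod yields one symbol per 128-sample bit and one symbol is written per
# line, the output file holds 20 pilot symbols plus K * 20 decoded symbols.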
def iter_CSOMP_demod(afterTR_signal, pilot_symbols, saving_path, demod_symbols=20, K=49):
    points_per_bit = 128
    sliding_size = demod_symbols * points_per_bit
    cache_symbols = None

    saving_f = open(saving_path, 'w')
    for i in range(K):
        if i == 0:
            # the first window is the known pilot block: write it out verbatim
            for itt in pilot_symbols:
                saving_f.write(str(int(itt)))
                saving_f.write('\n')
            cache_symbols_bpsk = symbols_to_bpsk_signal(pilot_symbols, fs=128000)
            p = np.zeros((1, sliding_size))
        else:
            cache_symbols_bpsk = symbols_to_bpsk_signal(cache_symbols, fs=128000)
            p = np.flip(cache_symbols_bpsk)

        this_TR_wav_slice = afterTR_signal[i * sliding_size: (i + 1) * sliding_size]
        next_TR_wav_slice = afterTR_signal[(i + 1) * sliding_size: (i + 2) * sliding_size]  # .reshape(1, -1)

        ChuanGAN_matrix = toeplitz(cache_symbols_bpsk.reshape(-1, 1), p.reshape(1, -1))
        matlab_position, transpose_flag = CS_OMP_FindPostion(ChuanGAN_matrix, this_TR_wav_slice.reshape(-1, 1))
        if transpose_flag:
            new_next_TR_realwav_slice = next_TR_wav_slice[matlab_position:] * -1.0
        else:
            new_next_TR_realwav_slice = next_TR_wav_slice[matlab_position:]
        # zero-pad at the end so the realigned slice keeps its original length;
        # in np.pad, (1, 2) would pad one element in front of a 1-D array and
        # two at the back, and constant_values=(0, 2) would fill 0 in front and
        # 2 at the back; here we append matlab_position zeros only
        demod_input = np.pad(new_next_TR_realwav_slice, (0, matlab_position), 'constant', constant_values=(0, 0))
        symbols_demod = BPSKdemod(demod_input)
        cache_symbols = symbols_demod
        for itt in symbols_demod:
            saving_f.write(str(int(itt)))
            saving_f.write('\n')

    saving_f.close()


if __name__ == '__main__':

    args = get_args()
    # train(args)
    ################################################
    do_test(args, do_vis=False)