|
- # -*- coding: utf-8 -*-
- import time
- import numpy as np
- import torch
- from torch.utils.data import Dataset
- import os
- import h5py
- import mat4py
- import librosa
- import matplotlib.pyplot as plt
- from py_matlab_tools.LFMgen import bandpass_yizx, LFMgen, get_signal_conv
- from py_matlab_tools.bpsk_mod import symbols_to_bpsk_signal
- import random
-
- from tqdm import tqdm
-
def load_TR_mat(path):
    """Load the ``'r1'`` variable from the MAT-file at *path* as a 1-D array.

    Args:
        path: Location of the MAT-file, passed straight to ``mat4py.loadmat``.

    Returns:
        A NumPy array holding the ``'r1'`` entry, flattened to one dimension.
    """
    mat_contents = mat4py.loadmat(path)
    return np.array(mat_contents['r1']).ravel()
-
def normalization(x):
    """Remove the mean of *x* and scale it so its peak magnitude is 1.

    Args:
        x: Array-like numeric signal.

    Returns:
        The centred signal divided by its maximum absolute value. A constant
        signal has zero peak after centring and is returned as all zeros
        instead of being divided by zero. An empty input is returned as an
        empty float array.
    """
    if np.size(x) == 0:
        # np.max raises on empty input; there is nothing to normalise.
        return np.asarray(x, dtype=float)

    centered = x - np.mean(x)
    peak = np.max(np.abs(centered))
    if peak == 0:
        # Dividing by a zero peak would fill the result with NaN.
        return centered
    return centered / peak
-
def toeplitz(c, r):
    """Build a Toeplitz matrix from a first column *c* and a first row *r*.

    Both arguments must be 2-D array-likes (for example an ``(m, 1)`` column
    and a ``(1, p)`` row); their elements are read in column-major order.
    Entry ``[i, j]`` of the result is ``c[i - j]`` when ``i >= j`` and
    ``r[j - i]`` otherwise, so the main diagonal comes from ``c[0]`` and
    ``r[0]`` is never used.

    Args:
        c: 2-D array-like providing the first column.
        r: 2-D array-like providing the first row.

    Returns:
        Array of shape ``(c.size, r.size)``.

    Raises:
        ValueError: If either input is not 2-D (the shape unpacking fails).
    """
    col_src = np.array(c)
    row_src = np.array(r)

    # Unpacking the shapes enforces that both inputs are exactly 2-D.
    _, c_width = col_src.shape
    _, r_width = row_src.shape

    # Gather every element of c and r into flat lists, column by column.
    col_vals = [line[k] for k in range(c_width) for line in col_src]
    row_vals = [line[k] for k in range(r_width) for line in row_src]

    n_rows = len(col_vals)
    n_cols = len(row_vals)

    # Lookup pool: the row reversed (without its first element), then the column.
    pool = row_vals[:0:-1] + col_vals

    # index[i, j] = i + (n_cols - 1 - j) selects the right pool element.
    row_offsets = np.arange(0, n_rows).reshape(n_rows, 1)
    col_offsets = np.arange(n_cols - 1, -1, -1).reshape(1, n_cols)
    index = row_offsets + col_offsets

    return np.array(pool)[index]
-
class ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset(Dataset):
    """Dataset pairing each slice with the slice that follows it.

    Every sample is the tuple
    ``(this TR wav slice, this BPSK hint slice, next TR wav slice,
    next 0/1 label slice, is-first-slice flag, file name)``.
    All ``.npy`` slices are loaded into memory when the dataset is built.
    """

    def __init__(self, fs=128000, do_norm=True, is_train=True, method=1,
                 sample_nums=400):
        """Randomly pick slice files and load them together with their successors.

        Args:
            fs: Stored on the instance as ``self.fs``; not used during loading.
            do_norm: When true, ``__getitem__`` normalises both TR wav slices.
            is_train: Accepted for interface compatibility; currently unused
                (the folders below always point at the training data).
            method: Stored on the instance as ``self.method``; not used here.
            sample_nums: Maximum number of directory entries to sample. It is
                clamped to the number of entries that actually exist.
        """
        self.do_norm = do_norm
        self.method = method
        self.fs = fs
        self.label01_folder = '/userhome/wave_training/raw_data/train/label01_slices/'
        self.TRwav_folder = '/userhome/wave_training/raw_data/train/TR_wav_slices/'
        self.bpsk_folder = '/userhome/wave_training/raw_data/train/bpsk_hint_slices/'

        self.TR_realwav_slices_list = []
        self.bpsk_hintwav_slices_list = []
        self.Next_TR_realwav_slices_list = []
        self.NextLabel01_slices_list = []

        self.is_front_list = []
        self.file_name_list = []

        file_globs = os.listdir(self.TRwav_folder)

        # random.sample raises ValueError when asked for more items than the
        # population holds, so never request more than the directory contains.
        sample_nums = min(sample_nums, len(file_globs))
        file_globs = random.sample(file_globs, sample_nums)
        print(file_globs[:300])

        for item in tqdm(file_globs):
            # Slice 49 is skipped because this loop needs a following slice.
            if not (item.endswith('.npy') and 'comm' in item
                    and 'slice49' not in item):
                continue

            common_file_name = 'comm' + item.split('.npy')[0].split('comm')[-1]
            TR_wav_path = os.path.join(self.TRwav_folder, common_file_name + '.npy')
            front_bpsk_sig_path = os.path.join(self.bpsk_folder, common_file_name + '.npy')

            # The successor shares the file-name prefix; only the slice index changes.
            this_slice_num = int(common_file_name.split('slice')[-1])
            next_file_name = (common_file_name.split('slice')[0]
                              + 'slice{}.npy'.format(this_slice_num + 1))
            next_TR_wav_path = os.path.join(self.TRwav_folder, next_file_name)
            label01_f_path = os.path.join(self.label01_folder, next_file_name)

            # Load all four arrays before appending so the lists stay aligned.
            Next_TR_wav_slice = np.load(next_TR_wav_path)
            NextLabel01_slice = np.load(label01_f_path)
            TR_wav_slice = np.load(TR_wav_path)
            front_bpsk_sig_slice = np.load(front_bpsk_sig_path)

            self.Next_TR_realwav_slices_list.append(Next_TR_wav_slice)
            self.NextLabel01_slices_list.append(NextLabel01_slice)
            self.TR_realwav_slices_list.append(TR_wav_slice)
            self.bpsk_hintwav_slices_list.append(front_bpsk_sig_slice)

            self.is_front_list.append('slice0.npy' in item)
            self.file_name_list.append(common_file_name)

        assert len(self.TR_realwav_slices_list) == len(self.bpsk_hintwav_slices_list)

    def normalization(self, x):
        """Centre *x* on zero and scale it to a peak magnitude of 1.

        A constant signal has zero peak after centring and is returned as all
        zeros rather than being divided by zero (which would produce NaN).
        """
        centered = x - np.mean(x)
        peak = np.max(np.abs(centered))
        if peak == 0:
            return centered
        return centered / peak

    def __getitem__(self, item):
        """Return the sample at index *item*.

        Returns:
            Tuple ``(this_TR_wav, this_bpsk_hint, next_TR_wav, next_label01,
            is_front, file_name)``. The two TR wav slices are normalised when
            ``self.do_norm`` is true.
        """
        this_TR_realwav_slice = self.TR_realwav_slices_list[item]
        this_bpsk_hintwav_slice = self.bpsk_hintwav_slices_list[item]
        next_TR_realwav_slice = self.Next_TR_realwav_slices_list[item]
        next_send01_slice = self.NextLabel01_slices_list[item]

        this_is_frontal = self.is_front_list[item]
        this_filename = self.file_name_list[item]

        if self.do_norm:
            this_TR_realwav_slice = self.normalization(this_TR_realwav_slice)
            next_TR_realwav_slice = self.normalization(next_TR_realwav_slice)

        return (this_TR_realwav_slice, this_bpsk_hintwav_slice,
                next_TR_realwav_slice, next_send01_slice,
                this_is_frontal, this_filename)

    def __len__(self):
        """Return the number of loaded slice pairs."""
        return len(self.TR_realwav_slices_list)
-
-
def load_debug_csomp_mat(path, key):
    """Load the variable named *key* from the MAT-file at *path* as an array.

    Args:
        path: Location of the MAT-file, passed straight to ``mat4py.loadmat``.
        key: Name of the variable to pull out of the loaded file.

    Returns:
        The selected entry converted with ``np.array`` (shape is not altered).
    """
    mat_contents = mat4py.loadmat(path)
    return np.array(mat_contents[key])
-
-
def L2distance(x, At, y):
    """Print the sum of squared differences between ``x * At`` and ``y``.

    This is a debugging helper: the value is only printed, nothing is returned.

    Args:
        x: Factor multiplied element-wise with *At*.
        At: Array (or scalar) being scaled by *x*.
        y: Target that ``x * At`` is compared against.
    """
    residual = x * At - y
    distance = np.sum(residual ** 2)
    print("x is: {}, L2distance is: {}".format(x, distance))
-
-
if __name__ == '__main__':
    # Smoke test: build the dataset, wrap it in a DataLoader and print the
    # shapes of the first batch.
    from torch.utils.data import DataLoader

    ######################################################################
    train_data_dir = '/userhome/wave_training/raw_data/data/train'  # currently unused
    dataset = ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset()
    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    print(len(dataset))

    for i, data in enumerate(train_dataloader):
        (this_TR_realwav_slice, this_bpsk_hintwav_slice, next_TR_realwav_slice,
         next_label01_slice, is_front_slice, this_filename) = data
        batch_shapes = (this_TR_realwav_slice.shape,
                        this_bpsk_hintwav_slice.shape,
                        next_TR_realwav_slice.shape)
        print(*batch_shapes)

        # Only the first batch is inspected.
        break
    #######################################################################
-
-
- # dataset = TRwav_BPSK_label_SlicePairsDataset()
- # train_dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
- #
- # for i, data in enumerate(train_dataloader):
- # this_TR_realwav_slice, this_bpsk_hintwav_slice, this_send01_slice = data
- # print(this_TR_realwav_slice.shape, this_bpsk_hintwav_slice.shape, this_send01_slice)
- # break
|