|
- import torch.nn as nn
- import torch
- import numpy as np
-
-
def CS_OMP_FindPostion(ChuanGAN_matrix, TR_wav):
    """Return the index of the largest-magnitude entry of ``matrix.T @ signal``.

    The product ``ChuanGAN_matrix.T @ TR_wav`` is flattened to one dimension
    and the position of the entry with the largest absolute value is returned.

    Args:
        ChuanGAN_matrix: NumPy array whose transpose is multiplied with
            ``TR_wav``.
        TR_wav: NumPy array compatible with ``ChuanGAN_matrix.T`` under
            ``np.matmul`` (the caller in this file passes a column vector).

    Returns:
        The NumPy integer index, into the flattened product, of the entry
        with the largest magnitude. Ties resolve to the first occurrence.
    """
    correlation = np.matmul(ChuanGAN_matrix.T, TR_wav).reshape(-1)
    return np.argmax(np.abs(correlation))
-
def padding_shift_signal_to_origin(x, padding_length):
    """Append ``padding_length`` zeros to the end of the last dimension of ``x``.

    Args:
        x: Tensor to pad.
        padding_length: Number of zero entries appended on the right side of
            the last dimension.

    Returns:
        A new tensor equal to ``x`` with zeros appended along the last
        dimension; all other dimensions are unchanged.
    """
    # Functional padding avoids building a throw-away module on every call and
    # uses the documented (left, right) form for last-dimension padding.
    return torch.nn.functional.pad(x, (0, padding_length), mode="constant", value=0)
-
def numpy_onehot(x, num_classes=2560):
    """One-hot encode integer class indices.

    Args:
        x: An integer index or an array-like of integer indices. Negative
            values count from the end, as with normal NumPy indexing.
        num_classes: Length of each one-hot vector.

    Returns:
        A new float64 array of shape ``np.shape(x) + (num_classes,)`` holding
        1.0 at each requested class position and 0.0 elsewhere.

    Raises:
        IndexError: If an index is out of range or not of integer type.
    """
    indices = np.asarray(x)
    if indices.size == 0:
        # An empty list converts to float64, which cannot be used as an index.
        indices = indices.astype(np.intp)

    # Allocate only the output instead of a full num_classes x num_classes
    # identity matrix that would be indexed and then thrown away.
    one_hot = np.zeros(indices.shape + (num_classes,))
    rows = one_hot.reshape(indices.size, num_classes)  # view onto one_hot
    rows[np.arange(indices.size), indices.reshape(-1)] = 1.0
    return one_hot
-
class position_prediction_model(nn.Module):
    """Three-layer MLP producing ``output_dim`` scores from two waveform slices.

    The forward pass concatenates three tensors along the last dimension and
    feeds them through ``Linear -> ReLU -> Linear -> ReLU -> Linear``. The
    first layer expects ``3 * input_dim`` features, so the concatenated last
    dimension must have that size.

    NOTE: a variant that wrapped the first two layers with ``BatchNorm1d`` was
    previously sketched here in commented-out form; it is not active.
    """

    def __init__(self, input_dim, output_dim, method=1):
        """Build the layers.

        Args:
            input_dim: Feature width used to size the layers
                (``3 * input_dim -> 2 * input_dim -> input_dim``).
            output_dim: Number of output scores.
            method: Stored on the instance as ``self.method``; not read inside
                this class.
        """
        super().__init__()

        self.method = method
        self.bpsk_prediction = None  # not assigned anywhere else in this class

        self.linear1 = nn.Linear(3 * input_dim, 2 * input_dim)
        self.linear2 = nn.Linear(2 * input_dim, input_dim)
        self.linear3 = nn.Linear(input_dim, output_dim)

        self.relu = nn.ReLU()

    def forward(self, TR_wav, front_bpsk_wav, is_frontal):
        """Compute the output scores for one input pair.

        Args:
            TR_wav: Tensor whose last dimension is concatenated as the middle
                third of the network input.
            front_bpsk_wav: Tensor concatenated as the first third of the
                network input.
            is_frontal: Truthy to use an all-zero tensor as the final third;
                otherwise ``front_bpsk_wav`` reversed along its last
                dimension is used.

        Returns:
            Tensor of raw (un-normalised) scores with last dimension
            ``output_dim``.
        """
        # Keep every operand of the concatenation on the same device.
        front_bpsk_wav = front_bpsk_wav.to(TR_wav.device)

        if is_frontal:
            # zeros_like follows the shape of TR_wav, so any batch size works.
            p = torch.zeros_like(TR_wav)
        else:
            p = torch.flip(front_bpsk_wav, dims=[-1])

        input_data = torch.concat((front_bpsk_wav, TR_wav, p), dim=-1).float()

        out = self.relu(self.linear1(input_data))
        out = self.relu(self.linear2(out))
        out = self.linear3(out)

        return out
-
class bpsk_demod_model(nn.Module):
    """MLP ``input_dim -> 512 -> 128 -> output_dim`` with a sigmoid output.

    ReLU follows each of the first two linear layers; the final layer's
    result is passed through a sigmoid, so every output lies in (0, 1).
    """

    def __init__(self, input_dim, output_dim):
        """Create the three linear layers and the activation modules.

        Args:
            input_dim: Size of the last dimension of the input tensor.
            output_dim: Size of the last dimension of the output tensor.
        """
        super().__init__()

        first_width, second_width = 512, 128
        self.linear1 = nn.Linear(input_dim, first_width)
        self.linear2 = nn.Linear(first_width, second_width)
        self.linear3 = nn.Linear(second_width, output_dim)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, reconstruct_TR_slice):
        """Run the network on ``reconstruct_TR_slice`` and return the sigmoid output."""
        hidden = self.linear1(reconstruct_TR_slice)
        hidden = self.relu(hidden)
        hidden = self.linear2(hidden)
        hidden = self.relu(hidden)
        scores = self.linear3(hidden)
        return self.sigmoid(scores)
-
-
-
if __name__ == '__main__':
    # Smoke test: run position_prediction_model on the first dataset sample and
    # check that its output is compatible with a cross-entropy loss against the
    # position found by CS_OMP_FindPostion.
    from dataset_tools.dataset_newer_1012 import toeplitz
    from dataset_tools.dataset_ThisTR_frontBPSK_NextTR_NextBPSK_1012 import ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset
    from torch.utils.data import DataLoader

    model = position_prediction_model(input_dim=2560, output_dim=2560)
    dataset = ThisTRwav_ThisBPSK_NextTRwav_NextLabel_SlicePairsDataset()
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    loss_func = nn.CrossEntropyLoss()

    for data in dataloader:
        (this_TR_realwav_slice, this_bpsk_hintwav_slice,
         next_TR_realwav_slice, next_label01_slice,
         is_front_slice, this_filename) = data
        output = model(this_TR_realwav_slice,
                       this_bpsk_hintwav_slice,
                       is_front_slice)

        # Rebuild the same third input the model uses, to form the Toeplitz
        # matrix from which the reference position is derived.
        if is_front_slice:
            p = torch.zeros((1, this_TR_realwav_slice.shape[-1]))
        else:
            p = torch.flip(this_bpsk_hintwav_slice, dims=[-1])
        ChuanGAN_matrix = toeplitz(this_bpsk_hintwav_slice.view(-1, 1).cpu().numpy(),
                                   p.view(1, -1).cpu().numpy())
        matlab_position = CS_OMP_FindPostion(
            ChuanGAN_matrix,
            this_TR_realwav_slice.cpu().numpy().reshape(-1, 1))

        # torch.tensor (unlike the legacy torch.LongTensor constructor)
        # accepts any target device, so this also works when the model
        # output lives on a GPU.
        position_label = torch.tensor([int(matlab_position)],
                                      dtype=torch.long,
                                      device=output.device)

        # Computed only to verify that output and label shapes are compatible.
        loss = loss_func(output, position_label)

        print(output.shape, matlab_position)
        # Only the first sample is inspected.
        break
|