|
- import os
- import torch
- import numpy as np
- import argparse
- # from dataset_tools.dataset_haoxinhu_2023project import Audio_Maoming_haoxinhu_Dataset, chunk_bits_size, Audio_ZJQ_SimulationDataset
- # Audio_ZJQ_SimulationDataset()
- # exit()
-
- from dataset_tools.npy_dataset_haoxinhu_pytorch import NPY_haoxinhu_Dataset, chunk_bits_size
- from dataset_tools.npy_dataset_haoxinhu_withSimulationData_pytorch import NPY_haoxinhu_WithSim_Dataset, chunk_bits_size
- from torch.utils.data import DataLoader
- import torch.optim as optim
- import torch.nn as nn
- from torch.optim import lr_scheduler
- from yizx_models.SBRNN_OOK import SBRnn, SBRnn2, SBRnn_old, Varying_SBRnn, BiLSTM_hlfm, BiLSTM_hlfm2
- from tqdm import tqdm
- import time
-
# Root directory holding the cached .npy waveform/label training data.
npy_dir = "/userhome/wave_training_old/dataset_cache_npy"

# Recording sessions (encoded as MDD / MMDD dates) included in the dataset.
used_DateList = [217, 220, 222, 226, 227, 228, 312, 313]
# used_DateList = [228]
-
def count_parameters(model):
    """Return the total number of trainable parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
-
def get_args():
    """Parse command-line arguments for SBRNN training / inference.

    Returns:
        argparse.Namespace with model hyper-parameters, learning-rate
        settings, optional pretrained-checkpoint / single-file test paths,
        and dataset options.
    """
    def str2bool(value):
        # BUG FIX: argparse's ``type=bool`` treats ANY non-empty string as
        # True (so ``--do_chunk False`` silently enabled chunking). Parse
        # common textual booleans explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', '1', 'yes', 'y'):
            return True
        if lowered in ('false', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('invalid boolean value: {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--chunk', default=49, type=int)
    parser.add_argument('--hid_dim', default=128, type=int)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--output_dim', default=2, type=int)
    parser.add_argument('--lr', default=2e-3, type=float)
    parser.add_argument('--min_lr', default=2e-4, type=float)
    parser.add_argument('--do_chunk', default=False, type=str2bool)
    parser.add_argument('--samplerate_khz', default=44.1, type=float)
    parser.add_argument('--sliding_window', default=10, type=int)

    parser.add_argument('--data_start_bits_idx', default=0, type=int)
    parser.add_argument('--data_end_bits_idx', default=0, type=int)

    parser.add_argument('--load_path', default=None, type=str,
                        help='using pretrained_model')
    parser.add_argument('--test_wave_file', default=None, type=str,
                        help='Testing model on One wave_file')
    parser.add_argument('--test_label_file', default=None, type=str,
                        help='Testing model on One wave_file, setting label_file')

    parser.add_argument('--add_sim_data', default=False, type=str2bool)

    args = parser.parse_args()
    return args
-
def do_eval(model, test_dataloader, args, is_test=True):
    """Evaluate bit-level accuracy of *model* over a dataloader.

    Args:
        model: one of the SBRNN / BiLSTM variants; its type decides how each
            batch is unpacked and how the forward pass is invoked.
        test_dataloader: yields (wav_data, wav_labels), plus h_lfm features
            (BiLSTM_hlfm) or h_lfm + channel labels (BiLSTM_hlfm2).
        args: parsed CLI namespace; uses args.chunk and args.sliding_window.
        is_test: selects the "test_dataset" vs "val_dataset" log line.
    """
    model.eval()
    sliding_window = args.sliding_window

    # BUG FIX: count raw correct/total bits instead of averaging per-batch
    # accuracies, so a smaller final batch is no longer over-weighted.
    correct_bits = 0
    total_bits = 0
    with torch.no_grad():
        for data in test_dataloader:
            # Unpack the batch according to the model family (order matches
            # the original: BiLSTM_hlfm is tested before BiLSTM_hlfm2).
            if isinstance(model, BiLSTM_hlfm):
                wav_data, wav_h_lfm, wav_labels = data
                wav_h_lfm = wav_h_lfm.float().cuda()
            elif isinstance(model, BiLSTM_hlfm2):
                wav_data, wav_h_lfm, wav_labels, channel_labels = data
                wav_h_lfm = wav_h_lfm.float().cuda()
                channel_labels = channel_labels.cuda()
            else:
                wav_data, wav_labels = data
            wav_data = wav_data.float().cuda()
            wav_labels = wav_labels.cuda()

            bs = wav_data.shape[0]
            # Split the flat waveform into chunk-sized model inputs.
            wav_data = wav_data.view(bs, -1, args.chunk)

            if isinstance(model, (Varying_SBRnn, SBRnn_old)):
                output = model(wav_data, sliding_window)
            elif isinstance(model, (SBRnn, SBRnn2)):
                output = model(wav_data, None)
            elif isinstance(model, BiLSTM_hlfm):
                output = model(wav_data, None, wav_h_lfm)
            elif isinstance(model, BiLSTM_hlfm2):
                output, output2 = model(wav_data, None, wav_h_lfm)
            else:
                # BUG FIX: previously an unknown model type fell through and
                # crashed later with an unbound ``output``.
                raise TypeError('unsupported model type: {}'.format(type(model).__name__))

            pred = torch.argmax(output, dim=-1)
            pred_mask = (pred == wav_labels)
            correct_bits += pred_mask.sum().item()
            total_bits += pred_mask.numel()

    acc = correct_bits / total_bits if total_bits else 0.0
    if is_test:
        print('Eval on test_dataset, [bits] acc is: [{}%]\n'.format(acc * 100))
    else:
        print('Eval on val_dataset, [bits] acc is: [{}%]'.format(acc * 100))
-
-
def do_inference(model, test_dataloader, args, is_test=True):
    """Evaluate bit accuracy like ``do_eval`` and report inference speed.

    Args:
        model: one of the SBRNN / BiLSTM variants.
        test_dataloader: same batch layouts as ``do_eval``.
        args: parsed CLI namespace; uses args.chunk and args.sliding_window.
        is_test: selects the "test_dataset" vs "val_dataset" log line.
    """
    model.eval()
    sliding_window = args.sliding_window

    start_time = time.time()
    total_numbers = 0

    # BUG FIX: accumulate correct/total bit counts instead of averaging
    # per-batch accuracies (a smaller final batch skewed the mean).
    correct_bits = 0
    total_bits = 0
    with torch.no_grad():
        for data in test_dataloader:
            if isinstance(model, BiLSTM_hlfm):
                wav_data, wav_h_lfm, wav_labels = data
                wav_h_lfm = wav_h_lfm.float().cuda()
            elif isinstance(model, BiLSTM_hlfm2):
                wav_data, wav_h_lfm, wav_labels, channel_labels = data
                wav_h_lfm = wav_h_lfm.float().cuda()
                channel_labels = channel_labels.cuda()
            else:
                wav_data, wav_labels = data
            wav_data = wav_data.float().cuda()
            wav_labels = wav_labels.cuda()

            bs = wav_data.shape[0]
            total_numbers += bs

            wav_data = wav_data.view(bs, -1, args.chunk)

            if isinstance(model, (Varying_SBRnn, SBRnn_old)):
                output = model(wav_data, sliding_window)
            elif isinstance(model, (SBRnn, SBRnn2)):
                output = model(wav_data, None)
            elif isinstance(model, BiLSTM_hlfm):
                output = model(wav_data, None, wav_h_lfm)
            elif isinstance(model, BiLSTM_hlfm2):
                output, output2 = model(wav_data, None, wav_h_lfm)
            else:
                # BUG FIX: unknown model types used to crash later with an
                # unbound ``output``; fail with a clear message instead.
                raise TypeError('unsupported model type: {}'.format(type(model).__name__))

            pred = torch.argmax(output, dim=-1)
            pred_mask = (pred == wav_labels)
            correct_bits += pred_mask.sum().item()
            total_bits += pred_mask.numel()

    end_time = time.time()

    acc = correct_bits / total_bits if total_bits else 0.0
    if is_test:
        print('Eval on test_dataset, [bits] acc is: [{}%]\n'.format(acc * 100))
    else:
        print('Eval on val_dataset, [bits] acc is: [{}%]'.format(acc * 100))

    # BUG FIX: guard against an empty dataloader (ZeroDivisionError).
    print("\n>>> Total {} samples, inference speed: {}s/sample >>>\n".format(total_numbers, (end_time - start_time) / max(total_numbers, 1)))
-
-
def do_inference_onFile(model, args, is_test=True):
    """Decode a single recorded wave file and report bit accuracy.

    Pipeline: load the wav, band-pass filter it, synchronise against an LFM
    preamble via cross-correlation to locate the payload, then run the model
    on the recovered samples and compare with labels from
    ``args.test_label_file`` (plain text, "random102" digit string, or a
    MATLAB .mat fallback).
    """
    import librosa
    from py_matlab_tools.FIR_python_vs_matlab import py_fir3
    from py_matlab_tools.LFMgen import get_signal_conv, LFMgen_haoxinhu
    sample_rate_KHz = 44.1

    model.eval()
    sliding_window = args.sliding_window

    total_numbers = 0
    test_acc_list = []

    LFM = LFMgen_haoxinhu(44100)
    data_bits = 49

    # BUG FIX: librosa.load returns a (samples, sr) tuple; wrapping it in
    # np.array built a ragged object array (an error on modern NumPy).
    realwav_data, _ = librosa.load(args.test_wave_file, sr=44100)
    # do filter
    recv_wav_signal = py_fir3(realwav_data, fs=int(sample_rate_KHz * 1000))
    # do auto-correlation against the LFM preamble to find the payload start
    sync_res = get_signal_conv(recv_wav_signal, LFM)
    sync_res = np.abs(sync_res)
    start_pos = np.argmax(sync_res)
    # Offset back by 5% of the preamble, plus a fixed skip — presumably a
    # guard interval between preamble and payload; TODO confirm.
    useful_data_pos = start_pos - int(0.05 * len(LFM)) + 10000
    recv_data = recv_wav_signal[useful_data_pos: useful_data_pos + data_bits * 800]

    label_f_path = args.test_label_file

    try:
        # BUG FIX: open read-only (was 'r+') and close the handle via `with`
        # (it previously leaked).
        with open(label_f_path, 'r') as label_f:
            label_lines = label_f.readlines()
        if not "random102" in label_f_path:
            # One bit per line.
            new_800_labels = np.array([int(itx.replace('\n', '')) for itx in label_lines])
        else:
            # Single line of digit characters.
            new_800_labels = np.array([int(itx) for itx in list(label_lines[0])])
    except Exception:
        # Fall back to the MATLAB label format (was a bare ``except:``,
        # which also swallowed KeyboardInterrupt/SystemExit).
        from dataset_tools.dataset_haoxinhu_2023project import load_data_mat
        new_800_labels = load_data_mat(label_f_path)

    start_time = time.time()
    with torch.no_grad():
        wav_data = torch.from_numpy(recv_data).unsqueeze(0)
        wav_labels = torch.from_numpy(new_800_labels).unsqueeze(0)

        wav_data = wav_data.float().cuda()
        wav_labels = wav_labels.cuda()

        bs, all_seq_lenth = wav_data.shape
        total_numbers += 1

        wav_data = wav_data.view(bs, -1, args.chunk)

        if isinstance(model, (Varying_SBRnn, SBRnn_old)):
            output = model(wav_data, sliding_window)
        elif isinstance(model, (SBRnn, SBRnn2)):
            output = model(wav_data, None)
        else:
            # BUG FIX: previously fell through with ``output`` unbound.
            raise TypeError('unsupported model type: {}'.format(type(model).__name__))

        pred = torch.argmax(output, dim=-1)
        pred_mask = (pred == wav_labels)
        correct_num = pred_mask.sum().item()
        test_acc = correct_num / (pred_mask.shape[0] * pred_mask.shape[1])
        test_acc_list.append(test_acc)

    end_time = time.time()

    if is_test:
        print('Eval on test_dataset, [bits] acc is: [{}%]\n'.format(np.mean(test_acc_list) * 100))
    else:
        print('Eval on val_dataset, [bits] acc is: [{}%]'.format(np.mean(test_acc_list) * 100))

    print("\n>>> Total {} samples, inference speed: {}s/sample >>>\n".format(total_numbers, (end_time - start_time) / total_numbers))
-
-
def setup_model_with_datasets(args):
    """Build the dataset splits, their dataloaders, and a CUDA SBRNN model.

    Returns:
        (model, train_dataloader, val_dataloader, test_dataloader)
    """
    dataset_cls = NPY_haoxinhu_WithSim_Dataset if args.add_sim_data else NPY_haoxinhu_Dataset
    all_dataset = dataset_cls(npy_dir, used_DateList, sample_chunk_len=args.chunk, do_chunk=args.do_chunk)

    # One third held out for test; the remainder split 90/10 into train/val.
    test_size = int(len(all_dataset) // 3)
    train_size = int(0.9 * (len(all_dataset) - test_size))
    val_size = len(all_dataset) - train_size - test_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        all_dataset, [train_size, val_size, test_size])

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=20)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=20)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=20)

    from yizx_models.SBRNN_OOK import SBRnn2 as SBRNN

    my_model = SBRNN(args.chunk, args.hid_dim, args.num_layers, args.output_dim)
    my_model.cuda()

    return my_model, train_dataloader, val_dataloader, test_dataloader
-
def train(args):
    """Train the SBRNN OOK decoder, or run inference when a checkpoint is given.

    Behaviour:
      * ``--load_path`` set  -> load weights and run inference only
        (on the test split, or on ``--test_wave_file`` when provided).
      * otherwise            -> train for MAX_EPOCH_NUM epochs, periodically
        evaluating on val/test and checkpointing to ``save_ckpt_path``.
    """
    my_model, train_dataloader, val_dataloader, test_dataloader = \
        setup_model_with_datasets(args)

    sliding_window = args.sliding_window
    M = 2  # number of output classes per bit (OOK) — TODO confirm

    print(f'The model has {count_parameters(my_model):,} trainable parameters')

    if args.load_path is not None:
        # Inference-only path: restore weights and evaluate.
        my_model.load_state_dict(torch.load(args.load_path))
        if args.test_wave_file is None:
            do_inference(my_model, test_dataloader, args, is_test=True)
        else:
            do_inference_onFile(my_model, args, is_test=True)
        return

    optimizer = optim.Adam(my_model.parameters(), lr=args.lr)
    # NOTE(review): scheduler.step() is called every batch with a fractional
    # T_max=0.7 — unusual for CosineAnnealingLR; confirm intent.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=0.7, eta_min=args.min_lr)
    criterion = nn.CrossEntropyLoss()

    # Cumulative loss average across the entire run (never reset per epoch).
    epoch_loss = 0
    tmp_idx = 0
    MAX_EPOCH_NUM = 50
    log_step = 100

    if args.do_chunk:
        save_ckpt_path = 'ckpts_2023/SampleRateKHz{}_haoxinhu_2023_AutoCorr_DoChunk{}'\
            .format(args.samplerate_khz, chunk_bits_size)
    else:
        save_ckpt_path = 'ckpts_2023/SampleRateKHz{}_haoxinhu_2023_AutoCorr'.format(args.samplerate_khz)

    # Encode the training dates into the checkpoint directory name.
    for date in used_DateList:
        save_ckpt_path += "_{}".format(date)

    if args.add_sim_data:
        # NOTE(review): this prefixes the TOP-LEVEL directory, producing
        # "AddSimData_ckpts_2023/..." — confirm this layout is intended.
        save_ckpt_path = "AddSimData_" + save_ckpt_path

    os.makedirs(save_ckpt_path, exist_ok=True)

    for epoch in tqdm(range(MAX_EPOCH_NUM)):
        for (i, data) in enumerate(train_dataloader):
            my_model.train()
            optimizer.zero_grad()

            # Unpack the batch according to the model family.
            if isinstance(my_model, BiLSTM_hlfm):
                wav_data, wav_h_lfm, wav_labels = data
                wav_h_lfm = wav_h_lfm.float().cuda()
            elif isinstance(my_model, BiLSTM_hlfm2):
                wav_data, wav_h_lfm, wav_labels, channel_labels = data
                wav_h_lfm = wav_h_lfm.float().cuda()
                channel_labels = channel_labels.cuda()
            else:
                wav_data, wav_labels = data
            wav_data = wav_data.float().cuda()
            wav_labels = wav_labels.cuda()

            bs = wav_data.shape[0]
            wav_data = wav_data.view(bs, -1, args.chunk)

            # Forward pass. BUG FIX: the original ended the elif chain with a
            # separate `if BiLSTM_hlfm2`, which could invoke the model twice;
            # the two-headed model is now dispatched exactly once, and an
            # unknown model type raises instead of leaving `output` unbound.
            if isinstance(my_model, (Varying_SBRnn, SBRnn_old)):
                output = my_model(wav_data, sliding_window)
            elif isinstance(my_model, (SBRnn, SBRnn2)):
                output = my_model(wav_data, None)
            elif isinstance(my_model, BiLSTM_hlfm2):
                output, output2 = my_model(wav_data, None, wav_h_lfm)
            elif isinstance(my_model, BiLSTM_hlfm):
                output = my_model(wav_data, None, wav_h_lfm)
            else:
                raise TypeError('unsupported model type: {}'.format(type(my_model).__name__))

            prediction = output.reshape(-1, M)
            wav_labels = wav_labels.reshape(-1)
            loss = criterion(prediction, wav_labels)

            if isinstance(my_model, BiLSTM_hlfm2):
                channel_pred_logits = output2.reshape(-1, M)
                channel_labels = channel_labels.reshape(-1)
                # BUG FIX: the auxiliary channel head must be trained against
                # channel_labels; it previously reused wav_labels.
                loss2 = criterion(channel_pred_logits, channel_labels)
                loss += loss2

            loss.backward()
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            tmp_idx += 1

            if (i + 1) % log_step == 0:
                print('epoch[{}]/[{}], [{}]/[{}], loss is: {}, lr {}'.format(epoch, MAX_EPOCH_NUM, i, len(train_dataloader),
                                                                             epoch_loss / tmp_idx,
                                                                             optimizer.state_dict()['param_groups'][0]['lr']
                                                                             )
                      )
                do_eval(my_model, val_dataloader, args, is_test=False)
                do_eval(my_model, test_dataloader, args, is_test=True)

            # NOTE(review): checkpointing is tied to batch index 39 every
            # second epoch — confirm this matches the dataloader length.
            if (epoch + 1) % 2 == 0 and i == 39:
                torch.save(my_model.state_dict(), '{}/SBRNN_slideWS{}_epoch_{}.pth'.format(save_ckpt_path, args.sliding_window, epoch))
-
-
-
-
if __name__ == '__main__':
    # Parse CLI options, then train (or run inference when --load_path is set).
    train(get_args())
-
-
|