import os
import torch
import numpy as np
import argparse
from dataset_tools.dataset_UWA_2022_grouped_TRSlice import Audio_UWA2022_Grouped_TRSliceDataset, \
    Audio_UWA2022_2_Group_150Samples_TRSliceDataset, \
    Audio_UWA2022_GroupAll_3kSamples_TRSliceDataset, \
    Audio_UWA2022_GroupAll_3kSamples_TRSliceWithChannelDataset, \
    Audio_UWA2022_GroupAll_3kSamples_TRSliceWithChannelabelsDataset, \
    Audio_UWA2022_GroupAll_3kSamples_TRDataset, Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset, \
    Audio_UWA2022_TrainA1k5SamplesAddDongjiang3K_TestOn_Juesai_TRSliceDataset, chunk_bits_size

from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from torch.optim import lr_scheduler
from yizx_models.SBRNN_OOK import SBRnn, SBRnn2, SBRnn_old, Varying_SBRnn, BiLSTM_hlfm, BiLSTM_hlfm2
from tqdm import tqdm


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_args():
    """Command-line arguments for SBRNN OOK training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--chunk', default=128, type=int)
    parser.add_argument('--hid_dim', default=128, type=int)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--output_dim', default=2, type=int)
    parser.add_argument('--lr', default=2e-3, type=float)
    parser.add_argument('--min_lr', default=2e-3, type=float)
    parser.add_argument('--sliding_window', default=10, type=int)
    parser.add_argument('--tsne_new_data', default=3, type=int)
    # --tsne_new_data selects the dataset/model configuration:
    # [
    # 0 for raw_grouped_data,
    # 1 for 1+2+3+4+5=150samples-group1-dataset,
    # 2 for 1+2+3+4+5=150samples-group2-dataset,
    # 3 for 3k-samples-group1-dataset,
    # 4 for 3k-samples-group2-dataset,
    # 5 for 3k-samples-finegrained-group1-dataset,
    # 6 for 3k-samples-finegrained-group2-dataset,
    # 7 for 3k-samples-allgroup-dataset (BiLSTM_hlfm, uses channel),
    # 8 for 3k-samples-allgroup-dataset with channel labels (BiLSTM_hlfm2),
    # 9 for 3k-samples-group1-dataset without TR slicing,
    # 10 for 3k-samples-allgroup-dataset, SBRNN, not using channel,
    ###################################
    # 11 for [TR-1]1k5-TrainA-samples-dataset/Test on juesai-dataset[TR-1], SBRNN, not using channel,
    # 12 for [TR-4]1k5-TrainA-samples-dataset/Test on juesai-dataset[TR-4], SBRNN, not using channel,
    # 13 for [TR-1]1k5-TrainA-samples-dataset/Test on juesai-dataset[TR-4], SBRNN, not using channel,
    # 14 for [TR-4]1k5-TrainA-samples-dataset/Test on juesai-dataset[TR-1], SBRNN, not using channel,
    # 15 for [TR-4]1k5-TrainA-samples-dataset+[TR-4]3k-dongjiang-samples/Test on juesai-dataset[TR-4], SBRNN, not using channel,
    # ]
    parser.add_argument('--load_path', default=None, type=str,
                        help='path to a pretrained model checkpoint')
    args = parser.parse_args()
    return args
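
# Example invocation (the script name is illustrative; the flags match
# get_args above):
#   python train_sbrnn.py --tsne_new_data 3 --chunk 128 --hid_dim 128 \
#       --num_layers 4 --lr 2e-3 --sliding_window 10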

def do_eval(model, test_dataloader, args, is_test=True):
    model.eval()
    sliding_window = args.sliding_window

    test_acc_list = []
    test_channel_acc_list = []
    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            if isinstance(model, BiLSTM_hlfm):
                wav_data, wav_h_lfm, wav_labels = data
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()

            elif isinstance(model, BiLSTM_hlfm2):
                wav_data, wav_h_lfm, wav_labels, channel_labels = data
                channel_labels = channel_labels.cuda()
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()

            else:
                if args.tsne_new_data == 10:
                    wav_data, _, wav_labels = data
                else:
                    wav_data, wav_labels = data
                wav_data = wav_data.float().cuda()
                wav_labels = wav_labels.cuda()

            bs, all_seq_len = wav_data.shape
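            # (bs, total_samples) -> (bs, num_chunks, chunk): each length-`chunk`
            # slice is decoded as one OOK symbol by the sequence model.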
            wav_data = wav_data.view(bs, -1, args.chunk)

            if isinstance(model, Varying_SBRnn) or isinstance(model, SBRnn_old):
                output = model(wav_data, sliding_window)
            elif isinstance(model, SBRnn) or isinstance(model, SBRnn2):
                output = model(wav_data, None)
            elif isinstance(model, BiLSTM_hlfm):
                output = model(wav_data, None, wav_h_lfm)
            elif isinstance(model, BiLSTM_hlfm2):
                output, output2 = model(wav_data, None, wav_h_lfm)

            pred = torch.argmax(output, dim=-1)
            pred_mask = (pred == wav_labels)
            correct_num = pred_mask.sum().item()
            test_acc = correct_num / (pred_mask.shape[0] * pred_mask.shape[1])
            test_acc_list.append(test_acc)

            # if isinstance(model, BiLSTM_hlfm2):
            #     pred2 = torch.argmax(output2, dim=-1)
            #     pred_mask2 = (pred2 == channel_labels)
            #     correct_num2 = pred_mask2.sum().item()
            #     test_acc2 = correct_num2 / (pred_mask2.shape[0] * pred_mask2.shape[1])
            #     test_channel_acc_list.append(test_acc2)
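    # Averaging per-batch accuracies slightly over-weights a smaller final
    # batch; an exact alternative would accumulate total correct bits and
    # total bits across batches and report their ratio.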
    if is_test:
        print('Eval on test_dataset, [bits] acc is: [{}%]\n'.format(np.mean(test_acc_list) * 100))
        # if isinstance(model, BiLSTM_hlfm2):
        #     print('Eval on test_dataset, [channel] acc is: [{}%]\n'.format(np.mean(test_channel_acc_list) * 100))
    else:
        print('Eval on val_dataset, [bits] acc is: [{}%]'.format(np.mean(test_acc_list) * 100))
        # if isinstance(model, BiLSTM_hlfm2):
        #     print('Eval on val_dataset, [channel] acc is: [{}%]\n'.format(np.mean(test_channel_acc_list) * 100))

def setup_model_with_datasets(train_data_dir, train_labels_dir, test_data_dir, test_labels_dir, args):
    if args.tsne_new_data == 0:
        all_dataset = Audio_UWA2022_Grouped_TRSliceDataset(train_data_dir, train_labels_dir)
        train_size = int(0.9 * len(all_dataset))
        val_size = len(all_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])

        test_dataset = Audio_UWA2022_Grouped_TRSliceDataset(test_data_dir, test_labels_dir)
    elif args.tsne_new_data in (1, 2):
        print('-' * 10 + ' grouped_tSNE_new_150samples_dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_2_Group_150Samples_TRSliceDataset(group_idx=args.tsne_new_data)
        test_size = len(all_dataset) // 3
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size

        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(all_dataset,
                                                                                 [train_size, val_size, test_size])
    elif args.tsne_new_data in (3, 4):
        print('-' * 10 + ' grouped_tSNE_new_3ksamples_dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_GroupAll_3kSamples_TRSliceDataset(group_idx=args.tsne_new_data - 2)
        test_size = len(all_dataset) // 3
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size

        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(all_dataset,
                                                                                 [train_size, val_size, test_size])
    elif args.tsne_new_data in (5, 6):
        print('-' * 10 + ' [FineGrained] grouped_tSNE_new_3ksamples_dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_GroupAll_3kSamples_TRSliceDataset(group_idx=args.tsne_new_data - 4, finegrained=True)
        test_size = len(all_dataset) // 3
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(all_dataset,
                                                                                 [train_size, val_size, test_size])
    elif args.tsne_new_data in (7, 10):
        print('-' * 10 + ' [all3k_with_channel] grouped_tSNE_new_3ksamples_dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_GroupAll_3kSamples_TRSliceWithChannelDataset()
        test_size = len(all_dataset) // 3
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(all_dataset,
                                                                                 [train_size, val_size, test_size])
    elif args.tsne_new_data == 8:
        print('-' * 10 + ' [all3k_with_channel_labels] grouped_tSNE_new_3ksamples_dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_GroupAll_3kSamples_TRSliceWithChannelabelsDataset()
        test_size = len(all_dataset) // 3
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(all_dataset,
                                                                                 [train_size, val_size, test_size])
    elif args.tsne_new_data == 9:
        print('-' * 10 + ' [new-group1-TRNoSlice] grouped_tSNE_new_3ksamples_dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_GroupAll_3kSamples_TRDataset(group_idx=args.tsne_new_data - 8, sample_chunk_len=args.chunk)
        test_size = len(all_dataset) // 3
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(all_dataset,
                                                                                 [train_size, val_size, test_size])
    elif args.tsne_new_data == 11:
        print('-' * 10 + ' [TrainA_[TR1]_TRSlice_TestON_juesai[TR1]] dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=1, is_train=True)
        test_size = 0  # the test set is the separate juesai dataset below
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])
        test_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=1, is_train=False)
    elif args.tsne_new_data == 12:
        print('-' * 10 + ' [TrainA_[TR4]_TRSlice_TestON_juesai[TR4]] dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=4, is_train=True)
        test_size = 0  # the test set is the separate juesai dataset below
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])
        test_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=4, is_train=False)
    elif args.tsne_new_data == 13:
        print('-' * 10 + ' [TrainA_[TR1]_TRSlice_TestON_juesai[TR4]] dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=1, is_train=True)
        test_size = 0  # the test set is the separate juesai dataset below
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])
        test_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=4, is_train=False)
    elif args.tsne_new_data == 14:
        print('-' * 10 + ' [TrainA_[TR4]_TRSlice_TestON_juesai[TR1]] dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=4, is_train=True)
        test_size = 0  # the test set is the separate juesai dataset below
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])
        test_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=1, is_train=False)
    elif args.tsne_new_data == 15:
        print('-' * 10 + ' [TrainA+Dongjiang3k_[TR4]_TRSlice_TestON_juesai[TR4]] dataset ' + '-' * 10)
        all_dataset = Audio_UWA2022_TrainA1k5SamplesAddDongjiang3K_TestOn_Juesai_TRSliceDataset(TR_method=4, is_train=True)
        test_size = 0  # the test set is the separate juesai dataset below
        train_size = int(0.9 * (len(all_dataset) - test_size))
        val_size = len(all_dataset) - train_size - test_size
        train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, [train_size, val_size])
        test_dataset = Audio_UWA2022_TrainA1k5Samples_TestOn_Juesai_TRSliceDataset(TR_method=4, is_train=False)

    else:
        raise ValueError('Unsupported --tsne_new_data value: {}'.format(args.tsne_new_data))

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=20)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=20)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=20)

    # Pick the model class for this data mode (all classes are imported above).
    if args.tsne_new_data == 7:
        SBRNN = BiLSTM_hlfm
    elif args.tsne_new_data == 8:
        SBRNN = BiLSTM_hlfm2
    else:  # modes < 7, 9, 10, and > 10 all use the plain SBRNN decoder
        SBRNN = SBRnn2
    my_model = SBRNN(args.chunk, args.hid_dim, args.num_layers, args.output_dim)
    my_model.cuda()

    return my_model, train_dataloader, val_dataloader, test_dataloader

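# A minimal helper mirroring the split logic repeated above (a sketch only;
# not wired into setup_model_with_datasets):
def split_train_val_test(dataset, test_frac=1 / 3):
    """Hold out ``test_frac`` of ``dataset`` for test, then split the
    remainder 90/10 into train/val, as each branch above does."""
    test_size = int(len(dataset) * test_frac)
    train_size = int(0.9 * (len(dataset) - test_size))
    val_size = len(dataset) - train_size - test_size
    return torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
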
def train(train_data_dir, train_labels_dir, test_data_dir, test_labels_dir, args, idx):

    my_model, train_dataloader, val_dataloader, test_dataloader = \
        setup_model_with_datasets(train_data_dir, train_labels_dir, test_data_dir, test_labels_dir, args)

    sliding_window = args.sliding_window
    M = 2  # binary (OOK) symbol alphabet

    print(f'The model has {count_parameters(my_model):,} trainable parameters')

    # optimizer = optim.AdamW(my_model.parameters(), lr=4e-3)
    optimizer = optim.Adam(my_model.parameters(), lr=args.lr)

    # scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
    # NOTE: T_max is normally the number of scheduler steps per cosine period
    # (scheduler.step() runs once per batch below); with the defaults, where
    # --min_lr equals --lr, this schedule keeps the learning rate constant.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=0.7, eta_min=args.min_lr, verbose=False)

    criterion = nn.CrossEntropyLoss()  # alternative: nn.BCEWithLogitsLoss()

    if args.load_path is not None:
        my_model.load_state_dict(torch.load(args.load_path))

    running_loss = 0.0  # cumulative loss; printed as a running mean below
    tmp_idx = 0
    MAX_EPOCH_NUM = 50

    if args.tsne_new_data < 3:
        save_ckpt_path = 'UWA_2022_Group_{}_sbrnn/'.format(idx)
    else:
        save_ckpt_path = 'ckpts_2023/UWA_2022_DataType{}_sbrnn_chunkBitSize{}/'.format(args.tsne_new_data, chunk_bits_size)
    os.makedirs(save_ckpt_path, exist_ok=True)

    for epoch in tqdm(range(MAX_EPOCH_NUM)):
        for (i, data) in enumerate(train_dataloader):
            my_model.train()
            optimizer.zero_grad()

            if isinstance(my_model, BiLSTM_hlfm):
                wav_data, wav_h_lfm, wav_labels = data
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()

            elif isinstance(my_model, BiLSTM_hlfm2):
                wav_data, wav_h_lfm, wav_labels, channel_labels = data
                channel_labels = channel_labels.cuda()
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()

            else:
                if args.tsne_new_data == 10:
                    wav_data, _, wav_labels = data
                else:
                    wav_data, wav_labels = data
                wav_data = wav_data.float().cuda()
                wav_labels = wav_labels.cuda()

            bs, all_seq_len = wav_data.shape
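            # (bs, total_samples) -> (bs, num_chunks, chunk), one OOK symbol per chunk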
            wav_data = wav_data.view(bs, -1, args.chunk)

            if isinstance(my_model, Varying_SBRnn) or isinstance(my_model, SBRnn_old):
                output = my_model(wav_data, sliding_window)
            elif isinstance(my_model, SBRnn) or isinstance(my_model, SBRnn2):
                output = my_model(wav_data, None)
            elif isinstance(my_model, BiLSTM_hlfm):
                output = my_model(wav_data, None, wav_h_lfm)
            elif isinstance(my_model, BiLSTM_hlfm2):
                output, output2 = my_model(wav_data, None, wav_h_lfm)

            prediction = output.reshape(-1, M)
            wav_labels = wav_labels.reshape(-1)
            loss = criterion(prediction, wav_labels)

            if isinstance(my_model, BiLSTM_hlfm2):
                channel_pred_logits = output2.reshape(-1, M)
                channel_labels = channel_labels.reshape(-1)

                # the auxiliary channel head is supervised by the channel labels
                loss2 = criterion(channel_pred_logits, channel_labels)
                loss += loss2

            loss.backward()
            optimizer.step()
            scheduler.step()  # per-batch scheduler step
            running_loss += loss.item()

            tmp_idx += 1

            if args.tsne_new_data < 3:
                log_step = 20
            elif args.tsne_new_data in (3, 4) or args.tsne_new_data > 9:
                log_step = int(1000 * (20.0 / float(chunk_bits_size)))
            elif args.tsne_new_data in (5, 6):
                log_step = 300
            elif args.tsne_new_data in (7, 8):
                log_step = 500
            elif args.tsne_new_data == 9:
                log_step = 20

            if (i + 1) % log_step == 0:
                print('epoch [{}]/[{}], step [{}]/[{}], loss: {}, lr: {}'.format(
                    epoch + 1, MAX_EPOCH_NUM, i + 1, len(train_dataloader),
                    running_loss / tmp_idx,
                    optimizer.state_dict()['param_groups'][0]['lr']))
                do_eval(my_model, val_dataloader, args, is_test=False)
                do_eval(my_model, test_dataloader, args, is_test=True)
            # NOTE: checkpoints are only written at batch index 39, so epochs
            # with fewer than 40 batches never save.
            if (epoch + 1) % 2 == 0 and i == 39:
                if args.tsne_new_data < 3:
                    torch.save(my_model.state_dict(), '{}/slideWS{}_epoch_{}.pth'.format(save_ckpt_path, args.sliding_window, epoch))
                else:
                    torch.save(my_model.state_dict(),
                               '{}/new_slideWS{}_epoch_{}.pth'.format(save_ckpt_path, args.sliding_window, epoch))

if __name__ == '__main__':
    args = get_args()
    idx_used = [5]
    for idx in idx_used:
        train_data_dir = '/userhome/wave_training_old/UWA_2022_dataset/{}/train{}_TR'.format(idx, idx)
        train_labels_dir = '/userhome/wave_training_old/UWA_2022_dataset/{}/labels_train{}'.format(idx, idx)
        test_data_dir = '/userhome/wave_training_old/UWA_2022_dataset/{}/test{}_TR'.format(idx, idx)
        test_labels_dir = '/userhome/wave_training_old/UWA_2022_dataset/{}/labels_test{}'.format(idx, idx)
        train(train_data_dir, train_labels_dir, test_data_dir, test_labels_dir, args, idx)