|
- import os
- import torch
- import numpy as np
- import argparse
- from dataset_tools.dataset_FrontTR_ThisTR_2023exp1 import FrontWithCurrent_TRSlicePairs_end2end_decodingCurrent_Dataset, \
- chunk_bits_size
-
- from torch.utils.data import DataLoader
- import torch.optim as optim
- import torch.nn as nn
- from torch.optim import lr_scheduler
- from yizx_models.SBRNN_OOK import SBRnn, SBRnn2, SBRnn_old, Varying_SBRnn, BiLSTM_hlfm, BiLSTM_hlfm2
- from tqdm import tqdm
-
-
def count_parameters(model):
    """Return the total number of trainable parameters in *model*."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable)
-
def get_args():
    """Parse the command-line hyper-parameters for SBRNN pair training."""
    arg_parser = argparse.ArgumentParser()

    # model / input-shape hyper-parameters
    arg_parser.add_argument('--chunk', default=128, type=int)
    arg_parser.add_argument('--hid_dim', default=128, type=int)
    arg_parser.add_argument('--num_layers', default=4, type=int)
    arg_parser.add_argument('--output_dim', default=2, type=int)

    # optimisation hyper-parameters
    arg_parser.add_argument('--lr', default=2e-3, type=float)
    arg_parser.add_argument('--min_lr', default=2e-3, type=float)
    arg_parser.add_argument('--sliding_window', default=10, type=int)

    arg_parser.add_argument('--load_path', default=None, type=str,
                            help='using pretrained_model')

    # dataset / pairing configuration
    arg_parser.add_argument('--tr_method', default=1, type=int)
    arg_parser.add_argument('--pairs_concat_method', default=1, type=int)

    return arg_parser.parse_args()
-
def do_eval(model, test_dataloader, args, is_test=True):
    """Evaluate *model* on *test_dataloader* and print mean per-bit accuracy.

    The batch layout must match the model family: ``BiLSTM_hlfm`` batches
    carry ``(wav, h_lfm, labels)``, ``BiLSTM_hlfm2`` batches additionally
    carry channel labels, and every other model receives
    ``(front_wav, wav, labels, _)`` slice pairs.

    Args:
        model: trained network; dispatch below is by ``isinstance``.
        test_dataloader: DataLoader yielding the batches described above.
        args: CLI namespace (reads ``chunk``, ``sliding_window``,
            ``pairs_concat_method``).
        is_test: if True, print as a test-set result, else as validation.
    """
    model.eval()
    sliding_window = args.sliding_window

    test_acc_list = []
    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            front_wav_data = None
            # NOTE(review): BiLSTM_hlfm2 is tested before BiLSTM_hlfm so a
            # potential subclass relationship cannot hijack the dispatch.
            if isinstance(model, BiLSTM_hlfm2):
                wav_data, wav_h_lfm, wav_labels, channel_labels = data
                channel_labels = channel_labels.cuda()
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()
            elif isinstance(model, BiLSTM_hlfm):
                wav_data, wav_h_lfm, wav_labels = data
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()
            else:
                front_wav_data, wav_data, wav_labels, _ = data
                front_wav_data = front_wav_data.float().cuda()
                wav_data = wav_data.float().cuda()
                wav_labels = wav_labels.cuda()

            bs, all_seq_lenth = wav_data.shape
            wav_data = wav_data.view(bs, -1, args.chunk)
            if front_wav_data is not None:
                # BUGFIX: this reshape used to run unconditionally and raised
                # NameError for the BiLSTM batch layouts, where no front
                # slice is provided.
                front_wav_data = front_wav_data.view(bs, -1, args.chunk)
                if args.pairs_concat_method == 1:
                    # concatenate the pair along the feature axis -> [bs, T, 2*chunk]
                    wav_data = torch.cat((front_wav_data, wav_data), dim=-1)
                elif args.pairs_concat_method == 2:
                    # BUGFIX: method 2 (element-wise sum) was applied during
                    # training but missing here, so evaluation fed the model a
                    # different input distribution than training.
                    wav_data = wav_data + front_wav_data

            if isinstance(model, (Varying_SBRnn, SBRnn_old)):
                output = model(wav_data, sliding_window)
            elif isinstance(model, (SBRnn, SBRnn2)):
                output = model(wav_data, None)
            elif isinstance(model, BiLSTM_hlfm2):
                output, output2 = model(wav_data, None, wav_h_lfm)
            elif isinstance(model, BiLSTM_hlfm):
                output = model(wav_data, None, wav_h_lfm)

            pred = torch.argmax(output, axis=-1)
            pred_mask = (pred == wav_labels)
            correct_num = pred_mask.sum().item()
            # per-bit accuracy over the [bs, seq] label grid
            test_acc = correct_num / (pred_mask.shape[0] * pred_mask.shape[1])
            test_acc_list.append(test_acc)

    if is_test:
        print('Eval on test_dataset, [bits] acc is: [{}%]\n'.format(np.mean(test_acc_list) * 100))
    else:
        print('Eval on val_dataset, [bits] acc is: [{}%]'.format(np.mean(test_acc_list) * 100))
-
def setup_model_with_datasets(args):
    """Build the SBRNN model plus the train/val/test dataloaders.

    The full dataset is split 1/3 test, then 90/10 train/val on the rest.

    Args:
        args: CLI namespace; reads ``tr_method``, ``pairs_concat_method``,
            ``chunk``, ``hid_dim``, ``num_layers``, ``output_dim``.

    Returns:
        Tuple ``(model, train_dataloader, val_dataloader, test_dataloader)``;
        the model is already moved to the GPU.

    Raises:
        ValueError: if ``args.pairs_concat_method`` is not 1 or 2.
    """
    print('-' * 10 + ' [TR-{}]TRSlice_Pairs_dataset '.format(args.tr_method) + '-' * 10)
    all_dataset = FrontWithCurrent_TRSlicePairs_end2end_decodingCurrent_Dataset(method=args.tr_method)

    test_size = int(len(all_dataset) // 3)
    train_size = int(0.9 * (len(all_dataset) - test_size))
    val_size = len(all_dataset) - train_size - test_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        all_dataset, [train_size, val_size, test_size])

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=20)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=20)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=20)

    # SBRnn2 is the architecture chosen for this experiment.
    from yizx_models.SBRNN_OOK import SBRnn2 as SBRNN

    if args.pairs_concat_method == 1:
        # pairs concatenated along the feature axis -> doubled input dim
        my_model = SBRNN(2 * args.chunk, args.hid_dim, args.num_layers, args.output_dim)
    elif args.pairs_concat_method == 2:
        # pairs summed element-wise -> input dim stays at chunk size
        my_model = SBRNN(args.chunk, args.hid_dim, args.num_layers, args.output_dim)
    else:
        # BUGFIX: previously fell through and crashed later with
        # UnboundLocalError; fail fast with a clear message instead.
        raise ValueError('unsupported pairs_concat_method: {}'.format(args.pairs_concat_method))

    my_model.cuda()
    return my_model, train_dataloader, val_dataloader, test_dataloader
-
def train(args):
    """End-to-end training loop for the pairwise SBRNN decoder.

    Builds the model and dataloaders, optionally restores weights from
    ``args.load_path``, trains for ``MAX_EPOCH_NUM`` epochs with Adam,
    periodically evaluates on the val/test splits, and checkpoints to
    ``./ckpts_2023/...``.

    Args:
        args: parsed CLI namespace from :func:`get_args`.
    """
    my_model, train_dataloader, val_dataloader, test_dataloader = \
        setup_model_with_datasets(args)

    sliding_window = args.sliding_window
    M = 2  # binary (OOK) symbol alphabet -> logits reshaped to [-1, 2]

    print(f'The model has {count_parameters(my_model):,} trainable parameters')

    optimizer = optim.Adam(my_model.parameters(), lr=args.lr)
    # NOTE(review): T_max=0.7 is an unusual fractional cycle length and the
    # scheduler is stepped per batch; with the default min_lr == lr the
    # schedule is effectively a constant lr — confirm this is intended.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=0.7, eta_min=args.min_lr, verbose=False)

    criterion = nn.CrossEntropyLoss()

    if args.load_path is not None:
        my_model.load_state_dict(torch.load(args.load_path))

    epoch_loss = 0
    tmp_idx = 0
    MAX_EPOCH_NUM = 50

    save_ckpt_path = './ckpts_2023/PairsConcatMethod{}_UWA_2022_TR{}_sbrnn/'.format(args.pairs_concat_method, args.tr_method)
    os.makedirs(save_ckpt_path, exist_ok=True)

    for epoch in tqdm(range(MAX_EPOCH_NUM)):
        for (i, data) in enumerate(train_dataloader):
            my_model.train()
            optimizer.zero_grad()

            front_wav_data = None
            channel_labels = None
            # NOTE(review): BiLSTM_hlfm2 is tested before BiLSTM_hlfm so a
            # potential subclass relationship cannot hijack the dispatch.
            if isinstance(my_model, BiLSTM_hlfm2):
                wav_data, wav_h_lfm, wav_labels, channel_labels = data
                channel_labels = channel_labels.cuda()
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()
            elif isinstance(my_model, BiLSTM_hlfm):
                wav_data, wav_h_lfm, wav_labels = data
                wav_data = wav_data.float().cuda()
                wav_h_lfm = wav_h_lfm.float().cuda()
                wav_labels = wav_labels.cuda()
            else:
                front_wav_data, wav_data, wav_labels, _ = data
                front_wav_data = front_wav_data.float().cuda()
                wav_data = wav_data.float().cuda()
                wav_labels = wav_labels.cuda()

            bs, all_seq_lenth = wav_data.shape
            wav_data = wav_data.view(bs, -1, args.chunk)
            if front_wav_data is not None:
                # BUGFIX: this reshape used to run unconditionally and raised
                # NameError for the BiLSTM batch layouts.
                front_wav_data = front_wav_data.view(bs, -1, args.chunk)
                if args.pairs_concat_method == 1:
                    # concat pair along the feature axis -> [bs, T, 2*chunk]
                    wav_data = torch.cat((front_wav_data, wav_data), dim=-1)
                elif args.pairs_concat_method == 2:
                    # element-wise sum keeps the input dim at chunk size
                    wav_data = wav_data + front_wav_data

            output2 = None
            if isinstance(my_model, (Varying_SBRnn, SBRnn_old)):
                output = my_model(wav_data, sliding_window)
            elif isinstance(my_model, (SBRnn, SBRnn2)):
                output = my_model(wav_data, None)
            elif isinstance(my_model, BiLSTM_hlfm2):
                # BUGFIX: this used to be a bare `if` after the BiLSTM_hlfm
                # `elif`, allowing both branches to run when the classes are
                # related; now a single elif chain, consistent with do_eval.
                output, output2 = my_model(wav_data, None, wav_h_lfm)
            elif isinstance(my_model, BiLSTM_hlfm):
                output = my_model(wav_data, None, wav_h_lfm)

            prediction = output.reshape(-1, M)
            wav_labels = wav_labels.reshape(-1)
            loss = criterion(prediction, wav_labels)

            if isinstance(my_model, BiLSTM_hlfm2):
                channel_pred_logits = output2.reshape(-1, M)
                channel_labels = channel_labels.reshape(-1)
                # BUGFIX: the auxiliary channel head was trained against
                # wav_labels instead of channel_labels (which had just been
                # reshaped and then ignored).
                loss2 = criterion(channel_pred_logits, channel_labels)
                loss += loss2

            loss.backward()
            optimizer.step()
            scheduler.step()  # stepped per batch (see NOTE on the scheduler)
            epoch_loss += loss.item()
            tmp_idx += 1

            # keep a roughly constant logging cadence independent of the
            # configured chunk size
            log_step = int(1000 * (20.0 / float(chunk_bits_size)))
            if (i + 1) % log_step == 0:
                print('epoch[{}]/[{}], [{}]/[{}], loss is: {}, lr {}'.format(epoch, MAX_EPOCH_NUM, i, len(train_dataloader),
                                                                             epoch_loss / tmp_idx,
                                                                             optimizer.state_dict()['param_groups'][0]['lr']
                                                                             )
                      )
                do_eval(my_model, val_dataloader, args, is_test=False)
                do_eval(my_model, test_dataloader, args, is_test=True)
            if (epoch + 1) % 2 == 0 and i == 39:
                torch.save(my_model.state_dict(),
                           '{}/new_slideWS{}_epoch_{}.pth'.format(save_ckpt_path, args.sliding_window, epoch))
-
if __name__ == '__main__':
    # CLI entry point: parse hyper-parameters, then run the training loop.
    cli_args = get_args()
    train(cli_args)
-
-
|