|
- # -*- coding: utf-8 -*-
- """
- Spyder Editor
-
- This is a temporary script file.
- """
- import os, logging
- import numpy as np
- # import tensorflow as tf
- # import tensorflow.compat.v1 as tf1
- import random
- from datetime import datetime
- import torch
- from dataset import ctx_dataset2
- from models import tfModel10 as mymodel
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- # physical_devices = tf.config.experimental.list_physical_devices('GPU')
- # assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
- # config = tf.config.experimental.set_memory_growth(physical_devices[0], True)
-
def getlogger(logdir):
    """Create a logger that writes INFO+ records to ``logdir/log.txt`` and to the console.

    Parameters
    ----------
    logdir : str
        Existing directory in which ``log.txt`` is created (appended to if present).

    Returns
    -------
    logging.Logger
        The module logger with one file handler and one console handler attached.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)  # NOTSET < DEBUG < INFO < WARNING < ERROR < CRITICAL
    # getLogger(__name__) returns the SAME logger object on repeated calls, so
    # re-attaching handlers would duplicate every log line (e.g. when this
    # script is re-run inside one interpreter session). Attach only once.
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%m/%d %H:%M:%S')
        handler = logging.FileHandler(os.path.join(logdir, 'log.txt'))  # file output
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        console = logging.StreamHandler()  # console output
        console.setLevel(logging.INFO)
        console.setFormatter(formatter)
        logger.addHandler(handler)
        logger.addHandler(console)
    return logger
#%%config
# Training hyperparameters.
num_epochs = 30000
batch_size = 30000
learning_rate = 0.001
# NOTE(review): loss_fn below defines its own minprob/lambda_wr and does not
# read these module-level values — they are effectively unused here.
lambda_wr = 0

minprob= 0 #.0001 #0.01  (previously tried floor values for log())
ctx_type=100  # context size passed to ctx_dataset2 and to the model
#%%
# Machine-specific absolute dataset locations.
val_data_dir = '/userhome/NNCTX-main/val_datasets/data_100/'
train_data_dir = '/userhome/NNCTX-main/train_datasets/data_100/'

train_ds = ctx_dataset2(train_data_dir,ctx_type)  # training contexts + counts
val_ds = ctx_dataset2(val_data_dir,ctx_type)      # validation contexts + counts
-
#%%
mdl= mymodel(ctx_type).to(device)  # model sized to the context length

#%% DISCARD THE MOST FREQUENT CONTEXT FROM TRAINING SET; SINCE COUNTS ARE SO HIGH
# Total count per context row; the argmax row is excluded from the index
# lists built further below.
tot_counts = np.sum(train_ds.counts,1)
disc_ind = np.argmax(tot_counts)
### same exclusion computed for the validation set
vtot_counts = np.sum(val_ds.counts,1)
vdisc_ind = np.argmax(vtot_counts)
-
#%%##REFINE TRAINING SET BY COUNT RATIOS##################################
# count_ratio_th = 0.2
# ratios= np.min(train_ds.counts,1)/np.max(train_ds.counts,1)
# train_inds = np.where(ratios<count_ratio_th)[0]

#%%
# Per-run log directory named by timestamp, e.g. train_logs/20240101-120000/.
curr_date = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'train_logs/' + curr_date + '/'
# makedirs also creates the 'train_logs' parent; plain os.mkdir raised
# FileNotFoundError on a fresh checkout where 'train_logs' did not exist yet.
os.makedirs(logdir)
logger = getlogger(logdir)
# (TF1-era note: tfcounts was a [None, 2] float placeholder; it is now a tensor arg.)
def loss_fn(tfcounts, mdl, mdl_output, minprob=0.0, lambda_wr=0.0):
    """Count-weighted negative log-likelihood plus optional weight regularization.

    Parameters
    ----------
    tfcounts : torch.Tensor, shape (n, 2)
        Observed counts for the two outcomes of each context.
    mdl : model with ``w1``/``b1``/``w2``/``b2`` tensor attributes
        Used only for the regularization term.
    mdl_output : torch.Tensor, shape (n, 2)
        Predicted probabilities for the two outcomes.
    minprob : float, optional
        Additive floor inside log(). NOTE(review): at the default 0.0 a
        predicted probability of exactly 0 yields log(0) = -inf; pass a
        small positive value to guard against that.
    lambda_wr : float, optional
        Regularization coefficient; the term is inert at the default 0.0.

    Returns
    -------
    torch.Tensor
        Scalar loss (previously the hard-coded locals matched these defaults).
    """
    # Cross-entropy weighted by the raw observation counts.
    cl_loss = -torch.sum(tfcounts[:, 0] * torch.log(mdl_output[:, 0] + minprob)
                         + tfcounts[:, 1] * torch.log(mdl_output[:, 1] + minprob))

    # "Regularization" as the signed mean of weights/biases (not L1/L2);
    # kept as in the original, and zero at the default lambda_wr.
    wr_loss = lambda_wr * (torch.mean(mdl.w1) + torch.mean(mdl.b1)
                           + torch.mean(mdl.w2) + torch.mean(mdl.b2))

    return cl_loss + wr_loss
-
# opti = tf1.train.AdamOptimizer(learning_rate=learning_rate)
optimizer = torch.optim.Adam(mdl.parameters(), lr = learning_rate)
# train_op = opti.minimize(loss)

# --- legacy TF1 / TensorBoard scaffolding, kept for reference ---
# step = tf.Variable(0, dtype=tf.int64)
# step_update = step.assign_add(1)
# train_writer = tf.summary.create_file_writer(logdir+ 'train')
# val_writer = tf.summary.create_file_writer(logdir+ 'val')
# with train_writer.as_default():
#     tr_loss_summ = tf.summary.scalar("train_loss", loss, step=step)
# with val_writer.as_default():
#     val_loss_summ = tf.summary.scalar("val_loss", loss, step=step)
# all_summary_ops = tf1.summary.all_v2_summary_ops()
# train_writer_flush = train_writer.flush()
# val_writer_flush = val_writer.flush()

# sess = tf1.Session()
# sess.run(tf1.global_variables_initializer())
# sess.run([train_writer.init(),val_writer.init(), step.initializer])

# Every context index except the single most frequent one found above.
train_inds = list(range(disc_ind))+list(range(disc_ind+1,train_ds.n_ctxs))#list(range(train_ds.n_ctxs))
val_inds = list(range(vdisc_ind))+list(range(vdisc_ind+1,val_ds.n_ctxs))#range(val_ds.n_ctxs)

best_val_loss = 100000000  # sentinel "infinity" for best-checkpoint tracking
prev_tr_loss = 10000000    # NOTE(review): assigned but never read below — appears unused

num_batches = len(train_inds)//batch_size  # remainder contexts are dropped each epoch
done=0  # NOTE(review): never used below
for epoch in range(num_epochs):
    print('epoch:' + str(epoch))
    logger.info('epoch:' + str(epoch))
    # if(epoch%10==0):
    np.random.shuffle(train_inds)  # in-place reshuffle of the index list each epoch
    total_loss = 0.
    for ibatch in range(num_batches):
        optimizer.zero_grad()
        batch_inds = train_inds[ibatch*batch_size:(ibatch+1)*batch_size]
        trctxs = train_ds.ctxs[batch_inds,:] #(n,100)
        trcounts = train_ds.counts[batch_inds,:] #(n,2)
        input = torch.from_numpy(trctxs).float().to(device)  # NOTE(review): shadows builtin input()
        tfcounts = torch.from_numpy(trcounts).float().to(device)
        output = mdl(input)
        loss = loss_fn(tfcounts,mdl,output)
        loss.backward()
        optimizer.step()
        # sess.run([train_op],feed_dict = {mdl.input:trctxs,tfcounts:trcounts})
        total_loss += loss.item()
    # Epoch-end bookkeeping and validation, with autograd disabled.
    with torch.no_grad():
        total_loss /= num_batches  # mean per-batch training loss this epoch
        logger.info('total_loss:' + str(total_loss))
        # tr_loss = sess.run([loss,tr_loss_summ],feed_dict = {mdl.input:trctxs,tfcounts:trcounts})
        # sess.run(step_update)
        # sess.run(train_writer_flush)
        # print('tr_loss:'+str(tr_loss))

        #%%# VALIDATION:
        # One random validation batch, sampled without replacement.
        # NOTE(review): random.sample raises ValueError if len(val_inds) < batch_size
        # — confirm the validation set is large enough. Also note mdl stays in
        # train mode here (no mdl.eval()); confirm intended if the model has
        # dropout/batch-norm.
        batch_inds = random.sample(val_inds,batch_size)
        vctxs = val_ds.ctxs[batch_inds,:]
        vcounts = val_ds.counts[batch_inds,:]
        input = torch.from_numpy(vctxs).float().to(device)
        tfcounts = torch.from_numpy(vcounts).float().to(device)
        output = mdl(input)
        val_loss = loss_fn(tfcounts,mdl,output)
        # val_loss,_ = sess.run([loss,val_loss_summ],feed_dict = {mdl.input:vctxs,tfcounts:vcounts})
        # sess.run(val_writer_flush)
        # print('val_loss:'+str(val_loss))
        logger.info('val_loss:' + str(val_loss.item()))
        # Checkpoint whenever the sampled validation loss improves.
        if val_loss<best_val_loss:
            best_val_loss = val_loss  # NOTE(review): holds a tensor after the first improvement
            logger.info('saving checkpoint..')
            # np.save(checkpoint_path,sess.run(tf1.trainable_variables()))
            save_dir = os.path.join(logdir, 'epoch_' + str(epoch) + '.pth')
            torch.save({'model': mdl.state_dict(),'best_val_loss':best_val_loss}, save_dir)
|