|
- import os
- import numpy as np
- import matplotlib.pyplot as plt
def _is_loss_line(line):
    """Return True if *line* is a training-loss record ('epoch[' and 'loss is:' both present)."""
    return 'epoch[' in line and 'loss is:' in line


def _parse_epoch_num(line):
    """Extract the integer epoch number from the 'epoch[<n>]/' part of a loss line."""
    return int(line.split('epoch[')[-1].split(']/')[0])


def _parse_loss(line):
    """Extract the loss value: the text after 'loss is:' up to the next comma."""
    return float(line.split('loss is:')[-1].split(',')[0])


def _parse_acc(line):
    """Extract the accuracy value: the text after 'acc is: [' up to the first '%'."""
    return float(line.split('acc is: [')[-1].split('%')[0])


def _records_per_epoch(all_lines):
    """Count the 'val_dataset' lines logged during the first epoch.

    The first epoch spans from the first loss line up to (not including) the
    first loss line whose epoch number differs. If the epoch never changes, the
    span runs to the end of the file. If no loss line exists at all, the whole
    file is used.
    """
    start_idx = None
    end_idx = None
    epoch_start_num = None
    for index, line in enumerate(all_lines):
        if not _is_loss_line(line):
            continue
        this_epoch_num = _parse_epoch_num(line)
        if start_idx is None:
            epoch_start_num = this_epoch_num
            start_idx = index
        elif this_epoch_num != epoch_start_num:
            end_idx = index
            break
    return sum('val_dataset' in line for line in all_lines[start_idx:end_idx])


def _full_group_means(values, group_size):
    """Average consecutive, non-overlapping runs of *group_size* values.

    A trailing run shorter than *group_size* is dropped. A non-positive
    *group_size* yields an empty list.
    """
    if group_size <= 0:
        return []
    return [np.mean(values[i:i + group_size])
            for i in range(0, len(values) - group_size + 1, group_size)]


def get_epochAVG_acc_loss(file_path):
    """Parse a training log and return per-epoch averaged loss, val acc and test acc.

    Line types recognised:
      * loss line: contains 'epoch[' and 'loss is:'; value = text after
        'loss is:' up to the next comma.
      * val line: contains 'val_dataset'; value = text after 'acc is: [' up to '%'.
      * test line: contains 'test_dataset'; same value format as val lines.

    The averaging group size is the number of val lines logged during the first
    epoch. Loss, val and test values are each averaged over consecutive groups
    of that size, and a trailing incomplete group is dropped. This assumes every
    epoch logs the same number of loss, val and test records -- TODO confirm
    against the training script.

    Args:
        file_path: path of the text log file.

    Returns:
        ``(loss_list, avg_val_acc_list, avg_test_acc_list)``: three equally long
        lists of ``np.mean`` results, one entry per averaged group. All three
        are empty when no val line is found in the first epoch.

    Raises:
        ValueError: if the three lists end up with different lengths, or if a
            matched line cannot be parsed as a number.
        OSError: if the file cannot be opened.
    """
    with open(file_path, 'r') as f:
        all_lines = f.readlines()

    # The whole first epoch is scanned (not a fixed-size prefix), so epochs
    # with many evaluation steps still give the correct group size.
    avg_num = _records_per_epoch(all_lines)

    loss_values = []
    val_values = []
    test_values = []
    for each_line in all_lines:
        # Same precedence as the line classification: loss, then val, then test.
        if _is_loss_line(each_line):
            loss_values.append(_parse_loss(each_line))
        elif 'val_dataset' in each_line:
            val_values.append(_parse_acc(each_line))
        elif 'test_dataset' in each_line:
            test_values.append(_parse_acc(each_line))

    loss_list = _full_group_means(loss_values, avg_num)
    avg_val_acc_list = _full_group_means(val_values, avg_num)
    avg_test_acc_list = _full_group_means(test_values, avg_num)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not len(loss_list) == len(avg_test_acc_list) == len(avg_val_acc_list):
        raise ValueError(
            'Inconsistent number of averaged epochs in {}: loss={}, val={}, test={}'.format(
                file_path, len(loss_list), len(avg_val_acc_list), len(avg_test_acc_list)))
    return loss_list, avg_val_acc_list, avg_test_acc_list
-
def get_matplot_maxValuePos(input_list):
    """Locate the peak of a curve so it can be annotated on a plot.

    Args:
        input_list: list or 1-D array of numeric values.

    Returns:
        ``(index, value)``: the position returned by ``np.argmax`` (first
        occurrence on ties) and the element of *input_list* at that position.

    Raises:
        ValueError: from ``np.argmax`` when *input_list* is empty.
    """
    peak_index = np.argmax(input_list)
    return peak_index, input_list[peak_index]
-
if __name__ == '__main__':
    # all_train_log_dir = './channel_plots/all_channel_tSNE_UWA2022/train_loss_log'
    # fig_save_dir = './channel_plots/all_channel_tSNE_UWA2022/'

    # plot_num and colorList are not used by the active haoxinhu section; they are
    # kept for the commented-out multi-curve plotting variants later in this script.
    plot_num = 0
    colorList = ['r', 'b', 'y', 'green', 'purple', 'orange', 'cyan']

    all_train_log_dir = './channel_plots/UWA2023_haoxinhu/train_loss_log'
    fig_save_dir = './channel_plots/UWA2023_haoxinhu/'
    # Set to True to also draw the epoch-averaged validation accuracy on each accuracy figure.
    plot_val_acc = False

    # exist_ok=True already tolerates an existing directory; no separate exists() check needed.
    os.makedirs(fig_save_dir, exist_ok=True)

    ################################ haoxinhu 2023 #########################################
    # For every training log, save one loss figure and one accuracy figure.
    # sorted() makes the processing order deterministic (os.listdir order is arbitrary).
    for each_file in sorted(os.listdir(all_train_log_dir)):
        this_f_path = os.path.join(all_train_log_dir, each_file)
        # Skip sub-directories and hidden files (e.g. OS/editor metadata), which are not logs.
        if each_file.startswith('.') or not os.path.isfile(this_f_path):
            continue
        loss_list, avg_val_acc_list, avg_test_acc_list = get_epochAVG_acc_loss(this_f_path)
        if not avg_test_acc_list:
            # Nothing to plot, and get_matplot_maxValuePos would raise on an empty list.
            print('Skipping {}: no complete epochs could be parsed'.format(each_file))
            continue

        # 'train_<a>_<b>.log' -> 'date__<a>_<b>' (every token is prefixed with '_').
        date_used_list = each_file.replace('.log', '').split('train_')[-1].split('_')
        all_label = 'date_' + ''.join('_{}'.format(date) for date in date_used_list)

        # Loss figure.
        plt.plot(loss_list, label=all_label)
        plt.title('Train-Loss for haoxinhu_{}'.format(all_label))
        plt.legend()
        plt.savefig(os.path.join(fig_save_dir, '{}_loss.png'.format(all_label)), dpi=300)
        plt.close()

        # Accuracy figure; the best value is written just above its point.
        if plot_val_acc:
            plt.plot(avg_val_acc_list, label='Val ACC')
            maxValPos, maxValValue = get_matplot_maxValuePos(avg_val_acc_list)
            plt.text(maxValPos, maxValValue + 0.5, '{:.2f}'.format(maxValValue))
        plt.plot(avg_test_acc_list, label='Test ACC')
        maxTestPos, maxTestValue = get_matplot_maxValuePos(avg_test_acc_list)
        plt.text(maxTestPos, maxTestValue + 0.5, '{:.2f}'.format(maxTestValue))
        plt.ylim([48, 105])
        plt.grid()
        plt.xlabel("epochs")
        plt.ylabel("Accuracy")
        plt.title('Accuracy for haoxinhu_{}'.format(all_label))
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(fig_save_dir, '{}_acc.png'.format(all_label)), dpi=300)
        plt.close()
    ################################ haoxinhu 2023 #########################################
-
- # for each_file in os.listdir(all_train_log_dir):
- # this_f_path = os.path.join(all_train_log_dir, each_file)
- # loss_list, avg_val_acc_list, avg_test_acc_list = get_epochAVG_acc_loss(this_f_path)
- # chunk_size = each_file.split('Mul')[-1].split('_')[0]
- # if 'GroupAll' in each_file:
- # label = 'dongjiang3k'
- # elif 'RawData' in each_file:
- # label = 'NoTR_RawData'
- # elif 'Group1' in each_file:
- # label = 'dongjiang-Group1'
- # all_label = label + '-' +str(chunk_size)
- #
- # plt.plot(loss_list, label=all_label)
- # plt.tight_layout()
- # plt.title('Train-Loss')
- # plt.legend()
- # plt.savefig(fig_save_dir + 'loss.png', dpi=300)
- # plt.close()
- # plot_num = 0
-
- # for each_file in os.listdir(all_train_log_dir):
- # this_f_path = os.path.join(all_train_log_dir, each_file)
- # loss_list, avg_val_acc_list, avg_test_acc_list = get_epochAVG_acc_loss(this_f_path)
- # print(each_file, len(avg_val_acc_list))
- # chunk_size = each_file.split('Mul')[-1].split('_')[0]
- # if 'GroupAll' in each_file:
- # label = 'dongjiang3k'
- # elif 'RawData' in each_file:
- # label = 'NoTR_RawData'
- # elif 'Group1' in each_file:
- # label = 'dongjiang-Group1'
- # all_label = label + '-' +str(chunk_size)
- #
- # plt.plot(avg_val_acc_list, label=all_label)
- # plt.tight_layout()
- # plt.title('Val-Acc')
- # plt.ylim([48, 98])
- # plt.legend()
- # plt.savefig(fig_save_dir + 'val_acc.png', dpi=300)
- # plt.close()
- # plot_num = 0
-
- # for each_file in os.listdir(all_train_log_dir):
- # this_f_path = os.path.join(all_train_log_dir, each_file)
- # loss_list, avg_val_acc_list, avg_test_acc_list = get_epochAVG_acc_loss(this_f_path)
- # chunk_size = each_file.split('Mul')[-1].split('_')[0]
- # # if 'GroupAll' in each_file:
- # # label = 'dongjiang3k'
- # # elif 'RawData' in each_file:
- # # label = 'NoTR_RawData'
- # # elif 'Group1' in each_file:
- # # label = 'dongjiang-Group1'
- # # elif 'TrainA_TestOnjuesai_[TR1-TR1]' in each_file:
- # # label = 'TrainA_TestOnjuesai_[TR1-TR1]'
- # # all_label = label + '-' +str(chunk_size)
- #
- # #########################################################
- # if 'TrainA_TestOnjuesai' in each_file:
- # trainTR_method = int(each_file.split('[TR')[-1].split('-')[0])
- # testTR_method = int(each_file.split('-TR')[-1].split(']')[0])
- # if trainTR_method == 4:
- # train_label = '[train]gOMP'
- # else:
- # train_label = '[train]CTR'
- # if testTR_method == 4:
- # test_label = '-[test]gOMP'
- # else:
- # test_label = '-[test]CTR'
- #
- # this_color = colorList[plot_num]
- # else:
- # # continue
- # if 'RawData' in each_file:
- # train_label = 'RawData'
- # test_label = ''
- # this_color = 'gray'
- # else:
- # if 'TrainAWithDongjiangTestOnjuesai' in each_file:
- # train_label = '[train]AWithDongjiang3K-gOMP'
- # test_label = '-[test]gOMP'
- # this_color = 'green'
- # else:
- # continue
- # all_label = '{}{}-windowSize{}'.format(train_label, test_label, chunk_size)
- # plt.plot(avg_test_acc_list, label=all_label, c=this_color)
- #
- # maxTestPos, maxTestValue = get_matplot_maxValuePos(avg_test_acc_list)
- # plt.text(maxTestPos, maxTestValue + 0.5, '{:.2f}'.format(maxTestValue))
- # plt.plot([maxTestPos, maxTestPos], [-1, maxTestValue], '.', c=this_color)
- # plt.plot([-1, maxTestPos], [maxTestValue, maxTestValue], '.', c=this_color)
- #
- # # #########################################################
- # # if 'TRSlice' in each_file and 'Group1' in each_file:
- # # all_label = 'windowSize{}'.format(chunk_size)
- # # this_color = colorList[plot_num]
- # # else:
- # # continue
- # #
- # # plt.plot(avg_test_acc_list, label=all_label, c=this_color)
- # # maxTestPos, maxTestValue = get_matplot_maxValuePos(avg_test_acc_list)
- # # plt.text(maxTestPos, maxTestValue + 0.5, '{:.2f}'.format(maxTestValue))
- # # plt.plot([maxTestPos, maxTestPos], [-1, maxTestValue], '.', c=this_color)
- # # plt.plot([-1, maxTestPos], [maxTestValue, maxTestValue], '.', c=this_color)
- # # ######################################################
- #
- # if not this_color == 'gray':
- # plot_num += 1
- #
- # # plt.tight_layout()
- #
- # plt.title('Test-Acc')
- # plt.ylim([48, 100])
- # plt.xlim([0, 50])
- # plt.xlabel("Train-Epoch numbers")
- # plt.ylabel("Decoding Accuracy")
- # plt.grid()
- # plt.legend()
- # plt.savefig(fig_save_dir + 'test_acc_diffWindow.png', dpi=300)
- # plt.close()
- # plot_num = 0
- #
|