|
- # A simple torch style logger
- # (C) Wei YANG 2017
- from __future__ import absolute_import
- import matplotlib.pyplot as plt
- import os
- import sys
- import numpy as np
-
- __all__ = ['Logger', 'LoggerMonitor', 'savefig']
-
def savefig(fname, dpi=None):
    '''Save the current matplotlib figure to fname.

    fname is passed straight to plt.savefig.
    dpi is the output resolution; when it is None, 150 is used.
    '''
    # Identity test: None is a singleton, and `==` could be hijacked by an
    # argument type that overloads equality.
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)
-
def plot_overlap(logger, names=None):
    '''Draw the selected columns of one logger onto the current axes.

    logger must expose .names, .numbers (a dict mapping column name to a
    list of values) and .title.
    names lists the columns to draw; when it is None, every column in
    logger.names is drawn.

    Returns the legend labels, one per drawn column, formatted as
    "<title>(<name>)". The legend itself is not created here.
    '''
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        series = np.asarray(numbers[name])
        # Values read back from a log file may still be text. Plotting text
        # treats every distinct string as a category, so convert to numbers.
        if series.dtype.kind in ('U', 'S'):
            try:
                series = series.astype(float)
            except ValueError:
                # Genuinely non-numeric column: plot it unchanged, exactly
                # as before this conversion was introduced.
                pass
        plt.plot(np.arange(len(series)), series)
    return [logger.title + '(' + name + ')' for name in names]
-
class Logger(object):
    '''Save training process to a log file, with a simple plot function.

    On-disk format (tab separated text):
      * first line: the column names, each followed by a tab;
      * every later line: one row of values written with six decimals,
        each followed by a tab.

    Attributes:
      file    -- the open log file, or None when no path was given.
      resume  -- the resume flag passed to the constructor.
      title   -- prefix used for legend labels ('' when not given).
      names   -- column names, in file order.
      numbers -- dict mapping each column name to its list of values.
    '''

    def __init__(self, fpath, title=None, resume=False):
        '''Open (or re-open) a log.

        fpath  -- path of the log file; None keeps the log in memory only.
        title  -- label prefix used by plot(); defaults to ''.
        resume -- when True, load the existing file at fpath and keep
                  appending to it; when False, truncate/create it.
        '''
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Defined up front so plot()/append() never hit a missing attribute
        # when they are called before set_names().
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            self._load_history(fpath)
            self.file = open(fpath, 'a')
        else:
            self.file = open(fpath, 'w')

    @staticmethod
    def _parse_value(text):
        '''Return text as a float, or unchanged if it is not a number.'''
        try:
            return float(text)
        except ValueError:
            # Keeps old logs loadable even if they contain stray text
            # (for example a header line repeated in the middle).
            return text

    def _load_history(self, fpath):
        '''Fill names and numbers from the existing log file at fpath.'''
        # `with` guarantees the read handle is released even if parsing fails.
        with open(fpath, 'r') as log:
            header = log.readline().rstrip()
            # An empty file has no columns (rather than one column named '').
            self.names = header.split('\t') if header else []
            self.numbers = {name: [] for name in self.names}
            for line in log:
                line = line.rstrip()
                if not line:
                    continue  # blank lines carry no data
                # zip() ignores surplus fields instead of raising IndexError.
                for name, text in zip(self.names, line.split('\t')):
                    self.numbers[name].append(self._parse_value(text))

    def set_names(self, names):
        '''Declare the column names and write the header line.

        On a resumed log whose header already equals names this is a no-op:
        the loaded history is kept and no second header is written into the
        middle of the file. In every other case the stored values are reset
        and a header line is written.
        '''
        if self.resume and self.names and list(names) == list(self.names):
            return
        self.names = names
        self.numbers = {name: [] for name in self.names}
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in self.names))
            self.file.write('\n')
            self.file.flush()

    def append(self, numbers):
        '''Record one row of values, one per column in self.names order.

        Raises AssertionError when the number of values does not match the
        number of columns (kept as an assert so existing callers see the
        same exception type).
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first: if a value cannot be formatted,
        # nothing has been written or stored yet.
        row = ''.join("{0:.6f}".format(num) + '\t' for num in numbers)
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)
        if self.file is not None:
            self.file.write(row)
            self.file.write('\n')
            self.file.flush()

    def plot(self, names=None):
        '''Plot the selected columns on the current axes, with legend and grid.

        names lists the columns to draw; None means every column.
        '''
        names = self.names if names is None else names
        for name in names:
            values = self.numbers[name]
            plt.plot(np.arange(len(values)), np.asarray(values))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the log file if one is open.'''
        if self.file is not None:
            self.file.close()
-
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''Load every log listed in paths.

        paths is a dictionary of {title: filepath} pairs; each title becomes
        the legend prefix of the corresponding log.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # resume=True re-opens the file for appending. The monitor only
            # reads the loaded values, so release the handle straight away
            # instead of leaking one open file per log.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Overlay the selected columns of all loaded logs in a new figure.

        names lists the columns to draw from every log; None draws each
        log's full set of columns.
        '''
        plt.figure()
        # Draw in the left cell of a 1x2 grid; the legend below is anchored
        # just outside the axes' right edge (presumably why the right cell
        # is left empty).
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text.extend(plot_overlap(logger, names))
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
-
if __name__ == '__main__':
    # Reference example for a single Logger (left disabled):
    #
    #     logger = Logger('test.txt')
    #     logger.set_names(['Train loss', 'Valid loss','Test loss'])
    #
    #     length = 100
    #     t = np.arange(length)
    #     train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    #     valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    #     test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    #
    #     for i in range(0, length):
    #         logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    #     logger.plot()

    # Demo of LoggerMonitor: overlay one column from several saved logs and
    # write the figure to 'test.eps'. The paths are machine-specific.
    log_files = {
        'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }
    columns = ['Valid Acc.']

    viewer = LoggerMonitor(log_files)
    viewer.plot(names=columns)
    savefig('test.eps')
|