|
- import os
- import numpy as np
- import torch
- from torch.autograd import Variable
- from pdb import set_trace as st
- from IPython import embed
-
class BaseModel:
    """Minimal base class for models, with no-op hooks and checkpoint helpers.

    Most methods are placeholders meant to be overridden. Several helpers
    read attributes that this class never assigns (``self.save_dir``,
    ``self.input``, ``self.image_paths``); they must be set elsewhere
    (e.g. by a subclass) before those helpers are called.
    """

    def __init__(self):
        pass

    def name(self):
        """Return the model's display name."""
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=None):
        """Record the device configuration on the instance.

        Args:
            use_gpu: Stored as ``self.use_gpu``.
            gpu_ids: Stored as ``self.gpu_ids``. Defaults to ``[0]``; a fresh
                list is created per call so instances never share (and
                accidentally mutate) one default list.
        """
        self.use_gpu = use_gpu
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids

    def forward(self):
        """Placeholder; does nothing in the base class."""
        pass

    def optimize_parameters(self):
        """Placeholder; does nothing in the base class."""
        pass

    def get_current_visuals(self):
        """Return ``self.input`` (not assigned by this class)."""
        return self.input

    def get_current_errors(self):
        """Return an empty dict; the base class tracks no errors."""
        return {}

    def save(self, label):
        """Placeholder; does nothing in the base class."""
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        """Save ``network.state_dict()`` to ``<path>/<epoch>_net_<label>.pth``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<self.save_dir>/<epoch>_net_<label>.pth`` into ``network``.

        Reads ``self.save_dir``, which is not assigned by this class.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        """Placeholder; does nothing in the base class."""
        pass

    def get_image_paths(self):
        """Return ``self.image_paths`` (not assigned by this class)."""
        return self.image_paths

    def save_done(self, flag=False):
        """Write ``flag`` to ``self.save_dir`` in two formats.

        ``np.save`` appends ``.npy`` to the name, producing ``done_flag.npy``;
        ``np.savetxt`` writes the integer form to a text file ``done_flag``.
        """
        done_path = os.path.join(self.save_dir, 'done_flag')
        np.save(done_path, flag)
        np.savetxt(done_path, [flag, ], fmt='%i')
|