|
- import os
- import torch
- from abc import ABC, abstractmethod
-
- from ..tasks import build_task
- from ..layers.HeteroLinear import HeteroFeature
- from ..utils import get_nodes_dict
-
-
- class BaseFlow(ABC):
- candidate_optimizer = {
- 'Adam': torch.optim.Adam,
- 'SGD': torch.optim.SGD,
- 'Adadelta': torch.optim.Adadelta
- }
-
- def __init__(self, args):
- """
-
- Parameters
- ----------
- args
-
- Attributes
- -------------
- evaluate_interval: int
- the interval of evaluation in validation
- """
- super(BaseFlow, self).__init__()
- self.evaluator = None
- self.evaluate_interval = 1
- if hasattr(args, '_checkpoint'):
- self._checkpoint = os.path.join(args._checkpoint, f"{args.model_name}_{args.dataset_name}.pt")
- else:
- if hasattr(args, 'load_from_pretrained'):
- self._checkpoint = os.path.join(args.output_dir,
- f"{args.model_name}_{args.dataset_name}_{args.task}.pt")
- else:
- self._checkpoint = None
-
- if not hasattr(args, 'HGB_results_path') and args.dataset_name[:3] == 'HGB':
- args.HGB_results_path = os.path.join(args.output_dir,
- "{}_{}_{}.txt".format(args.model_name, args.dataset_name[5:],
- args.seed))
-
- # stage flags: whether to run the corresponding stages
- # todo: only take effects in node classification trainer flow
-
- # args.training_flag = getattr(args, 'training_flag', True)
- # args.validation_flag = getattr(args, 'validation_flag', True)
- args.test_flag = getattr(args, 'test_flag', True)
- args.prediction_flag = getattr(args, 'prediction_flag', False)
- args.use_uva = getattr(args, 'use_uva', False)
-
- self.args = args
- self.logger = self.args.logger
- self.model_name = args.model_name
- self.model = args.model
- self.device = args.device
- self.task = build_task(args)
- if self.args.use_uva:
- self.hg = self.task.get_graph()
- else:
- self.hg = self.task.get_graph().to(self.device)
- self.args.meta_paths_dict = self.task.dataset.meta_paths_dict
- self.patience = args.patience
- self.max_epoch = args.max_epoch
- self.optimizer = None
- self.loss_fn = self.task.get_loss_fn()
-
- def preprocess(self):
- r"""
- Every trainerflow should run the preprocess_feature if you want to get a feature preprocessing.
- The Parameters in input_feature will be added into optimizer and input_feature will be added into the model.
-
- Attributes
- -----------
- input_feature : HeteroFeature
- It will return the processed feature if call it.
-
- """
- if hasattr(self.args, 'activation'):
- if hasattr(self.args.activation, 'weight'):
- import torch.nn as nn
- act = nn.PReLU()
- else:
- act = self.args.activation
- else:
- act = None
- # useful type selection
- if hasattr(self.args, 'feat'):
- pass
- else:
- # Default 0, nothing to do.
- self.args.feat = 0
- self.feature_preprocess(act)
- self.optimizer.add_param_group({'params': self.input_feature.parameters()})
- # for early stop, load the model with input_feature module.
- self.model.add_module('input_feature', self.input_feature)
- self.load_from_pretrained()
-
- def feature_preprocess(self, act):
- """
- Feat
- 0, 1 ,2
- Node feature
- 1 node type & more than 1 node types
- no feature
-
- Returns
- -------
-
- """
-
- if self.hg.ndata.get('h', {}) == {} or self.args.feat == 2:
- if self.hg.ndata.get('h', {}) == {}:
- self.logger.feature_info('Assign embedding as features, because hg.ndata is empty.')
- else:
- self.logger.feature_info('feat2, drop features!')
- self.hg.ndata.pop('h')
- self.input_feature = HeteroFeature({}, get_nodes_dict(self.hg), self.args.hidden_dim,
- act=act).to(self.device)
- elif self.args.feat == 0:
- self.input_feature = self.init_feature(act)
- elif self.args.feat == 1:
- if self.args.task != 'node_classification':
- self.logger.feature_info('\'feat 1\' is only for node classification task, set feat 0!')
- self.input_feature = self.init_feature(act)
- else:
- h_dict = self.hg.ndata.pop('h')
- self.logger.feature_info('feat1, preserve target nodes!')
- self.input_feature = HeteroFeature({self.category: h_dict[self.category]}, get_nodes_dict(self.hg), self.args.hidden_dim,
- act=act).to(self.device)
-
- def init_feature(self, act):
- self.logger.feature_info("Feat is 0, nothing to do!")
- if isinstance(self.hg.ndata['h'], dict):
- # The heterogeneous contains more than one node type.
- input_feature = HeteroFeature(self.hg.ndata['h'], get_nodes_dict(self.hg),
- self.args.hidden_dim, act=act).to(self.device)
- elif isinstance(self.hg.ndata['h'], torch.Tensor):
- # The heterogeneous only contains one node type.
- input_feature = HeteroFeature({self.hg.ntypes[0]: self.hg.ndata['h']}, get_nodes_dict(self.hg),
- self.args.hidden_dim, act=act).to(self.device)
- return input_feature
-
- @abstractmethod
- def train(self):
- pass
-
- def _full_train_step(self):
- r"""
- Train with a full_batch graph
- """
- raise NotImplementedError
-
- def _mini_train_step(self):
- r"""
- Train with a mini_batch seed nodes graph
- """
- raise NotImplementedError
-
- def _full_test_step(self):
- r"""
- Test with a full_batch graph
- """
- raise NotImplementedError
-
- def _mini_test_step(self):
- r"""
- Test with a mini_batch seed nodes graph
- """
- raise NotImplementedError
-
- def load_from_pretrained(self):
- if hasattr(self.args, 'load_from_pretrained') and self.args.load_from_pretrained:
- try:
- ck_pt = torch.load(self._checkpoint)
- self.model.load_state_dict(ck_pt)
- self.logger.info('[Load Model] Load model from pretrained model:' + self._checkpoint)
- except FileNotFoundError:
- self.logger.info('[Load Model] Do not load the model from pretrained, '
- '{} doesn\'t exists'.format(self._checkpoint))
- # return self.model
-
- def save_checkpoint(self):
- if self._checkpoint and hasattr(self.model, "_parameters()"):
- torch.save(self.model.state_dict(), self._checkpoint)
|