|
- from src_npu.network import GraphNet
- from npu_bridge.npu_init import *
-
flags = tf.app.flags

# --- Custom options ---
flags.DEFINE_string('data_type', 'adapt', 'lgcn, adapt or lpa')
flags.DEFINE_integer('save_checkpoint_steps', 1000, '')
flags.DEFINE_integer('early_stop_steps', 15, 'val stat stops')
flags.DEFINE_integer('random_seed', 1000, 'random seed')
flags.DEFINE_boolean('allow_mix_precision', True, 'allow_mix_precision')

# --- Original options ---
flags.DEFINE_integer('max_step', 10000, '# of step for training')
flags.DEFINE_integer('summary_interval', 10, '# of step to save summary')
flags.DEFINE_float('learning_rate', 0.1, 'learning rate')
flags.DEFINE_boolean('is_train', False, 'is train')
flags.DEFINE_integer('class_num', 7, 'output class number')

# --- Debug: log / checkpoint locations ---
flags.DEFINE_string('logdir', './logdir', 'Log dir')
flags.DEFINE_string('modeldir', './modeldir', 'Model dir')
flags.DEFINE_string('model_name', 'model', 'Model file name')
flags.DEFINE_integer('reload_step', 0, 'Reload step to continue training')
flags.DEFINE_integer('test_step', 0, 'Test or predict model at this step')

# --- Network architecture ---
# NOTE(review): an earlier comment read "8, 1"; presumably alternative
# ch_num / layer_num values that were tried. Confirm before relying on it.
flags.DEFINE_integer('ch_num', 1, 'channel number')
flags.DEFINE_integer('layer_num', 1, 'block number')
# Distinct help text so --help can tell this flag apart from keep_r below.
flags.DEFINE_float('adj_keep_r', 1., 'adjacency dropout keep rate')

# NOTE(review): an earlier comment read "0.1"; presumably an alternative
# keep_r value that was tried. Confirm before relying on it.
flags.DEFINE_float('keep_r', 1., 'dropout keep rate')
flags.DEFINE_float('weight_decay', 5e-4, 'Weight for L2 loss on embedding matrix.')
flags.DEFINE_integer('k', 8, 'top k')
flags.DEFINE_string('first_conv', 'simple_conv', 'simple_conv, chan_conv')
flags.DEFINE_string('second_conv', 'simple_conv', 'graph_conv, simple_conv')
flags.DEFINE_boolean('use_batch', True, 'use batch training')
flags.DEFINE_integer('batch_size', 2500, 'batch size number')
flags.DEFINE_integer('center_num', 1500, 'start center number')

# Parsed flag values; GraphNet receives this object as its `conf` in main().
FLAGS = flags.FLAGS
-
-
def _npu_session_config():
    """Return a tf.ConfigProto that routes graph execution through NpuOptimizer."""
    session_config = tf.ConfigProto()
    rewrite_options = session_config.graph_options.rewrite_options

    npu_optimizer = rewrite_options.custom_optimizers.add()
    npu_optimizer.name = "NpuOptimizer"
    # Execute on the Ascend AI processor.
    npu_optimizer.parameter_map["use_off_line"].b = True
    # NOTE(review): presumably lets NPU-unsupported ops fall back to the host;
    # confirm against the Ascend TF adapter documentation.
    npu_optimizer.parameter_map["mix_compile_mode"].b = True

    # Turn off Grappler's remapping and memory-optimization passes.
    rewrite_options.remapping = RewriterConfig.OFF
    rewrite_options.memory_optimization = RewriterConfig.OFF
    return session_config


def main(argv=None):
    """Build GraphNet, call its reload step, and print test-split accuracy.

    Args:
        argv: Not used; accepted because tf.app.run passes the leftover
            command-line arguments to main.
    """
    default_graph = tf.get_default_graph()
    with default_graph.as_default():
        with tf.Session(config=_npu_session_config()) as session:
            model = GraphNet(sess=session, conf=FLAGS)
            # NOTE(review): what reload(None) restores is decided inside
            # GraphNet.reload; confirm it picks the intended checkpoint.
            model.reload(None)
            test_feed = model.pack_trans_dict('test')
            accuracy = session.run(model.accuracy_op, feed_dict=test_feed)
            print('Restore Model, test_accuracy: %.4f' % accuracy)
-
-
if __name__ == '__main__':
    # Parse the flags defined above, then dispatch to main() with leftover argv.
    tf.app.run(main=main)
|