GTN [NeurIPS 2019] & fastGTN
Paper: [GTN] Graph Transformer Networks
Extension Paper: [fastGTN] Graph Transformer Networks: Learning Meta-path Graphs to Improve GNNs
Code from the authors: https://github.com/seongjunyun/Graph_Transformer_Networks
How to run
Clone the OpenHGNN repository, then run:
# Run GTN
python main.py -m GTN -t node_classification -d acm4GTN -g 0 --use_best_config
# Run the fastGTN
python main.py -m fastGTN -t node_classification -d acm4GTN -g 0 --use_best_config
If you do not have a GPU, set -g -1 to run on the CPU.
Candidate datasets:
acm4GTN / imdb4GTN / dblp4GTN
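To reproduce every cell of the results table in one go, a small driver can loop over both models and all candidate datasets. This is a minimal sketch, assuming it runs from the root of an OpenHGNN checkout; build_command and run_all are hypothetical helpers that only reuse the flags shown above, and dry_run=True just prints the commands instead of executing them.

```python
import subprocess

MODELS = ["GTN", "fastGTN"]
DATASETS = ["acm4GTN", "imdb4GTN", "dblp4GTN"]


def build_command(model, dataset, gpu=0):
    """Assemble the main.py invocation shown above for one model/dataset pair."""
    return [
        "python", "main.py",
        "-m", model,
        "-t", "node_classification",
        "-d", dataset,
        "-g", str(gpu),  # -1 runs on the CPU
        "--use_best_config",
    ]


def run_all(gpu=0, dry_run=True):
    """Print every model/dataset command and, unless dry_run, execute it."""
    commands = [build_command(m, d, gpu) for m in MODELS for d in DATASETS]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # Assumes the working directory is the OpenHGNN checkout root.
            subprocess.run(cmd, check=True)
    return commands
```

For example, run_all(gpu=-1, dry_run=False) would run all six experiments on the CPU, one after another.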
Performance
Node classification
| Node classification (F1 score) | acm4GTN | imdb4GTN | dblp4GTN |
| --- | --- | --- | --- |
| paper [GTN] | 92.68 | 60.92 | 94.18 |
| OpenHGNN [GTN] | Macro: 92.03, Micro: 92.00 | Macro: 56.97, Micro: 58.61 | 87.33 (OOM on Tesla T4 (16 GB); CPU result) |
| OpenHGNN [fastGTN] | Macro: 92.92, Micro: 92.85 | Macro: 60.62, Micro: 62.59 | Macro: 90.39, Micro: 91.39 |
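The table reports both Macro-F1 (the unweighted mean of per-class F1 scores) and Micro-F1 (F1 over counts pooled across classes, which equals accuracy for single-label classification). The sketch below shows how the two differ; it follows the standard definitions (the same ones used by scikit-learn's f1_score with average='macro' and average='micro') and is not OpenHGNN's evaluation code.

```python
from collections import Counter


def f1_scores(y_true, y_pred):
    """Return (macro_f1, micro_f1) for single-label multi-class predictions."""
    labels = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but wrongly
            fn[t] += 1  # true class t was missed

    def f1(tp_, fp_, fn_):
        denom = 2 * tp_ + fp_ + fn_
        return 2 * tp_ / denom if denom else 0.0

    # Macro-F1: every class counts equally, regardless of its size.
    macro = sum(f1(tp[c], fp[c], fn[c]) for c in labels) / len(labels)
    # Micro-F1: pool TP/FP/FN over all classes, then compute one F1.
    micro = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    return macro, micro
```

For y_true = [0, 0, 1, 1, 2, 2] and y_pred = [0, 1, 1, 1, 2, 0], the per-class F1 scores are 0.5, 0.8 and 2/3, so Macro-F1 is about 0.656 while Micro-F1 is 4/6 (the accuracy).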
TrainerFlow: node_classification
The model is trained with semi-supervised node classification.
Model
- transform_relation_graph_list
  - Extracts a list of graphs, each containing the edges of exactly one relation.
- GTLayer
  - Contains GTConv modules.
  - Computes the product of the adjacency matrices of the two graphs produced by GTConv, yielding a meta-path graph.
- GTConv
  - Creates a weighted graph whose adjacency matrix is a learned weighted sum of the adjacency matrices in the given graph list.
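The three components above can be illustrated with dense adjacency matrices in plain Python. This is a minimal single-channel sketch under simplifying assumptions: the function names mirror the list above but are not OpenHGNN's implementation (which works on DGL graphs with edge weights), and the GCN that GTN applies on top of the learned meta-path graph is omitted. GTConv is modelled as a softmax-weighted sum of the per-relation adjacency matrices, and a GT layer as the product of two such soft selections.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def matmul(a, b):
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(m)] for i in range(n)]


def transform_relation_graph_list(edges, num_nodes):
    """Split typed edges (src, dst, relation) into one adjacency matrix per relation.

    Hypothetical dense stand-in; relations are returned in sorted order.
    """
    relations = sorted({r for _, _, r in edges})
    adjs = {r: [[0.0] * num_nodes for _ in range(num_nodes)] for r in relations}
    for src, dst, rel in edges:
        adjs[rel][src][dst] = 1.0
    return [adjs[r] for r in relations]


def gt_conv(adj_list, weights):
    """Soft relation selection: softmax(weights)-weighted sum of the adjacency matrices."""
    alpha = softmax(weights)
    n = len(adj_list[0])
    return [[sum(a * adj[i][j] for a, adj in zip(alpha, adj_list)) for j in range(n)]
            for i in range(n)]


def gt_layer(adj_list, weights1, weights2):
    """One single-channel GT layer: multiply two soft-selected graphs."""
    return matmul(gt_conv(adj_list, weights1), gt_conv(adj_list, weights2))
```

With a toy graph where node 0 is an author, node 1 a paper and node 2 a subject (edges AP: 0 to 1, PA: 1 to 0, PS: 1 to 2), weights that strongly favour AP in the first GTConv and PS in the second make the product approximate the author-paper-subject meta-path graph, so entry [0][2] is close to 1.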
Dataset
Supported datasets: acm4GTN, imdb4GTN, dblp4GTN
Note: every node in the dataset must have the same feature dimension.
Dataset Description
We process the ACM dataset provided by HAN. It is saved as a dgl.heterograph and can be loaded with dgl.load_graphs.
You can download the datasets with:
wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/acm4GTN.zip
wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/imdb4GTN.zip
Alternatively, run the commands above and the dataset will be downloaded automatically.
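The download-and-load steps can also be scripted. This is a minimal sketch, assuming the URL pattern of the two wget links above (a dblp4GTN archive at the same location is an assumption, since no link is given here). The local root directory is a hypothetical choice, and because the file layout inside each archive is not documented here, load_graph takes the path of the extracted .bin file explicitly. dgl.load_graphs returns a (graph_list, labels_dict) pair.

```python
import os
import urllib.request
import zipfile

BASE_URL = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/"


def dataset_url(name):
    """URL of the zipped dataset, following the pattern of the wget links."""
    return f"{BASE_URL}{name}.zip"


def download_and_extract(name, root="openhgnn_data"):
    """Download <name>.zip into root (if not already there) and unzip it."""
    os.makedirs(root, exist_ok=True)
    zip_path = os.path.join(root, f"{name}.zip")
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(dataset_url(name), zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(root)
    return root


def load_graph(bin_path):
    """Load the stored heterograph from an extracted .bin file."""
    import dgl  # imported lazily so the helpers above work without DGL

    graphs, _labels = dgl.load_graphs(bin_path)
    return graphs[0]
```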
Hyper-parameters specific to the model
num_channels = 2 # number of channels
num_layers = 3 # number of GT layers
adaptive_lr_flag = True # use a different learning rate for the weights in GTLayer
The best config can be found in best_config.
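With adaptive_lr_flag = True the GTLayer weights get their own learning rate. In PyTorch this is usually done with optimizer parameter groups (a list of dicts, each with "params" and an optional per-group "lr"). The sketch below builds such groups from (name, parameter) pairs; it assumes the GTLayer parameters can be recognised by a name prefix (the "layers." default is hypothetical), and it is not OpenHGNN's trainer code.

```python
def build_param_groups(named_params, gt_lr, gt_prefix="layers.", adaptive_lr=True):
    """Split (name, param) pairs into optimizer parameter groups.

    Parameters whose name starts with gt_prefix (assumed to be the GTLayer
    stack) get their own learning rate gt_lr; all others use the optimizer's
    default lr. Empty groups are dropped.
    """
    named_params = list(named_params)
    if not adaptive_lr:
        return [{"params": [p for _, p in named_params]}]
    gt = [p for n, p in named_params if n.startswith(gt_prefix)]
    rest = [p for n, p in named_params if not n.startswith(gt_prefix)]
    groups = [{"params": rest}, {"params": gt, "lr": gt_lr}]
    return [g for g in groups if g["params"]]
```

A typical call would be torch.optim.Adam(build_param_groups(model.named_parameters(), gt_lr=0.1), lr=0.005); both learning rates there are placeholders, not the values in best_config.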
Related API in DGL
dgl.adj_product_graph, which is equivalent to SpSpMM (sparse-sparse matrix multiplication)
dgl.adj_sum_graph
GraphConv, EdgeWeightNorm
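To make the semantics of these DGL operations concrete, here is a pure-Python sketch on graphs stored as {(src, dst): weight} dicts. It illustrates what the operations compute, not DGL's implementation: adj_sum corresponds to dgl.adj_sum_graph, adj_product to dgl.adj_product_graph (a sparse-sparse matrix product), and edge_weight_norm_both to the symmetric normalization of EdgeWeightNorm(norm='both'), where an edge i to j is divided by the square roots of the weighted out-degree of i and the weighted in-degree of j.

```python
import math
from collections import defaultdict


def adj_sum(graphs):
    """Sum of weighted adjacency matrices, each given as {(src, dst): weight}."""
    out = defaultdict(float)
    for g in graphs:
        for edge, w in g.items():
            out[edge] += w
    return dict(out)


def adj_product(a, b):
    """Sparse-sparse product A @ B: edge (i, k) accumulates w_ij * w_jk over paths i -> j -> k."""
    b_by_src = defaultdict(list)
    for (j, k), w in b.items():
        b_by_src[j].append((k, w))
    out = defaultdict(float)
    for (i, j), w_ij in a.items():
        for k, w_jk in b_by_src[j]:
            out[(i, k)] += w_ij * w_jk
    return dict(out)


def edge_weight_norm_both(g):
    """Symmetric normalization: w_ij / sqrt(out_weight(i) * in_weight(j))."""
    out_w, in_w = defaultdict(float), defaultdict(float)
    for (i, j), w in g.items():
        out_w[i] += w
        in_w[j] += w
    return {(i, j): w / math.sqrt(out_w[i] * in_w[j]) for (i, j), w in g.items()}
```

In DGL itself the same operations act on graphs whose edge weights are stored in edata, e.g. dgl.adj_product_graph(A, B, 'w') and dgl.adj_sum_graph([A, B], 'w').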
More
Contributor
Tianyu Zhao[GAMMA LAB]
If you have any questions,
submit an issue or email tyzhao@bupt.edu.cn.