Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
chenzomi 24c9001cf8 | 1 year ago | |
---|---|---|
.. | ||
graph_to_mindrecord | 1 year ago | |
src | 1 year ago | |
README.md | 1 year ago | |
main.py | 1 year ago |
图神经网络(Graph Neural Network, GNN)把深度学习应用到图结构(Graph)中,其中的图卷积网络(Graph Convolutional Network,GCN)可以在Graph上进行卷积操作。但是GCN存在一些缺陷:依赖拉普拉斯矩阵,不能直接用于有向图;模型训练依赖于整个图结构,不能用于动态图;卷积的时候没办法为邻居节点分配不同的权重。
图注意力网络(Graph Attention Networks)由Petar Veličković等人于2018年提出。GAT采用了Attention机制,可以为不同节点分配不同权重,训练时依赖于成对的相邻节点,而不依赖具体的网络结构,可以用于inductive任务。
[1] https://baijiahao.baidu.com/s?id=1671028964544884749
本实验主要介绍在Cora和Citeseer数据集上使用MindSpore进行图注意力网络的训练和验证。
Cora和CiteSeer是图神经网络常用的数据集,数据集官网LINQS Datasets。
Cora数据集包含2708个科学出版物,分为七个类别。 引用网络由5429个链接组成。 数据集中的每个出版物都用一个0/1值的词向量描述,0/1指示词向量中是否出现字典中相应的词。 该词典包含1433个独特的单词。 数据集中的README文件提供了更多详细信息。
CiteSeer数据集包含3312种科学出版物,分为六类。 引用网络由4732个链接组成。 数据集中的每个出版物都用一个0/1值的词向量描述,0/1指示词向量中是否出现字典中相应的词。 该词典包含3703个独特的单词。 数据集中的README文件提供了更多详细信息。
本实验使用Github上kimiyoung/planetoid预处理和划分好的数据集。
将数据集放置到所需的路径下,该文件夹应包含以下文件:
data
├── ind.cora.allx
├── ind.cora.ally
├── ...
├── ind.cora.test.index
├── trans.citeseer.tx
├── trans.citeseer.ty
├── ...
└── trans.pubmed.y
inductive模型的输入包含:
x
,已标记的训练实例的特征向量,y
,已标记的训练实例的one-hot标签,allx
,标记的和未标记的训练实例(x
的超集)的特征向量,graph
,一个dict
,格式为{index: [index_of_neighbor_nodes]}.
令n为标记和未标记训练实例的数量。在graph
中这n个实例的索引应从0到n-1,其顺序与allx
中的顺序相同。
除了x
,y
,allx
,和graph
如上所述,预处理的数据集还包括:
tx
,测试实例的特征向量,ty
,测试实例的one-hot标签,test.index
,graph
中测试实例的索引,ally
,是allx
中实例的标签。从课程gitee仓库上下载本实验相关脚本。将脚本和数据集组织为如下形式:
gat
├── data
├── graph_to_mindrecord
│ ├── citeseer
│ ├── cora
│ ├── graph_map_schema.py
│ └── writer.py
├── src
│ ├── utils.py
│ ├── gat.py
│ ├── dataset.py
│ └── config.py
│── main.py
└── README.md
本实验需要使用华为云OBS存储实验脚本和数据集,可以参考快速通过OBS控制台上传下载文件了解使用OBS创建桶、上传文件、下载文件的使用方法(下文给出了操作步骤)。
提示: 华为云新用户使用OBS时通常需要创建和配置“访问密钥”,可以在使用OBS时根据提示完成创建和配置。也可以参考获取访问密钥并完成ModelArts全局配置获取并配置访问密钥。
打开OBS控制台,点击右上角的“创建桶”按钮进入桶配置页面,创建OBS桶的参考配置如下:
点击新建的OBS桶名,再打开“对象”标签页,通过“上传对象”、“新建文件夹”等功能,将脚本和数据集上传到OBS桶中。上传文件后,查看页面底部的“任务管理”状态栏(正在运行、已完成、失败),确保文件均上传完成。若失败请:
ModelArts提供了训练作业服务,训练作业资源池大,且具有作业排队等功能,适合大规模并发使用。使用训练作业时,如果有修改代码和调试的需求,有如下三个方案:
创建训练作业时,运行参数会通过脚本传参的方式输入给脚本代码,脚本必须解析传参才能在代码中使用相应参数。如data_url和train_url,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。脚本对传参进行解析后赋值到args
变量里,在后续代码里可以使用。
import argparse
parser = argparse.ArgumentParser(description='GAT')
parser.add_argument('--data_url', required=True, help='Location of data.')
parser.add_argument('--train_url', required=True, help='Location of training outputs.')
args_opt = parser.parse_args()
MindSpore暂时没有提供直接访问OBS数据的接口,需要通过ModelArts自带的moxing框架与OBS交互。拷贝自己账户下或他人共享的OBS桶内的数据集至执行容器。
import moxing as mox
# src_url形如's3://OBS/PATH',为OBS桶中数据集的路径,dst_url为执行容器中的路径
mox.file.copy_parallel(src_url=args_opt.data_url, dst_url='./data')
可以参考使用常用框架训练模型来创建并启动训练作业(下文给出了操作步骤)。
打开ModelArts控制台-训练管理-训练作业,点击“创建”按钮进入训练作业配置页面,创建训练作业的参考配置:
main.py
启动并查看训练过程:
推荐使用ModelArts训练作业进行实验,适合大规模并发使用。若使用ModelArts Notebook,请参考LeNet5及Checkpoint实验案例,了解Notebook的使用方法和注意事项。
导入MindSpore模块和辅助模块,设置MindSpore上下文,如执行模式、设备等。
import os
import argparse
import numpy as np
from easydict import EasyDict as edict
from mindspore import context
from src.gat import GAT
from src.config import GatConfig
from src.dataset import load_and_process
from src.utils import LossAccuracyWrapper, TrainGAT
from graph_to_mindrecord.writer import run
from mindspore.train.serialization import load_checkpoint, save_checkpoint
context.set_context(mode=context.GRAPH_MODE,device_target="Ascend", save_graphs=False)
训练参数可以在config.py中设置。
"learning_rate": 0.005, # Learning rate
"num_epochs": 200, # Epoch sizes for training
"hid_units": [8], # Hidden units for attention head at each layer
"n_heads": [8, 1], # Num heads for each layer
"early_stopping": 100, # Early stop patience
"l2_coeff": 0.0005 # l2 coefficient
"attn_dropout": 0.6 # Attention dropout ratio
"feature_dropout":0.6 # Feature dropout ratio
def train(args_opt):
"""Train GAT model."""
if not os.path.exists("ckpts"):
os.mkdir("ckpts")
# train parameters
hid_units = GatConfig.hid_units
n_heads = GatConfig.n_heads
early_stopping = GatConfig.early_stopping
lr = GatConfig.lr
l2_coeff = GatConfig.l2_coeff
num_epochs = GatConfig.num_epochs
feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(args_opt.data_dir,
args_opt.train_nodes_num,
args_opt.eval_nodes_num,
args_opt.test_nodes_num)
feature_size = feature.shape[2]
num_nodes = feature.shape[1]
num_class = y_train.shape[2]
gat_net = GAT(feature,
biases,
feature_size,
num_class,
num_nodes,
hid_units,
n_heads,
attn_drop=GatConfig.attn_dropout,
ftr_drop=GatConfig.feature_dropout)
gat_net.add_flags_recursive(fp16=True)
eval_net = LossAccuracyWrapper(gat_net,
num_class,
y_val,
eval_mask,
l2_coeff)
train_net = TrainGAT(gat_net,
num_class,
y_train,
train_mask,
lr,
l2_coeff)
train_net.set_train(True)
val_acc_max = 0.0
val_loss_min = np.inf
for _epoch in range(num_epochs):
train_result = train_net()
train_loss = train_result[0].asnumpy()
train_acc = train_result[1].asnumpy()
eval_result = eval_net()
eval_loss = eval_result[0].asnumpy()
eval_acc = eval_result[1].asnumpy()
print("Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}".format(
_epoch, train_loss, train_acc, eval_loss, eval_acc))
if eval_acc >= val_acc_max or eval_loss < val_loss_min:
if eval_acc >= val_acc_max and eval_loss < val_loss_min:
val_acc_model = eval_acc
val_loss_model = eval_loss
if os.path.exists('ckpts/gat.ckpt'):
os.remove('ckpts/gat.ckpt')
save_checkpoint(train_net.network, "ckpts/gat.ckpt")
val_acc_max = np.max((val_acc_max, eval_acc))
val_loss_min = np.min((val_loss_min, eval_loss))
curr_step = 0
else:
curr_step += 1
if curr_step == early_stopping:
print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
break
gat_net_test = GAT(feature,
biases,
feature_size,
num_class,
num_nodes,
hid_units,
n_heads,
attn_drop=0.0,
ftr_drop=0.0)
load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
gat_net_test.add_flags_recursive(fp16=True)
test_net = LossAccuracyWrapper(gat_net_test,
num_class,
y_test,
test_mask,
l2_coeff)
test_result = test_net()
print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))
使用不同的数据集训练操作起来非常方便,只需要将参数dataname
修改为需要训练的数据集名称即可。
#------------------------定义变量------------------------------
dataname = 'cora'
datadir_save = './data_mr'
datadir = os.path.join(datadir_save, dataname)
cfg = edict({
'SRC_PATH': './data',
'MINDRECORD_PATH': datadir_save,
'DATASET_NAME': dataname, # citeseer,cora
'mindrecord_partitions':1,
'mindrecord_header_size_by_bit' : 18,
'mindrecord_page_size_by_bit' : 20,
'data_dir': datadir,
'seed' : 123,
'train_nodes_num':140,
'eval_nodes_num':500,
'test_nodes_num':1000
})
# 转换数据格式
print("============== Graph To Mindrecord ==============")
run(cfg)
#训练
print("============== Starting Training ==============")
train(cfg)
训练结果将打印如下结果:
============== Starting Training ==============
Epoch:0, train loss=1.98498 train acc=0.17143 | val loss=1.97946 val acc=0.27200
Epoch:1, train loss=1.98345 train acc=0.15000 | val loss=1.97233 val acc=0.32600
Epoch:2, train loss=1.96968 train acc=0.21429 | val loss=1.96747 val acc=0.37400
Epoch:3, train loss=1.97061 train acc=0.20714 | val loss=1.96410 val acc=0.47600
Epoch:4, train loss=1.96864 train acc=0.13571 | val loss=1.96066 val acc=0.59600
...
Epoch:195, train loss=1.45111 train_acc=0.56429 | val_loss=1.44325 val_acc=0.81200
Epoch:196, train loss=1.52476 train_acc=0.52143 | val_loss=1.43871 val_acc=0.81200
Epoch:197, train loss=1.35807 train_acc=0.62857 | val_loss=1.43364 val_acc=0.81400
Epoch:198, train loss=1.47566 train_acc=0.51429 | val_loss=1.42948 val_acc=0.81000
Epoch:199, train loss=1.56411 train_acc=0.55000 | val_loss=1.42632 val_acc=0.80600
Test loss=1.5366285, test acc=0.84199995
下表显示了Cora数据集上的结果:
MindSpore + Ascend910 | Tensorflow + V100 | |
---|---|---|
精度 | 0.830933271 | 0.828649968 |
训练耗时(200 epochs) | 27.62298311 s | 36.711862 s |
端到端训练耗时(200 epochs) | 39.074 s | 50.894 s |
MindSpore实验,仅用于教学或培训目的。配合MindSpore官网使用。 MindSpore experiments, for teaching or training purposes only. Use it together with the MindSpore official website.
CSV Jupyter Notebook Text Python Markdown other
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》