Skip to content

快速开始

Siran Yang edited this page Apr 9, 2019 · 7 revisions

本节将详细介绍如何使用Euler和TensorFlow进行GraphSage模型训练。GraphSage是由Stanford提出的一种Inductive的图学习方法,具有GCN模型的良好性质,同时在实际使用中可以扩展到十亿顶点的大规模图。

快速开始

1. 数据准备

首先需要准备Euler引擎可以读取的图数据,这里我们以PPI(Protein-Protein Interactions)数据集作为例子,提供一个预处理脚本:

apt-get update && apt-get install -y curl
curl -k -O https://raw.githubusercontent.com/alibaba/euler/master/examples/ppi_data.py
pip install networkx==1.11 sklearn
python ppi_data.py

上面的命令会在当前目录下生成一个ppi目录,其中包含构建好的PPI图数据。

2. 模型训练

在训练集上训练一个半监督的GraphSage模型:

python -m tf_euler \
  --data_dir ppi \
  --max_id 56944 --feature_idx 1 --feature_dim 50 --label_idx 0 --label_dim 121 \
  --model graphsage_supervised --mode train

上面的命令会在当前目录下生成一个ckpt目录,其中包含训练好的TensorFlow模型。

3. 模型评估

在测试集上评估模型的效果:

python -m tf_euler \
  --data_dir ppi --id_file ppi/ppi_test.id \
  --max_id 56944 --feature_idx 1 --feature_dim 50 --label_idx 0 --label_dim 121 \
  --model graphsage_supervised --mode evaluate

使用Euler算法包默认参数训练得到的模型在测试集上的mirco-F1 score大概在0.6左右。

4. embedding输出

导出顶点的embedding:

python -m tf_euler \
  --data_dir ppi \
  --max_id 56944 --feature_idx 1 --feature_dim 50 --label_idx 0 --label_dim 121 \
  --model graphsage_supervised --mode save_embedding

上面的命令会在当下目录下的ckpt目录中生成一个embedding.npy文件和一个id.txt文件,分别表示图中所有顶点的embedding和id。

5. 导入embedding到Faiss中进行检索(可选)

Euler所生成的embedding可以根据用户的实际需要在后续流程中使用。这里给出一个利用Faiss进行相似性检索的例子:

import faiss
import numpy as np

embedding = np.load('ckpt/embedding.npy')
index = faiss.IndexFlatIP(256)
index.add(embedding)
print(index.search(embedding[:5], 4))

分布式训练

Euler算法包及其底层的图查询引擎支持分布式的模型训练,用户需要在原始的训练命令上加入四个参数--ps_hosts--worker_hosts--job_name--task_index指定分布式角色。使用分布式训练时,数据需要放置在HDFS上。这里提供一个示例脚本在本机的1998至2001端口上启动两个ps和两个worker进行分布式训练。

bash tf_euler/scripts/dist_tf_euler.sh \
  --data_dir hdfs://host:port/data \
  --euler_zk_addr zk.host.com:port --euler_zk_path /path/for/euler \
  --max_id 56944 --feature_idx 1 --feature_dim 50 --label_idx 0 --label_dim 121 \
  --model graphsage_supervised --mode train

上面的命令会在/tmp/log.{woker,ps}.{0,1}文件中打印日志。

Clone this wiki locally