Use GraphStorm CLIs for Multi-task Learning

This notebook demonstrates how to use the GraphStorm Command Line Interfaces (CLIs) to run multi-task GNN model training and inference. By working through this notebook, users will become familiar with the GraphStorm CLIs and can then apply them to their own tasks and models.

In this notebook, we will train an RGCN model on the ACM dataset with two training supervisions, i.e., link prediction and node feature reconstruction.

Note: For more details about multi-task learning, please refer to Multi-task Learning in GraphStorm.

0. Setup environment

First let’s install GraphStorm and its dependencies, PyTorch and DGL.

[1]:
!pip install scikit-learn==1.4.2
!pip install scipy==1.13.0
!pip install pandas==1.3.5
!pip install pyarrow==14.0.0
!pip install graphstorm
!pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install dgl==1.1.3 -f https://data.dgl.ai/wheels-internal/repo.html
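Before moving on, it can help to confirm that the packages from the install cell are actually visible to the notebook kernel. The following is a minimal sketch using only the Python standard library; the `EXPECTED` table and the `installed_versions` helper are illustrative names, not part of GraphStorm, and the pinned versions simply mirror the install cell above.

```python
from importlib import metadata

# Versions pinned in the install cell; None means the cell did not pin one.
EXPECTED = {
    "scikit-learn": "1.4.2",
    "scipy": "1.13.0",
    "pandas": "1.3.5",
    "pyarrow": "14.0.0",
    "torch": "2.1.0",
    "dgl": "1.1.3",
    "graphstorm": None,
}


def installed_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


for name, version in installed_versions(EXPECTED).items():
    pinned = EXPECTED[name]
    if version is None:
        status = "MISSING"
    elif pinned is None or version.startswith(pinned):
        # startswith tolerates local build suffixes such as "+cpu"
        status = "ok"
    else:
        status = f"expected {pinned}"
    print(f"{name:<14} {version or '-':<14} {status}")
```

A `MISSING` row usually means the corresponding `pip install` failed or ran against a different Python environment than the kernel.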

1. Create the example ACM graph data

This notebook uses the ACM graph as an example. We use the following script to create the ACM graph data.

[2]:
!mkdir -p /example
!wget -O /example/acm_data.py https://github.com/awslabs/graphstorm/raw/main/examples/acm_data.py
!python /example/acm_data.py --output-path /example/acm_raw

The ACM graph data includes node files and edge files. It also includes a JSON configuration file describing how to construct a graph for model training. More details can be found in Use Your Own Data (ACM data example).

[3]:
!ls -al /example/acm_raw/
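To see what the JSON configuration file describes, one option is to load it and list the node types and edge relations it declares. The sketch below assumes the construction config has top-level `nodes` and `edges` lists whose entries carry `node_type` and `relation` fields; `summarize_gconstruct_config` is a hypothetical helper written for this notebook, not a GraphStorm function.

```python
import json
from pathlib import Path


def summarize_gconstruct_config(config):
    """Collect node types and edge relations from a parsed construction config.

    Assumes top-level "nodes" and "edges" lists, with a "node_type" string
    per node entry and a [src, relation, dst] "relation" list per edge entry.
    """
    node_types = sorted(
        {n["node_type"] for n in config.get("nodes", []) if "node_type" in n}
    )
    relations = sorted(
        {tuple(e["relation"]) for e in config.get("edges", []) if "relation" in e}
    )
    return node_types, relations


# Only runs when the raw ACM data from the previous cell is present.
config_path = Path("/example/acm_raw/config.json")
if config_path.exists():
    with config_path.open() as f:
        node_types, relations = summarize_gconstruct_config(json.load(f))
    print("node types:", node_types)
    for src, rel, dst in relations:
        print(f"edge type: <{src}, {rel}, {dst}>")
```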

2. Construct and Partition ACM Graph

Because GraphStorm is designed for distributed GNN training, we need to construct a graph and split it into multiple partitions. For simplicity, this example creates a graph with a single partition (no actual splitting).

[4]:
!python -m graphstorm.gconstruct.construct_graph \
           --conf-file /example/acm_raw/config.json \
           --output-dir /example/acm_gs \
           --num-parts 1 \
           --graph-name acm

The generated ACM graph contains all the information required for GNN model training. For more details on preparing data for multi-task learning, please refer to Preparing multi-task learning data.
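The partition configuration file written by the construction step (`/example/acm_gs/acm.json`, which is passed to `--part-config` later) can be inspected to confirm the graph name and partition count. This is a rough sketch that assumes the file follows the DGL partition metadata layout, with keys such as `graph_name`, `num_parts`, `num_nodes`, `num_edges`, and one `part-<i>` entry per partition; missing keys are reported as `None` rather than raising. `summarize_partition_meta` is an illustrative helper name.

```python
import json
from pathlib import Path


def summarize_partition_meta(meta):
    """Pick the headline fields out of a parsed partition config dict."""
    keys = ("graph_name", "num_parts", "num_nodes", "num_edges")
    summary = {key: meta.get(key) for key in keys}
    # Each partition is assumed to be listed under a "part-<i>" key.
    summary["partitions"] = sorted(k for k in meta if k.startswith("part-"))
    return summary


# Only runs when the constructed graph from the previous cell is present.
part_config = Path("/example/acm_gs/acm.json")
if part_config.exists():
    with part_config.open() as f:
        for key, value in summarize_partition_meta(json.load(f)).items():
            print(f"{key}: {value}")
```

With `--num-parts 1` and `--graph-name acm`, one would expect a single partition entry and the graph name `acm`.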

3. GNN Model Training

Once the graph is constructed, we can call the GraphStorm multi-task learning CLI to run model training. Before kicking off the model training, we need to create a YAML configuration file for the CLI.

[5]:
!wget -O /example/acm_mt.yaml https://github.com/awslabs/graphstorm/raw/main/examples/use_your_own_data/acm_mt.yaml

[6]:
!cat /example/acm_mt.yaml

The YAML configuration file defines two training tasks:

  • A link prediction task on the <paper, citing, paper> edges. The task-specific settings are under the gsf::multi_task_learning::link_prediction configuration block.

  • A node feature reconstruction task on the paper nodes, with the node feature label to be reconstructed. The task-specific settings are under the gsf::multi_task_learning::reconstruct_node_feat configuration block.

For more details on multi-task YAML configuration, please refer to Define Multi-task for training.
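As a quick check that the downloaded YAML file declares the two tasks described above, the task blocks under `gsf::multi_task_learning` can be listed programmatically. The sketch below assumes that `multi_task_learning` is either a list of single-key mappings (one per task) or a plain mapping keyed by task type, and that PyYAML is available in the notebook environment; `list_multi_task_types` is a hypothetical helper, not a GraphStorm API.

```python
from pathlib import Path


def list_multi_task_types(config):
    """Return the task type names found under gsf -> multi_task_learning."""
    tasks = (config or {}).get("gsf", {}).get("multi_task_learning", [])
    if isinstance(tasks, dict):
        return list(tasks)
    names = []
    for entry in tasks:
        # Each list entry is assumed to be a mapping {task_type: settings}.
        names.extend(entry.keys())
    return names


# Only runs when the YAML file from the previous cells is present.
yaml_path = Path("/example/acm_mt.yaml")
if yaml_path.exists():
    import yaml  # PyYAML; assumed to be installed alongside GraphStorm

    with yaml_path.open() as f:
        print(list_multi_task_types(yaml.safe_load(f)))
```

For the configuration described above, the printed list should contain `link_prediction` and `reconstruct_node_feat`.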

Launch the training job

[7]:
!python -m graphstorm.run.gs_multi_task_learning \
           --workspace /example \
           --part-config /example/acm_gs/acm.json \
           --num-trainers 1 \
           --cf /example/acm_mt.yaml \
           --num-epochs 4

The saved model is under /example/acm_lp/models/.

[8]:
!ls -a /example/acm_lp/models/
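The inference commands later in this notebook restore a checkpoint from a path such as `/example/acm_lp/models/epoch-2`. As a small sketch, the available checkpoints can be listed and sorted by epoch number, assuming each one is saved in a subdirectory named `epoch-<n>`; directories with any other naming scheme are ignored. `list_epoch_checkpoints` is an illustrative helper name.

```python
import re
from pathlib import Path


def list_epoch_checkpoints(model_dir):
    """Return (epoch, path) pairs for 'epoch-<n>' subdirectories, sorted by epoch."""
    model_dir = Path(model_dir)
    if not model_dir.is_dir():
        return []
    found = []
    for child in model_dir.iterdir():
        match = re.fullmatch(r"epoch-(\d+)", child.name)
        if match and child.is_dir():
            found.append((int(match.group(1)), child))
    return sorted(found, key=lambda pair: pair[0])


# Prints nothing if the model directory does not exist yet.
for epoch, path in list_epoch_checkpoints("/example/acm_lp/models"):
    print(f"epoch {epoch}: {path}")
```

Sorting on the parsed integer rather than the directory name keeps `epoch-10` after `epoch-2` when more epochs are trained.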

4. GNN Model Inference

Once the model is trained, we can do model inference with the trained model artifacts by using the GraphStorm multi-task learning CLI. We can use the same YAML configuration file for model inference.

Launch the inference job

In this example, we load the epoch-2 model checkpoint for inference. The inference command reports the test scores for both the link prediction task and the node feature reconstruction task.

[9]:
!python -m graphstorm.run.gs_multi_task_learning \
           --inference \
           --workspace /example \
           --part-config /example/acm_gs/acm.json \
           --restore-model-path /example/acm_lp/models/epoch-2 \
           --num-trainers 1 \
           --cf /example/acm_mt.yaml

Launch the embedding generation inference job

You can also use the GraphStorm gs_gen_node_embedding CLI to generate node embeddings with the trained GNN model on the ACM graph. The saved node embeddings are under /example/acm_lp/emb/.

[10]:
!python -m graphstorm.run.gs_gen_node_embedding \
           --inference \
           --workspace /example \
           --part-config /example/acm_gs/acm.json \
           --restore-model-path /example/acm_lp/models/epoch-2 \
           --save-embed-path /example/acm_lp/emb/ \
           --restore-model-layers "embed,gnn" \
           --num-trainers 1 \
           --cf /example/acm_mt.yaml

[11]:
!ls -al /example/acm_lp/emb/
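For a more structured view than `ls`, the embedding output directory can be walked and its files grouped by subdirectory. This is only a sketch: it assumes the embeddings may be organized into subdirectories (for example one per node type) and makes no assumption about file names or formats. `summarize_embedding_dir` is a hypothetical helper.

```python
from pathlib import Path


def summarize_embedding_dir(emb_dir):
    """Map each subdirectory ('.' for the top level) to its sorted file names."""
    emb_dir = Path(emb_dir)
    summary = {}
    if not emb_dir.is_dir():
        return summary
    for path in sorted(emb_dir.rglob("*")):
        if path.is_file():
            group = str(path.parent.relative_to(emb_dir))
            summary.setdefault(group, []).append(path.name)
    return summary


# Prints nothing if the embedding directory does not exist yet.
for group, files in summarize_embedding_dir("/example/acm_lp/emb").items():
    print(f"{group}: {len(files)} file(s)")
    for name in files:
        print(f"  {name}")
```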