Notebook 6: Running Custom Model with GraphStorm CLIs
Notebooks 1 to 5 provide examples of how to use GraphStorm APIs to implement various GNN components and models. These notebooks can run in the GraphStorm Standalone mode, i.e., on a single CPU or GPU of a single machine. To fully leverage GraphStorm’s distributed model training and inference capability, however, we need to convert the code implemented in these notebooks into Python scripts that can be launched with GraphStorm Command Line Interfaces (CLIs).
This notebook introduces the method of conversion, and explains the key components of the example Python scripts. For this notebook, we use the custom model developed in Notebook 4: Use GraphStorm APIs for Customizing Model Components as an example.
Prerequisites
GraphStorm. Please find more details on installation of GraphStorm.
ACM data that has been created according to Notebook 0: Data Preparation, and is stored in the ./acm_gs_1p/ folder.
Brief Introduction and Run CLIs on a Single Machine
In order to use GraphStorm CLIs, we need to put the custom model into a Python file, which can be called in the Task-agnostic CLI for model training and inference as an argument. We build two files for model training and inference separately.
We can reuse most of the code of the customized RGAT module in Notebook 4, i.e., Ara_GatLayer, Ara_GatEncoder, and RgatNCModel, in the training and inference files.
For the training file, we can copy and paste the code of the 4.1 Training pipeline section in Notebook 4, and enclose it in a fit() function. Similarly, for the inference file, we can copy and paste the code of the 4.3 Inference pipeline section in Notebook 4, and enclose it in an infer() function.
We have provided the two files, named demo_run_train.py and demo_run_infer.py under the GraphStorm API documentation folder. With the two files, we can call GraphStorm’s task-agnostic CLI to run our custom model as shown below.
[ ]:
# download the example yaml configuration file
!wget -O acm_nc.yaml https://github.com/awslabs/graphstorm/raw/main/examples/use_your_own_data/acm_nc.yaml
# CLI for the custom RGAT model training
!python -m graphstorm.run.launch \
--part-config ./acm_gs_1p/acm.json \
--num-trainers 4 \
--num-servers 1 \
--num-samplers 0 \
demo_run_train.py --cf acm_nc.yaml \
--save-model-path models/ \
--node-feat-name paper:feat author:feat subject:feat \
--num-epochs 5 \
--rgat-encoder-type ara
# CLI for the custom RGAT model inference
!python -m graphstorm.run.launch \
--part-config ./acm_gs_1p/acm.json \
--num-trainers 4 \
--num-servers 1 \
--num-samplers 0 \
demo_run_infer.py --cf acm_nc.yaml \
--restore-model-path models/epoch-4 \
--save-prediction-path predictions/ \
--save-embed-path embeddings/ \
--node-feat-name paper:feat author:feat subject:feat \
--rgat-encoder-type ara
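The inference command above restores the model from models/epoch-4, which is the last checkpoint of the 5-epoch training run if checkpoint folders are named epoch-N and numbered from 0, as that path suggests. The following is a minimal sketch, under that assumption, of how the newest checkpoint folder could be located instead of typing it by hand. The helper name latest_epoch_dir is hypothetical and is not part of GraphStorm.

```python
import re
import tempfile
from pathlib import Path


def latest_epoch_dir(model_root):
    """Return the epoch-N subfolder with the largest N, or None if none exists.

    Hypothetical helper: assumes checkpoints live in folders named epoch-N.
    """
    pattern = re.compile(r"^epoch-(\d+)$")
    candidates = []
    for child in Path(model_root).iterdir():
        match = pattern.match(child.name)
        if child.is_dir() and match:
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    # Compare on the integer N so that epoch-10 sorts after epoch-9
    return max(candidates, key=lambda item: item[0])[1]


# Simulate the assumed layout of a 5-epoch run: epoch-0 ... epoch-4
with tempfile.TemporaryDirectory() as root:
    for n in range(5):
        (Path(root) / f"epoch-{n}").mkdir()
    picked = latest_epoch_dir(root).name

print(picked)
```

The returned folder could then be passed as the value of --restore-model-path.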
CLI argument processing explanation
Compared to the code in Notebook 4, the majority of modifications in the two Python files relate to how to collect and parse GraphStorm CLI configurations. Instead of hard-coding some variables, e.g., nfeats_4_modeling, or setting fixed input values, e.g., label_field='label' or encoder_type='ara', we need to provide these values via CLI configurations.
As shown in the above commands, there are three types of configurations passed to the GraphStorm task-agnostic command.
Launch CLI arguments, which directly follow graphstorm.run.launch.
Model training and inference configurations, which are predefined in GraphStorm. These configurations can be put into a yaml file, which is passed as the value of the --cf argument following the training or inference Python file name. You can also set them as arguments, which will overwrite the same configurations set in the yaml file.
Configurations specific to custom modules, which are not predefined in GraphStorm and are used only by the custom modules. These should be defined as input arguments of the training or inference Python files.
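The override rule in the second item can be illustrated with a small sketch. It is not GraphStorm's implementation: the dictionary stands in for values loaded from a yaml file, its keys and values are illustrative rather than the real contents of acm_nc.yaml, and merge_config is a hypothetical helper built only on the standard argparse module.

```python
import argparse

# Stand-in for values loaded from a yaml file (illustrative, not the real acm_nc.yaml)
yaml_config = {"num_epochs": 20, "hidden_size": 128}


def merge_config(yaml_values, argv):
    """Return the yaml values, overridden by any option given on the command line."""
    parser = argparse.ArgumentParser()
    # default=None lets us tell "not given on the CLI" apart from a real value
    parser.add_argument("--num-epochs", type=int, default=None)
    parser.add_argument("--hidden-size", type=int, default=None)
    # Options this parser does not know (e.g. custom ones) are left untouched
    cli_args, _unknown = parser.parse_known_args(argv)
    merged = dict(yaml_values)
    for key, value in vars(cli_args).items():
        if value is not None:
            merged[key] = value
    return merged


merged = merge_config(yaml_config, ["--num-epochs", "5", "--rgat-encoder-type", "ara"])
print(merged)
```

Here num_epochs takes the command-line value, while hidden_size keeps the value from the yaml stand-in because no override was given.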
Below we show the main entry point of the demo_run_train.py file.
[3]:
import argparse
from graphstorm.config import get_argument_parser
......
if __name__ == '__main__':
    # Leverage GraphStorm's argument parser to accept the configuration yaml file
    arg_parser = get_argument_parser()
    # parse all arguments and split GraphStorm's built-in arguments from the custom ones
    gs_args, unknown_args = arg_parser.parse_known_args()
    print(f'GS arguments: {gs_args}')
    # create a new argument parser dedicated to custom arguments
    cust_parser = argparse.ArgumentParser(description="Customized Arguments")
    # add custom arguments
    cust_parser.add_argument('--rgat-encoder-type', type=str, default="ara")
    cust_args = cust_parser.parse_args(unknown_args)
    print(f'Customized arguments: {cust_args}')
    # use both argument sets in our main function
    fit(gs_args, cust_args)
GraphStorm’s config module provides a get_argument_parser method, which creates an argument parser, e.g., arg_parser, dedicated to processing GraphStorm launch CLI arguments and model training and inference configurations. Using the parse_known_args() method, the argument parser extracts all GraphStorm built-in configurations and also returns the remaining custom arguments, which can be processed by another argument parser, e.g., cust_parser. We can then pass these arguments to the corresponding methods. Please refer to the get_argument_parser API document for more details about this method.
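The two-stage parsing pattern can be tried without GraphStorm installed. In the sketch below, make_builtin_parser is a hypothetical stand-in for get_argument_parser that knows only two options, so the result shows the mechanics of parse_known_args rather than GraphStorm's real argument set.

```python
import argparse


def make_builtin_parser():
    # Hypothetical stand-in for graphstorm.config.get_argument_parser()
    parser = argparse.ArgumentParser(description="Built-in arguments (stand-in)")
    parser.add_argument("--cf", type=str)
    parser.add_argument("--num-epochs", type=int, default=None)
    return parser


def parse_all(argv):
    # Stage 1: keep what the built-in parser knows, collect the rest as a list
    builtin_args, unknown = make_builtin_parser().parse_known_args(argv)
    # Stage 2: a dedicated parser consumes the leftover custom options
    cust_parser = argparse.ArgumentParser(description="Customized Arguments")
    cust_parser.add_argument("--rgat-encoder-type", type=str, default="ara")
    cust_args = cust_parser.parse_args(unknown)
    return builtin_args, unknown, cust_args


argv = ["--cf", "acm_nc.yaml", "--num-epochs", "5", "--rgat-encoder-type", "ara"]
builtin_args, unknown, cust_args = parse_all(argv)

print(builtin_args.cf, builtin_args.num_epochs)
print(unknown)
# argparse turns the dashes of --rgat-encoder-type into underscores
print(cust_args.rgat_encoder_type)
```

Because the second stage uses parse_args rather than parse_known_args, an option that neither parser recognizes, such as a misspelled flag, makes the script exit with an error instead of being silently ignored.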
GraphStorm GSConfig object explanation
Once we have obtained these arguments, we can use them to create a GSConfig object and then pass the object to different modules to get related configurations. The GSConfig object checks every argument’s format and value to ensure compliance with GraphStorm specifications. The cells below show the code for creating the GSConfig object and examples of how to use it to pass configurations. For example, we can pass the IP list file, GraphStorm backend, and local rank configurations to GraphStorm’s distributed context initialization function, i.e., gs.initialize(), to start the GraphStorm distributed context.
For more details of GSConfig, please refer to the GSConfig API documentation page.
[20]:
# in demo_run_train.py file
import graphstorm as gs
from graphstorm.config import GSConfig
......
def fit(gs_args, cust_args):
    # Utilize GraphStorm's GSConfig class to accept arguments
    config = GSConfig(gs_args)
    # Initialize distributed training and inference context
    gs.initialize(ip_config=config.ip_config, backend=config.backend, local_rank=config.local_rank)
    acm_data = gs.dataloading.GSgnnData(part_config=config.part_config)
    ......
    model = RgatNCModel(g=acm_data.g,
                        num_heads=config.num_heads,
                        num_hid_layers=config.num_layers,
                        node_feat_field=config.node_feat_name,
                        hid_size=config.hidden_size,
                        num_classes=config.num_classes,
                        encoder_type=cust_args.rgat_encoder_type)  # here use the custom argument instead of GSConfig
    ......
[ ]:
# in demo_run_infer.py file
from graphstorm.config import GSConfig
......
def infer(gs_args, cust_args):
    # Utilize GraphStorm's GSConfig class to accept arguments
    config = GSConfig(gs_args)
    ......
    model = RgatNCModel(g=acm_data.g,
                        num_heads=config.num_heads,
                        num_hid_layers=config.num_layers,
                        node_feat_field=config.node_feat_name,
                        hid_size=config.hidden_size,
                        num_classes=config.num_classes,
                        encoder_type=cust_args.rgat_encoder_type)  # here use the custom argument instead of GSConfig
    model.restore_model(config.restore_model_path)
    ......
Run CLIs on a Distributed Cluster
It is easy to modify the commands in the above cell to run them on a distributed cluster. We need to conduct three additional operations:
As demonstrated in the Use Your Own Data tutorial, partition the ACM data into multiple partitions, e.g., 2 partitions by setting the argument --num-parts 2, and record its JSON file path, e.g., ./acm_gs_2p/acm.json.
Follow the tutorial of creating a GraphStorm cluster to prepare a cluster with 2 machines.
Prepare an IP list file, e.g., ip_list.txt, on the cluster, and record its file path, e.g., ./ip_list.txt.
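As a rough illustration of the third operation, the sketch below writes an IP list file with one address per line, which is assumed here to be the layout the launch CLI expects. The helper write_ip_list is hypothetical, and the addresses are placeholders to be replaced by the private IPs of the two machines in the cluster.

```python
import ipaddress
import tempfile
from pathlib import Path


def write_ip_list(ips, path):
    """Validate each address and write one per line (assumed file layout)."""
    # ipaddress.ip_address raises ValueError for a malformed address
    lines = [str(ipaddress.ip_address(ip)) for ip in ips]
    Path(path).write_text("\n".join(lines) + "\n")
    return lines


# Placeholder addresses for a 2-machine cluster; written to a temp folder for the demo
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "ip_list.txt"
    write_ip_list(["10.0.0.1", "10.0.0.2"], target)
    content = target.read_text()

print(content)
```

Validating the addresses before writing catches typos early, before the launch command tries to reach the machines.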
Then we just add two additional CLI launch arguments, --ip-config and --ssh-port, and run the CLIs below on the cluster within a running docker container.
[ ]:
# CLI for the custom RGAT model training
!python -m graphstorm.run.launch \
--part-config ./acm_gs_2p/acm.json \
--num-trainers 4 \
--num-servers 1 \
--num-samplers 0 \
--ip-config ./ip_list.txt \
--ssh-port 2222 \
demo_run_train.py --cf acm_nc.yaml \
--save-model-path models/ \
--node-feat-name paper:feat author:feat subject:feat \
--num-epochs 5 \
--rgat-encoder-type ara
# CLI for the custom RGAT model inference
!python -m graphstorm.run.launch \
--part-config ./acm_gs_2p/acm.json \
--num-trainers 4 \
--num-servers 1 \
--num-samplers 0 \
--ip-config ./ip_list.txt \
--ssh-port 2222 \
demo_run_infer.py --cf acm_nc.yaml \
--restore-model-path models/epoch-4 \
--save-prediction-path predictions/ \
--save-embed-path embeddings/ \
--node-feat-name paper:feat author:feat subject:feat \
--rgat-encoder-type ara
Run CLIs on an Amazon SageMaker Cluster
In order to run the custom model on an Amazon SageMaker cluster, we need to conduct four steps:
Partition the ACM data into multiple partitions, e.g., 2 partitions, and upload them to an Amazon S3 location, e.g., s3://<PATH_TO_DATA>/acm_gs_2p.
Upload the configuration yaml file to an Amazon S3 location, e.g., s3://<PATH_TO_TRAINING_CONFIG>/acm_nc.yaml.
Git clone the GraphStorm source code, and move the demo_run_train.py and demo_run_infer.py files from the graphstorm/docs/source/api/notebooks/ folder to the graphstorm/python/graphstorm/ folder.
Follow the Setup GraphStorm SageMaker Docker Image tutorial to create a docker image.
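The third step can be scripted. The sketch below is only an illustration: stage_custom_scripts is a hypothetical helper, it copies the two files rather than moving them so that the originals stay in place, and the demo runs against a throwaway folder that mimics the repository layout named above instead of a real clone.

```python
import shutil
import tempfile
from pathlib import Path

SCRIPTS = ("demo_run_train.py", "demo_run_infer.py")


def stage_custom_scripts(repo_root):
    """Copy the two demo scripts into the graphstorm Python package folder."""
    src = Path(repo_root) / "docs" / "source" / "api" / "notebooks"
    dst = Path(repo_root) / "python" / "graphstorm"
    staged = []
    for name in SCRIPTS:
        # copy2 keeps file metadata; it raises FileNotFoundError if a script is missing
        shutil.copy2(src / name, dst / name)
        staged.append(dst / name)
    return staged


# Demo on a fake repository tree so the sketch runs without a real clone
with tempfile.TemporaryDirectory() as fake_repo:
    notebooks = Path(fake_repo) / "docs" / "source" / "api" / "notebooks"
    package = Path(fake_repo) / "python" / "graphstorm"
    notebooks.mkdir(parents=True)
    package.mkdir(parents=True)
    for name in SCRIPTS:
        (notebooks / name).write_text("# placeholder\n")
    staged_names = [path.name for path in stage_custom_scripts(fake_repo)]
    all_present = all((package / name).exists() for name in SCRIPTS)

print(staged_names, all_present)
```

With a real clone, repo_root would be the path of the cloned graphstorm folder.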
Then use the following SageMaker CLIs to run the custom model on an Amazon SageMaker cluster. Please refer to GraphStorm Model Training and Inference on SageMaker for more details.
[ ]:
# SageMaker CLIs should be run under the graphstorm/sagemaker folder
# (%cd changes the notebook's working directory; !cd would only affect a temporary subshell)
%cd /<path-to-graphstorm>/sagemaker/
# SageMaker CLI for the customized RGAT model training
!python launch/launch_train.py \
--image-url <AMAZON_ECR_IMAGE_URI> \
--region <REGION> \
--entry-point run/train_entry.py \
--role <ROLE_ARN> \
--instance-count 2 \
--graph-data-s3 s3://<PATH_TO_DATA>/acm_gs_2p \
--yaml-s3 s3://<PATH_TO_TRAINING_CONFIG>/acm_nc.yaml \
--model-artifact-s3 s3://<PATH_TO_SAVE_TRAINED_MODEL> \
--graph-name acm \
--task-type node_classification \
--custom-script graphstorm/python/graphstorm/demo_run_train.py \
--node-feat-name paper:feat author:feat subject:feat \
--num-epochs 5 \
--rgat-encoder-type ara
# SageMaker CLI for the customized RGAT model inference
!python launch/launch_infer.py \
--image-url <AMAZON_ECR_IMAGE_URI> \
--region <REGION> \
--entry-point run/infer_entry.py \
--role <ROLE_ARN> \
--instance-count 2 \
--graph-data-s3 s3://<PATH_TO_DATA>/acm_gs_2p \
--yaml-s3 s3://<PATH_TO_TRAINING_CONFIG>/acm_nc.yaml \
--model-artifact-s3 s3://<PATH_TO_SAVE_BEST_TRAINED_MODEL> \
--raw-node-mappings-s3 s3://<PATH_TO_DATA>/acm_gs_2p/raw_id_mappings \
--output-emb-s3 s3://<PATH_TO_SAVE_GENERATED_NODE_EMBEDDING>/ \
--output-prediction-s3 s3://<PATH_TO_SAVE_PREDICTION_RESULTS> \
--graph-name acm \
--task-type node_classification \
--custom-script graphstorm/python/graphstorm/demo_run_infer.py \
--node-feat-name paper:feat author:feat subject:feat \
--rgat-encoder-type ara