{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Notebook 3: Use GraphStorm APIs for Implementing Built-in GNN Models\n", "\n", "This notebook demonstrates how to use GraphStorm APIs to implement GraphStorm built-in GNN models, such as RGAT and HGT, for different tasks.\n", "\n", "In this notebook, we use different ``GSgnnEncoder`` modules and set the corresponding arguments in a GNN model, thereby reproducing several GraphStorm built-in GNN models, such as `RGAT` and `HGT`. Using the same pipelines demonstrated in **[Notebook 1: Node Classification Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_1_NC_Pipeline.html)** and **[Notebook 2: Link Prediction Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_2_LP_Pipeline.html)**, users can easily run node classification and link prediction tasks on the ACM dataset created in **[Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html)**.\n", "\n", "### Prerequisites\n", "\n", "- GraphStorm. See [more details on installing GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).\n", "- ACM data created according to [Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html) and stored in the `./acm_gs_1p/` folder.\n", "- Supporting libraries, e.g., matplotlib." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "## 1. Revisit an `RGCN` model in `demo_models.py`\n", "\n", "Notebook 1 and Notebook 2 both use `RGCN` models that share the same GNN model architecture defined by GraphStorm. To modify a GraphStorm GNN model, let's first revisit an RGCN model in the `demo_models.py` file. 
For simplicity, some docstrings are removed and the code is restructured to fit in notebook cells." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "tags": [] }, "outputs": [], "source": [ "import graphstorm as gs\n", "from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer, RelationalGCNEncoder, EntityClassifier, ClassifyLossFunc\n", "\n", "class RgcnNCModel(GSgnnNodeModel):\n", "    \"\"\" A simple RGCN model for node classification using GraphStorm APIs\n", "    \"\"\"\n", "    def __init__(self, g, num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):\n", "        super(RgcnNCModel, self).__init__(alpha_l2norm=0.)\n", "\n", "        # extract node feature dimensions\n", "        feat_size = gs.get_node_feat_size(g, node_feat_field)\n", "\n", "        # set an input encoder\n", "        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)\n", "        self.set_node_input_encoder(encoder)\n", "\n", "        # set a GNN encoder\n", "        gnn_encoder = RelationalGCNEncoder(g=g, h_dim=hid_size, out_dim=hid_size, num_hidden_layers=num_hid_layers-1)\n", "        self.set_gnn_encoder(gnn_encoder)\n", "\n", "        # set a decoder specific to node classification task\n", "        decoder = EntityClassifier(in_dim=hid_size, num_classes=num_classes, multilabel=multilabel)\n", "        self.set_decoder(decoder)\n", "\n", "        # classification loss function\n", "        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))\n", "\n", "        # initialize model's optimizer\n", "        self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 GraphStorm built-in model architecture\n", "\n", "A GraphStorm built-in model normally contains four modules:\n", "\n", "- An input encoder that converts input node features into embeddings of the hidden dimension.\n", "- A GNN encoder that takes the embeddings from the input encoder and performs message passing computation.\n", "- A decoder that is task specific, e.g., the 
`EntityClassifier` for classification tasks.\n", "- A loss function that matches the specific task, e.g., the `ClassifyLossFunc`.\n", "\n", "Besides the four modules, a GraphStorm GNN model also needs to initialize its own optimizer object." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 Model arguments\n", "\n", "Each GNN model may have its own model arguments. Some arguments are common to other models, like the input and output dimensions, while others are model specific. For example, an `RGCN` model takes the number of bases to reduce the number of learnable parameters, and attention-based models may need the number of attention heads. GNN models are not the only components that take arguments: GML tasks need specific arguments too. For example, a classification task may be multi-label, which is why `RgcnNCModel` takes a `multilabel` argument.\n", "\n", "GraphStorm APIs provide default values for many arguments. For better flexibility, we can expose some arguments in model initialization, such as `num_hid_layers` and `hid_size`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.3 GML task modules\n", "\n", "Besides model-related modules, a GNN model also contains task-specific modules, including task-specific decoders and loss functions. For example, to perform a node classification task, the above `RgcnNCModel` model uses the `EntityClassifier` as its decoder and the `ClassifyLossFunc` as its loss function." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "----\n", "## 2. Reproduce GraphStorm Built-in GNN Model Variants\n", "\n", "Knowing the common architecture and arguments, it is easy to reproduce GraphStorm built-in GNN model variants.\n", "\n", "### 2.1 Reproduce an `RGAT` Model for Node Classification\n", "\n", "To turn the demo `RgcnNCModel` code into an `RgatNCModel` model, only two modifications are needed:\n", "\n", "1. For the GNN encoder, replace the `RelationalGCNEncoder` with the `RelationalGATEncoder`.\n", "2. 
Add `RelationalGATEncoder`-specific arguments to the initialization.\n", "\n", "Below is the simplified code of the `RgatNCModel` model. The complete code can be found in the `demo_models.py` file." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [], "source": [ "from graphstorm.model import RelationalGATEncoder\n", "\n", "class RgatNCModel(GSgnnNodeModel):\n", "    \"\"\" A simple RGAT model for node classification using GraphStorm APIs\n", "    \"\"\"\n", "    def __init__(self, g,\n", "                 num_heads,  # an argument specific to RelationalGATEncoder\n", "                 num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):\n", "        super(RgatNCModel, self).__init__(alpha_l2norm=0.)\n", "\n", "        # input encoder remains the same ......\n", "\n", "        # set a GNN encoder\n", "        gnn_encoder = RelationalGATEncoder(g=g, h_dim=hid_size, out_dim=hid_size,\n", "                                           num_heads=num_heads,  # pass the num_heads to the RelationalGATEncoder\n", "                                           num_hidden_layers=num_hid_layers-1)\n", "        self.set_gnn_encoder(gnn_encoder)\n", "\n", "        # decoder, loss function, and optimizer initialization remain the same ......" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Reproduce an `HGT` Model with `DistMult` Decoder for Link Prediction\n", "\n", "Similar to the `RGAT` variant, replacing the `RelationalGCNEncoder` with the `HGTEncoder` and setting the corresponding arguments reproduces an `HGT` model. In addition, this example replaces the `LinkPredictDotDecoder` decoder with the `LinkPredictDistMultDecoder` and sets its arguments. Below is the simplified code of the `HgtLPModel` model. The complete code can be found in the `demo_models.py` file." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [], "source": [ "from graphstorm.model import GSgnnLinkPredictionModel, HGTEncoder, LinkPredictDistMultDecoder\n", "\n", "class HgtLPModel(GSgnnLinkPredictionModel):\n", "    \"\"\" A simple HGT model for link prediction using GraphStorm APIs\n", "    \"\"\"\n", "    def __init__(self, g,\n", "                 num_heads,  # an argument specific to HGTEncoder\n", "                 num_hid_layers, node_feat_field, hid_size):\n", "        super(HgtLPModel, self).__init__(alpha_l2norm=0.)\n", "\n", "        # input encoder remains the same ......\n", "\n", "        # set a GNN encoder\n", "        gnn_encoder = HGTEncoder(g=g,\n", "                                 num_heads=num_heads,  # pass the num_heads to the HGTEncoder\n", "                                 hid_dim=hid_size, out_dim=hid_size, num_hidden_layers=num_hid_layers-1)\n", "        self.set_gnn_encoder(gnn_encoder)\n", "\n", "        # set a decoder specific to link prediction task\n", "        decoder = LinkPredictDistMultDecoder(etypes=g.canonical_etypes,  # an argument specific to LinkPredictDistMultDecoder\n", "                                             h_dim=hid_size)\n", "        self.set_decoder(decoder)\n", "\n", "        # loss function, and optimizer initialization remain the same ......" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "----\n", "\n", "## 3. Link Prediction Pipeline Using the `HGT` Model\n", "\n", "To use the above GNN model variant, the overall GML pipeline needs only a few modifications to adapt to the model-specific arguments. The example below reuses the link prediction pipeline of **Notebook 2**. 
For simplicity, this example merges multiple cells into one and omits most comments.\n", "\n", "### 3.1 Training pipeline" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [] }, "outputs": [], "source": [ "import logging\n", "import graphstorm as gs\n", "\n", "logging.basicConfig(level=logging.INFO)\n", "gs.initialize()\n", "\n", "nfeats_4_modeling = {'author': ['feat'], 'paper': ['feat'], 'subject': ['feat']}\n", "\n", "acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)\n", "\n", "train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(\n", "    dataset=acm_data,\n", "    target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),\n", "    fanout=[20, 20],\n", "    num_negative_edges=10,\n", "    node_feats=nfeats_4_modeling,\n", "    batch_size=64,\n", "    exclude_training_targets=True,\n", "    reverse_edge_types_map={('paper', 'citing', 'paper'): ('paper', 'cited', 'paper')},\n", "    train_task=True)\n", "val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n", "    dataset=acm_data,\n", "    target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),\n", "    fanout=[100, 100],\n", "    num_negative_edges=100,\n", "    node_feats=nfeats_4_modeling,\n", "    batch_size=256)\n", "test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n", "    dataset=acm_data,\n", "    target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),\n", "    fanout=[100, 100],\n", "    num_negative_edges=100,\n", "    node_feats=nfeats_4_modeling,\n", "    batch_size=256)\n", "\n", "from demo_models import HgtLPModel  # Import the HGT model variant\n", "\n", "model = HgtLPModel(g=acm_data.g,\n", "                   num_heads=8,\n", "                   num_hid_layers=2,\n", "                   node_feat_field=nfeats_4_modeling,\n", "                   hid_size=128)\n", "\n", "evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)\n", "\n", "trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)\n", "trainer.setup_evaluator(evaluator)\n", 
"trainer.setup_device(gs.utils.get_device())\n", "\n", "trainer.fit(train_loader=train_dataloader,\n", " val_loader=val_dataloader,\n", " test_loader=test_dataloader,\n", " num_epochs=50,\n", " save_model_path='a_save_path/',\n", " save_model_frequency=1000,\n", " use_mini_batch_infer=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2 Visualize Model Performance History" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "val_metrics, test_metrics = [], []\n", "for val_metric, test_metric in trainer.evaluator.history:\n", " val_metrics.append(val_metric['mrr'])\n", " test_metrics.append(test_metric['mrr'])\n", "\n", "fig, ax = plt.subplots()\n", "ax.plot(val_metrics, label='val')\n", "ax.plot(test_metrics, label='test')\n", "ax.set(xlabel='Epoch', ylabel='Mrr')\n", "ax.legend(loc='best')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.3 Inference pipeline" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "best_model_path = trainer.get_best_model_path()\n", "print('Best model path:', best_model_path)\n", "\n", "model.restore_model(best_model_path)\n", "\n", "infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n", " dataset=acm_data,\n", " target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),\n", " fanout=[100, 100],\n", " num_negative_edges=100,\n", " node_feats=nfeats_4_modeling,\n", " batch_size=256)\n", "\n", "infer = gs.inference.GSgnnLinkPredictionInferrer(model)\n", "\n", "infer.infer(acm_data,\n", " infer_dataloader,\n", " save_embed_path='infer/embeddings',\n", " use_mini_batch_infer=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_gsf", "language": "python", "name": "conda_gsf" }, "language_info": { "codemirror_mode": { 
"name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 4 }