{ "cells": [ { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Notebook 5: Use GraphStorm APIs for a Customized Model to Perform Graph-level Prediction\n", "\n", "Graph-level prediction, such as graph classification or graph regression, is a common task in Graph Machine Learning (GML) across various domains, including life sciences and chemistry. In graph-level prediction, the entire graph data is typically organized as a batch of subgraphs, where each subgraph's nodes have edges only within the subgraph and no edges connecting to nodes in other subgraphs. GML labels are linked to these subgraphs. Once trained, GML models can make predictions on new and unseen subgraphs.\n", "\n", "A typical operation used in graph prediction is called `Read-out`, which aggregates the representations of nodes in a subgraph to form one representation for the subgraph. The outputs of the `Read-out` can then be used to make predictions downstream, acting as a single representation of the entire subgraph.\n", "\n", "\n", "\n", "The current version of GraphStorm cannot directly perform graph-level prediction. However, because GraphStorm supports node-level prediction, we can use a method called `super-node` to perform graph-level prediction.\n", "\n", "----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `super-node` Method Explanation\n", "\n", "Instead of using the `Read-out` operation, we add a new node, called a **super node**, to each subgraph, and link all original nodes of the subgraph to it, without adding reversed edges. With these inbound edges, the representations of all original nodes in a subgraph can be easily aggregated to the **super node**. We can then use the **super node** as the representation of this subgraph to perform graph-level prediction tasks. 
The `super-node` method helps us to turn a graph prediction task into a node prediction task.\n", "\n", "\n", "\n", "----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implementation Ideas\n", "\n", "In order to use the `super-node` method in GraphStorm, we need to implement two extra functions.\n", "\n", "- **Raw Graph Data Conversion**: \n", " Add a super node to each subgraph in the original batch of subgraphs, and store all of them as one heterogeneous graph ready for GraphStorm's graph construction CLIs. After the graph is converted into GraphStorm's distributed graph format, we can use all of GraphStorm's built-in GNN models to perform the `super-node` prediction.\n", "- **Customized GNN Encoder** (Optional):\n", " Create a specialized GNN encoder to aggregate **super node** representations. This is an optional function as all built-in GraphStorm GNN encoders can aggregate and generate embeddings for **super nodes** naturally. But creating a customized GNN encoder can provide fine-grained control of aggregation methods, which can mimic the `Read-out` method.\n", "\n", "This notebook will demonstrate the `super-node` method by using GraphStorm APIs and other libraries to implement both functions. This notebook serves as an example of a Graph Classification Solution using GraphStorm APIs. Users can modify the custom GNN model and implement their own version." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "-----\n", "\n", "## Prerequisites\n", "This notebook assumes the following:\n", "\n", "- GraphStorm. Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).\n", "- Installation of supporting libraries, e.g., matplotlib.\n", "- [Jupyter web interactive server](https://jupyter.org/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "## 1. 
Raw Graph Data Conversion\n", "\n", "The conversion of the raw graph prediction dataset into the `super-node` format for GraphStorm is illustrated in the diagram below.\n", "\n", "\n", "\n", "In this notebook, we use the [OGBG Molhiv Data](https://ogb.stanford.edu/docs/graphprop/#ogbg-mol), which is a popular graph-level dataset for molecular property prediction. \n", "In the interest of space, we will not show the actual raw graph data conversion code in this notebook. Users can find the [source code of OGBG data conversion](https://github.com/awslabs/graphstorm/blob/main/examples/graph_prediction/gen_ogbg_supernode.py) in GraphStorm's [graph prediction example folder](https://github.com/awslabs/graphstorm/blob/main/examples/graph_prediction/).\n", "\n", "
\n", "Tip: We also provide a Python script to generate synthetic supernode-based graph data for users to better understand the `super-node` graph data format, which is available here.
\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can download the source code of the OGBG conversion and generate the `super-node` format OGBG data with the following commands.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [], "source": [ "!wget -c https://raw.githubusercontent.com/awslabs/graphstorm/main/examples/graph_prediction/gen_ogbg_supernode.py\n", "\n", "!python gen_ogbg_supernode.py --ogbg-data-name molhiv --output-path ./supernode_raw/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The converted OGBG data will be stored at `./supernode_raw/`. Then we can run GraphStorm's GConstruct command to partition the graph for model training and inference. The processed graph is stored in the `./supernode_gs_1p/` folder." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [], "source": [ "!python -m graphstorm.gconstruct.construct_graph \\\n", "    --conf-file ./supernode_raw/config.json \\\n", "    --output-dir ./supernode_gs_1p/ \\\n", "    --num-parts 1 \\\n", "    --graph-name supernode_molhiv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Customized GNN Encoder for Graph Prediction\n", "\n", "The key component of this `super-node` based solution is the GNN model that can perform message passing and aggregation in each subgraph, and then perform a sort of `Read-out` operation in the super nodes. This component can be easily implemented as a customized GraphStorm GNN encoder, as demonstrated in [Notebook 4: Customized Models](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_4_Customized_Models.html).\n", "\n", "As shown in the diagram below, a **super node** will aggregate the representations from other nodes in each GNN layer. Built-in GNN encoders will update the aggregated representations with additional trainable parameters. 
This operation is different from the common `Read-out` operation, and can therefore potentially hurt model performance in graph-level prediction.\n", "\n", "\n", "\n", "To mimic the `Read-out` operation, we can cache the aggregated representations, and reset the **super node**'s own representation to zeros after each GNN layer computation. Using this method, we can still leverage the built-in GraphStorm encoders, e.g., `RelationalGCNEncoder` and `RelationalGATEncoder`, but avoid the built-in self-update operation from one layer to another. In addition, we can design a more flexible `Read-out` function on these cached representations, rather than just using the last layer's aggregated representations.\n", "\n", "The `GPEncoder4SupernodeOgbg` class below implements this cached-representation mechanism, and provides a few options for the read-out function." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch as th\n", "from graphstorm.model import (GSgnnNodeModel,\n", "                              GSNodeEncoderInputLayer,\n", "                              RelationalGCNEncoder,\n", "                              RelationalGATEncoder,\n", "                              HGTEncoder,\n", "                              EntityClassifier,\n", "                              ClassifyLossFunc)\n", "from graphstorm.model.gnn_encoder_base import GraphConvEncoder\n", "\n", "class GPEncoder4SupernodeOgbg(GraphConvEncoder):\n", "    r\"\"\"A graph conv encoder for Graph Classification\n", "\n", "    Unique arguments in this class:\n", "    -------------------------------\n", "    base_encoder: GraphStorm ConvEncoder types, options:\n", "        1. `RelationalGCNEncoder`;\n", "        2. `RelationalGATEncoder`;\n", "        3. `HGTEncoder`.\n", "    read_out_opt: string in the following options:\n", "        The aggregation method for the cached supernodes' representations. The current options:\n", "        1. `last_only`: only use the last layer's representations. If this option is used,\n", "           no aggregation across layers is performed because only one layer's\n", "           representation is involved in the final read-out.\n",
 "        2. `mean`: compute the mean of all of the cached supernode representations.\n", "        3. `sum`: compute the summation of all of the cached supernode representations.\n", "        4. `weighted_sum`: use additional weight parameters to compute the weighted summation\n", "           of all of the cached supernode representations.\n", "        5. `min`: compute the minimum in each dimension of all of the cached supernode\n", "           representations.\n", "        6. `max`: compute the maximum in each dimension of all of the cached supernode\n", "           representations.\n", "    super_ntype: string\n", "        The name of supernode type. Default is 'super'.\n", "\n", "    \"\"\"\n", "    def __init__(self,\n", "                 h_dim,\n", "                 out_dim,\n", "                 base_encoder,\n", "                 read_out_opt='last_only',\n", "                 super_ntype='super'\n", "                 ):\n", "        assert isinstance(base_encoder, (RelationalGCNEncoder, RelationalGATEncoder, HGTEncoder)), \\\n", "            'Only support GraphStorm\\'s RelationalGCNEncoder, RelationalGATEncoder, and HGTEncoder'\n", "        assert base_encoder.num_layers >= 3, 'For Graph Prediction task, at least two layers GNN ' + \\\n", "            f'encoder required, but got {base_encoder.num_layers - 1} ...'\n", "        super(GPEncoder4SupernodeOgbg, self).__init__(h_dim, out_dim, base_encoder.num_layers)\n", "\n", "        assert read_out_opt in ['last_only', 'mean', 'sum', 'weighted_sum', 'min', 'max'], \\\n", "            f'Not recognized read_out_opt {read_out_opt}. Options include ' + \\\n",
 "            '\\'last_only\\', \\'mean\\', \\'sum\\', \\'weighted_sum\\', \\'min\\', ' + \\\n", "            'and \\'max\\'.'\n", "        self.base_encoder = base_encoder\n", "        self.read_out_opt = read_out_opt\n", "        self.super_ntype = super_ntype\n", "        if read_out_opt == 'weighted_sum':\n", "            # one learnable weight per cached layer output, initialized to ones;\n", "            # this assumes one block per GNN layer of the base encoder\n", "            self.weighted_sum_para = th.nn.Parameter(th.ones(1, base_encoder.num_layers))\n", "        else:\n", "            self.weighted_sum_para = None\n", "\n", "    def forward(self, blocks, h):\n", "\n", "        supernode_cache = []\n", "\n", "        # message passing in subgraphs and cache super-nodes representations\n", "        for layer, block in zip(self.base_encoder.layers, blocks):\n", "            h = layer(block, h)\n", "\n", "            # 1. cache the output of supernodes in each layer\n", "            supernode_cache.append(h[self.super_ntype])\n", "            # 2. zero out the representations of supernodes as the next layer input\n", "            h[self.super_ntype] = th.zeros_like(h[self.super_ntype])\n", "\n", "        # add final read_out functions.\n", "        supernode_cache = th.stack(supernode_cache)\n", "        output = self._read_out_ops(supernode_cache)\n", "\n", "        return {self.super_ntype: output}\n", "\n", "    def _read_out_ops(self, supernode_cache):\n", "        \"\"\" The supernode_cache shape L * N * D\n", "            The output shape N * D\n", "        \"\"\"\n", "        if self.read_out_opt == 'last_only':\n", "            output = supernode_cache[-1]\n", "        elif self.read_out_opt == 'mean':\n", "            output = th.mean(supernode_cache, dim=0)\n", "        elif self.read_out_opt == 'sum':\n", "            output = th.sum(supernode_cache, dim=0)\n", "        elif self.read_out_opt == 'weighted_sum' and self.weighted_sum_para is not None:\n", "            output = th.einsum('ij, jkl->kl', self.weighted_sum_para, supernode_cache)\n", "        elif self.read_out_opt == 'min':\n", "            # element-wise minimum over the layer dimension\n", "            output = th.min(supernode_cache, dim=0).values\n", "        elif self.read_out_opt == 'max':\n", "            # element-wise maximum over the layer dimension\n", "            output = th.max(supernode_cache, dim=0).values\n", "        else:\n", "            raise NotImplementedError('Only support last_only, mean, sum, weighted_sum, min, and max ' + \\\n", "                                      f'read_out_opt, but got {self.read_out_opt}.')\n", "\n", "        return output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
Training and Inference Pipeline\n", "\n", "With the customized encoder modified for graph prediction, we can reuse GraphStorm's end-to-end training and inference pipeline, in the same way as in [Notebook 1: Node Classification Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_1_NC_Pipeline.html) and \n", "[Notebook 2: Link Prediction Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_2_LP_Pipeline.html), to conduct the graph classification task on the converted `super-node` OGBG data." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "import logging\n", "logging.basicConfig(level=logging.INFO)\n", "\n", "import graphstorm as gs\n", "gs.initialize()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [] }, "outputs": [], "source": [ "ogbg_data = gs.dataloading.GSgnnData(part_config='./supernode_gs_1p/supernode_molhiv.json')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [], "source": [ "# define dataloaders for training, validation, and testing\n", "nfeats_4_modeling = {'node': ['n_feat'], 'super': ['n_feat']}\n", "\n", "train_dataloader = gs.dataloading.GSgnnNodeDataLoader(\n", "    dataset=ogbg_data,\n", "    target_idx=ogbg_data.get_node_train_set(ntypes=['super']),\n", "    node_feats=nfeats_4_modeling,\n", "    label_field='labels',\n", "    fanout=[20, 20, 20],\n", "    batch_size=128,\n", "    train_task=True)\n", "val_dataloader = gs.dataloading.GSgnnNodeDataLoader(\n", "    dataset=ogbg_data,\n", "    target_idx=ogbg_data.get_node_val_set(ntypes=['super']),\n", "    node_feats=nfeats_4_modeling,\n", "    label_field='labels',\n", "    fanout=[100, 100, 100],\n", "    batch_size=256,\n", "    train_task=False)\n", "test_dataloader = gs.dataloading.GSgnnNodeDataLoader(\n", "    dataset=ogbg_data,\n", "    target_idx=ogbg_data.get_node_test_set(ntypes=['super']),\n", "    node_feats=nfeats_4_modeling,\n", "    label_field='labels',\n", "    fanout=[100, 
100, 100],\n", "    batch_size=256,\n", "    train_task=False)" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "For the GNN model, we can create a GraphStorm GNN model using nearly the same architecture as in the other notebooks, except that we replace the built-in GNN encoder, e.g., `RelationalGCNEncoder`, with the customized `GPEncoder4SupernodeOgbg`, which wraps a `RelationalGCNEncoder` as its base encoder." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [] }, "outputs": [], "source": [ "class RgcnGCModel4SuperOgbg(GSgnnNodeModel):\n", "    \"\"\" A customized GNN model for graph classification using GraphStorm APIs\n", "\n", "    Arguments\n", "    ----------\n", "    g: DistGraph\n", "        A DGL DistGraph.\n", "    num_hid_layers: int\n", "        The number of GNN layers.\n", "    node_feat_field: dict of list of strings\n", "        The list of features for each node type to be used in the model.\n", "    hid_size: int\n", "        The dimension of hidden layers.\n", "    num_classes: int\n", "        The target number of classes for classification.\n", "    multilabel: boolean\n", "        Indicator of whether this is a multilabel task.\n", "    \"\"\"\n", "    def __init__(self,\n", "                 g,\n", "                 num_hid_layers,\n", "                 node_feat_field,\n", "                 hid_size,\n", "                 num_classes,\n", "                 multilabel=False):\n", "        super().__init__(alpha_l2norm=0.)\n", "\n", "        # extract feature size\n", "        feat_size = gs.get_node_feat_size(g, node_feat_field)\n", "\n", "        # set an input layer encoder\n", "        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)\n", "        self.set_node_input_encoder(encoder)\n", "\n", "        # set an RGCN encoder as the base encoder\n", "        gnn_encoder = RelationalGCNEncoder(g=g,\n", "                                           h_dim=hid_size,\n", "                                           out_dim=hid_size,\n", "                                           num_hidden_layers=num_hid_layers)\n", "        # wrap the base RGCN encoder into GPEncoder4SupernodeOgbg\n", "        gp_encoder = GPEncoder4SupernodeOgbg(hid_size,\n", "                                             hid_size,\n", "                                             gnn_encoder,\n", "                                             read_out_opt='last_only',\n", "                                             super_ntype='super')\n", "        self.set_gnn_encoder(gp_encoder)\n", "\n", "        # set a decoder specific to node classification task\n", "        decoder = EntityClassifier(in_dim=hid_size,\n", "                                   num_classes=num_classes,\n", "                                   multilabel=multilabel)\n", "        self.set_decoder(decoder)\n", "\n", "        # classification loss function\n", "        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))\n", "\n", "        # initialize model's optimizer\n", "        self.init_optimizer(lr=0.001,\n", "                            sparse_optimizer_lr=0.001,\n", "                            weight_decay=0)\n" ] },
 { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [] }, "outputs": [], "source": [ "model = RgcnGCModel4SuperOgbg(g=ogbg_data.g,\n", "                              num_hid_layers=3,\n", "                              node_feat_field=nfeats_4_modeling,\n", "                              hid_size=128,\n", "                              num_classes=2)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [], "source": [ "# setup a classification evaluator for the trainer\n", "evaluator = gs.eval.GSgnnClassificationEvaluator(eval_frequency=100,\n", "                                                 eval_metric_list=['roc_auc'])" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [] }, "outputs": [], "source": [ "# create a GraphStorm node task trainer for the RGCN model\n", "trainer = gs.trainer.GSgnnNodePredictionTrainer(model)\n", "trainer.setup_evaluator(evaluator)\n", "trainer.setup_device(gs.utils.get_device())" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Train the model with the trainer using fit() function\n", "trainer.fit(train_loader=train_dataloader,\n", "            val_loader=val_dataloader,\n", "            test_loader=test_dataloader,\n", "            num_epochs=50,\n", "            save_model_path='a_save_path/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we examine the model performance over the training process." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Extract ROC-AUC scores from the trainer's evaluator:\n", "val_accs, test_accs = [], []\n", "for val_acc, test_acc in trainer.evaluator.history:\n", "    val_accs.append(val_acc['roc_auc'])\n", "    test_accs.append(test_acc['roc_auc'])" ] },
 { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# plot the learning curves\n", "fig, ax = plt.subplots(figsize=(15, 10))\n", "ax.plot(val_accs, label='val')\n", "ax.plot(test_accs, label='test')\n", "ax.set(xlabel='Eval Times', ylabel='ROC_AUC')\n", "ax.legend(loc='best');" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [] }, "outputs": [], "source": [ "# after training, the best model is saved to disk:\n", "best_model_path = trainer.get_best_model_path()\n", "print('Best model checkpoint:', best_model_path)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "tags": [] }, "outputs": [], "source": [ "# check the saved artifacts\n", "!ls -ls {best_model_path}" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [] }, "outputs": [], "source": [ "# we can restore the model from the checkpoint:\n", "model.restore_model(best_model_path)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Setup dataloader for inference\n", "infer_dataloader = gs.dataloading.GSgnnNodeDataLoader(dataset=ogbg_data,\n", "    target_idx=ogbg_data.get_node_test_set(ntypes=['super']),\n", "    node_feats=nfeats_4_modeling,\n", "    label_field='labels',\n", "    fanout=[100, 100, 100],\n", "    batch_size=256,\n", "    train_task=False)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Create an Inferrer object\n", "infer = gs.inference.GSgnnNodePredictionInferrer(model)\n", "\n", "# Run inference on the inference dataset\n", "infer.infer(infer_dataloader,\n", "            save_embed_path='infer/embeddings',\n", "            save_prediction_path='infer/predictions',\n", "            use_mini_batch_infer=True)" ] }, { "cell_type": 
"code", "execution_count": 19, "metadata": { "tags": [] }, "outputs": [], "source": [ "# The GNN embeddings on the inference graph are saved to:\n", "!ls -lh infer/embeddings\n", "!ls -lh infer/embeddings/super/" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "tags": [] }, "outputs": [], "source": [ "!ls -lh infer/predictions\n", "!ls -lh infer/predictions/super" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "gsf", "language": "python", "name": "gsf" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 4 }