{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Notebook 4: Use GraphStorm APIs for Customizing Model Components\n", "\n", "This notebook shows how to customize components of GML models to fit specific requirements. The customized models should extend GraphStorm's higher-level APIs, which lets them implement their own functionality while still integrating easily into GraphStorm training and inference pipelines.\n", "\n", "----\n", "\n", "### An Example of a Customized Model\n", "\n", "A widely used GNN model is the `RGAT` model, proposed in [Relational Graph Attention Networks](https://arxiv.org/abs/1904.05811). The original `RGAT` model weighs the neighbors of a node by importance and uses an attention mechanism to aggregate messages from neighbors within the same relation type. Then, the aggregations of neighbors from different relation types are added together as the output representation of a node,\n", "\n", "$$h_i = \\sum_{r\\in \\mathscr{R}}\\sum_{j\\in \\mathcal{N}^{r}_{(i)}} \\alpha^{r}_{i,j} W^{r} h_j^{r}.$$\n", "\n", "An alternative way to aggregate representations across different relation types is to use attention instead of summation. We can use an additional weight vector $\\phi$ to compute the attention coefficients for different relation types,\n", "$$\\beta^r_i = \\dfrac{\\exp(h^{r}_i \\cdot \\phi)}{\\sum_{k \\in \\mathscr{R}} \\exp(h^{k}_i \\cdot \\phi)},$$\n", "and then compute the weighted sum of aggregations across relation types,\n", "$$h_i = \\sum_{r \\in \\mathscr{R}}{\\beta^r_i \\times h_i^{r}},$$\n", "where the per-relation aggregation is\n", "$$h_i^{r} = \\mathrm{LeakyReLU}(\\sum_{j\\in \\mathcal{N}^{r}_{(i)}} \\alpha^{r}_{i,j} W^{r} h_j^{r}).$$\n", "\n", "In this notebook, we will implement this `Across Relation Attention GAT (ARA_GAT)`, fit it into the GraphStorm model architecture, and run training and inference using the existing pipelines.\n", "\n", "### Prerequisites\n", "\n", "- GraphStorm. 
Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).\n", "- ACM data that has been created according to **[Notebook 0: Data Preparation](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_0_Data_Prepare.html)**, and is stored in the `./acm_gs_1p/` folder.\n", "- Installation of supporting libraries, e.g., matplotlib." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Recap GraphStorm Model Architecture\n", "\n", "As explained in **[Notebook 3: Use GraphStorm APIs for Implementing Built-in GNN Models](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_3_Model_Variants.html)**, a GraphStorm model normally contains four modules:\n", "\n", "- An input encoder that converts input node features into embeddings of the hidden dimension.\n", "- A GNN encoder that takes the embeddings from the input encoder and performs message passing computation.\n", "- A decoder that is task specific, e.g., the `EntityClassifier` for classification tasks.\n", "- A loss function that matches specific tasks, e.g., the `ClassifyLossFunc`.\n", "\n", "Given this architecture, we only need to build a **customized GNN encoder** that implements the `ARA_GAT` variant and can leave the other modules untouched." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. ARA_GAT Variant Encoder Implementation\n", "\n", "To build a customized GNN encoder, we can refer to the implementation of GraphStorm's GNN encoders, such as `graphstorm.model.RelationalGATEncoder`, which extends `graphstorm.model.GraphConvEncoder` and implements the required method.\n", "\n", "The code in the cells below includes a layer module named `Ara_GatLayer`, which implements the ARA_GAT computation for one GNN layer, and an encoder module named `Ara_GatEncoder`, which extends `graphstorm.model.gnn_encoder_base.GraphConvEncoder`. 
\n", "\n", "### 2.1 `Ara_GatLayer` Implementation" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [] }, "outputs": [], "source": [ "import dgl\n", "import torch as th\n", "import torch.nn as nn\n", "\n", "class Ara_GatLayer(nn.Module):\n", "    \"\"\" One layer of ARA_GAT\n", "    \"\"\"\n", "    def __init__(self, in_dim, out_dim, num_heads, rel_names, bias=True,\n", "                 activation=None, self_loop=False, dropout=0.0, norm=None):\n", "        super(Ara_GatLayer, self).__init__()\n", "        self.bias = bias\n", "        self.activation = activation\n", "        self.self_loop = self_loop\n", "        self.leaky_relu = nn.LeakyReLU(0.2)\n", "\n", "        # GAT module for each relation type\n", "        self.rel_gats = nn.ModuleDict()\n", "        for rel in rel_names:\n", "            # out_dim should be divisible by num_heads\n", "            self.rel_gats[str(rel)] = dgl.nn.GATConv(in_dim, out_dim//num_heads,\n", "                                                     num_heads, allow_zero_in_degree=True)\n", "\n", "        # across-relation attention weight set\n", "        self.acr_attn_weights = nn.Parameter(th.Tensor(out_dim, 1))\n", "        nn.init.normal_(self.acr_attn_weights)\n", "\n", "        # bias\n", "        if bias:\n", "            self.h_bias = nn.Parameter(th.Tensor(out_dim))\n", "            nn.init.zeros_(self.h_bias)\n", "\n", "        # weight for self loop\n", "        if self.self_loop:\n", "            self.loop_weight = nn.Parameter(th.Tensor(in_dim, out_dim))\n", "            nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))\n", "\n", "        # dropout\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "        # normalization for each node type\n", "        ntypes = set()\n", "        for rel in rel_names:\n", "            ntypes.add(rel[0])\n", "            ntypes.add(rel[2])\n", "\n", "        if norm == \"batch\":\n", "            self.norm = nn.ParameterDict({ntype:nn.BatchNorm1d(out_dim) for ntype in ntypes})\n", "        elif norm == \"layer\":\n", "            self.norm = nn.ParameterDict({ntype:nn.LayerNorm(out_dim) for ntype in ntypes})\n", "        else:\n", "            self.norm = None\n", "\n", "    def forward(self, g, inputs):\n", "        \"\"\"\n", "        g: DGL.block\n", "            A DGL block\n", "        inputs : dict[str, torch.Tensor]\n", "            Node 
feature for each node type.\n", "\n", "        Returns\n", "        -------\n", "        dict[str, torch.Tensor]\n", "            New node features for each node type.\n", "        \"\"\"\n", "        g = g.local_var()\n", "\n", "        # loop each edge type to fulfill GAT computation within each edge type\n", "        for src_type, e_type, dst_type in g.canonical_etypes:\n", "\n", "            # extract subgraph of each edge type\n", "            sub_graph = g[src_type, e_type, dst_type]\n", "\n", "            # check if no edges exist for this edge type\n", "            if sub_graph.num_edges() == 0:\n", "                continue\n", "\n", "            # extract source and destination node features\n", "            src_feat = inputs[src_type]\n", "            dst_feat = inputs[dst_type][:sub_graph.num_dst_nodes()]\n", "\n", "            # GAT in one relation type\n", "            agg = self.rel_gats[str((src_type, e_type, dst_type))](sub_graph, (src_feat, dst_feat))\n", "            agg = agg.view(agg.shape[0], -1)\n", "\n", "            # store aggregations in destination nodes\n", "            sub_graph.dstdata['agg_' + str((src_type, e_type, dst_type))] = self.leaky_relu(agg)\n", "\n", "        h = {}\n", "        for n_type in g.dsttypes:\n", "            if g.num_dst_nodes(n_type) == 0:\n", "                continue\n", "\n", "            # cross relation attention enhancement as outputs\n", "            agg_list = []\n", "            for k, v in g.dstnodes[n_type].data.items():\n", "                if k.startswith('agg_'):\n", "                    agg_list.append(v)\n", "\n", "            # cross-relation attention\n", "            if agg_list:\n", "                acr_agg = th.stack(agg_list, dim=1)\n", "\n", "                acr_att = th.matmul(acr_agg, self.acr_attn_weights)\n", "                acr_sfm = th.softmax(acr_att, dim=1)\n", "\n", "                # cross-relation weighted aggregation\n", "                acr_sum = (acr_agg * acr_sfm).sum(dim=1)\n", "            elif not self.self_loop:\n", "                raise ValueError(f'Some nodes in the {n_type} type have no in-degree.' 
+ \\\n", "                                 ' Please check the data or set \\\"self_loop=True\\\"')\n", "\n", "            # process new features\n", "            if self.self_loop:\n", "                if agg_list:\n", "                    h_n = acr_sum + th.matmul(inputs[n_type][:g.num_dst_nodes(n_type)], self.loop_weight)\n", "                else:\n", "                    h_n = th.matmul(inputs[n_type][:g.num_dst_nodes(n_type)], self.loop_weight)\n", "            else:\n", "                # without a self loop, the output is the cross-relation aggregation alone\n", "                h_n = acr_sum\n", "            if self.bias:\n", "                h_n = h_n + self.h_bias\n", "            if self.activation:\n", "                h_n = self.activation(h_n)\n", "            if self.norm:\n", "                h_n = self.norm[n_type](h_n)\n", "            h_n = self.dropout(h_n)\n", "\n", "            h[n_type] = h_n\n", "\n", "        return h\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 `Ara_GatEncoder` Implementation\n", "\n", "Here, we implement the `Ara_GatEncoder` by extending the base GraphStorm encoder, `graphstorm.model.gnn_encoder_base.GraphConvEncoder`, and implementing the `forward(self, blocks, h)` function to make this class compatible with the GraphStorm model architecture. The `forward()` function takes a list of DGL blocks and a dictionary of node representations as input arguments, and returns a dictionary that contains the new node representations. GraphStorm model classes call this `forward()` function within their own `forward()` function." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [], "source": [ "from graphstorm.model.gnn_encoder_base import GraphConvEncoder\n", "import torch.nn.functional as F\n", "\n", "class Ara_GatEncoder(GraphConvEncoder):\n", "    \"\"\" Across Relation Attention GAT Encoder by extending GraphStorm APIs\n", "    \"\"\"\n", "    def __init__(self, g, h_dim, out_dim, num_heads, num_hidden_layers=1,\n", "                 dropout=0, use_self_loop=True, norm='batch'):\n", "        super(Ara_GatEncoder, self).__init__(h_dim, out_dim, num_hidden_layers)\n", "\n", "        # h2h\n", "        for _ in range(num_hidden_layers):\n", "            self.layers.append(Ara_GatLayer(h_dim, h_dim, num_heads, g.canonical_etypes,\n", "                activation=F.relu, self_loop=use_self_loop, dropout=dropout, norm=norm))\n", "        # h2o\n", "        self.layers.append(Ara_GatLayer(h_dim, out_dim, num_heads, g.canonical_etypes,\n", "            activation=F.relu, self_loop=use_self_loop, norm=norm))\n", "\n", "    def forward(self, blocks, h):\n", "        \"\"\" accept block list and feature dictionary as inputs\n", "        \"\"\"\n", "        for layer, block in zip(self.layers, blocks):\n", "            h = layer(block, h)\n", "        return h" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Build a Node Classification Model based on the `Ara_GatEncoder`\n", "\n", "The `RgatNCModel` below follows the same node classification model architecture used in **[Notebook 1: Use GraphStorm APIs for Building a Node Classification Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_1_NC_Pipeline.html)**. For the GNN encoder component, this model provides the option to use either the `Ara_GatEncoder` or the built-in `RelationalGATEncoder` from GraphStorm." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "tags": [] }, "outputs": [], "source": [ "import graphstorm as gs\n", "from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer, RelationalGATEncoder, EntityClassifier, ClassifyLossFunc\n", "\n", "class RgatNCModel(GSgnnNodeModel):\n", "    \"\"\" A customized RGAT model for node classification using GraphStorm APIs\n", "    \"\"\"\n", "    def __init__(self, g, num_heads, num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False,\n", "                 encoder_type='ara'    # option for different rgat encoders\n", "                 ):\n", "        super(RgatNCModel, self).__init__(alpha_l2norm=0.)\n", "\n", "        # extract feature size\n", "        feat_size = gs.get_node_feat_size(g, node_feat_field)\n", "\n", "        # set an input layer encoder\n", "        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)\n", "        self.set_node_input_encoder(encoder)\n", "\n", "        # set the option of using either customized RGAT or built-in RGAT encoder\n", "        if encoder_type == 'ara':\n", "            gnn_encoder = Ara_GatEncoder(g, hid_size, hid_size, num_heads,\n", "                                         num_hidden_layers=num_hid_layers-1)\n", "        elif encoder_type == 'rgat':\n", "            gnn_encoder = RelationalGATEncoder(g, hid_size, hid_size, num_heads,\n", "                                               num_hidden_layers=num_hid_layers-1)\n", "        else:\n", "            raise Exception(f'Not supported encoders \\\"{encoder_type}\\\".')\n", "        self.set_gnn_encoder(gnn_encoder)\n", "\n", "        # set a decoder specific to node classification task\n", "        decoder = EntityClassifier(in_dim=hid_size, num_classes=num_classes, multilabel=multilabel)\n", "        self.set_decoder(decoder)\n", "\n", "        # classification loss function\n", "        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))\n", "\n", "        # initialize model's optimizer\n", "        self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 
Node Classification Pipeline Using the `RgatNCModel` Model\n", "\n", "The overall pipeline for using the customized model for node classification tasks is identical to the one in **[Notebook 1: Use GraphStorm APIs for Building a Node Classification Pipeline](https://graphstorm.readthedocs.io/en/latest/api/notebooks/Notebook_1_NC_Pipeline.html)**.\n", "\n", "### 4.1 Training pipeline" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [], "source": [ "import logging\n", "logging.basicConfig(level=logging.INFO)\n", "import graphstorm as gs\n", "gs.initialize()\n", "\n", "acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json')\n", "\n", "nfeats_4_modeling = {'author':['feat'], 'paper':['feat'], 'subject':['feat']}\n", "\n", "train_dataloader = gs.dataloading.GSgnnNodeDataLoader(\n", "    dataset=acm_data,\n", "    target_idx=acm_data.get_node_train_set(ntypes=['paper']),\n", "    node_feats=nfeats_4_modeling,\n", "    label_field='label',\n", "    fanout=[20, 20],\n", "    batch_size=64,\n", "    train_task=True)\n", "val_dataloader = gs.dataloading.GSgnnNodeDataLoader(\n", "    dataset=acm_data,\n", "    target_idx=acm_data.get_node_val_set(ntypes=['paper']),\n", "    node_feats=nfeats_4_modeling,\n", "    label_field='label',\n", "    fanout=[100, 100],\n", "    batch_size=256,\n", "    train_task=False)\n", "test_dataloader = gs.dataloading.GSgnnNodeDataLoader(\n", "    dataset=acm_data,\n", "    target_idx=acm_data.get_node_test_set(ntypes=['paper']),\n", "    node_feats=nfeats_4_modeling,\n", "    label_field='label',\n", "    fanout=[100, 100],\n", "    batch_size=256,\n", "    train_task=False)\n", "\n", "model = RgatNCModel(g=acm_data.g, num_heads=8, num_hid_layers=2, node_feat_field=nfeats_4_modeling,\n", "                    hid_size=128, num_classes=14, encoder_type='ara')\n", "\n", "evaluator = gs.eval.GSgnnClassificationEvaluator(eval_frequency=100)\n", "\n", "trainer = gs.trainer.GSgnnNodePredictionTrainer(model)\n", "trainer.setup_evaluator(evaluator)\n", 
"trainer.setup_device(gs.utils.get_device())\n", "\n", "trainer.fit(train_loader=train_dataloader,\n", " val_loader=val_dataloader,\n", " test_loader=test_dataloader,\n", " num_epochs=50,\n", " save_model_path='a_save_path/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.2 Visualize Model Performance History" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "val_metrics, test_metrics = [], []\n", "for val_metric, test_metric in trainer.evaluator.history:\n", " val_metrics.append(val_metric['accuracy'])\n", " test_metrics.append(test_metric['accuracy'])\n", "\n", "fig, ax = plt.subplots()\n", "ax.plot(val_metrics, label='val')\n", "ax.plot(test_metrics, label='test')\n", "ax.set(xlabel='Epoch', ylabel='Accuracy')\n", "ax.legend(loc='best')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.3 Inference pipeline" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "best_model_path = trainer.get_best_model_path()\n", "print('Best model path:', best_model_path)\n", "\n", "model.restore_model(best_model_path)\n", "\n", "infer_dataloader = gs.dataloading.GSgnnNodeDataLoader(dataset=acm_data,\n", " target_idx=acm_data.get_node_test_set(ntypes=['paper']),\n", " node_feats=nfeats_4_modeling,\n", " label_field='label',\n", " fanout=[100, 100],\n", " batch_size=256,\n", " train_task=False)\n", "\n", "infer = gs.inference.GSgnnNodePredictionInferrer(model)\n", "\n", "infer.infer(infer_dataloader,\n", " save_embed_path='infer/embeddings',\n", " save_prediction_path='infer/predictions',\n", " use_mini_batch_infer=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 4 }