GSgnnNodeDataLoader

class graphstorm.dataloading.GSgnnNodeDataLoader(dataset, target_idx, fanout, batch_size, label_field, node_feats=None, edge_feats=None, train_task=True, construct_feat_ntype=None, construct_feat_fanout=5)

Bases: GSgnnNodeDataLoaderBase

Minibatch dataloader for node tasks

GSgnnNodeDataLoader samples a GraphStorm node dataset into an iterable of mini-batches. Each mini-batch contains the target nodes and their sampled neighbor nodes, and is consumed by GraphStorm Trainers and Inferrers.

Parameters

dataset: GSgnnData

The GraphStorm dataset

target_idx: dict of Tensors

The target nodes for prediction

fanout: list of int or dict of list

Neighbor sample fanout. If it’s a dict, it indicates the fanout for each edge type.

label_field: str

Label field of the node task. (TODO:xiangsx) Support list of str for single dataloader multiple node tasks.

node_feats: str, list of str or dict of list of str

Node features. str: All the node types share the same feature name. list of str: All the node types share the same list of features. dict of list of str: Each node type has its own set of node features. Default: None

edge_feats: str, list of str or dict of list of str

Edge features. str: All the edge types share the same feature name. list of str: All the edge types share the same list of features. dict of list of str: Each edge type has its own set of edge features. Default: None

batch_size: int

Batch size

train_task: bool

Whether the dataloader is used for training. Default: True

construct_feat_ntype: list of str

The node types whose node features need to be constructed. Default: None

construct_feat_fanout: int

The fanout used when constructing node features. Default: 5
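The three accepted forms of node_feats (and edge_feats) can be read as shorthand for a per-type mapping. The sketch below only illustrates that reading: expand_feat_spec is a hypothetical helper that is not part of GraphStorm, and the node type and feature names are made up for the example.

```python
def expand_feat_spec(feats, types):
    """Expand a str / list / dict feature spec into {type: [feature names]}.

    Hypothetical helper for illustration only; not a GraphStorm function.
    """
    if feats is None:
        # No features requested for any type.
        return {}
    if isinstance(feats, str):
        # One feature name shared by every type.
        return {t: [feats] for t in types}
    if isinstance(feats, (list, tuple)):
        # The same list of feature names for every type.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # Each type lists its own feature names.
        return {t: list(names) for t, names in feats.items()}
    raise TypeError(f"unsupported feature spec: {type(feats).__name__}")


ntypes = ["paper", "author"]  # assumed node types
same_name = expand_feat_spec("feat", ntypes)
same_list = expand_feat_spec(["feat", "year"], ntypes)
per_type = expand_feat_spec({"paper": ["feat"], "author": ["affil"]}, ntypes)
print(same_name)
print(same_list)
print(per_type)
```

With the dict form, a type that is absent from the mapping simply gets no features in this sketch.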

Examples

To train a 2-layer GNN for node classification on a set of nodes target_idx, where each node takes messages from 15 neighbors on the first layer and 10 neighbors on the second:

from graphstorm.dataloading import GSgnnData
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.trainer import GSgnnNodePredictionTrainer

np_data = GSgnnData(...)
target_idx = np_data.get_node_train_set(...)
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx, fanout=[15, 10],
                                    batch_size=128,
                                    label_field="label",
                                    node_feats="feat")
np_trainer = GSgnnNodePredictionTrainer(...)
np_trainer.fit(np_dataloader, num_epochs=10)
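To get a feel for how fanout and batch_size interact, the following back-of-the-envelope sketch bounds the number of input nodes a single mini-batch can pull in. It assumes a graph with a single node and edge type, a list-of-int fanout, and no overlap among sampled neighbors, so it gives an upper bound rather than the size the sampler will actually produce. max_input_nodes is a hypothetical helper, not a GraphStorm function.

```python
import math


def max_input_nodes(batch_size, fanout):
    """Upper bound on input nodes for one mini-batch.

    Each layer keeps the nodes it already has and adds at most `f`
    sampled neighbors per node, so the frontier grows by a factor
    of at most (1 + f) per layer.
    """
    return batch_size * math.prod(1 + f for f in fanout)


bound = max_input_nodes(128, [15, 10])
print(bound)  # 128 * 16 * 11 = 22528
```

In practice neighbors overlap and low-degree nodes have fewer neighbors than the fanout, so real mini-batches are usually smaller than this bound.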
__iter__()

Return an iterator object.

__next__()

Return a mini-batch of data for the node task.

A mini-batch comprises three objects: the input node IDs of the mini-batch, the target nodes and the subgraph blocks for message passing.

Returns

dict of Tensors: the input node IDs of the mini-batch.

dict of Tensors: the target node IDs.

list of DGLGraph: the subgraph blocks for message passing.
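A loop over the dataloader unpacks these three objects on every step. The sketch below shows only the shape of that loop. ToyNodeLoader is a made-up stand-in that follows the same __iter__/__next__ protocol with plain Python lists; it performs no neighbor sampling, returns no tensors or DGL blocks, and is not how GSgnnNodeDataLoader is implemented.

```python
class ToyNodeLoader:
    """Stand-in that mimics the (input_nodes, seeds, blocks) mini-batch shape."""

    def __init__(self, target_ids, batch_size, ntype="node"):
        self._ids = list(target_ids)
        self._batch_size = batch_size
        self._ntype = ntype  # assumed single node type
        self._pos = 0

    def __iter__(self):
        # Restart from the first target node on every new pass.
        self._pos = 0
        return self

    def __next__(self):
        if self._pos >= len(self._ids):
            raise StopIteration
        seeds = self._ids[self._pos:self._pos + self._batch_size]
        self._pos += self._batch_size
        # A real loader would return the seeds plus their sampled neighbors here.
        input_nodes = {self._ntype: list(seeds)}
        # A real loader would return one message-passing block per GNN layer.
        blocks = []
        return input_nodes, {self._ntype: seeds}, blocks


batches = list(ToyNodeLoader(range(5), batch_size=2))
for input_nodes, seeds, blocks in ToyNodeLoader(range(5), batch_size=2):
    print(seeds, len(blocks))
```

The last mini-batch may hold fewer than batch_size target nodes, as it does here with five targets and a batch size of two.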