GSgnnNodeDataLoader

class graphstorm.dataloading.GSgnnNodeDataLoader(dataset, target_idx, fanout, batch_size, device, train_task=True, construct_feat_ntype=None, construct_feat_fanout=5)

Bases: GSgnnNodeDataLoaderBase

Mini-batch dataloader for node tasks.

GSgnnNodeDataLoader samples a GraphStorm node dataset into an iterable over mini-batches. Each mini-batch contains the target nodes and their sampled neighbor nodes, and is consumed by GraphStorm Trainers and Inferrers.

Parameters

dataset: GSgnnNodeData

The GraphStorm dataset

target_idx: dict of Tensors

The target nodes for prediction

fanout: list of int or dict of list

Neighbor sample fanout. If it’s a dict, it indicates the fanout for each edge type.

batch_size: int

Batch size

device: torch.device

The device the trainer is running on.

train_task: bool

Whether the dataloader is used for training.

construct_feat_ntype: list of str

The node types for which node features must be constructed.

construct_feat_fanout: int

The fanout required to construct node features.

Examples

The following example trains a 2-layer GNN for node classification on a set of target nodes target_idx, where each node takes messages from 15 neighbors in the first layer and 10 neighbors in the second:

import torch

from graphstorm.dataloading import GSgnnNodeTrainData
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.trainer import GSgnnNodePredictionTrainer

# device is a required argument: pass the device the trainer runs on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

np_data = GSgnnNodeTrainData(...)
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx, fanout=[15, 10], batch_size=128, device=device)
np_trainer = GSgnnNodePredictionTrainer(...)
np_trainer.fit(np_dataloader, num_epochs=10)

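To get a feel for what fanout=[15, 10] with batch_size=128 implies, the sketch below computes upper bounds on the number of nodes per layer. It is an illustration only, under stated assumptions: a single edge type, sampling that starts at the target nodes and moves outward (so the last fanout entry is applied first), and blocks whose source nodes include their destination nodes, as in DGL-style neighbor sampling. The helper max_block_sizes is hypothetical and not part of GraphStorm.

```python
def max_block_sizes(batch_size, fanout):
    """Return upper bounds (num_src, num_dst) per block, first GNN layer first.

    Hypothetical helper for illustration; assumes one edge type and that each
    block's source nodes include its destination nodes.
    """
    sizes = []
    num_dst = batch_size
    # Sampling starts at the target nodes, so walk the fanout list in reverse.
    for k in reversed(fanout):
        # Keep every destination node and add at most k sampled neighbors each.
        num_src = num_dst + num_dst * k
        sizes.append((num_src, num_dst))
        num_dst = num_src
    sizes.reverse()  # report the first (outermost) layer first
    return sizes


sizes = max_block_sizes(batch_size=128, fanout=[15, 10])
print(sizes)
```

Under these assumptions, a mini-batch of 128 target nodes pulls in at most 1408 nodes for the second layer and at most 22528 input nodes for the first. Real counts are usually lower, because nodes have fewer neighbors than the fanout or share neighbors.
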
__iter__()

Return an iterator object over the mini-batches.

__next__()

Return the next mini-batch of data for the node task.

A mini-batch comprises three objects: the input node IDs of the mini-batch, the target nodes and the subgraph blocks for message passing.

Returns

dict of Tensors: the input node IDs of the mini-batch.

dict of Tensors: the target node IDs.

list of DGLGraph: the subgraph blocks for message passing.
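
The three-part mini-batch described above can be consumed with ordinary tuple unpacking in a for loop. The sketch below uses a toy stand-in class, ToyNodeLoader, which is hypothetical and is not GraphStorm's implementation: it uses plain Python lists instead of Tensors, does no neighbor sampling, and returns an empty list where a real loader would return DGLGraph blocks. It only illustrates the __iter__/__next__ protocol and the shape of each mini-batch.

```python
class ToyNodeLoader:
    """Hypothetical stand-in mimicking the (input_nodes, seeds, blocks) protocol."""

    def __init__(self, ntype, node_ids, batch_size):
        self.ntype = ntype          # single node type, for simplicity
        self.node_ids = node_ids    # target node IDs as a plain list
        self.batch_size = batch_size
        self._start = 0

    def __iter__(self):
        # Restart from the first target node on every new iteration.
        self._start = 0
        return self

    def __next__(self):
        if self._start >= len(self.node_ids):
            raise StopIteration
        seeds = self.node_ids[self._start:self._start + self.batch_size]
        self._start += self.batch_size
        # A real sampler would add sampled neighbor IDs to the input nodes
        # and build one message-passing block per GNN layer.
        input_nodes = {self.ntype: list(seeds)}
        blocks = []
        return input_nodes, {self.ntype: seeds}, blocks


loader = ToyNodeLoader("paper", list(range(10)), batch_size=4)
batches = []
for input_nodes, seeds, blocks in loader:
    batches.append((input_nodes, seeds, blocks))
    print(seeds)
```

With a real GSgnnNodeDataLoader the loop has the same shape, but the dictionaries hold Tensors keyed by node type and blocks holds one subgraph per GNN layer.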