GSgnnNodeDataLoaderBase

class graphstorm.dataloading.GSgnnNodeDataLoaderBase(dataset, target_idx, fanout, label_field, node_feats=None, edge_feats=None)

Bases: object

The base dataloader class for node tasks.

To customize the dataloader for node prediction tasks, users should extend this base class and implement the special methods __iter__ and __next__.
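The following is a minimal, dependency-free sketch of the __iter__/__next__ contract that a custom node dataloader fulfills. It is an illustration under stated assumptions, not GraphStorm code: the class ToyNodeDataLoader and its batch_size argument are hypothetical, plain Python lists stand in for torch Tensors, and no neighbor sampling is done, so the block list is empty. A real subclass of GSgnnNodeDataLoaderBase would sample neighbors according to fanout and return DGLGraph blocks.

```python
from typing import Dict, List, Tuple

# A mini-batch is (input node IDs, target node IDs, message-passing blocks).
# Lists replace Tensors and DGLGraph blocks so this sketch runs standalone.
MiniBatch = Tuple[Dict[str, List[int]], Dict[str, List[int]], List[object]]


class ToyNodeDataLoader:
    """Hypothetical loader that follows the __iter__/__next__ protocol.

    This is NOT graphstorm.dataloading.GSgnnNodeDataLoaderBase; it only
    shows the shape of what a subclass's methods return.
    """

    def __init__(self, target_idx: Dict[str, List[int]], batch_size: int):
        self.target_idx = target_idx
        self.batch_size = batch_size
        self._batches: List[Dict[str, List[int]]] = []
        self._pos = 0

    def __iter__(self) -> "ToyNodeDataLoader":
        # Split the target IDs of each node type into fixed-size batches
        # and reset the cursor, so every loop over the loader is one epoch.
        self._batches = []
        for ntype, ids in self.target_idx.items():
            for start in range(0, len(ids), self.batch_size):
                self._batches.append({ntype: ids[start:start + self.batch_size]})
        self._pos = 0
        return self

    def __next__(self) -> MiniBatch:
        if self._pos >= len(self._batches):
            raise StopIteration
        target_nodes = self._batches[self._pos]
        self._pos += 1
        # Without neighbor sampling, the input nodes are just the targets
        # and there are no message-passing blocks.
        input_nodes = {ntype: list(ids) for ntype, ids in target_nodes.items()}
        blocks: List[object] = []
        return input_nodes, target_nodes, blocks


# Usage: prints {'paper': [0, 1]}, then {'paper': [2, 3]}, then {'paper': [4]}.
for input_nodes, target_nodes, blocks in ToyNodeDataLoader({"paper": [0, 1, 2, 3, 4]}, batch_size=2):
    print(target_nodes)
```

Returning self from __iter__ and rebuilding the batch list there is one simple design choice; it lets the same loader object be reused across epochs.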

Parameters

dataset: GSgnnData

The dataset for the node task.

target_idx: dict of Tensors

The target node IDs, keyed by node type.

fanout: list or dict of lists

The fanout for each GNN layer.

label_field: str or dict of str

Label field of the node task.

node_feats: str, list of str, or dict of list of str

Node features. str: all nodes have the same feature name. list of str: all nodes have the same list of features. dict of list of str: each node type has its own set of node features. Default: None

edge_feats: str, list of str, or dict of list of str

Edge features. str: all edges have the same feature name. list of str: all edges have the same list of features. dict of list of str: each edge type has its own set of edge features. Default: None
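To make the three accepted forms of node_feats and edge_feats concrete, here is a small sketch that expands any of them into one mapping from type to a list of feature names. The helper normalize_feat_names is hypothetical and written only for illustration; GraphStorm's internal handling of these arguments may differ. Keys are typed as Hashable because edge types may be tuples rather than strings.

```python
from typing import Dict, Hashable, List, Optional, Sequence, Union

# The three documented forms, plus None (the default).
FeatSpec = Union[None, str, Sequence[str], Dict[Hashable, Sequence[str]]]


def normalize_feat_names(
    feats: FeatSpec, types: Sequence[Hashable]
) -> Optional[Dict[Hashable, List[str]]]:
    """Hypothetical helper: expand a feature spec into {type: [names]}."""
    if feats is None:
        # No features requested.
        return None
    if isinstance(feats, str):
        # str: one feature name shared by every type.
        # (Checked first because a str is also a Sequence.)
        return {t: [feats] for t in types}
    if isinstance(feats, dict):
        # dict of list of str: each type lists its own features;
        # types absent from the dict get no features.
        return {t: list(names) for t, names in feats.items()}
    # list of str: the same feature list for every type.
    return {t: list(feats) for t in types}


print(normalize_feat_names("feat", ["paper", "author"]))
# {'paper': ['feat'], 'author': ['feat']}
```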

__iter__()

Returns an iterator object.

__next__()

Returns a mini-batch of data for the node task.

A mini-batch comprises three objects: the input node IDs of the mini-batch, the target node IDs, and the subgraph blocks for message passing.

Returns

dict of Tensors: the input node IDs of the mini-batch.
dict of Tensors: the target node IDs.
list of DGLGraph: the subgraph blocks for message passing.
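The sketch below shows how a training loop might unpack the three returned objects. It is illustrative only: summarize_batches and fake_batch are hypothetical, lists stand in for Tensors, and strings stand in for DGLGraph blocks. It assumes a 2-layer GNN yields one block per layer, and that the input nodes include sampled neighbors in addition to the targets.

```python
from typing import Dict, Iterable, List, Tuple

# (input node IDs, target node IDs, blocks), with lists in place of Tensors.
Batch = Tuple[Dict[str, List[int]], Dict[str, List[int]], List[object]]


def summarize_batches(batches: Iterable[Batch]) -> List[Tuple[int, int, int]]:
    """Unpack each mini-batch the way a training loop would.

    Returns (number of input nodes, number of target nodes, number of blocks)
    per batch. A real loop would instead feed the blocks and input-node
    features to the GNN and compute the loss on the target nodes' labels.
    """
    summary = []
    for input_nodes, target_nodes, blocks in batches:
        n_input = sum(len(ids) for ids in input_nodes.values())
        n_target = sum(len(ids) for ids in target_nodes.values())
        summary.append((n_input, n_target, len(blocks)))
    return summary


# Hand-built stand-in batch: targets are two "paper" nodes; the input nodes
# also contain sampled neighbors of other node types.
fake_batch: Batch = (
    {"paper": [0, 1, 7, 9], "author": [3]},
    {"paper": [0, 1]},
    ["block_layer0", "block_layer1"],
)
print(summarize_batches([fake_batch]))  # [(5, 2, 2)]
```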