GSgnnNodeDataLoaderBase

class graphstorm.dataloading.GSgnnNodeDataLoaderBase(dataset, target_idx, fanout, label_field, node_feats=None, edge_feats=None)

Bases: object

The base dataloader class for node tasks.

If users want to customize dataloaders for their node prediction tasks, they should extend this base class by implementing the special methods __iter__, __next__, and __len__.
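The extension pattern described above might look like the following minimal sketch. To keep it runnable without GraphStorm installed, `NodeDataLoaderBaseStub` is a hypothetical stand-in that only mirrors the documented constructor signature; in real code you would subclass `graphstorm.dataloading.GSgnnNodeDataLoaderBase` instead. `FixedBatchNodeLoader` and its `batch_size` argument are also assumptions made for illustration, plain Python lists stand in for Tensors, and no neighbor sampling is performed.

```python
class NodeDataLoaderBaseStub:
    """Hypothetical stand-in for GSgnnNodeDataLoaderBase (not the real class)."""

    def __init__(self, dataset, target_idx, fanout, label_field,
                 node_feats=None, edge_feats=None):
        self._data = dataset
        self._target_idx = target_idx
        self._fanout = fanout
        self._label_field = label_field
        self._node_feats = node_feats
        self._edge_feats = edge_feats

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class FixedBatchNodeLoader(NodeDataLoaderBaseStub):
    """Illustrative loader: walks one node type's targets in fixed-size batches."""

    def __init__(self, dataset, target_idx, fanout, label_field,
                 batch_size, node_feats=None, edge_feats=None):
        super().__init__(dataset, target_idx, fanout, label_field,
                         node_feats=node_feats, edge_feats=edge_feats)
        # Assumption for this sketch: exactly one target node type.
        (self._ntype,) = target_idx.keys()
        self._ids = list(target_idx[self._ntype])
        self._batch_size = batch_size
        self._pos = 0

    def __iter__(self):
        # Restart from the first target node on every new iteration.
        self._pos = 0
        return self

    def __next__(self):
        if self._pos >= len(self._ids):
            raise StopIteration
        seeds = self._ids[self._pos:self._pos + self._batch_size]
        self._pos += self._batch_size
        # A real loader would sample neighbors according to the fanout here.
        # This sketch returns the seeds as input nodes and one placeholder
        # per GNN layer in place of the DGL MFGs.
        input_nodes = {self._ntype: seeds}
        target_nodes = {self._ntype: seeds}
        blocks = [None] * len(self._fanout)
        return input_nodes, target_nodes, blocks

    def __len__(self):
        # Number of mini-batches, rounding up for a final partial batch.
        return -(-len(self._ids) // self._batch_size)


loader = FixedBatchNodeLoader(None, {"paper": [0, 1, 2, 3, 4]},
                              fanout=[10, 5], label_field="label", batch_size=2)
batches = list(loader)
```

The three methods cover the whole contract: `__iter__` resets the position, `__next__` yields the three-part mini-batch or raises `StopIteration`, and `__len__` reports how many mini-batches one pass produces.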

Parameters

dataset: GSgnnData

The GraphStorm data for node tasks.

target_idx: dict of Tensors

The target node indexes for prediction.

fanout: list of int, or dict of lists

The fanout for each GNN layer.

label_field: str, or dict of str

Label field name of the target node types.

node_feats: str, list of str, or dict of list of str

Node feature fields in three possible formats:

  • string: All nodes have the same feature name.

  • list of string: All nodes have the same list of features.

  • dict of list of string: Each node type has a different set of node features.

Default: None.

edge_feats: str, list of str, or dict of list of str

Edge feature fields in three possible formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None.
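The three accepted formats for node_feats and edge_feats can be illustrated by expanding each of them into an explicit per-type mapping. The helper below, `normalize_feat_fields`, is hypothetical and is not part of GraphStorm; it is only a sketch of how the three formats relate to each other, not a description of how the library resolves them internally.

```python
def normalize_feat_fields(feats, types):
    """Expand a feature-field spec into {type_name: [feature names]}.

    Hypothetical helper for illustration only; `types` is the list of
    node (or edge) type names the spec should be applied to.
    """
    if feats is None:
        # Default: no features requested.
        return {}
    if isinstance(feats, str):
        # string: every type uses the same single feature name.
        return {t: [feats] for t in types}
    if isinstance(feats, (list, tuple)):
        # list of string: every type uses the same list of features.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # dict of list of string: each type lists its own features.
        return {t: list(names) for t, names in feats.items()}
    raise TypeError(f"Unsupported feature field spec: {type(feats).__name__}")


ntypes = ["paper", "author"]
from_str = normalize_feat_fields("feat", ntypes)
from_list = normalize_feat_fields(["feat", "year"], ntypes)
from_dict = normalize_feat_fields({"paper": ["feat", "year"]}, ntypes)
```

In the dict form, only the types that appear as keys carry features, which is what makes it suitable for heterogeneous graphs whose types store different fields.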

__iter__()

Return an iterator object.

__next__()

Return a mini-batch of data for node tasks.

A mini-batch comprises three objects: 1) the input node IDs of the mini-batch, 2) the target nodes, and 3) the sampled subgraph as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.

Returns

  • dict of Tensors : the input node IDs of the mini-batch.

  • dict of Tensors : the target node indexes.

  • list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing.
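As a rough illustration of the return structure listed above, the sketch below unpacks one mini-batch into its three parts. `describe_minibatch` is a hypothetical helper, and `toy_batch` uses plain lists and strings as placeholders for the Tensors and DGL MFGs a real dataloader would return.

```python
def describe_minibatch(batch):
    """Unpack one mini-batch and summarize its size per node type."""
    input_nodes, target_nodes, blocks = batch
    return {
        "inputs_per_ntype": {nt: len(ids) for nt, ids in input_nodes.items()},
        "targets_per_ntype": {nt: len(ids) for nt, ids in target_nodes.items()},
        # The MFG list typically holds one entry per GNN layer.
        "num_mfgs": len(blocks),
    }


# Placeholder mini-batch: lists stand in for Tensors, strings for MFGs.
toy_batch = (
    {"paper": [0, 1, 2, 3, 4, 5], "author": [7, 8]},  # input node IDs
    {"paper": [0, 1]},                                # target node indexes
    ["mfg_layer_0", "mfg_layer_1"],                   # list of MFGs
)
summary = describe_minibatch(toy_batch)
```

The input nodes are usually a superset of the target nodes, because they also include the sampled neighbors whose features are needed for message passing.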

__len__()

Return the length (number of mini-batches) of the dataloader.

Returns

int: the number of mini-batches.

property data

The data of the dataloader, which is given in class initialization.

Returns

GSgnnData : The data of the dataloader.

property target_nidx

Target node indexes for prediction, which are given in class initialization.

Returns

dict of Tensors : the target node indexes.

property fanout

The fanout of each GNN layer, which is given in class initialization.

Returns

list or a dict of list : the fanout for each GNN layer.

property label_field

The label field, which is given in class initialization.

Returns

str, or dict of str: the label field(s) given in class initialization.

property node_feat_fields

Node feature fields, which are given in class initialization.

Returns

str, list of str, or dict of list of str: the node feature fields given in class initialization.

property edge_feat_fields

Edge feature fields, which are given in class initialization.

Returns

str, list of str, or dict of list of str: the edge feature fields given in class initialization.