GSgnnEdgeDataLoaderBase

class graphstorm.dataloading.GSgnnEdgeDataLoaderBase(dataset, target_idx, fanout, label_field, node_feats=None, edge_feats=None, decoder_edge_feats=None)

Bases: object

The base dataloader class for edge tasks.

If users want to customize dataloaders for edge prediction tasks, they should extend this base class by implementing the special methods __iter__, __next__, and __len__.
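The three special methods follow Python's iterator protocol. The sketch below is a framework-free, hypothetical stand-in (`ToyEdgeDataLoader`) that only shows the batching half of that contract. It does not import GraphStorm, uses plain lists in place of Tensors, writes edge types as (source type, relation, destination type) tuples purely for illustration, and adds a `batch_size` argument that is not part of the base-class signature. A real subclass would extend GSgnnEdgeDataLoaderBase, and its __next__ would return the three objects described under __next__ below.

```python
import math


class ToyEdgeDataLoader:
    """Hypothetical stand-in for a GSgnnEdgeDataLoaderBase subclass.

    It only splits the target edge IDs into mini-batches; a real subclass
    would also sample neighbors and build DGL MFGs.
    """

    def __init__(self, target_idx, batch_size):
        # target_idx: dict mapping edge type -> list of edge IDs
        # (plain lists stand in for the dict of Tensors used by GraphStorm).
        self._items = [(etype, eid) for etype, eids in target_idx.items() for eid in eids]
        self._batch_size = batch_size
        self._pos = 0

    def __iter__(self):
        # Restart from the first edge each time a new pass (epoch) begins.
        self._pos = 0
        return self

    def __next__(self):
        if self._pos >= len(self._items):
            raise StopIteration
        chunk = self._items[self._pos:self._pos + self._batch_size]
        self._pos += self._batch_size
        # Regroup the chunk by edge type, mirroring the dict-of-IDs layout.
        batch = {}
        for etype, eid in chunk:
            batch.setdefault(etype, []).append(eid)
        return batch

    def __len__(self):
        # Number of mini-batches, counting a final partial batch.
        return math.ceil(len(self._items) / self._batch_size)


loader = ToyEdgeDataLoader(
    {("user", "rates", "item"): [0, 1, 2], ("user", "follows", "user"): [7, 8]},
    batch_size=2,
)
for batch in loader:
    print(batch)
```

Because __iter__ resets the position, the same loader object can be iterated once per epoch.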

Parameters

dataset: GSgnnData

The GraphStorm data for edge tasks.

target_idx: dict of Tensors

The target edge indexes for prediction.

fanout: list or dict of lists

The fanout for each GNN layer. If it’s a dict of lists, it indicates the fanout for each edge type.
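The two accepted shapes can be illustrated with a small lookup helper. This is a hypothetical sketch: `layer_fanout` is not a GraphStorm function, and it assumes, as the description above suggests, that a list holds one fanout per GNN layer and that a dict maps each edge type to such a per-layer list.

```python
def layer_fanout(fanout, layer, etype=None):
    """Return the fanout used at GNN layer `layer` (hypothetical helper).

    Assumes `fanout` is either a list with one entry per layer, or a dict
    mapping each edge type to such a list.
    """
    if isinstance(fanout, dict):
        if etype is None:
            raise ValueError("etype is required when fanout is a dict")
        return fanout[etype][layer]
    return fanout[layer]


# Same fanout for every edge type: 10 neighbors in layer 0, 5 in layer 1.
uniform = [10, 5]

# Per-edge-type fanout for a two-layer GNN (edge types are illustrative).
per_etype = {
    ("user", "rates", "item"): [10, 5],
    ("user", "follows", "user"): [4, 2],
}
```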

label_field: str or dict of str

Label field of the edge task.

node_feats: str, list of str, or dict of list of str

Node feature fields in three possible formats:

  • string: All nodes have the same feature name.

  • list of string: All nodes have the same list of features.

  • dict of list of string: Each node type has a different set of node features.

Default: None.
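One way to see how the three formats relate is to expand each of them into the most general one, a dict of lists keyed by node type. The helper below is a hypothetical sketch, not part of GraphStorm; `ntypes` is an assumed list of the graph's node types, and the dict case is assumed to leave unlisted node types without features.

```python
def normalize_feat_fields(feats, ntypes):
    """Expand a feature-field spec into a dict of lists keyed by node type.

    Hypothetical helper; `ntypes` is the list of node types in the graph.
    """
    if feats is None:
        # No features requested.
        return {}
    if isinstance(feats, str):
        # One feature name shared by every node type.
        return {ntype: [feats] for ntype in ntypes}
    if isinstance(feats, list):
        # The same list of feature names for every node type.
        return {ntype: list(feats) for ntype in ntypes}
    if isinstance(feats, dict):
        # Already per node type; copy the lists so the input is not shared.
        return {ntype: list(names) for ntype, names in feats.items()}
    raise TypeError(f"unsupported feature spec: {type(feats).__name__}")
```

The same expansion applies to edge_feats and decoder_edge_feats, with edge types in place of node types.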

edge_feats: str, list of str, or dict of list of str

Edge feature fields in three possible formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None.

decoder_edge_feats: str, list of str, or dict of list of str

Edge feature fields used in edge decoders, in three possible formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None.
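To make the role of decoder edge features concrete, the following hypothetical sketch builds one decoder input per target edge by concatenating the source-node embedding, the destination-node embedding, and that edge's decoder features. The function name, the concatenation strategy, and the use of plain Python lists instead of Tensors are assumptions for illustration; GraphStorm's own decoders may combine these inputs differently.

```python
from typing import List

Vector = List[float]


def build_decoder_inputs(src_emb: List[Vector], dst_emb: List[Vector],
                         edge_feats: List[Vector]) -> List[Vector]:
    """Concatenate [src embedding | dst embedding | decoder edge features] per target edge.

    Hypothetical helper: index i of each list refers to the i-th target edge.
    """
    if not (len(src_emb) == len(dst_emb) == len(edge_feats)):
        raise ValueError("all inputs must have one entry per target edge")
    # List "+" concatenates the three vectors for each edge.
    return [s + d + e for s, d, e in zip(src_emb, dst_emb, edge_feats)]
```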

__iter__()

Returns an iterator object.

__next__()

Return a mini-batch of data for the edge task.

A mini-batch comprises three objects: 1) the input node IDs, 2) the target edges, and 3) the sampled subgraphs as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.

Returns

  • dict of Tensors : the input node IDs of the mini-batch.

  • DGLGraph : the target edges.

  • list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing.
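Callers typically unpack the three objects directly in a for loop. The sketch below is a hypothetical, framework-free training-loop skeleton: `run_epoch` and `step` are illustrative names, and in the usage example plain Python placeholders stand in for the node-ID Tensors, the target-edge DGLGraph, and the MFG list.

```python
from typing import Any, Callable, Iterable, Tuple

Batch = Tuple[Any, Any, list]


def run_epoch(loader: Iterable[Batch], step: Callable[[Any, Any, list], None]) -> int:
    """Iterate once over an edge data loader and return the number of mini-batches seen."""
    num_batches = 0
    for input_nodes, target_edges, blocks in loader:
        # input_nodes: node IDs whose features feed the first GNN layer.
        # target_edges: the edges to be scored in this mini-batch.
        # blocks: the MFGs used for message passing.
        step(input_nodes, target_edges, blocks)
        num_batches += 1
    return num_batches


# Usage with placeholder batches instead of real tensors and DGL graphs.
fake_batches = [
    ({"user": [0, 1]}, "target-edges-0", ["mfg0", "mfg1"]),
    ({"user": [2]}, "target-edges-1", ["mfg0", "mfg1"]),
]
seen = []
n = run_epoch(fake_batches, lambda nodes, edges, blocks: seen.append(edges))
```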

__len__()

Return the length (number of mini-batches) of the data loader.

Returns

int: length

property data

The dataset of this dataloader, which is given in class initialization.

Returns

GSgnnData: The dataset of the dataloader.

property target_eidx

Target edge indexes for prediction, which are given in class initialization.

Returns

dict of Tensors: the target edge IDs, which are given in class initialization.

property fanout

The fanout of each GNN layer, which is given in class initialization.

Returns

list or dict of lists: the fanout for each GNN layer, which is given in class initialization.

property label_field

The label field, which is given in class initialization.

Returns

str or dict of str: the label field(s) in the graph, which is given in class initialization.

property node_feat_fields

Node feature fields, which are given in class initialization.

Returns

str, list of str, or dict of list of str: node feature fields in the graph, which are given in class initialization.

property edge_feat_fields

Edge feature fields, which are given in class initialization.

Returns

str, list of str, or dict of list of str: edge feature fields in the graph, which are given in class initialization.

property decoder_edge_feat_fields

Edge feature fields used by the edge decoder, which are given in class initialization.

Returns

str, list of str, or dict of list of str: edge feature fields used by the edge decoder, which are given in class initialization.
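Each property above exposes a value passed to the constructor without modification. A minimal, hypothetical sketch of that pattern follows; `EdgeLoaderArgs` is an illustrative name, not GraphStorm's implementation. It also makes the naming differences explicit: target_idx is exposed as target_eidx, node_feats as node_feat_fields, edge_feats as edge_feat_fields, and decoder_edge_feats as decoder_edge_feat_fields.

```python
class EdgeLoaderArgs:
    """Hypothetical holder mirroring the read-only properties of an edge data loader."""

    def __init__(self, dataset, target_idx, fanout, label_field,
                 node_feats=None, edge_feats=None, decoder_edge_feats=None):
        # Store the constructor arguments in private attributes.
        self._data = dataset
        self._target_eidx = target_idx
        self._fanout = fanout
        self._label_field = label_field
        self._node_feats = node_feats
        self._edge_feats = edge_feats
        self._decoder_edge_feats = decoder_edge_feats

    # Properties without setters are read-only: assigning to them raises AttributeError.
    @property
    def data(self):
        return self._data

    @property
    def target_eidx(self):
        return self._target_eidx

    @property
    def fanout(self):
        return self._fanout

    @property
    def label_field(self):
        return self._label_field

    @property
    def node_feat_fields(self):
        return self._node_feats

    @property
    def edge_feat_fields(self):
        return self._edge_feats

    @property
    def decoder_edge_feat_fields(self):
        return self._decoder_edge_feats


args = EdgeLoaderArgs("my-dataset", {("user", "rates", "item"): [0, 1]}, [10, 5], "rating",
                      node_feats="feat")
```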