GSgnnLinkPredictionDataLoaderBase

class graphstorm.dataloading.GSgnnLinkPredictionDataLoaderBase(dataset, target_idx, fanout, node_feats=None, edge_feats=None, pos_graph_edge_feats=None)

Bases: object

The base dataloader class for link prediction tasks.

If users want to customize dataloaders for link prediction tasks, they should extend this base class by implementing the special methods __iter__, __next__, and __len__.
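As a rough illustration of the three special methods named above, the sketch below walks over a list of target edge IDs in fixed-size mini-batches. It deliberately does not import GraphStorm: `ToyLinkPredictionLoader`, its `batch_size` argument, and the plain-list batches are hypothetical stand-ins. A real subclass would inherit from GSgnnLinkPredictionDataLoaderBase and return the objects described under `__next__`.

```python
import math


class ToyLinkPredictionLoader:
    """Hypothetical stand-in that shows only the iterator protocol.

    A real subclass would inherit from GSgnnLinkPredictionDataLoaderBase
    and return DGL objects instead of plain Python lists.
    """

    def __init__(self, target_edge_ids, batch_size):
        # batch_size is an assumption of this sketch, not a base-class argument.
        self._target_edge_ids = list(target_edge_ids)
        self._batch_size = batch_size
        self._cursor = 0

    def __iter__(self):
        # Restart from the first mini-batch on every new iteration.
        self._cursor = 0
        return self

    def __next__(self):
        if self._cursor >= len(self._target_edge_ids):
            raise StopIteration
        start = self._cursor
        self._cursor += self._batch_size
        return self._target_edge_ids[start:self._cursor]

    def __len__(self):
        # Number of mini-batches, counting a final partial batch.
        return math.ceil(len(self._target_edge_ids) / self._batch_size)


loader = ToyLinkPredictionLoader(range(10), batch_size=4)
batches = list(loader)
```

Resetting the cursor in `__iter__` lets the same loader object be reused across epochs, which is one possible design and not necessarily what GraphStorm's built-in loaders do.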

Parameters

dataset: GSgnnData

The GraphStorm data for link prediction tasks.

target_idx: dict of Tensors

The target edge indexes for link prediction.

fanout: list of int, or dict of list

Neighbor sampling fanout. If it’s a dict of list, it indicates the fanout for each edge type.
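The two accepted shapes of `fanout` might look like the following. This is a minimal sketch: the edge type names, the use of tuples as dict keys, and the `num_gnn_layers` helper are assumptions made for illustration and are not part of GraphStorm.

```python
def num_gnn_layers(fanout):
    """Return the number of GNN layers implied by a fanout setting."""
    if isinstance(fanout, dict):
        # Every edge type is expected to list one fanout per GNN layer.
        lengths = {len(per_layer) for per_layer in fanout.values()}
        if len(lengths) != 1:
            raise ValueError("every edge type needs one fanout per GNN layer")
        return lengths.pop()
    return len(fanout)


# One fanout per GNN layer, shared by all edge types.
shared_fanout = [10, 5]

# Per-edge-type fanouts; the edge types and the tuple keys are made up.
per_etype_fanout = {
    ("user", "rates", "movie"): [10, 5],
    ("movie", "rated-by", "user"): [5, 5],
}

layers_shared = num_gnn_layers(shared_fanout)
layers_per_etype = num_gnn_layers(per_etype_fanout)
```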

node_feats: str, list of str, or dict of list of str

Node feature fields in three possible formats:

  • string: All nodes have the same feature name.

  • list of string: All nodes have the same list of features.

  • dict of list of string: Each node type has a different set of node features.

Default: None.
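To make the three formats concrete, the sketch below expands each of them into one explicit per-type dictionary. `normalize_feat_fields` and the node type and feature names are hypothetical; this is not how GraphStorm itself processes the argument, only a way to see what each format means.

```python
def normalize_feat_fields(feat_fields, types):
    """Expand any of the three accepted formats into {type: [field, ...]}."""
    if feat_fields is None:
        return {}
    if isinstance(feat_fields, str):
        # A single name: every type uses the same feature.
        return {t: [feat_fields] for t in types}
    if isinstance(feat_fields, (list, tuple)):
        # A list of names: every type uses the same list of features.
        return {t: list(feat_fields) for t in types}
    if isinstance(feat_fields, dict):
        # A dict: each type carries its own list of features.
        return {t: list(fields) for t, fields in feat_fields.items()}
    raise TypeError(f"unsupported feature field format: {type(feat_fields)!r}")


node_types = ["user", "movie"]  # made-up node types

from_string = normalize_feat_fields("feat", node_types)
from_list = normalize_feat_fields(["feat", "age"], node_types)
from_dict = normalize_feat_fields({"user": ["age"], "movie": ["title"]}, node_types)
```

The same three formats apply to `edge_feats`, with edge types in place of node types.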

edge_feats: str, list of str, or dict of list of str

Edge feature fields in three possible formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None.

pos_graph_edge_feats: str, or dict of list of str

The edge feature fields used by the positive graph in link prediction, for example edge weights. Default: None.

__iter__()

Return an iterator object.

__next__()

Return a mini-batch for link prediction.

A link prediction mini-batch contains four objects:

  • the input node IDs of the mini-batch.

  • the target positive edges for prediction.

  • the sampled negative edges for prediction.

  • the sampled subgraphs, as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.

Returns

  • Tensor or dict of Tensors: the input nodes of a mini-batch.

  • DGLGraph: positive edges.

  • DGLGraph: negative edges.

  • list of DGL MFGs: the DGL message flow graphs (MFGs) for message passing. More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.
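The four returned objects are typically unpacked in a loop such as the one sketched below. The generator `toy_link_prediction_batches` is a hypothetical stand-in that yields placeholder values in the documented order; a real dataloader would yield Tensors, DGLGraphs, and DGL MFGs instead of lists and strings.

```python
def toy_link_prediction_batches():
    """Yield stand-in mini-batches with the documented four-part layout."""
    yield (
        [0, 1, 2, 3],        # input node IDs (Tensor or dict of Tensors in practice)
        "positive-graph",    # placeholder for the DGLGraph of positive edges
        "negative-graph",    # placeholder for the DGLGraph of negative edges
        ["mfg-0", "mfg-1"],  # placeholder for the list of DGL MFGs
    )


seen = []
for input_nodes, pos_graph, neg_graph, blocks in toy_link_prediction_batches():
    # A training step would run message passing over `blocks`, then score
    # the edges in `pos_graph` against those in `neg_graph`.
    seen.append((len(input_nodes), pos_graph, neg_graph, len(blocks)))
```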

__len__()

Return the length (number of mini-batches) of the data loader.

Returns

int: the number of mini-batches.

property data

The dataset of this dataloader, which is given in class initialization.

Returns

GSgnnData: The dataset of the dataloader.

property fanout

The fanout of each GNN layer, which is given in class initialization.

Returns

list or dict of list: the fanouts for each GNN layer.

property target_eidx

The target edge indexes for prediction, which are given in class initialization.

Returns

dict of Tensors: the target edge IDs.

property node_feat_fields

Node feature fields, which are given in class initialization.

Returns

str or dict of list of str: Node feature fields in the graph.

property edge_feat_fields

Edge feature fields, which are given in class initialization.

Returns

str or dict of list of str: Edge feature fields in the graph.

property pos_graph_edge_feat_fields

The edge feature fields of positive graphs, which are given in class initialization.

Returns

str or dict of list of str: Edge feature fields in the positive graph.