GSgnnLinkPredictionDataLoaderBase

class graphstorm.dataloading.GSgnnLinkPredictionDataLoaderBase(dataset, target_idx, fanout)

Bases: object

The base class for link prediction dataloaders.

To customize the dataloader for link prediction tasks, extend this base class and implement the special methods __iter__ and __next__.

Parameters

dataset: GSgnnEdgeData

The GraphStorm edge dataset

target_idx: dict of Tensors

The target edges for prediction

fanout: list of int or dict of list

Neighbor sample fanout. If it’s a dict, it indicates the fanout for each edge type.
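The two accepted shapes of fanout can be illustrated with plain Python values. This is only a sketch: the edge type names, the use of (source type, relation, destination type) tuples as dict keys, and the helper expand_fanout are assumptions made for illustration and are not part of the GraphStorm API.

```python
# Hypothetical edge types, written as (source type, relation, destination type).
etypes = [("user", "rates", "movie"), ("movie", "rated-by", "user")]

# List form: one fanout per GNN layer, shared by every edge type.
fanout_list = [10, 5]

# Dict form: a separate per-layer list for each edge type.
fanout_dict = {
    ("user", "rates", "movie"): [10, 5],
    ("movie", "rated-by", "user"): [4, 2],
}


def expand_fanout(fanout, etypes):
    """Hypothetical helper: return a per-edge-type fanout dict for either form."""
    if isinstance(fanout, dict):
        # Already per edge type; copy the lists so callers cannot mutate the input.
        return {etype: list(fanout[etype]) for etype in etypes}
    # A plain list applies the same per-layer fanout to every edge type.
    return {etype: list(fanout) for etype in etypes}


per_etype_from_list = expand_fanout(fanout_list, etypes)
per_etype_from_dict = expand_fanout(fanout_dict, etypes)
```

With the list form every edge type samples the same number of neighbors at each layer; the dict form lets a dense relation use a smaller fanout than a sparse one.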

__iter__()

Return an iterator object.

__next__()

Return a mini-batch for link prediction.

A mini-batch of link prediction contains four objects:

* the input node IDs of the mini-batch,
* the target positive edges for prediction,
* the negative edges for prediction,
* the subgraph blocks for message passing.

Returns

Tensor or dict of Tensors: the input nodes of a mini-batch.

DGLGraph: positive edges.

DGLGraph: negative edges.

list of DGLGraph: subgraph blocks for message passing.
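The iteration protocol described above can be sketched without GraphStorm or DGL installed. The class below is a hypothetical stand-in, not the library's implementation: it does not subclass GSgnnLinkPredictionDataLoaderBase, it uses plain Python lists where a real dataloader would return Tensors and DGLGraph objects, and its negative sampling rule is a placeholder chosen only to keep the example deterministic.

```python
class ToyLinkPredictionDataLoader:
    """Hypothetical stand-in that mimics only the documented iteration protocol."""

    def __init__(self, target_edges, batch_size):
        self.target_edges = target_edges  # list of (src, dst) node ID pairs
        self.batch_size = batch_size
        self._pos = 0

    def __iter__(self):
        # Restart from the first target edge each time iteration begins.
        self._pos = 0
        return self

    def __next__(self):
        if self._pos >= len(self.target_edges):
            raise StopIteration
        pos_edges = self.target_edges[self._pos:self._pos + self.batch_size]
        self._pos += self.batch_size
        # Placeholder negatives: keep each source and shift the destination by one.
        neg_edges = [(src, dst + 1) for src, dst in pos_edges]
        # Input nodes: every node touched by a positive or negative edge.
        input_nodes = sorted({n for edge in pos_edges + neg_edges for n in edge})
        # Placeholder for the message-passing blocks a real sampler would build.
        blocks = []
        return input_nodes, pos_edges, neg_edges, blocks


loader = ToyLinkPredictionDataLoader([(0, 1), (2, 3), (4, 5)], batch_size=2)
batches = list(loader)
```

Each element of batches is a 4-tuple in the documented order (input nodes, positive edges, negative edges, blocks), so a training loop can unpack it directly in a for statement.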