GSgnnLinkPredictionDataLoaderBase
- class graphstorm.dataloading.GSgnnLinkPredictionDataLoaderBase(dataset, target_idx, fanout, node_feats=None, edge_feats=None, pos_graph_edge_feats=None)
Bases: object
The base class of link prediction dataloaders.
To customize the dataloader for link prediction tasks, extend this base class and implement the special methods __iter__ and __next__.
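A minimal sketch of the iterator protocol such a subclass implements. So that it runs without GraphStorm or DGL installed, the hypothetical ToyLinkPredictionDataLoader below does not inherit from GSgnnLinkPredictionDataLoaderBase, takes a hypothetical batch_size argument, and returns placeholders where a real implementation would sample the input nodes, positive graph, negative graph, and message-passing blocks.

```python
import random


class ToyLinkPredictionDataLoader:
    """Hypothetical stand-in for a GSgnnLinkPredictionDataLoaderBase subclass.

    Only __iter__/__next__ are shown; graph sampling is replaced by placeholders.
    """

    def __init__(self, target_idx, batch_size, shuffle=True, seed=0):
        # target_idx: dict mapping an edge type to a list of target edge IDs
        self.target_idx = target_idx
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = random.Random(seed)
        self._batches = []

    def __iter__(self):
        # Plan this epoch's batches: split each edge type's IDs into chunks.
        self._batches = []
        for etype, eids in self.target_idx.items():
            eids = list(eids)
            if self.shuffle:
                self._rng.shuffle(eids)
            for start in range(0, len(eids), self.batch_size):
                self._batches.append({etype: eids[start:start + self.batch_size]})
        if self.shuffle:
            self._rng.shuffle(self._batches)
        return self

    def __next__(self):
        if not self._batches:
            raise StopIteration
        pos_eids = self._batches.pop()
        # Placeholders: a real loader would sample input nodes, a positive graph,
        # a negative graph and message-passing blocks for pos_eids here.
        input_nodes, neg_graph, blocks = None, None, []
        return input_nodes, pos_eids, neg_graph, blocks


etype = ("user", "click", "item")
loader = ToyLinkPredictionDataLoader({etype: list(range(10))}, batch_size=4)
for epoch in range(2):
    batch_sizes = [len(pos[etype]) for _, pos, _, _ in loader]
    print(epoch, sorted(batch_sizes))  # prints "0 [2, 4, 4]", then "1 [2, 4, 4]"
```

Rebuilding the batch plan inside __iter__ lets the same loader object be iterated once per epoch, with __next__ signalling the end of an epoch by raising StopIteration.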
Parameters
- dataset: GSgnnData
The GraphStorm edge dataset
- target_idx: dict of Tensors
The target edges for prediction
- fanout: list of int or dict of list
Neighbor sampling fanout. If it is a dict, it specifies the fanout for each edge type.
- node_feats: str, list of str, or dict of list of str
Node features. str: all nodes have the same feature name. list of str: all nodes have the same list of features. dict of list of str: each node type has its own set of node features. Default: None
- edge_feats: str, list of str, or dict of list of str
Edge features. str: all edges have the same feature name. list of str: all edges have the same list of features. dict of list of str: each edge type has its own set of edge features. Default: None
- pos_graph_edge_feats: str or dict of list of str
The edge feature field used by the positive graph in link prediction, for example, edge weights. Default: None
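The parameter formats above can be illustrated with two hypothetical helpers, check_fanout and normalize_feat_names; neither is a GraphStorm function. The sketch assumes that a dict fanout maps each edge type to one fanout value per GNN layer, and that the three feature-name forms are expanded into a per-type dict of lists. GraphStorm's own handling may differ.

```python
def check_fanout(fanout, num_layers):
    """Hypothetical check: a list of int, or a dict of per-edge-type lists."""
    per_etype = fanout if isinstance(fanout, dict) else {"all edge types": fanout}
    for etype, layers in per_etype.items():
        if len(layers) != num_layers:
            raise ValueError(f"{etype}: {len(layers)} fanouts for {num_layers} layers")
        if not all(isinstance(f, int) for f in layers):
            raise TypeError(f"{etype}: fanouts must be int")
    return fanout


def normalize_feat_names(feats, types):
    """Hypothetical helper: expand node_feats/edge_feats into {type: [names]}."""
    if feats is None:
        return {}
    if isinstance(feats, str):  # one feature name shared by every type
        return {t: [feats] for t in types}
    if isinstance(feats, dict):  # a different list of names per type
        return {t: list(names) for t, names in feats.items()}
    return {t: list(feats) for t in types}  # one list shared by every type


ntypes = ["user", "item"]
check_fanout([10, 15], num_layers=2)
check_fanout({("user", "click", "item"): [10, 15],
              ("item", "rev-click", "user"): [5, 5]}, num_layers=2)
print(normalize_feat_names("feat", ntypes))
# prints {'user': ['feat'], 'item': ['feat']}
print(normalize_feat_names({"user": ["age", "city"], "item": ["title"]}, ntypes))
# prints {'user': ['age', 'city'], 'item': ['title']}
```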
- __iter__()
Return an iterator object.
- __next__()
Return a mini-batch for link prediction.
A link prediction mini-batch contains four objects:
* the input node IDs of the mini-batch,
* the target positive edges for prediction,
* the negative edges for prediction,
* the subgraph blocks for message passing.
Returns
* Tensor or dict of Tensors: the input nodes of a mini-batch.
* DGLGraph: the positive edges.
* DGLGraph: the negative edges.
* list of DGLGraph: the subgraph blocks for message passing.
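Because __next__ returns the four objects in a fixed order, a training loop can unpack them directly. In the sketch below, run_epoch and train_step are hypothetical names, and a list of tuples with string placeholders stands in for a real dataloader so the example runs without GraphStorm or DGL.

```python
def run_epoch(dataloader, train_step):
    """Hypothetical loop: unpack each link prediction mini-batch in order."""
    num_batches = 0
    for input_nodes, pos_graph, neg_graph, blocks in dataloader:
        # input_nodes: Tensor, or dict of Tensors keyed by node type
        # pos_graph / neg_graph: positive and negative edges (DGLGraph objects)
        # blocks: list of subgraph blocks used for message passing
        train_step(input_nodes, pos_graph, neg_graph, blocks)
        num_batches += 1
    return num_batches


# Placeholder mini-batches (strings stand in for tensors and DGL graphs).
fake_batches = [
    ({"user": [0, 1]}, "pos_graph_0", "neg_graph_0", ["block_0", "block_1"]),
    ({"user": [2, 3]}, "pos_graph_1", "neg_graph_1", ["block_0", "block_1"]),
]
seen = []
print(run_epoch(fake_batches, lambda *batch: seen.append(batch[1])))  # prints 2
```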