GSgnnLinkPredictionDataLoaderBase
- class graphstorm.dataloading.GSgnnLinkPredictionDataLoaderBase(dataset, target_idx, fanout, node_feats=None, edge_feats=None, pos_graph_edge_feats=None)
Bases: object
The base dataloader class for link prediction tasks.
If users want to customize dataloaders for link prediction tasks, they should extend this base class by implementing the special methods __iter__, __next__, and __len__.
Parameters
- dataset: GSgnnData
The GraphStorm data for link prediction tasks.
- target_idx: dict of Tensors
The target edge indexes for link prediction.
- fanout: list of int, or dict of list
Neighbor sampling fanout. If it’s a dict of list, it indicates the fanout for each edge type.
- node_feats: str, or dict of list of str
Node feature fields in three possible formats:
string: All nodes have the same feature name.
list of string: All nodes have the same list of features.
dict of list of string: Each node type has a different set of node features.
Default: None.
- edge_feats: str, or dict of list of str
Edge feature fields in three possible formats:
string: All edges have the same feature name.
list of string: All edges have the same list of features.
dict of list of string: Each edge type has a different set of edge features.
Default: None.
- pos_graph_edge_feats: str, or dict of list of str
The edge feature fields used by the positive graph in link prediction, for example edge weights. Default: None.
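The three accepted formats for node_feats (and likewise edge_feats) can be illustrated by expanding each one into a per-type dictionary. The helper normalize_feat_fields below is hypothetical and not part of GraphStorm; it is a plain-Python sketch that assumes the node or edge types are supplied as a list of keys (for edges, canonical (src, relation, dst) tuples), and it makes no claim about how GraphStorm handles these arguments internally.

```python
from typing import Dict, Hashable, List, Optional, Sequence, Union

# Hypothetical alias for the three documented formats (plus None).
FeatSpec = Optional[Union[str, Sequence[str], Dict[Hashable, List[str]]]]


def normalize_feat_fields(spec: FeatSpec,
                          types: Sequence[Hashable]) -> Dict[Hashable, List[str]]:
    """Hypothetical helper: expand a node_feats/edge_feats-style spec
    into {type: [feature field, ...]}."""
    if spec is None:
        # Default: no feature fields for any type.
        return {}
    if isinstance(spec, str):
        # string: every type uses the same single feature name.
        return {t: [spec] for t in types}
    if isinstance(spec, dict):
        # dict of list of string: already per type; `types` is ignored
        # and each list is copied so the caller's dict is not aliased.
        return {t: list(fields) for t, fields in spec.items()}
    # list of string: every type uses the same list of feature names.
    return {t: list(spec) for t in types}
```

For example, normalize_feat_fields("feat", ["user", "item"]) yields {"user": ["feat"], "item": ["feat"]}, while a dict such as {"user": ["age"], "item": ["title", "price"]} is returned unchanged in content.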
- __iter__()
Returns an iterator object.
- __next__()
Return a mini-batch for link prediction.
A link prediction mini-batch contains four objects:
the input node IDs of the mini-batch.
the target positive edges for prediction.
the sampled negative edges for prediction.
the sampled subgraph in the list of DGL message flow graph (MFG) format. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
Returns
Tensor or dict of Tensors: the input nodes of a mini-batch.
DGLGraph: positive edges.
DGLGraph: negative edges.
list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
- property data
The dataset of this dataloader, which is given in class initialization.
Returns
GSgnnData : The dataset of the dataloader.
- property fanout
The fanout of each GNN layer, which is given in class initialization.
Returns
list or a dict of list : the fanouts for each GNN layer.
- property target_eidx
The target edge indexes for prediction, which is given in class initialization.
Returns
dict of Tensors : the target edge IDs.
- property node_feat_fields
Node feature fields, which are given in class initialization.
Returns
str or dict of list of str: Node feature fields in the graph.
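The iterator protocol that subclasses are asked to implement (__iter__, __next__, and __len__) can be sketched in plain Python as follows. LinkPredictionLoaderSketch is a hypothetical standalone class, not a subclass of GSgnnLinkPredictionDataLoaderBase, so that it runs without GraphStorm or DGL installed. Its batch_size, shuffle, seed, and sample_fn parameters are assumptions made for illustration; sample_fn stands in for the real neighbor and negative sampling and is expected to return the four objects listed for __next__ (input nodes, positive edges, negative edges, and the list of MFGs).

```python
import math
import random
from typing import Callable, Dict, Hashable, Iterator, List, Tuple

# Hypothetical: maps {edge_type: [edge IDs]} to the four-tuple
# (input_nodes, pos_graph, neg_graph, blocks) described for __next__.
SampleFn = Callable[[Dict[Hashable, List[int]]], Tuple]


class LinkPredictionLoaderSketch:
    """Standalone sketch of the __iter__/__next__/__len__ protocol.

    Not a GraphStorm class: a real loader would subclass
    GSgnnLinkPredictionDataLoaderBase and sample with DGL.
    """

    def __init__(self, target_idx: Dict[Hashable, List[int]], batch_size: int,
                 sample_fn: SampleFn, shuffle: bool = True, seed: int = 0):
        self._target_idx = target_idx
        self._batch_size = batch_size
        self._sample_fn = sample_fn
        self._shuffle = shuffle
        self._rng = random.Random(seed)
        self._batches: List[Dict[Hashable, List[int]]] = []
        self._pos = 0

    def __iter__(self) -> Iterator:
        # Start a new epoch: flatten (etype, eid) pairs, optionally
        # shuffle them, then group them into per-type mini-batches.
        pairs = [(etype, eid) for etype, eids in self._target_idx.items()
                 for eid in eids]
        if self._shuffle:
            self._rng.shuffle(pairs)
        self._batches = []
        for start in range(0, len(pairs), self._batch_size):
            batch: Dict[Hashable, List[int]] = {}
            for etype, eid in pairs[start:start + self._batch_size]:
                batch.setdefault(etype, []).append(eid)
            self._batches.append(batch)
        self._pos = 0
        return self

    def __next__(self) -> Tuple:
        if self._pos >= len(self._batches):
            raise StopIteration
        batch = self._batches[self._pos]
        self._pos += 1
        # Delegate construction of input nodes, positive/negative
        # edges and MFGs to the (hypothetical) sampling callable.
        return self._sample_fn(batch)

    def __len__(self) -> int:
        # Number of mini-batches in one epoch.
        total = sum(len(eids) for eids in self._target_idx.values())
        return math.ceil(total / self._batch_size)

    @property
    def target_eidx(self) -> Dict[Hashable, List[int]]:
        # Mirrors the target_eidx property: the target edge IDs per type.
        return self._target_idx
```

With five target edges of one type and batch_size=2, len(loader) is 3 and an unshuffled epoch yields the edge ID groups [0, 1], [2, 3], and [4]. A real subclass would presumably pass dataset, target_idx, fanout, and the feature arguments to the base class and build the positive graph, negative graph, and MFGs with DGL samplers; building the batch list once in __iter__ keeps each __next__ call cheap.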