GSgnnLinkPredictionTestDataLoader
- class graphstorm.dataloading.GSgnnLinkPredictionTestDataLoader(dataset, target_idx, batch_size, num_negative_edges, fanout=None, fixed_test_size=None, node_feats=None, edge_feats=None, pos_graph_edge_feats=None)
Bases: GSgnnLinkPredictionDataLoaderBase
Mini-batch dataloader for link prediction validation and test. To compute positive and negative scores efficiently for link prediction tasks, GSgnnLinkPredictionTestDataLoader is designed to generate only edges, i.e., source and destination node pairs. The negative edges are sampled uniformly.
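To make "sampled uniformly" concrete, the following is a minimal sketch of uniform negative sampling over destination nodes. It is not GraphStorm's sampler: the helper name `sample_uniform_negatives`, its arguments, and the choice to corrupt only the destination side are assumptions made for illustration, and the sketch does not filter out draws that happen to coincide with a true edge.

```python
import random


def sample_uniform_negatives(pos_edges, num_dst_nodes, num_negative_edges, seed=0):
    """Draw negative destinations uniformly for each positive edge.

    Hypothetical helper for illustration only; not part of GraphStorm.
    Returns a list of (src, [negative dst IDs]) pairs, one per positive edge.
    """
    rng = random.Random(seed)
    negatives = []
    for src, _dst in pos_edges:
        # Every destination node ID in [0, num_dst_nodes) is equally likely.
        neg_dsts = [rng.randrange(num_dst_nodes) for _ in range(num_negative_edges)]
        negatives.append((src, neg_dsts))
    return negatives


# Three positive (source, destination) pairs in a toy graph with 10 destination nodes.
pos_edges = [(0, 3), (1, 4), (2, 5)]
negatives = sample_uniform_negatives(pos_edges, num_dst_nodes=10, num_negative_edges=4)
for src, neg_dsts in negatives:
    print(src, neg_dsts)
```

Each positive edge is paired with `num_negative_edges` candidate destinations, which mirrors the role of the num_negative_edges parameter described below.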
Parameters
- dataset: GSgnnData
The GraphStorm data.
- target_idx: dict of Tensors
The target edge indexes for link prediction.
- batch_size: int
Mini-batch size.
- num_negative_edges: int
The number of negative edges per positive edge.
- fanout: list of int, or dict of list
Neighbor sampling fanout. If it is a dict of lists, it specifies the fanout for each edge type. Default: None.
- fixed_test_size: int
Fixed number of test samples used in evaluation. If None, the whole test set is used. When the test set is huge, setting fixed_test_size can reduce validation and test time. Default: None.
- node_feats: str, list of str, or dict of list of str
Node feature fields in three possible formats:
string: All nodes have the same feature name.
list of string: All nodes have the same list of features.
dict of list of string: Each node type has a different set of node features.
Default: None.
- edge_feats: str, list of str, or dict of list of str
Edge feature fields in three possible formats:
string: All edges have the same feature name.
list of string: All edges have the same list of features.
dict of list of string: Each edge type has a different set of edge features.
Default: None.
- pos_graph_edge_feats: str or dict of list of str
The edge feature fields used by the positive graph in link prediction, for example, edge weights. Default: None.
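The three accepted formats for node_feats and edge_feats can be read as shorthand for one explicit per-type mapping. The sketch below shows that reading; `normalize_feat_spec` and the type names `user` and `item` are hypothetical, and GraphStorm may handle these formats differently internally.

```python
def normalize_feat_spec(feats, types):
    """Expand a feature spec into an explicit {type: [feature names]} dict.

    Hypothetical helper for illustration only; not part of GraphStorm.
    """
    if feats is None:
        # Default: no features are requested.
        return {}
    if isinstance(feats, str):
        # string: every type uses the same single feature name.
        return {t: [feats] for t in types}
    if isinstance(feats, (list, tuple)):
        # list of string: every type uses the same list of features.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # dict of list of string: each type lists its own features.
        return {t: list(names) for t, names in feats.items()}
    raise TypeError(f"unsupported feature spec: {type(feats).__name__}")


ntypes = ["user", "item"]
print(normalize_feat_spec("feat", ntypes))
print(normalize_feat_spec(["feat", "age"], ntypes))
print(normalize_feat_spec({"user": ["age"], "item": ["price", "feat"]}, ntypes))
```

The string and list forms apply to every type, while the dict form can name features for only the types that have them.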
- __iter__()
Returns an iterator object.
- __next__()
Return a mini-batch for link prediction.
A link prediction mini-batch contains four objects:
the input node IDs of the mini-batch.
the target positive edges for prediction.
the sampled negative edges for prediction.
the sampled subgraph in the list of DGL message flow graph (MFG) format. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
Returns
Tensor or dict of Tensors: the input nodes of a mini-batch.
DGLGraph: positive edges.
DGLGraph: negative edges.
list of DGL MFGs: the list of DGL message flow graphs (MFGs) for message passing. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
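A consumer of this dataloader would typically unpack the four objects listed above on each iteration. The sketch below shows that unpacking pattern against a stand-in iterable so that it runs without GraphStorm or DGL installed; `count_mfg_layers`, `fake_loader`, and the placeholder values are assumptions for illustration, not real GraphStorm objects.

```python
def count_mfg_layers(loader):
    """Iterate over four-object mini-batches and record the MFG count of each.

    Hypothetical helper for illustration only; `loader` is any iterable that
    yields (input_nodes, pos_graph, neg_graph, blocks) tuples.
    """
    layer_counts = []
    for input_nodes, pos_graph, neg_graph, blocks in loader:
        # input_nodes: input node IDs of the mini-batch.
        # pos_graph / neg_graph: positive and sampled negative edges.
        # blocks: list of message flow graphs, one per message passing layer.
        layer_counts.append(len(blocks))
    return layer_counts


# Stand-in for a real dataloader: plain Python placeholders replace tensors and graphs.
fake_loader = [
    ({"user": [0, 1, 2]}, "pos_graph_0", "neg_graph_0", ["mfg_0", "mfg_1"]),
    ({"user": [3, 4]}, "pos_graph_1", "neg_graph_1", ["mfg_0", "mfg_1"]),
]
print(count_mfg_layers(fake_loader))
```

With a real dataloader, the number of MFGs per batch would be expected to match the length of the fanout list, since neighbor sampling produces one MFG per layer.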