GSgnnLinkPredictionDataLoader

class graphstorm.dataloading.GSgnnLinkPredictionDataLoader(dataset, target_idx, fanout, batch_size, num_negative_edges, node_feats=None, edge_feats=None, pos_graph_edge_feats=None, train_task=True, reverse_edge_types_map=None, exclude_training_targets=False, edge_mask_for_gnn_embeddings='train_mask', construct_feat_ntype=None, construct_feat_fanout=5, edge_dst_negative_field=None, num_hard_negs=None)

Bases: GSgnnLinkPredictionDataLoaderBase

Mini-batch dataloader for link prediction.

GSgnnLinkPredictionDataLoader samples GraphStorm data into an iterable over mini-batches of samples. In each batch, pos_graph and neg_graph are the sampled subgraphs for positive and negative edges, which are used by GraphStorm Trainers and Inferrers.

Given a positive edge, a negative edge is composed of the positive edge's source node and a random negative destination node drawn from a uniform distribution.
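
The uniform scheme described above can be sketched in plain Python. This is a minimal illustration, not GraphStorm's sampler: `sample_uniform_negatives` and its parameters are hypothetical names, node IDs are assumed to be integers in `range(num_dst_nodes)`, and the sketch does not filter out destinations that happen to form true edges.

```python
import random


def sample_uniform_negatives(src, num_dst_nodes, num_negative_edges, rng=None):
    """Pair `src` with uniformly sampled destinations (hypothetical helper)."""
    rng = rng or random.Random()
    # Each negative edge keeps the positive edge's source node and
    # replaces the destination with a uniformly sampled node ID.
    return [(src, rng.randrange(num_dst_nodes)) for _ in range(num_negative_edges)]


# Expand one positive edge (3 -> 7) into 10 negative edges.
negatives = sample_uniform_negatives(src=3, num_dst_nodes=100,
                                     num_negative_edges=10,
                                     rng=random.Random(0))
```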

Arguments

dataset: GSgnnData

The GraphStorm data.

target_idx: dict of Tensors

The target edge indexes for prediction.

fanout: list of int, or dict of list

Neighbor sampling fanout. If it’s a dict of list, it indicates the fanout for each edge type.

batch_size: int

Mini-batch size.

num_negative_edges: int

The number of negative edges per positive edge.

node_feats: str, list of str, or dict of list of str

Node feature fields in three possible formats:

  • string: All nodes have the same feature name.

  • list of string: All nodes have the same list of features.

  • dict of list of string: Each node type has a different set of node features.

Default: None.

edge_feats: str, list of str, or dict of list of str

Edge feature fields in three possible formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None.
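
The three accepted formats for node_feats and edge_feats can be reduced to one per-type mapping. The sketch below shows one way to do that; `normalize_feat_fields` is a hypothetical helper written for illustration, and the type and feature names are made up. It is not necessarily how GraphStorm resolves these arguments internally.

```python
def normalize_feat_fields(feats, types):
    """Expand a feature spec into {type: [field, ...]} (hypothetical helper)."""
    if feats is None:
        # No features requested.
        return {}
    if isinstance(feats, str):
        # A single name shared by every node (or edge) type.
        return {t: [feats] for t in types}
    if isinstance(feats, (list, tuple)):
        # The same list of names shared by every type.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # Per-type names; a bare string value is treated as a one-element list.
        return {t: [f] if isinstance(f, str) else list(f)
                for t, f in feats.items()}
    raise TypeError(f"Unsupported feature spec: {type(feats).__name__}")


ntypes = ["user", "item"]  # made-up node types
same_name = normalize_feat_fields("feat", ntypes)
same_list = normalize_feat_fields(["feat", "emb"], ntypes)
per_type = normalize_feat_fields({"user": ["age", "emb"], "item": ["price"]}, ntypes)
```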

pos_graph_edge_feats: str, or dict of list of str

The edge feature fields used by the positive graph in link prediction, for example edge weights. Default: None.

train_task: bool

Whether or not it is a dataloader for training.

reverse_edge_types_map: dict

A map from each edge type to its reverse edge type.

exclude_training_targets: bool

Whether to exclude training edges during neighbor sampling.

edge_mask_for_gnn_embeddings: str

The mask that indicates which edges are used for computing GNN embeddings. By default, the dataloader uses the edges in the training graph to compute GNN embeddings, which avoids information leakage in link prediction.

construct_feat_ntype: list of str

The node types whose node features need to be constructed.

construct_feat_fanout: int

The fanout used when constructing node features for feature-less nodes.

edge_dst_negative_field: str, or dict of str

The feature fields that store the hard negative edges for each edge type.

num_hard_negs: int, or dict of int

The number of hard negatives per positive edge for each edge type.
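
The sketch below shows one plausible way hard negatives and random negatives could be combined for a single positive edge. The filling order (hard negatives first, uniform random negatives for the remainder) is an assumption, not something the descriptions above confirm, and `build_negative_dsts` and its parameters are hypothetical names.

```python
import random


def build_negative_dsts(hard_negs, num_hard_negs, num_negative_edges,
                        num_dst_nodes, rng=None):
    """Mix stored hard negatives with random ones (hypothetical helper)."""
    rng = rng or random.Random()
    # Use at most `num_hard_negs` stored hard negatives, and never more
    # than the total number of negatives requested.
    k = min(num_hard_negs, len(hard_negs), num_negative_edges)
    chosen = list(hard_negs[:k])
    # Assumption: the remaining slots are filled with uniform random negatives.
    remaining = num_negative_edges - len(chosen)
    chosen += [rng.randrange(num_dst_nodes) for _ in range(remaining)]
    return chosen


dsts = build_negative_dsts(hard_negs=[11, 42, 57], num_hard_negs=2,
                           num_negative_edges=5, num_dst_nodes=100,
                           rng=random.Random(0))
```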

Examples

The following example trains a 2-layer GNN for link prediction on a set of positive edges, target_idx. Each edge (a source and destination node pair) takes messages from 15 neighbors on the first layer and 10 neighbors on the second, and 10 negative edges are sampled per positive edge.

from graphstorm.dataloading import GSgnnData
from graphstorm.dataloading import GSgnnLinkPredictionDataLoader
from graphstorm.trainer import GSgnnLinkPredictionTrainer

# Load the graph data and get the positive training edges.
lp_data = GSgnnData(...)
target_idx = lp_data.get_edge_train_set(...)
# fanout=[15, 10]: 15 neighbors on the first layer, 10 on the second.
lp_dataloader = GSgnnLinkPredictionDataLoader(lp_data, target_idx, fanout=[15, 10],
                                              num_negative_edges=10, batch_size=128)
lp_trainer = GSgnnLinkPredictionTrainer(...)
lp_trainer.fit(lp_dataloader, num_epochs=10)

__iter__()

Returns an iterator object.

__next__()

Return a mini-batch for link prediction.

A link prediction mini-batch contains four objects:

  • the input node IDs of the mini-batch.

  • the target positive edges for prediction.

  • the sampled negative edges for prediction.

  • the sampled subgraphs, as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.

Returns

  • Tensor or dict of Tensors: the input nodes of a mini-batch.

  • DGLGraph: positive edges.

  • DGLGraph: negative edges.

  • list of DGL MFGs: the list of DGL message flow graphs (MFGs) for message passing. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
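
The iteration protocol and the four-object mini-batch can be illustrated with a stand-in. `ToyLinkPredictionLoader` below is a hypothetical class that yields placeholder values in the same order as the return values listed above. It does not touch GraphStorm or DGL and only shows how a consumer would unpack each batch.

```python
class ToyLinkPredictionLoader:
    """Stand-in that yields 4-tuples shaped like the mini-batches above."""

    def __init__(self, num_batches):
        self._num_batches = num_batches
        self._next = 0

    def __iter__(self):
        self._next = 0  # restart from the first batch
        return self

    def __next__(self):
        if self._next >= self._num_batches:
            raise StopIteration
        i = self._next
        self._next += 1
        # Placeholders for: input node IDs, positive graph, negative graph, MFGs.
        return [i], f"pos_graph_{i}", f"neg_graph_{i}", [f"mfg_{i}_0", f"mfg_{i}_1"]

    def __len__(self):
        return self._num_batches


seen = []
for input_nodes, pos_graph, neg_graph, blocks in ToyLinkPredictionLoader(3):
    # A real training loop would run the GNN over `blocks` and score
    # the edges in `pos_graph` and `neg_graph` here.
    seen.append((pos_graph, neg_graph, len(blocks)))
```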

__len__()

Return the length (number of batches) of the dataloader. This follows DGL's DistDataLoader (https://github.com/dmlc/dgl/blob/1.0.x/python/dgl/distributed/dist_dataloader.py#L116), where DistDataLoader.expected_idxs is the length (number of batches) of the dataloader.

Returns

int: The length (number of batches) of the dataloader.
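
As a rough illustration of what "number of batches" means, the arithmetic can be sketched as below. `expected_num_batches` is a hypothetical helper, and the `drop_last` handling is an assumption modeled on common dataloader behavior. How GraphStorm aggregates the count across edge types is not covered here.

```python
def expected_num_batches(num_items, batch_size, drop_last=False):
    """Count the mini-batches needed to cover `num_items` (hypothetical helper)."""
    full, remainder = divmod(num_items, batch_size)
    # A trailing partial batch counts unless it is explicitly dropped.
    if remainder and not drop_last:
        full += 1
    return full


# 1000 target edges with batch_size=128: 7 full batches plus 1 partial batch.
n = expected_num_batches(1000, 128)
```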