GSgnnEdgeDataLoader

class graphstorm.dataloading.GSgnnEdgeDataLoader(dataset, target_idx, fanout, batch_size, label_field, node_feats=None, edge_feats=None, decoder_edge_feats=None, train_task=True, reverse_edge_types_map=None, remove_target_edge_type=True, exclude_training_targets=False, construct_feat_ntype=None, construct_feat_fanout=5)

Bases: GSgnnEdgeDataLoaderBase

The mini-batch dataloader for edge prediction tasks.

GSgnnEdgeDataLoader samples a GraphStorm edge dataset and produces an iterable over mini-batches. The source and destination nodes of each target edge are both included in the batch_graph, which is consumed by GraphStorm Trainers and Inferrers.
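The idea that both endpoints of every target edge become seed nodes of the mini-batch can be sketched in plain Python. This is a hypothetical illustration, not GraphStorm's implementation. The helper `collect_seed_nodes` is made up for this sketch, and it assumes target edges are given as per-edge-type lists of `(src, dst)` node-ID pairs. GraphStorm itself works with DGL graphs and tensors of edge IDs.

```python
from collections import defaultdict

# Assumed representation: canonical edge type (src_type, relation, dst_type).
EdgeType = tuple[str, str, str]


def collect_seed_nodes(
    target_edges: dict[EdgeType, list[tuple[int, int]]]
) -> dict[str, list[int]]:
    """Return, per node type, the sorted unique IDs of the source and
    destination nodes of all target edges (hypothetical helper)."""
    seeds: dict[str, set[int]] = defaultdict(set)
    for (src_type, _relation, dst_type), pairs in target_edges.items():
        for src, dst in pairs:
            seeds[src_type].add(src)  # source endpoint becomes a seed node
            seeds[dst_type].add(dst)  # destination endpoint becomes a seed node
    return {ntype: sorted(ids) for ntype, ids in seeds.items()}


batch = {("user", "click", "item"): [(0, 5), (1, 5), (0, 7)]}
print(collect_seed_nodes(batch))  # {'user': [0, 1], 'item': [5, 7]}
```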

Parameters

dataset: GSgnnData

The GraphStorm edge dataset

target_idx: dict of Tensors

The target edges for prediction

fanout: list of int or dict of list

Neighbor sample fanout. If it’s a dict, it indicates the fanout for each edge type.

batch_size: int

Batch size

label_field: str or dict of str

Label field of the edge task.

node_feats: str, list of str, or dict of list of str

Node features. str: all nodes have the same feature name. list of str: all nodes have the same list of features. dict of list of str: each node type has its own set of node features. Default: None

edge_feats: str, list of str, or dict of list of str

Edge features. str: all edges have the same feature name. list of str: all edges have the same list of features. dict of list of str: each edge type has its own set of edge features. Default: None

decoder_edge_feats: str, list of str, or dict of list of str

Edge features used in the decoder. str: all edges have the same feature name. list of str: all edges have the same list of features. dict of list of str: each edge type has its own set of edge features. Default: None

train_task: bool

Whether the dataloader is used for training.

reverse_edge_types_map: dict

A map from each edge type to its reverse edge type.

exclude_training_targets: bool

Whether to exclude the training edges during neighbor sampling.

remove_target_edge_type: bool

Whether to exclude all edges of the target edge type from message passing.

construct_feat_ntype: list of str

The node types whose node features need to be constructed.

construct_feat_fanout: int

The fanout required to construct node features.
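The three accepted forms of `node_feats` can be made concrete by normalizing them into one per-type mapping. The same idea applies to `edge_feats` and `decoder_edge_feats`, with edge types as keys. The helper `normalize_feat_names` below is hypothetical. It only illustrates the documented semantics and is not part of GraphStorm. The feature names `feat` and `text_emb` are made up. Treating types that are missing from a dict specification as having no features is an assumption of this sketch.

```python
def normalize_feat_names(feats, types):
    """Map each type in `types` to a list of feature names (hypothetical helper).

    `feats` may be None, a str, a list of str, or a dict of list of str.
    """
    if feats is None:
        return {}  # no features requested
    if isinstance(feats, str):
        return {t: [feats] for t in types}  # one shared feature name
    if isinstance(feats, list):
        return {t: list(feats) for t in types}  # one shared list of names
    if isinstance(feats, dict):
        # Per-type lists; types absent from the dict get no features (assumption).
        return {t: list(feats[t]) for t in types if t in feats}
    raise TypeError(f"unsupported feature specification: {type(feats).__name__}")


ntypes = ["user", "item"]
print(normalize_feat_names("feat", ntypes))
# {'user': ['feat'], 'item': ['feat']}
print(normalize_feat_names({"item": ["feat", "text_emb"]}, ntypes))
# {'item': ['feat', 'text_emb']}
```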
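The interaction of `reverse_edge_types_map`, `exclude_training_targets`, and `remove_target_edge_type` can be sketched on a toy edge list. This is a simplified, hypothetical model of the documented behavior, not GraphStorm's code. The function `message_passing_edges` is made up for illustration. Edges are modeled as `(canonical_etype, src, dst)` triples, and the reverse edge type name `click-rev` is an assumed naming convention. The sketch also assumes, following DGL's reverse-type exclusion, that the reverse of edge `(etype, u, v)` is `(reverse_map[etype], v, u)`. Dropping a target edge together with its reverse keeps the model from seeing the edge it is asked to predict during message passing.

```python
def message_passing_edges(edges, targets, reverse_map,
                          exclude_training_targets=False,
                          remove_target_edge_type=True):
    """Return the edges left for message passing (hypothetical model).

    edges, targets: sets of (canonical_etype, src, dst) triples.
    reverse_map: canonical_etype -> canonical etype of its reverse edges.
    """
    kept = set(edges)
    if remove_target_edge_type:
        # Drop every edge whose type is one of the target edge types.
        target_etypes = {etype for etype, _, _ in targets}
        kept = {edge for edge in kept if edge[0] not in target_etypes}
    if exclude_training_targets:
        # Drop the target edges themselves ...
        blocked = set(targets)
        # ... and their reverse counterparts, assumed to be (reverse_etype, dst, src).
        for etype, src, dst in targets:
            if etype in reverse_map:
                blocked.add((reverse_map[etype], dst, src))
        kept -= blocked
    return kept


click = ("user", "click", "item")
click_rev = ("item", "click-rev", "user")  # assumed reverse-type naming
edges = {(click, 0, 5), (click, 1, 5), (click_rev, 5, 0), (click_rev, 5, 1)}
targets = {(click, 0, 5)}  # target edges of this mini-batch
reverse_map = {click: click_rev}

print(sorted(message_passing_edges(edges, targets, reverse_map,
                                   exclude_training_targets=True,
                                   remove_target_edge_type=False)))
# [(('item', 'click-rev', 'user'), 5, 1), (('user', 'click', 'item'), 1, 5)]
```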

Examples

To train a 2-layer GNN for edge prediction on a set of edges target_idx, on a graph where each node takes messages from 15 neighbors on the first layer and 10 neighbors on the second:

from graphstorm.dataloading import GSgnnData
from graphstorm.dataloading import GSgnnEdgeDataLoader
from graphstorm.trainer import GSgnnEdgePredictionTrainer

ep_data = GSgnnData(...)
target_idx = ep_data.get_edge_train_set(...)
ep_dataloader = GSgnnEdgeDataLoader(
    ep_data, target_idx,
    fanout=[15, 10], batch_size=128,
    label_field=config.label_field)
ep_trainer = GSgnnEdgePredictionTrainer(...)
ep_trainer.fit(ep_dataloader, num_epochs=10)
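The worst-case number of input nodes in one mini-batch of the example above can be estimated with a short calculation. This is a rough upper bound, not a measurement. It ignores deduplication of nodes shared between neighborhoods and nodes with fewer neighbors than the fanout, so real mini-batches are usually much smaller. The function `max_input_nodes` is hypothetical. It assumes DGL-style multi-layer sampling, where sampling starts from the seed nodes at the last layer. It counts two seed nodes per target edge because both endpoints are included in the batch.

```python
def max_input_nodes(fanout, num_seeds):
    """Upper bound on the input nodes of a sampled mini-batch (hypothetical).

    fanout[i] is the number of neighbors sampled for GNN layer i; sampling
    starts from the seed nodes at the last layer and walks back to layer 0.
    """
    frontier = num_seeds  # nodes whose neighbors are sampled next
    total = num_seeds     # seed nodes are always part of the input
    for f in reversed(fanout):
        frontier *= f     # each frontier node adds at most f neighbors
        total += frontier
    return total


batch_size = 128             # target edges per mini-batch, as in the example
num_seeds = 2 * batch_size   # source and destination node of every edge
print(max_input_nodes([15, 10], num_seeds))  # 41216
```

Per seed node this is 1 + 10 + 10 × 15 = 161 nodes, so the order of the fanout list matters for the bound.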

__iter__()

Returns an iterator object

__next__()

Returns a mini-batch of data for the edge task.

A mini-batch comprises three objects: the input node IDs, the target edges and the subgraph blocks for message passing.

Returns

dict of Tensors: the input node IDs of the mini-batch.
DGLGraph: the target edges.
list of DGLGraph: the subgraph blocks for message passing.
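The iterator protocol described by `__iter__()` and `__next__()` can be mimicked with a toy class. This shows how a caller unpacks each mini-batch into its three parts. `ToyEdgeDataLoader` is hypothetical. It replaces the tensors and DGLGraphs that GraphStorm returns with plain Python objects. It does no neighbor sampling, so its block list is empty.

```python
class ToyEdgeDataLoader:
    """Hypothetical stand-in that mimics the documented iterator protocol.

    Each mini-batch is a triple (input_nodes, batch_graph, blocks).
    """

    def __init__(self, target_edges, batch_size):
        self.target_edges = target_edges  # list of (src, dst) pairs
        self.batch_size = batch_size
        self._pos = 0

    def __iter__(self):
        self._pos = 0  # restart from the first edge on every new pass
        return self

    def __next__(self):
        if self._pos >= len(self.target_edges):
            raise StopIteration
        batch = self.target_edges[self._pos:self._pos + self.batch_size]
        self._pos += self.batch_size
        # Input node IDs: endpoints of this batch's edges (no neighbors sampled here).
        input_nodes = {"node": sorted({n for edge in batch for n in edge})}
        blocks = []  # a real loader returns one block per GNN layer
        return input_nodes, batch, blocks


loader = ToyEdgeDataLoader([(0, 1), (1, 2), (2, 3)], batch_size=2)
for input_nodes, batch_graph, blocks in loader:
    print(input_nodes["node"], batch_graph)
# [0, 1, 2] [(0, 1), (1, 2)]
# [2, 3] [(2, 3)]
```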