GSgnnEdgeDataLoader

class graphstorm.dataloading.GSgnnEdgeDataLoader(dataset, target_idx, fanout, batch_size, device='cpu', train_task=True, reverse_edge_types_map=None, remove_target_edge_type=True, exclude_training_targets=False, construct_feat_ntype=None, construct_feat_fanout=5)

Bases: GSgnnEdgeDataLoaderBase

The mini-batch dataloader for edge prediction tasks.

GSgnnEdgeDataLoader samples target edges from a GraphStorm edge dataset and iterates over them in mini-batches. Both the source and destination nodes of each target edge are included in the batch_graph, which is consumed by GraphStorm Trainers and Inferrers.
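To make the batching behavior concrete, here is a minimal, framework-free sketch of how target edges could be grouped into mini-batches whose endpoints become the seed nodes for neighbor sampling. The function `iter_edge_batches` and the `(src_ntype, relation, dst_ntype)` tuple keys are illustrative assumptions, not GraphStorm's implementation. The sketch also assumes a mini-batch may mix edge types, and it omits shuffling and the sampling step itself.

```python
from typing import Dict, Iterator, List, Set, Tuple

# Hypothetical canonical edge type: (src_ntype, relation, dst_ntype).
EType = Tuple[str, str, str]
Edge = Tuple[int, int]  # (src node ID, dst node ID)


def iter_edge_batches(
    target_edges: Dict[EType, List[Edge]], batch_size: int
) -> Iterator[Tuple[Dict[EType, List[Edge]], Dict[str, List[int]]]]:
    """Yield (target edges in the batch, seed node IDs per node type)."""
    # Flatten so that one mini-batch may hold edges of several types (assumption).
    flat = [(etype, e) for etype, edges in target_edges.items() for e in edges]
    for start in range(0, len(flat), batch_size):
        batch: Dict[EType, List[Edge]] = {}
        seeds: Dict[str, Set[int]] = {}
        for etype, (src, dst) in flat[start:start + batch_size]:
            batch.setdefault(etype, []).append((src, dst))
            # Both endpoints become seed nodes for neighbor sampling.
            seeds.setdefault(etype[0], set()).add(src)
            seeds.setdefault(etype[2], set()).add(dst)
        yield batch, {ntype: sorted(ids) for ntype, ids in seeds.items()}


edges = {("user", "rates", "item"): [(0, 5), (1, 5), (2, 7)]}
for batch, seeds in iter_edge_batches(edges, batch_size=2):
    print(batch, seeds)
```

Because both endpoints are added to the seed set, a model can compute representations for the source and destination of every target edge from the same sampled subgraph.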

Parameters

dataset: GSgnnEdgeData

The GraphStorm edge dataset.

target_idx: dict of Tensors

The target edges for prediction, given as a dict that maps each edge type to a Tensor of edge IDs.
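A sketch of the expected shape of target_idx, assuming canonical `(src_ntype, relation, dst_ntype)` keys. Plain lists stand in for the torch Tensors GraphStorm expects, and `num_target_edges` is a hypothetical helper; `len()` behaves the same way on a 1-D Tensor.

```python
from typing import Dict, Sequence, Tuple

EType = Tuple[str, str, str]  # hypothetical canonical edge type


def num_target_edges(target_idx: Dict[EType, Sequence[int]]) -> Dict[EType, int]:
    """Count the target edges for each edge type."""
    return {etype: len(ids) for etype, ids in target_idx.items()}


# Hypothetical graph; in GraphStorm the values would be Tensors of edge IDs.
target_idx = {
    ("user", "rates", "item"): [0, 3, 4, 9],
    ("user", "follows", "user"): [1, 2],
}
counts = num_target_edges(target_idx)
print(counts)
```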

fanout: list of int or dict of list

The neighbor-sampling fanout for each GNN layer. If it is a dict, it specifies the fanout for each edge type.
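The two accepted fanout forms can be read as follows. This is a minimal sketch that assumes the dict form maps each edge type to a per-layer list, following the documented type "dict of list"; `layer_fanout` is a hypothetical helper, not a GraphStorm function.

```python
from typing import Dict, List, Tuple, Union

EType = Tuple[str, str, str]  # hypothetical canonical edge type
FanoutSpec = Union[List[int], Dict[EType, List[int]]]


def layer_fanout(fanout: FanoutSpec, layer: int, etypes: List[EType]) -> Dict[EType, int]:
    """Return how many neighbors to sample per edge type for one GNN layer."""
    if isinstance(fanout, dict):
        # dict of list: every edge type has its own per-layer fanout list.
        return {etype: fanout[etype][layer] for etype in etypes}
    # list of int: one fanout per layer, shared by all edge types.
    return {etype: fanout[layer] for etype in etypes}


etypes = [("user", "rates", "item"), ("item", "rated-by", "user")]
shared = [layer_fanout([15, 10], layer, etypes) for layer in range(2)]
per_type = layer_fanout({etypes[0]: [20, 5], etypes[1]: [10, 5]}, 0, etypes)
```

With `fanout=[15, 10]`, every edge type samples 15 neighbors in the first layer and 10 in the second; the dict form lets a dense relation be sampled more sparsely than a rare one.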

batch_size: int

Batch size.

device: torch.device

The device the trainer is running on.

train_task: bool

Whether the dataloader is used for training.

reverse_edge_types_map: dict

A map from each edge type to its reverse edge type.

exclude_training_targets: bool

Whether to exclude the training edges during neighbor sampling.
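The interaction between exclude_training_targets and reverse_edge_types_map can be sketched as follows. This assumes behavior similar to DGL's `exclude='reverse_types'` option, where the edge with the same ID in the reverse edge type is treated as the reverse of each target edge, and that exclusion applies to the target edges of the current mini-batch. `edges_to_exclude` is a hypothetical helper, not a GraphStorm API.

```python
from typing import Dict, List, Set, Tuple

EType = Tuple[str, str, str]  # hypothetical canonical edge type


def edges_to_exclude(
    batch_edges: Dict[EType, List[int]],
    reverse_edge_types_map: Dict[EType, EType],
) -> Dict[EType, Set[int]]:
    """Edge IDs, per edge type, that neighbor sampling must not traverse."""
    excluded: Dict[EType, Set[int]] = {}
    for etype, ids in batch_edges.items():
        # The target edges themselves are never used for message passing.
        excluded.setdefault(etype, set()).update(ids)
        rev = reverse_edge_types_map.get(etype)
        if rev is not None:
            # Assumption: edge i of `rev` is the reverse of edge i of `etype`.
            excluded.setdefault(rev, set()).update(ids)
    return excluded


rates = ("user", "rates", "item")
rated_by = ("item", "rated-by", "user")
excluded = edges_to_exclude({rates: [3, 8]}, {rates: rated_by})
```

Excluding the reverse edges matters because otherwise the label edge could still reach the model through its reverse direction during message passing.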

remove_target_edge_type: bool

Whether to exclude all edges of the target edge type from message passing.
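In contrast to excluding individual target edges, remove_target_edge_type drops the whole edge type from the message-passing graph. A minimal sketch of that filtering, where `message_passing_etypes` is a hypothetical helper and only the target edge types themselves are removed, as the description states:

```python
from typing import List, Tuple

EType = Tuple[str, str, str]  # hypothetical canonical edge type


def message_passing_etypes(
    all_etypes: List[EType],
    target_etypes: List[EType],
    remove_target_edge_type: bool = True,
) -> List[EType]:
    """Return the edge types that remain available for message passing."""
    if not remove_target_edge_type:
        return list(all_etypes)
    removed = set(target_etypes)
    # Every edge of a target edge type is dropped, not only the sampled targets.
    return [etype for etype in all_etypes if etype not in removed]


rates = ("user", "rates", "item")
rated_by = ("item", "rated-by", "user")
follows = ("user", "follows", "user")
kept = message_passing_etypes([rates, rated_by, follows], [rates])
```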

construct_feat_ntype: list of str

The node types that require node features to be constructed.

construct_feat_fanout: int

The fanout required to construct node features.
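One plausible way to construct features for a featureless node type is to sample up to construct_feat_fanout neighbors and average their features. The mean aggregation, the zero-vector fallback, and the helper `construct_node_feat` are all assumptions for illustration; GraphStorm's actual construction method may differ.

```python
import random
from typing import Dict, List, Sequence


def construct_node_feat(
    neighbors: Dict[int, List[int]],
    neighbor_feat: Dict[int, Sequence[float]],
    fanout: int,
    dim: int,
    seed: int = 0,
) -> Dict[int, List[float]]:
    """Average the features of up to `fanout` sampled neighbors per node."""
    rng = random.Random(seed)
    out: Dict[int, List[float]] = {}
    for node, nbrs in neighbors.items():
        # Sample at most `fanout` neighbors without replacement.
        sampled = rng.sample(nbrs, min(fanout, len(nbrs)))
        if not sampled:
            out[node] = [0.0] * dim  # assumption: zeros when no neighbors exist
            continue
        out[node] = [
            sum(neighbor_feat[n][d] for n in sampled) / len(sampled)
            for d in range(dim)
        ]
    return out


neighbors = {0: [10, 11], 1: []}
feats = {10: [1.0, 2.0], 11: [3.0, 4.0]}
full = construct_node_feat(neighbors, feats, fanout=5, dim=2)
capped = construct_node_feat(neighbors, feats, fanout=1, dim=2)
```

A small fanout bounds the cost of feature construction on high-degree nodes at the price of a noisier estimate.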

Examples

The following example trains a 2-layer GNN for edge prediction on a set of edges target_idx, on a graph where each node takes messages from 15 neighbors in the first layer and 10 neighbors in the second.

from graphstorm.dataloading import GSgnnEdgeTrainData
from graphstorm.dataloading import GSgnnEdgeDataLoader
from graphstorm.trainer import GSgnnEdgePredictionTrainer

ep_data = GSgnnEdgeTrainData(...)
# target_idx: dict mapping each edge type to a Tensor of target edge IDs
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx, fanout=[15, 10], batch_size=128)
ep_trainer = GSgnnEdgePredictionTrainer(...)
ep_trainer.fit(ep_dataloader, num_epochs=10)

__iter__()

Returns an iterator object.

__next__()

Returns a mini-batch of data for the edge task.

A mini-batch comprises three objects: the input node IDs, the target edges and the subgraph blocks for message passing.

Returns

dict of Tensors: the input node IDs of the mini-batch.
DGLGraph: the target edges.
list of DGLGraph: the subgraph blocks for message passing.
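The __iter__/__next__ contract above means a training loop can unpack each mini-batch directly. The sketch below uses a hypothetical stand-in class, `ToyEdgeLoader`, with strings in place of DGLGraph objects, purely to show the unpacking order (input node IDs, target edges, blocks) and that a new pass starts when the loader is iterated again.

```python
from typing import Any, List, Tuple

Batch = Tuple[dict, Any, List[Any]]  # (input node IDs, target edges, blocks)


class ToyEdgeLoader:
    """Hypothetical stand-in that mimics the documented iterator contract."""

    def __init__(self, batches: List[Batch]):
        self._batches = batches
        self._pos = 0

    def __iter__(self):
        self._pos = 0  # restart from the first mini-batch on every pass
        return self

    def __next__(self) -> Batch:
        if self._pos >= len(self._batches):
            raise StopIteration
        batch = self._batches[self._pos]
        self._pos += 1
        return batch


loader = ToyEdgeLoader([
    ({"user": [0, 1], "item": [5]}, "batch_graph_0", ["block_0", "block_1"]),
    ({"user": [2], "item": [7]}, "batch_graph_1", ["block_0", "block_1"]),
])
seen = []
for input_nodes, batch_graph, blocks in loader:
    # A trainer would feed `blocks` and the input features to the GNN here.
    seen.append(batch_graph)
```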