GSgnnEdgeDataLoader

class graphstorm.dataloading.GSgnnEdgeDataLoader(dataset, target_idx, fanout, batch_size, label_field, node_feats=None, edge_feats=None, decoder_edge_feats=None, train_task=True, reverse_edge_types_map=None, remove_target_edge_type=True, exclude_training_targets=False, construct_feat_ntype=None, construct_feat_fanout=5)

Bases: GSgnnEdgeDataLoaderBase

The mini-batch dataloader for edge prediction tasks.

GSgnnEdgeDataLoader samples target edges and exposes them as an iterable over mini-batches. Both the source and destination nodes of the target edges are included in the batch_graph, which is used by GraphStorm Trainers and Inferrers.

Parameters

dataset: GSgnnData

The GraphStorm data.

target_idx: dict of Tensors

The target edge indexes for prediction.

fanout: list of int, or dict of list

Neighbor sampling fanout. If it’s a dict of list, it indicates the fanout for each edge type.

batch_size: int

Mini-batch size.

label_field: str or dict of str

Label field of the edge task.

node_feats: str, list of str, or dict of list of str

Node feature fields in three possible formats:

  • string: All nodes have the same feature name.

  • list of string: All nodes have the same list of features.

  • dict of list of string: Each node type has a different set of node features.

Default: None.
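The three accepted forms can be compared by expanding each into a per-node-type dictionary. The sketch below is illustrative only: expand_node_feats is a hypothetical helper, not part of GraphStorm, and the node type and feature names are made up.

```python
from typing import Dict, List, Optional, Union

FeatSpec = Union[None, str, List[str], Dict[str, List[str]]]


def expand_node_feats(node_feats: FeatSpec,
                      ntypes: List[str]) -> Optional[Dict[str, List[str]]]:
    """Hypothetical helper: expand any accepted form into {node type: [feature names]}."""
    if node_feats is None:
        # Default: no node features are requested.
        return None
    if isinstance(node_feats, str):
        # string: all node types share the same single feature name.
        return {ntype: [node_feats] for ntype in ntypes}
    if isinstance(node_feats, list):
        # list of string: all node types share the same list of features.
        return {ntype: list(node_feats) for ntype in ntypes}
    # dict of list of string: each node type has its own set of features.
    return {ntype: list(names) for ntype, names in node_feats.items()}


ntypes = ["user", "item"]  # hypothetical node types
print(expand_node_feats("feat", ntypes))
print(expand_node_feats(["feat", "embed"], ntypes))
print(expand_node_feats({"user": ["age"], "item": ["price", "title"]}, ntypes))
```

The same reasoning applies to edge_feats and decoder_edge_feats, with edge types in place of node types.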

edge_feats: str, list of str, or dict of list of str

Edge feature fields in three possible formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None.

decoder_edge_feats: str, list of str, or dict of list of str

Edge features used in edge decoders in three possible formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None.

train_task: bool

Whether the dataloader is used for training.

reverse_edge_types_map: dict

A map from each edge type to its reverse edge type.
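As an illustration of what such a map might look like, the sketch below assumes that edge types are written as canonical (source node type, relation, destination node type) tuples, as in DGL, and that each key maps to the edge type running in the opposite direction. The node types and relation names are hypothetical, and is_consistent is an illustrative helper, not a GraphStorm function.

```python
# Hypothetical canonical edge types: (source node type, relation, destination node type).
target_etype = ("user", "rates", "item")
reverse_etype = ("item", "rated-by", "user")

# Assumed shape of the map: each edge type points to its reverse edge type.
reverse_edge_types_map = {target_etype: reverse_etype}


def is_consistent(mapping):
    """Check that every reverse edge type swaps the source and destination node types."""
    return all(src == rev_dst and dst == rev_src
               for (src, _, dst), (rev_src, _, rev_dst) in mapping.items())


print(is_consistent(reverse_edge_types_map))
```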

exclude_training_targets: bool

Whether to exclude training edges during neighbor sampling.

remove_target_edge_type: bool

Whether to exclude all edges of the target edge type in message passing.

construct_feat_ntype: list of str

The node types for which node features must be constructed.

construct_feat_fanout: int

The fanout used when constructing node features for feature-less nodes.

Examples

The following example trains a 2-layer GNN for edge prediction on a set of edges target_idx, where each edge (source and destination node pair) takes messages from 15 neighbors on the first layer and 10 neighbors on the second.

from graphstorm.dataloading import GSgnnData
from graphstorm.dataloading import GSgnnEdgeDataLoader
from graphstorm.trainer import GSgnnEdgePredictionTrainer

ep_data = GSgnnData(...)
target_idx = ep_data.get_edge_train_set(...)
ep_dataloader = GSgnnEdgeDataLoader(
    ep_data, target_idx,
    fanout=[15, 10], batch_size=128,
    label_field=config.label_field)
ep_trainer = GSgnnEdgePredictionTrainer(...)
ep_trainer.fit(ep_dataloader, num_epochs=10)

__iter__()

Returns an iterator object.

__next__()

Returns a mini-batch of data for the edge task.

A mini-batch comprises three objects: 1) the input node IDs, 2) the target edges, and 3) the sampled subgraph as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.

Returns

  • dict of Tensors : the input node IDs of the mini-batch.

  • DGLGraph : the target edges.

  • list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing.
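The three returned objects are typically unpacked directly in a loop over the dataloader. The sketch below shows that pattern without requiring DGL or GraphStorm: fake_loader is a stand-in list whose entries only mimic the shape of a mini-batch (placeholder strings replace the DGLGraph and the MFGs), and summarize_batches is a hypothetical helper.

```python
def summarize_batches(dataloader):
    """Iterate an object that yields (input_nodes, batch_graph, blocks) triples."""
    summary = []
    for input_nodes, batch_graph, blocks in dataloader:
        # input_nodes: dict mapping node type -> input node IDs
        # batch_graph: graph holding the target edges of this mini-batch
        # blocks: list of MFGs used for message passing
        summary.append((sorted(input_nodes), len(blocks)))
    return summary


# Stand-in for a real dataloader; all values are placeholders, not real DGL objects.
fake_loader = [
    ({"user": [0, 1, 2], "item": [5, 6]},
     "target-edge-graph",
     ["mfg-layer-1", "mfg-layer-2"]),
]
print(summarize_batches(fake_loader))
```

With a real GSgnnEdgeDataLoader, the same loop body would receive tensors, a DGLGraph, and DGL MFGs in place of the placeholders.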

__len__()

Returns the number of mini-batches, following the DGL implementation at https://github.com/dmlc/dgl/blob/1.0.x/python/dgl/distributed/dist_dataloader.py#L116. In DGL, DistDataLoader.expected_idxs is the length (number of batches) of the dataloader.

Returns:

int: The length (number of batches) of the dataloader.
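As a rough guide to the value returned, the sketch below assumes the usual batch-count rule: the number of full batches, plus one partial batch when leftover targets are kept. The function expected_num_batches and its drop_last argument are hypothetical and only illustrate the arithmetic. How the count is aggregated across edge types or trainers is not covered here.

```python
def expected_num_batches(num_targets, batch_size, drop_last=False):
    """Hypothetical helper: number of mini-batches needed to cover num_targets items."""
    num_batches = num_targets // batch_size
    # Keep the final, smaller batch unless incomplete batches are dropped.
    if not drop_last and num_targets % batch_size != 0:
        num_batches += 1
    return num_batches


# 1000 target edges with batch_size=128: 7 full batches plus 1 partial batch.
print(expected_num_batches(1000, 128))
```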