GSgnnEdgeDataLoader
- class graphstorm.dataloading.GSgnnEdgeDataLoader(dataset, target_idx, fanout, batch_size, device='cpu', train_task=True, reverse_edge_types_map=None, remove_target_edge_type=True, exclude_training_targets=False, construct_feat_ntype=None, construct_feat_fanout=5)
Bases: GSgnnEdgeDataLoaderBase
The mini-batch dataloader for edge prediction.
GSgnnEdgeDataLoader turns a GraphStorm edge dataset into an iterable over mini-batches of target edges. Both the source and destination nodes of the target edges are included in the batch_graph, which is used by GraphStorm Trainers and Inferrers.
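The following plain-Python sketch illustrates what it means for both endpoints to be in the batch. It takes one mini-batch of target edges, given as hypothetical (src, dst) ID pairs per canonical edge type, and collects the unique seed nodes per node type whose representations the GNN must then compute. This is not GraphStorm code. The real loader works on DGL graphs and tensors, and the helper `collect_seed_nodes` is hypothetical.

```python
from collections import defaultdict


def collect_seed_nodes(batch_edges):
    """Collect the unique source and destination node IDs per node type.

    batch_edges maps a canonical edge type (src_type, relation, dst_type)
    to a list of (src_id, dst_id) pairs. Hypothetical helper, not part of
    the GraphStorm API.
    """
    seeds = defaultdict(set)
    for (src_type, _rel, dst_type), pairs in batch_edges.items():
        for src, dst in pairs:
            seeds[src_type].add(src)  # source endpoint of a target edge
            seeds[dst_type].add(dst)  # destination endpoint of a target edge
    # Sort the IDs so the result is deterministic.
    return {ntype: sorted(ids) for ntype, ids in seeds.items()}


batch = {("user", "buys", "item"): [(0, 5), (1, 5), (0, 7)]}
print(collect_seed_nodes(batch))  # {'user': [0, 1], 'item': [5, 7]}
```

Node 5 appears in two target edges but is a single seed. When both endpoints share a node type, as in a user-follows-user relation, their IDs are merged into one set.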
Parameters
- dataset: GSgnnEdgeData
The GraphStorm edge dataset.
- target_idx: dict of Tensors
The IDs of the target edges for prediction, keyed by edge type.
- fanout: list of int or dict of list
The neighbor sampling fanout for each GNN layer. If it is a dict, it specifies the fanout for each edge type.
- batch_size: int
The number of target edges in each mini-batch.
- device: torch.device
The device the trainer is running on.
- train_task: bool
Whether the dataloader is used for training.
- reverse_edge_types_map: dict
A map from each edge type to its reverse edge type.
- exclude_training_targets: bool
Whether to exclude the training edges during neighbor sampling.
- remove_target_edge_type: bool
Whether to exclude all edges of the target edge type from message passing.
- construct_feat_ntype: list of str
The node types whose node features need to be constructed.
- construct_feat_fanout: int
The fanout used when constructing node features.
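The fanout and exclusion parameters above can be pictured with the plain-Python sketch below. This page does not confirm the details, so two assumptions apply. First, a dict fanout is taken to map each canonical edge type to a list holding one fanout per layer, and it is converted into one {edge type: fanout} dict per layer. That per-layer form is what DGL's `MultiLayerNeighborSampler` accepts. Second, when `remove_target_edge_type=True`, the target edge types and their reverses from `reverse_edge_types_map` are assumed to be blocked by giving them a fanout of 0. Both helpers are hypothetical and are not GraphStorm's implementation.

```python
def per_layer_fanout(fanout, etypes):
    """Normalize a fanout spec into one {etype: fanout} dict per layer.

    Hypothetical helper, not GraphStorm code. A list of ints gives every
    edge type in `etypes` the same fanout within a layer. A dict is
    assumed to map each edge type to a list with one fanout per layer.
    """
    if isinstance(fanout, dict):
        num_layers = len(next(iter(fanout.values())))
        return [{et: fans[layer] for et, fans in fanout.items()}
                for layer in range(num_layers)]
    return [{et: f for et in etypes} for f in fanout]


def drop_target_etypes(layers, target_etypes, reverse_edge_types_map):
    """Set the fanout of the target edge types to 0 in every layer.

    Assumption: the reverse edge type of each target edge type, looked up
    in reverse_edge_types_map, is blocked as well.
    """
    blocked = set(target_etypes)
    for et in target_etypes:
        if et in reverse_edge_types_map:
            blocked.add(reverse_edge_types_map[et])
    return [{et: (0 if et in blocked else f) for et, f in layer.items()}
            for layer in layers]


etypes = [("user", "buys", "item"),
          ("item", "bought-by", "user"),
          ("user", "follows", "user")]
reverse_map = {("user", "buys", "item"): ("item", "bought-by", "user")}

layers = per_layer_fanout([15, 10], etypes)
masked = drop_target_etypes(layers, [("user", "buys", "item")], reverse_map)
print(masked[0])
# {('user', 'buys', 'item'): 0, ('item', 'bought-by', 'user'): 0, ('user', 'follows', 'user'): 15}
```

A fanout of 0 keeps an edge type out of message passing without modifying the graph itself. This is one plausible way to realize the behavior described for remove_target_edge_type.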
Examples
To train a 2-layer GNN for edge prediction on a set of edges target_idx, on a graph where each node takes messages from 15 neighbors on the first layer and 10 neighbors on the second:

```python
from graphstorm.dataloading import GSgnnEdgeTrainData
from graphstorm.dataloading import GSgnnEdgeDataLoader
from graphstorm.trainer import GSgnnEdgePredictionTrainer

# Load the edge training data (arguments elided).
ep_data = GSgnnEdgeTrainData(...)
# Sample 15 neighbors on the first layer and 10 on the second,
# with 128 target edges per mini-batch.
ep_dataloader = GSgnnEdgeDataLoader(ep_data, target_idx, fanout=[15, 10], batch_size=128)
ep_trainer = GSgnnEdgePredictionTrainer(...)
ep_trainer.fit(ep_dataloader, num_epochs=10)
```
- __iter__()
Returns an iterator object.
- __next__()
Returns a mini-batch of data for the edge task.
A mini-batch comprises three objects: the input node IDs, the target edges, and the subgraph blocks for message passing.
Returns
dict of Tensors: the input node IDs of the mini-batch.
DGLGraph: the target edges.
list of DGLGraph: the subgraph blocks for message passing.
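The `__iter__`/`__next__` pair follows Python's standard iterator protocol. The toy class below is a hypothetical stand-in, not GraphStorm code, and it mimics only the batching of target edges. It assumes the target edges are given as a dict of edge-ID lists keyed by edge type. It restarts on each call to `__iter__`, which corresponds to one epoch, and shuffles only when `train_task` is True. It then yields each mini-batch regrouped by edge type. The real loader additionally samples neighbors and returns the (input node IDs, target edges, blocks) triple described above.

```python
import random


class ToyEdgeBatcher:
    """Hypothetical stand-in that mimics only the batching of target edges.

    Each mini-batch is returned as {etype: [edge IDs]}; no neighbor
    sampling is performed.
    """

    def __init__(self, target_idx, batch_size, train_task=True, seed=0):
        # Flatten {etype: [edge IDs]} into (etype, edge_id) pairs.
        self._edges = [(et, eid) for et, eids in target_idx.items() for eid in eids]
        self._batch_size = batch_size
        self._train_task = train_task
        self._rng = random.Random(seed)
        self._pos = 0

    def __iter__(self):
        # Start a new epoch; shuffle the target edges only for training.
        self._pos = 0
        if self._train_task:
            self._rng.shuffle(self._edges)
        return self

    def __next__(self):
        if self._pos >= len(self._edges):
            raise StopIteration
        chunk = self._edges[self._pos:self._pos + self._batch_size]
        self._pos += self._batch_size
        # Regroup the chunk by edge type, like a dict of Tensors.
        batch = {}
        for et, eid in chunk:
            batch.setdefault(et, []).append(eid)
        return batch


loader = ToyEdgeBatcher({("user", "buys", "item"): [0, 1, 2, 3, 4]},
                        batch_size=2, train_task=False)
for batch in loader:
    print(batch)  # edge IDs [0, 1], then [2, 3], then [4]
```

The last mini-batch can be smaller than batch_size. Because `__iter__` resets the position, the same object can be looped over once per epoch, as a trainer's `fit` loop would do.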