GSgnnEdgeDataLoader
- class graphstorm.dataloading.GSgnnEdgeDataLoader(dataset, target_idx, fanout, batch_size, label_field, node_feats=None, edge_feats=None, decoder_edge_feats=None, train_task=True, reverse_edge_types_map=None, remove_target_edge_type=True, exclude_training_targets=False, construct_feat_ntype=None, construct_feat_fanout=5)
Bases: GSgnnEdgeDataLoaderBase
The minibatch dataloader for edge prediction.
GSgnnEdgeDataLoader samples a GraphStorm edge dataset into an iterable over mini-batches. Both the source and destination nodes of the target edges are included in the batch_graph, which is used by GraphStorm Trainers and Inferrers.
Parameters
- dataset: GSgnnData
The GraphStorm edge dataset
- target_idx: dict of Tensors
The target edges for prediction
- fanout: list of int or dict of list
Neighbor sample fanout. If it’s a dict, it indicates the fanout for each edge type.
- batch_size: int
Batch size
- label_field: str or dict of str
Label field of the edge task.
- node_feats: str, list of str, or dict of list of str
Node features. str: All nodes have the same feature name. list of str: All nodes have the same list of features. dict of list of str: Each node type has a different set of node features. Default: None
- edge_feats: str, list of str, or dict of list of str
Edge features. str: All edges have the same feature name. list of str: All edges have the same list of features. dict of list of str: Each edge type has a different set of edge features. Default: None
- decoder_edge_feats: str, list of str, or dict of list of str
Edge features used in the decoder. str: All edges have the same feature name. list of str: All edges have the same list of features. dict of list of str: Each edge type has a different set of edge features. Default: None
- train_task: bool
Whether the dataloader is used for training. Default: True
- reverse_edge_types_map: dict
A map from an edge type to its reverse edge type. Default: None
- exclude_training_targets: bool
Whether to exclude training edges during neighbor sampling
- remove_target_edge_type: bool
Whether we will exclude all edges of the target edge type in message passing.
- construct_feat_ntype: list of str
The node types whose node features need to be constructed. Default: None
- construct_feat_fanout: int
The fanout used when constructing node features. Default: 5
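The three accepted forms of node_feats, edge_feats and decoder_edge_feats can be illustrated with a small sketch. The helper expand_feat_spec below is hypothetical and is not a GraphStorm function; the type names "user" and "item" are also made up for illustration. It only shows one way a str, a list of str, or a dict of list of str could be read as a per-type feature list.

```python
from typing import Dict, List, Optional, Sequence, Union

# The three documented forms, plus None (the default).
FeatSpec = Optional[Union[str, Sequence[str], Dict[str, Sequence[str]]]]


def expand_feat_spec(spec: FeatSpec, types: Sequence[str]) -> Dict[str, List[str]]:
    """Hypothetical helper (NOT part of GraphStorm): expand a feature
    specification into an explicit {type name: [feature names]} dict."""
    if spec is None:
        # Default: no features for any type.
        return {}
    if isinstance(spec, str):
        # One feature name shared by every type.
        return {t: [spec] for t in types}
    if isinstance(spec, dict):
        # Each type lists its own features; unlisted types get none.
        return {t: [f] if isinstance(f, str) else list(f) for t, f in spec.items()}
    if isinstance(spec, (list, tuple)):
        # The same list of feature names for every type.
        return {t: list(spec) for t in types}
    raise TypeError(f"unsupported feature specification: {type(spec).__name__}")


ntypes = ["user", "item"]  # assumed node type names
same_name = expand_feat_spec("feat", ntypes)
same_list = expand_feat_spec(["feat", "age"], ntypes)
per_type = expand_feat_spec({"user": ["age"], "item": ["price", "brand"]}, ntypes)
```

In this reading, the dict form is the only one that lets node or edge types differ from one another; the str and list forms apply uniformly.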
Examples
To train a 2-layer GNN for edge prediction on a set of edges target_idx on a graph where each node takes messages from 15 neighbors on the first layer and 10 neighbors on the second:

    from graphstorm.dataloading import GSgnnData
    from graphstorm.dataloading import GSgnnEdgeDataLoader
    from graphstorm.trainer import GSgnnEdgePredictionTrainer

    ep_data = GSgnnData(...)
    target_idx = ep_data.get_edge_train_set(...)
    ep_dataloader = GSgnnEdgeDataLoader(
        ep_data, target_idx, fanout=[15, 10], batch_size=128,
        label_field=config.label_field)
    ep_trainer = GSgnnEdgePredictionTrainer(...)
    ep_trainer.fit(ep_dataloader, num_epochs=10)
- __iter__()
Returns an iterator object
- __next__()
Return a mini-batch of data for the edge task.
A mini-batch comprises three objects: the input node IDs, the target edges and the subgraph blocks for message passing.
Returns
- dict of Tensors: the input node IDs of the mini-batch.
- DGLGraph: the target edges.
- list of DGLGraph: the subgraph blocks for message passing.
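The iteration protocol described above can be sketched without GraphStorm installed. StandInEdgeLoader below is a hypothetical stand-in, not the GraphStorm class: it uses plain Python placeholders where the real loader would yield tensors and DGLGraph objects, and it assumes the three objects arrive as a tuple in the order listed above, with one block per GNN layer.

```python
from typing import Dict, List, Tuple

# Placeholder types: a real mini-batch holds tensors and DGLGraph objects.
MiniBatch = Tuple[Dict[str, List[int]], str, List[str]]


class StandInEdgeLoader:
    """Hypothetical stand-in (NOT GraphStorm code) that mimics only the
    __iter__/__next__ protocol of an edge dataloader."""

    def __init__(self, num_batches: int, num_layers: int) -> None:
        self._num_batches = num_batches
        self._num_layers = num_layers
        self._served = 0

    def __iter__(self) -> "StandInEdgeLoader":
        self._served = 0  # restart from the first mini-batch
        return self

    def __next__(self) -> MiniBatch:
        if self._served >= self._num_batches:
            raise StopIteration
        self._served += 1
        # Input node IDs, keyed by an assumed node type name.
        input_nodes = {"user": [0, 1, 2]}
        # Stands in for the DGLGraph holding the target edges.
        batch_graph = f"target-edges-{self._served}"
        # Stands in for the message-passing blocks (assumed one per layer).
        blocks = [f"block-{i}" for i in range(self._num_layers)]
        return input_nodes, batch_graph, blocks


def count_batches(loader) -> int:
    """Walk the loader the way a training loop would, unpacking each batch."""
    seen = 0
    for input_nodes, batch_graph, blocks in loader:
        # A real step would run the GNN over `blocks` starting from
        # `input_nodes`, then score the edges in `batch_graph`.
        seen += 1
    return seen


loader = StandInEdgeLoader(num_batches=3, num_layers=2)
num_seen = count_batches(loader)
```

Because __iter__ resets the counter in this sketch, the same loader object can be walked once per epoch.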