GSgnnEdgeDataLoaderBase
- class graphstorm.dataloading.GSgnnEdgeDataLoaderBase(dataset, target_idx, fanout, label_field, node_feats=None, edge_feats=None, decoder_edge_feats=None)
Bases: object
The base dataloader class for edge tasks.
If users want to customize the dataloader for edge prediction tasks, they should extend this base class and implement the special methods __iter__ and __next__.
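A minimal sketch of such a subclass follows. It makes several assumptions so that it runs without GraphStorm or DGL installed. StandInEdgeDataLoaderBase is a hypothetical stand-in for GSgnnEdgeDataLoaderBase that only stores the constructor arguments. Target edge IDs are plain lists rather than Tensors. batch_size is a hypothetical subclass parameter. The neighbor-sampling step is omitted, so __next__ yields only the grouped target edge IDs, not the full (input node IDs, target edges, blocks) tuple.

```python
class StandInEdgeDataLoaderBase:
    """Hypothetical stand-in for graphstorm.dataloading.GSgnnEdgeDataLoaderBase.

    It only stores the constructor arguments so this sketch runs without
    GraphStorm or DGL installed.
    """

    def __init__(self, dataset, target_idx, fanout, label_field,
                 node_feats=None, edge_feats=None, decoder_edge_feats=None):
        self.dataset = dataset
        self.target_idx = target_idx
        self.fanout = fanout
        self.label_field = label_field
        self.node_feats = node_feats
        self.edge_feats = edge_feats
        self.decoder_edge_feats = decoder_edge_feats


class SequentialEdgeDataLoader(StandInEdgeDataLoaderBase):
    """Hypothetical subclass that walks the target edges in fixed-size chunks."""

    def __init__(self, dataset, target_idx, fanout, label_field, batch_size, **kwargs):
        super().__init__(dataset, target_idx, fanout, label_field, **kwargs)
        self.batch_size = batch_size  # hypothetical; not a base-class parameter
        # Flatten {etype: [eid, ...]} into (etype, eid) pairs in a fixed order.
        self._pairs = [(etype, eid) for etype, eids in target_idx.items() for eid in eids]
        self._pos = 0

    def __iter__(self):
        # Restart from the first target edge each time iteration begins.
        self._pos = 0
        return self

    def __next__(self):
        if self._pos >= len(self._pairs):
            raise StopIteration
        chunk = self._pairs[self._pos:self._pos + self.batch_size]
        self._pos += len(chunk)
        # Regroup the chunk by edge type, mirroring the dict-shaped target_idx.
        batch = {}
        for etype, eid in chunk:
            batch.setdefault(etype, []).append(eid)
        # A real loader would sample neighbors here and return
        # (input_nodes, target_edges, blocks); this sketch returns only
        # the grouped target edge IDs.
        return batch
```

Because __iter__ resets the cursor and returns self, the same loader object can be looped over once per epoch. In real code the class would subclass graphstorm.dataloading.GSgnnEdgeDataLoaderBase directly.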
Parameters
- dataset: GSgnnData
The dataset for the edge task.
- target_idx: dict of Tensors
The target edge IDs.
- fanout: list or dict of lists
The fanout for each GNN layer.
- label_field: str or dict of str
Label field of the edge task.
- node_feats: str, list of str, or dict of list of str
Node features. str: All the nodes have the same feature name. list of str: All the nodes have the same list of features. dict of list of str: Each node type has a different set of node features. Default: None
- edge_feats: str, list of str, or dict of list of str
Edge features. str: All the edges have the same feature name. list of str: All the edges have the same list of features. dict of list of str: Each edge type has a different set of edge features. Default: None
- decoder_edge_feats: str, list of str, or dict of list of str
Edge features used in the decoder. str: All the edges have the same feature name. list of str: All the edges have the same list of features. dict of list of str: Each edge type has a different set of edge features. Default: None
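The three accepted forms of node_feats, edge_feats and decoder_edge_feats can be illustrated with a small normalization helper. normalize_feat_fields is hypothetical, not a GraphStorm function. It assumes the list of node or edge types is known and only shows how each form maps to a per-type list of feature names. In this sketch, types that a dict does not mention get no features.

```python
from typing import Dict, List, Optional, Sequence, Union

FeatSpec = Optional[Union[str, List[str], Dict[str, List[str]]]]


def normalize_feat_fields(feats: FeatSpec, types: Sequence) -> Dict:
    """Hypothetical helper: expand a feature spec into {type: [feature names]}.

    - None        -> {} (no features)
    - str         -> every type gets [that name]
    - list of str -> every type gets the same list
    - dict        -> copied as-is; only the listed types get features
    """
    if feats is None:
        return {}
    if isinstance(feats, str):
        return {t: [feats] for t in types}
    if isinstance(feats, list):
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        return {t: list(names) for t, names in feats.items()}
    raise TypeError(f"unsupported feature spec: {type(feats).__name__}")
```

For example, normalize_feat_fields("feat", ["user", "item"]) gives every node type the single feature "feat", while a dict such as {"user": ["age", "emb"]} assigns features to the user type only.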
- __iter__()
Return an iterator object over the mini-batches.
- __next__()
Return a mini-batch of data for the edge task.
A mini-batch comprises three objects: the input node IDs, the target edges and the subgraph blocks for message passing.
Returns
- dict of Tensors: The input node IDs of the mini-batch.
- DGLGraph: The target edges.
- list of DGLGraph: The subgraph blocks for message passing.
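The following sketch shows how a caller might consume these mini-batches by unpacking each one in the documented order. run_epoch and step_fn are hypothetical names. The length check assumes a DGL multi-layer neighbor sampler, which produces one block per entry in fanout. The test feeds plain Python placeholders in place of Tensors and DGLGraphs.

```python
def run_epoch(loader, fanout, step_fn):
    """Hypothetical training loop over an edge-task dataloader.

    Each mini-batch is unpacked as (input node IDs, target edges, blocks).
    Returns the number of mini-batches processed.
    """
    num_batches = 0
    for input_nodes, target_edges, blocks in loader:
        # Assumes multi-layer neighbor sampling: one block per GNN layer.
        if len(blocks) != len(fanout):
            raise ValueError(
                f"expected {len(fanout)} blocks, got {len(blocks)}"
            )
        step_fn(input_nodes, target_edges, blocks)
        num_batches += 1
    return num_batches
```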