GSgnnEdgeDataLoaderBase
- class graphstorm.dataloading.GSgnnEdgeDataLoaderBase(dataset, target_idx, fanout, label_field, node_feats=None, edge_feats=None, decoder_edge_feats=None)
Bases: object
The base dataloader class for edge tasks.
If users want to customize dataloaders for edge prediction tasks, they should extend this base class by implementing the special methods __iter__, __next__, and __len__.
Parameters
- dataset: GSgnnData
The GraphStorm data for edge tasks.
- target_idx: dict of Tensors
The target edge indexes for prediction.
- fanout: list or dict of lists
The fanout for each GNN layer. If it’s a dict of lists, it indicates the fanout for each edge type.
- label_field: str or dict of str
Label field of the edge task.
- node_feats: str, list of str, or dict of list of str
Node feature fields in three possible formats:
string: All nodes have the same feature name.
list of string: All nodes have the same list of features.
dict of list of string: Each node type has a different set of node features.
Default: None.
- edge_feats: str, list of str, or dict of list of str
Edge feature fields in three possible formats:
string: All edges have the same feature name.
list of string: All edges have the same list of features.
dict of list of string: Each edge type has a different set of edge features.
Default: None.
- decoder_edge_feats: str, list of str, or dict of list of str
Edge feature fields used by edge decoders in three possible formats:
string: All edges have the same feature name.
list of string: All edges have the same list of features.
dict of list of string: Each edge type has a different set of edge features.
Default: None.
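The three formats accepted by node_feats, edge_feats, and decoder_edge_feats can all be reduced to one per-type mapping. The sketch below does that with a hypothetical helper, normalize_feat_fields, which is not part of the GraphStorm API; it assumes the caller already knows the list of node or edge types, and GraphStorm's own internal handling may differ.

```python
from typing import Dict, Hashable, List, Optional, Sequence, Union

# The three documented formats, plus None (the default).
FeatFields = Optional[Union[str, List[str], Dict[Hashable, List[str]]]]


def normalize_feat_fields(feats: FeatFields,
                          types: Sequence[Hashable]) -> Dict[Hashable, List[str]]:
    """Expand a feature-field spec into a {type: [field, ...]} mapping.

    Hypothetical helper for illustration; not part of the GraphStorm API.
    `types` holds node type names or canonical edge type tuples.
    """
    if feats is None:
        # Default: no features are used.
        return {}
    if isinstance(feats, str):
        # string: every type shares the same single feature name.
        return {t: [feats] for t in types}
    if isinstance(feats, list):
        # list of string: every type shares the same list of features.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # dict of list of string: each type has its own feature list;
        # types missing from the dict are simply absent from the result.
        return {t: list(fields) for t, fields in feats.items()}
    raise TypeError(f"unsupported feature-field spec: {type(feats).__name__}")


if __name__ == "__main__":
    print(normalize_feat_fields("feat", ["user", "item"]))
```

For edge features, canonical edge types such as ("user", "rates", "item") can be passed as the types; the helper treats tuple keys the same way as node-type strings.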
- __iter__()
Returns an iterator object.
- __next__()
Returns a mini-batch of data for the edge task.
A mini-batch comprises three objects: 1) the input node IDs, 2) the target edges, and 3) the sampled subgraph as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.
Returns
dict of Tensors : the input node IDs of the mini-batch.
DGLGraph : the target edges.
list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
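The iterator protocol described above (__iter__, __next__, and __len__, with __next__ returning the input node IDs, the target edges, and the MFGs in that order) can be sketched in plain Python. ToyEdgeDataLoader and its batch_size parameter are hypothetical; the class does not import GraphStorm or DGL, so dicts and lists stand in for the tensors, the target-edge DGLGraph, and the MFGs. A real customization would instead extend GSgnnEdgeDataLoaderBase and perform neighbor sampling in __next__. For simplicity, each mini-batch here contains edges of a single edge type.

```python
import math
from typing import Dict, Iterator, List, Tuple

# Canonical edge type: (source node type, relation, destination node type).
EType = Tuple[str, str, str]


class ToyEdgeDataLoader:
    """Hypothetical stand-in for a GSgnnEdgeDataLoaderBase subclass.

    Plain Python lists replace tensors and DGL graphs so the sketch runs
    without GraphStorm or DGL installed.
    """

    def __init__(self, target_idx: Dict[EType, List[int]], batch_size: int):
        self._target_idx = target_idx  # target edge IDs per edge type
        self._batch_size = batch_size  # hypothetical; not in the base class signature
        self._batches: List[Tuple[EType, List[int]]] = []
        self._pos = 0

    @property
    def target_eidx(self) -> Dict[EType, List[int]]:
        # Mirrors the target_eidx property of the base class.
        return self._target_idx

    def __iter__(self) -> Iterator:
        # Split each edge type's target IDs into fixed-size chunks and reset
        # the cursor, so the loader can be iterated again in the next epoch.
        self._batches = [
            (etype, eids[i:i + self._batch_size])
            for etype, eids in self._target_idx.items()
            for i in range(0, len(eids), self._batch_size)
        ]
        self._pos = 0
        return self

    def __next__(self):
        if self._pos >= len(self._batches):
            raise StopIteration
        etype, eids = self._batches[self._pos]
        self._pos += 1
        # A real loader would sample neighbors here; these are placeholders.
        input_nodes: Dict[str, List[int]] = {}  # empty placeholder for the input node IDs
        target_edges = {etype: eids}            # stands in for the target-edge DGLGraph
        blocks: List = []                       # stands in for the list of MFGs
        return input_nodes, target_edges, blocks

    def __len__(self) -> int:
        # Number of mini-batches per epoch.
        return sum(math.ceil(len(eids) / self._batch_size)
                   for eids in self._target_idx.values())


if __name__ == "__main__":
    loader = ToyEdgeDataLoader({("user", "rates", "item"): [0, 1, 2, 3, 4]}, batch_size=2)
    print(len(loader))  # 3
    for input_nodes, target_edges, blocks in loader:
        print(target_edges)
```

Rebuilding the batch list in __iter__ lets the same loader object be iterated once per epoch, and __len__ reports the number of mini-batches per epoch: ceil(5 / 2) = 3 in the usage example.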
- property data
The dataset of this dataloader, which is given in class initialization.
Returns
GSgnnData: The dataset of the dataloader.
- property target_eidx
Target edge indexes for prediction, which are given in class initialization.
Returns
dict of Tensors: the target edge IDs, which are given in class initialization.
- property fanout
The fanout of each GNN layer, which is given in class initialization.
Returns
list or dict of lists: the fanouts for each GNN layer, which are given in class initialization.
- property label_field
The label field, which is given in class initialization.
Returns
str or dict of str: The label field of the edge task, which is given in class initialization.
- property node_feat_fields
Node feature fields, which are given in class initialization.
Returns
str, list of str, or dict of list of str: Node feature fields in the graph, which are given in class initialization.