GSgnnNodeDataLoaderBase
- class graphstorm.dataloading.GSgnnNodeDataLoaderBase(dataset, target_idx, fanout, label_field, node_feats=None, edge_feats=None)
Bases: object
The base dataloader class for node tasks.
If users want to customize dataloaders for their node prediction tasks, they should extend this base class by implementing the special methods __iter__, __next__, and __len__.
Parameters
- dataset: GSgnnData
The GraphStorm data for node tasks.
- target_idx: dict of Tensors
The target node indexes for prediction.
- fanout: list of int, or dict of lists
The fanout for each GNN layer.
- label_field: str, or dict of str
Label field name of the target node types.
- node_feats: str, list of str, or dict of list of str
Node feature fields in three possible formats:
string: All node types use the same feature name.
list of string: All node types use the same list of features.
dict of list of string: Each node type has its own set of node features.
Default: None.
- edge_feats: str, list of str, or dict of list of str
Edge feature fields in three possible formats:
string: All edge types use the same feature name.
list of string: All edge types use the same list of features.
dict of list of string: Each edge type has its own set of edge features.
Default: None.
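The three accepted formats for node_feats and edge_feats can be illustrated with plain Python values. This is only a sketch: the node type, edge type, and feature names are hypothetical, the tuple-style edge type key is an assumption, and features_for is an illustrative helper that is not part of GraphStorm.

```python
# Format 1: a single string. Every node type uses the same feature name.
node_feats_str = "feat"

# Format 2: a list of strings. Every node type uses the same list of features.
node_feats_list = ["feat", "degree"]

# Format 3: a dict of lists. Each node type has its own features.
# "paper" and "author" are hypothetical node type names.
node_feats_dict = {"paper": ["feat", "year"], "author": ["feat"]}

# The same three formats apply to edge_feats. The key below assumes edge
# types are written as (source, relation, destination) tuples.
edge_feats_dict = {("author", "writes", "paper"): ["weight"]}


def features_for(type_name, feats):
    """Hypothetical helper: resolve the feature names for one node or edge type."""
    if feats is None:
        return []                      # Default: no features requested
    if isinstance(feats, str):
        return [feats]                 # same single feature for every type
    if isinstance(feats, dict):
        return list(feats.get(type_name, []))  # per-type feature list
    return list(feats)                 # same feature list for every type
```

For example, features_for("author", node_feats_dict) yields only the features listed for the hypothetical "author" type, while the string and list formats return the same names regardless of type.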
- __iter__()
Returns an iterator object.
- __next__()
Returns a mini-batch of data for node tasks.
A mini-batch comprises three objects: 1) the input node IDs of the mini-batch, 2) the target nodes, and 3) the sampled subgraph as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.
Returns
dict of Tensors : the input node IDs of the mini-batch.
dict of Tensors : the target node indexes.
list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing.
- property data
The data of the dataloader, which is given in class initialization.
Returns
GSgnnData : The data of the dataloader.
- property target_nidx
Target node indexes for prediction, which are given in class initialization.
Returns
dict of Tensors : the target node indexes.
- property fanout
The fanout of each GNN layer, which is given in class initialization.
Returns
list or a dict of list : the fanouts for each GNN layer.
- property label_field
The label field, which is given in class initialization.
Returns
str, or dict of str: the label fields, which are given in class initialization.
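The iterator contract described above (__iter__, __next__, and __len__, with __next__ returning three objects) can be sketched without GraphStorm installed. The class below is a hypothetical stand-in, not a real subclass: an actual custom dataloader would extend GSgnnNodeDataLoaderBase and use a DGL sampler to produce the input nodes and MFGs. Here plain lists stand in for Tensors, only one node type is handled, and the MFG list is left empty.

```python
class ToyNodeLoader:
    """Hypothetical stand-in for a subclass of GSgnnNodeDataLoaderBase."""

    def __init__(self, target_idx, batch_size):
        # target_idx: dict mapping a node type to its target node IDs.
        # For brevity this sketch only reads the first node type.
        self._ntype, self._ids = next(iter(target_idx.items()))
        self._batch_size = batch_size
        self._pos = 0

    def __iter__(self):
        # Restart from the first target node and return the iterator object.
        self._pos = 0
        return self

    def __next__(self):
        if self._pos >= len(self._ids):
            raise StopIteration
        seeds = self._ids[self._pos:self._pos + self._batch_size]
        self._pos += self._batch_size
        # Placeholder: a real sampler would expand the seeds to their sampled
        # neighbors; here the input nodes are just the seeds themselves.
        input_nodes = {self._ntype: list(seeds)}
        blocks = []  # placeholder for the list of DGL MFGs
        return input_nodes, {self._ntype: seeds}, blocks

    def __len__(self):
        # Number of mini-batches per pass, rounding up for a partial batch.
        return (len(self._ids) + self._batch_size - 1) // self._batch_size


loader = ToyNodeLoader({"paper": [0, 1, 2, 3, 4]}, batch_size=2)
batches = list(loader)
```

Each element of batches is a tuple in the order documented for __next__: input node IDs, target node indexes, and the list of MFGs.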