GSgnnNodeDataLoader
- class graphstorm.dataloading.GSgnnNodeDataLoader(dataset, target_idx, fanout, batch_size, label_field, node_feats=None, edge_feats=None, train_task=True, construct_feat_ntype=None, construct_feat_fanout=5)
Bases: GSgnnNodeDataLoaderBase
Mini-batch dataloader for node tasks.
GSgnnNodeDataLoader samples GraphStorm data into an iterable over mini-batches of samples, including target nodes and sampled neighbor nodes, which will be used by GraphStorm Trainers and Inferrers.
Parameters
- dataset: GSgnnData
The GraphStorm data.
- target_idx: dict of Tensors
The target node indexes for prediction.
- fanout: list of int, or dict of list
Neighbor sampling fanout. If it’s a dict of list, it indicates the fanout for each edge type.
- label_field: str
Label field of the node task.
- node_feats: str, list of str or dict of list of str
Node feature fields in three possible formats:
string: All nodes have the same feature name.
list of string: All nodes have the same list of features.
dict of list of string: Each node type has a different set of node features.
Default: None.
- edge_feats: str, list of str or dict of list of str
Edge feature fields in three possible formats:
string: All edges have the same feature name.
list of string: All edges have the same list of features.
dict of list of string: Each edge type has a different set of edge features.
Default: None.
- batch_size: int
Mini-batch size.
- train_task: bool
Whether this dataloader is used for training.
Default: True.
- construct_feat_ntype: list of str
The node types that require node features to be constructed.
Default: None.
- construct_feat_fanout: int
The fanout used when constructing node features for feature-less nodes.
Default: 5.
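The three accepted feature-field formats can be made concrete with a small normalization sketch. `normalize_feat_fields` is a hypothetical helper written for illustration, not a GraphStorm function, and it assumes that with the dict format any type that is not listed simply gets no features. It expands each format into one explicit per-type mapping:

```python
from typing import Dict, List, Optional, Union

# The three formats accepted by node_feats / edge_feats, plus the None default.
FeatSpec = Optional[Union[str, List[str], Dict[str, List[str]]]]


def normalize_feat_fields(feats: FeatSpec, types: List[str]) -> Dict[str, List[str]]:
    """Hypothetical helper: expand a feature spec into {type: [field, ...]}."""
    if feats is None:
        # Default: no feature fields for any type.
        return {}
    if isinstance(feats, str):
        # string: every type uses the same single feature name.
        return {t: [feats] for t in types}
    if isinstance(feats, list):
        # list of string: every type uses the same list of feature names.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # dict of list of string: each listed type has its own feature names.
        return {t: list(names) for t, names in feats.items()}
    raise TypeError(f"unsupported feature spec: {type(feats).__name__}")


ntypes = ["user", "item"]
print(normalize_feat_fields("feat", ntypes))
print(normalize_feat_fields(["feat", "emb"], ntypes))
print(normalize_feat_fields({"user": ["age"], "item": ["title_emb"]}, ntypes))
```

The same expansion applies to edge_feats, with the list of edge types in place of the node types.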
Examples
To train a 2-layer GNN for node classification on the set of nodes target_idx, on a graph where each node takes messages from 15 neighbors on the first layer and 10 neighbors on the second:

```python
from graphstorm.dataloading import GSgnnData
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.trainer import GSgnnNodePredictionTrainer

np_data = GSgnnData(...)
target_idx = np_data.get_node_train_set(...)
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx, fanout=[15, 10],
                                    batch_size=128, label_field="label",
                                    node_feats="feat")
np_trainer = GSgnnNodePredictionTrainer(...)
np_trainer.fit(np_dataloader, num_epochs=10)
```
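The fanout list also bounds how many nodes a mini-batch can pull in. The sketch below (`max_input_nodes_per_target` is a hypothetical helper, not part of GraphStorm) assumes that each node needed at a layer keeps its own representation and receives messages from at most `f` sampled neighbors, so every layer multiplies the node count by at most `1 + f`:

```python
from math import prod
from typing import List


def max_input_nodes_per_target(fanout: List[int]) -> int:
    """Hypothetical helper: upper bound on input nodes for one target node.

    At each layer, every node that is needed keeps its own representation and
    takes messages from at most f sampled neighbors, so the set of required
    nodes grows by at most a factor of (1 + f). Overlapping neighborhoods and
    nodes with fewer than f neighbors only make the real count smaller.
    """
    return prod(1 + f for f in fanout)


fanout = [15, 10]
batch_size = 128
per_target = max_input_nodes_per_target(fanout)  # (1 + 15) * (1 + 10) = 176
per_batch = batch_size * per_target  # loose bound: targets often share neighbors
print(per_target, per_batch)
```

For fanout=[15, 10] this gives at most 176 input nodes per target node, or 22,528 for a batch of 128 targets; in practice the count is lower because neighborhoods overlap.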
- __iter__()
Returns an iterator object.
- __next__()
Returns a mini-batch of data for node tasks.
A mini-batch comprises three objects: 1) the input node IDs of the mini-batch, 2) the target nodes, and 3) the sampled subgraph as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.
Returns
dict of Tensors : the input node IDs of the mini-batch.
dict of Tensors : the target node indexes.
list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
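The `__iter__`/`__next__` protocol and the three-part mini-batch can be illustrated with a toy stand-in. `ToyNodeLoader` is hypothetical and not a GraphStorm class: it uses plain Python lists instead of tensors, skips neighbor sampling, and uses string placeholders instead of DGL MFGs. It only mirrors the documented shape of each mini-batch (input node IDs, target nodes, list of MFGs), with dicts keyed by node type as an assumption:

```python
from typing import Dict, List, Tuple

# (input node IDs per node type, target node IDs per node type, per-layer MFG placeholders)
Batch = Tuple[Dict[str, List[int]], Dict[str, List[int]], List[str]]


class ToyNodeLoader:
    """Hypothetical stand-in that mimics GSgnnNodeDataLoader's iteration protocol."""

    def __init__(self, ntype: str, targets: List[int], batch_size: int, num_layers: int = 2):
        self.ntype = ntype
        self.targets = targets
        self.batch_size = batch_size
        self.num_layers = num_layers
        self._pos = 0

    def __iter__(self) -> "ToyNodeLoader":
        # Start again from the first target node, e.g. at the beginning of each epoch.
        self._pos = 0
        return self

    def __next__(self) -> Batch:
        if self._pos >= len(self.targets):
            raise StopIteration
        seeds = self.targets[self._pos:self._pos + self.batch_size]
        self._pos += self.batch_size
        # A real loader samples neighbors here, so its input nodes are a superset
        # of the seeds; the toy reuses the seeds and fakes one MFG per GNN layer.
        input_nodes = {self.ntype: list(seeds)}
        blocks = [f"mfg_layer_{i}" for i in range(self.num_layers)]
        return input_nodes, {self.ntype: seeds}, blocks


loader = ToyNodeLoader("paper", list(range(10)), batch_size=4)
for input_nodes, seeds, blocks in loader:
    print(seeds["paper"], len(blocks))
```

Because `__next__` returns three objects, a loop over the real dataloader can unpack each mini-batch the same way, as `input_nodes, seeds, blocks`.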
- __len__()
Follows the DGL implementation at https://github.com/dmlc/dgl/blob/1.0.x/python/dgl/distributed/dist_dataloader.py#L116. In DGL, DistDataLoader.expected_idxs is the length (number of batches) of the dataloader.
Returns
int: The length (number of batches) of the dataloader.
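The batch count described above can be sketched as follows, assuming the DGL 1.0.x rule for `DistDataLoader.expected_idxs`: integer division of the number of target nodes by the batch size, plus one extra batch for a leftover partial batch unless `drop_last` is set. `expected_num_batches` is a hypothetical helper, and `num_targets` stands for the number of target nodes this dataloader iterates over:

```python
def expected_num_batches(num_targets: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper mirroring the batch-count rule assumed for expected_idxs."""
    num_batches = num_targets // batch_size
    if not drop_last and num_targets % batch_size != 0:
        num_batches += 1  # the final, partial batch still counts
    return num_batches


print(expected_num_batches(1000, 128))                  # 8
print(expected_num_batches(1000, 128, drop_last=True))  # 7
```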