GSgnnNodeDataLoader
- class graphstorm.dataloading.GSgnnNodeDataLoader(dataset, target_idx, fanout, batch_size, label_field, node_feats=None, edge_feats=None, train_task=True, construct_feat_ntype=None, construct_feat_fanout=5)
Bases: GSgnnNodeDataLoaderBase
Minibatch dataloader for node tasks
GSgnnNodeDataLoader samples a GraphStorm node dataset into an iterable over mini-batches. Each mini-batch includes the target nodes and their sampled neighbor nodes, and is consumed by GraphStorm Trainers and Inferrers.
Parameters
- dataset: GSgnnData
The GraphStorm dataset
- target_idx: dict of Tensors
The target nodes for prediction
- fanout: list of int or dict of list
Neighbor sample fanout. If it’s a dict, it indicates the fanout for each edge type.
- label_field: str
Label field of the node task. TODO (xiangsx): support a list of str so that a single dataloader can serve multiple node tasks.
- node_feats: str, list of str or dict of list of str
Node features. str: all node types use the same feature name. list of str: all node types use the same list of features. dict of list of str: each node type has its own set of node features. Default: None
- edge_feats: str, list of str or dict of list of str
Edge features. str: all edge types use the same feature name. list of str: all edge types use the same list of features. dict of list of str: each edge type has its own set of edge features. Default: None
- batch_size: int
Batch size
- train_task: bool
Whether the dataloader is used for training. Default: True
- construct_feat_ntype: list of str
The node types whose node features need to be constructed. Default: None
- construct_feat_fanout: int
The fanout used when constructing node features. Default: 5
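The three accepted shapes of node_feats and edge_feats can be illustrated with a small sketch. The helper `normalize_feat_names` and the node type names "paper" and "author" are hypothetical and are not part of GraphStorm; the sketch only shows how a str, a list of str, or a dict of list of str could each be read as a per-type mapping.

```python
from typing import Dict, List, Optional, Union

FeatSpec = Optional[Union[str, List[str], Dict[str, List[str]]]]


def normalize_feat_names(feats: FeatSpec, types: List[str]) -> Dict[str, List[str]]:
    """Expand a str / list / dict feature spec into {type: [feature names]}.

    Hypothetical helper for illustration only; not GraphStorm code.
    """
    if feats is None:
        # No features requested for any type.
        return {}
    if isinstance(feats, str):
        # One feature name shared by every type.
        return {t: [feats] for t in types}
    if isinstance(feats, list):
        # One list of feature names shared by every type.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # Each listed type carries its own feature names.
        return {t: list(names) for t, names in feats.items()}
    raise TypeError(f"Unsupported feature spec: {type(feats).__name__}")


ntypes = ["paper", "author"]  # assumed node types
same_name = normalize_feat_names("feat", ntypes)
same_list = normalize_feat_names(["feat", "year"], ntypes)
per_type = normalize_feat_names({"paper": ["feat"]}, ntypes)
```

In the dict form, only the types that appear as keys receive features, which is why `per_type` has no entry for "author".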
Examples
To train a 2-layer GNN for node classification on a set of nodes target_idx on a graph where each node takes messages from 15 neighbors on the first layer and 10 neighbors on the second:

    from graphstorm.dataloading import GSgnnData
    from graphstorm.dataloading import GSgnnNodeDataLoader
    from graphstorm.trainer import GSgnnNodePredictionTrainer

    np_data = GSgnnData(...)
    target_idx = np_data.get_node_train_set(...)
    np_dataloader = GSgnnNodeDataLoader(np_data, target_idx, fanout=[15, 10],
                                        batch_size=128, label_field="label",
                                        node_feats="feat")
    np_trainer = GSgnnNodePredictionTrainer(...)
    np_trainer.fit(np_dataloader, num_epochs=10)
- __iter__()
Returns an iterator object
- __next__()
Return a mini-batch of data for the node task.
A mini-batch comprises three objects: the input node IDs of the mini-batch, the target nodes and the subgraph blocks for message passing.
Returns
- dict of Tensors: the input node IDs of the mini-batch.
- dict of Tensors: the target node IDs.
- list of DGLGraph: the subgraph blocks for message passing.
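The iteration protocol described above can be sketched with a stand-in class. `ToyNodeLoader` is hypothetical and is not GraphStorm code: it uses plain Python lists in place of Tensors and placeholder strings in place of DGLGraph blocks, performs no neighbor sampling, and assumes the three objects are returned as a tuple in the order listed.

```python
from typing import Dict, List, Tuple

# (input node IDs, target node IDs, blocks), with lists standing in for Tensors.
Batch = Tuple[Dict[str, List[int]], Dict[str, List[int]], List[str]]


class ToyNodeLoader:
    """Stand-in that mimics the __iter__/__next__ shape; illustration only."""

    def __init__(self, target_idx: Dict[str, List[int]], batch_size: int, num_layers: int):
        self._target_idx = target_idx
        self._batch_size = batch_size
        self._num_layers = num_layers
        self._pending: List[Dict[str, List[int]]] = []

    def __iter__(self) -> "ToyNodeLoader":
        # Split the target nodes of each node type into batch_size chunks.
        self._pending = [
            {ntype: ids[start:start + self._batch_size]}
            for ntype, ids in self._target_idx.items()
            for start in range(0, len(ids), self._batch_size)
        ]
        return self

    def __next__(self) -> Batch:
        if not self._pending:
            raise StopIteration
        seeds = self._pending.pop(0)
        # No real sampling here: the input nodes are just the seeds themselves.
        input_nodes = {ntype: list(ids) for ntype, ids in seeds.items()}
        # One placeholder block per GNN layer.
        blocks = [f"block_{layer}" for layer in range(self._num_layers)]
        return input_nodes, seeds, blocks


loader = ToyNodeLoader({"paper": [0, 1, 2, 3, 4]}, batch_size=2, num_layers=2)
seen = []
for input_nodes, seeds, blocks in loader:
    # A trainer would fetch features for input_nodes and run the GNN over blocks.
    seen.append((seeds["paper"], len(blocks)))
```

With a real sampler, the input node IDs would normally be a superset of the target node IDs, since they also cover the sampled neighbors needed for message passing.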