GSgnnNodeDataLoader
- class graphstorm.dataloading.GSgnnNodeDataLoader(dataset, target_idx, fanout, batch_size, device, train_task=True, construct_feat_ntype=None, construct_feat_fanout=5)
Bases: GSgnnNodeDataLoaderBase
Minibatch dataloader for node tasks
GSgnnNodeDataLoader turns a GraphStorm node dataset into an iterable over mini-batches. Each mini-batch contains the target nodes and their sampled neighbor nodes, and is consumed by GraphStorm Trainers and Inferrers.
Parameters
- dataset: GSgnnNodeData
The GraphStorm dataset
- target_idx: dict of Tensors
The target nodes for prediction
- fanout: list of int or dict of list
Neighbor sample fanout. If it’s a dict, it indicates the fanout for each edge type.
- batch_size: int
Batch size
- device: torch.device
The device the trainer is running on.
- train_task: bool
Whether the dataloader is used for training.
- construct_feat_ntype: list of str
The node types whose node features need to be constructed.
- construct_feat_fanout: int
The fanout used to construct node features.
Examples
To train a 2-layer GNN for node classification on a set of nodes target_idx, on a graph where each node takes messages from 15 neighbors on the first layer and 10 neighbors on the second:

    import torch
    from graphstorm.dataloading import GSgnnNodeTrainData
    from graphstorm.dataloading import GSgnnNodeDataLoader
    from graphstorm.trainer import GSgnnNodePredictionTrainer

    # device is a required argument in the signature above;
    # CPU is chosen here only for illustration.
    device = torch.device("cpu")

    np_data = GSgnnNodeTrainData(...)
    np_dataloader = GSgnnNodeDataLoader(np_data, target_idx, fanout=[15, 10],
                                        batch_size=128, device=device)
    np_trainer = GSgnnNodePredictionTrainer(...)
    np_trainer.fit(np_dataloader, num_epochs=10)
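For a sense of scale, the fanout and batch size used in this example bound how many input nodes a single mini-batch can pull in. The sketch below is plain-Python arithmetic, not GraphStorm code. It assumes a homogeneous graph, no overlap between sampled neighborhoods, and DGL-style sampling in which the fanout list is consumed from the last layer back to the first and each block's source nodes include its destination nodes. The helper name `max_nodes_per_layer` is hypothetical.

```python
def max_nodes_per_layer(batch_size, fanout):
    """Upper bound on node counts, from the target nodes out to the input nodes."""
    counts = [batch_size]
    # Sampling starts at the target nodes, so walk the fanout list backwards.
    for k in reversed(fanout):
        frontier = counts[-1]
        # Each frontier node keeps itself and adds at most k sampled neighbors.
        counts.append(frontier + frontier * k)
    return counts


bounds = max_nodes_per_layer(batch_size=128, fanout=[15, 10])
print(bounds)  # [128, 1408, 22528]
```

Real mini-batches are usually much smaller than this bound, because neighbors are shared between target nodes and many nodes have fewer neighbors than the fanout.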
- __iter__()
Return an iterator object.
- __next__()
Return a mini-batch of data for the node task.
A mini-batch comprises three objects: the input node IDs of the mini-batch, the target nodes and the subgraph blocks for message passing.
Returns
- dict of Tensors: the input node IDs of the mini-batch.
- dict of Tensors: the target node IDs.
- list of DGLGraph: the subgraph blocks for message passing.
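The iteration protocol described by `__iter__()` and `__next__()` can be sketched with a stand-in class. `ToyNodeLoader` below is hypothetical and is not GraphStorm's implementation: it uses plain lists in place of tensors and strings in place of DGL blocks, and the node type name `"paper"` is made up. It only illustrates how each mini-batch unpacks into the three returned objects.

```python
class ToyNodeLoader:
    """Stand-in that mimics the (input_nodes, seeds, blocks) mini-batch shape."""

    def __init__(self, target_idx, batch_size, num_layers=2):
        self.target_idx = target_idx      # dict: node type -> list of node IDs
        self.batch_size = batch_size
        self.num_layers = num_layers
        self._pos = 0

    def __iter__(self):
        self._pos = 0                     # restart from the first batch
        return self

    def __next__(self):
        # The toy version handles a single node type only.
        ntype, ids = next(iter(self.target_idx.items()))
        if self._pos >= len(ids):
            raise StopIteration
        seeds = {ntype: ids[self._pos:self._pos + self.batch_size]}
        self._pos += self.batch_size
        # A real loader would return the sampled neighborhood here; the toy
        # version just reuses the seeds as the input nodes.
        input_nodes = {ntype: list(seeds[ntype])}
        blocks = [f"block_{i}" for i in range(self.num_layers)]
        return input_nodes, seeds, blocks


loader = ToyNodeLoader({"paper": list(range(10))}, batch_size=4)
batch_sizes = []
for input_nodes, seeds, blocks in loader:
    # seeds are the nodes being predicted; there is one block per GNN layer.
    batch_sizes.append(len(seeds["paper"]))
print(batch_sizes)  # [4, 4, 2]
```

With the real dataloader the loop has the same shape, with tensors and DGL blocks in place of the lists and strings used here.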