GSgnnNodeSemiSupDataLoader

class graphstorm.dataloading.GSgnnNodeSemiSupDataLoader(dataset, target_idx, unlabeled_idx, fanout, batch_size, label_field, node_feats=None, edge_feats=None, train_task=True, construct_feat_ntype=None, construct_feat_fanout=5)

Bases: GSgnnNodeDataLoader

Semi-supervised mini-batch dataloader for node tasks.

Parameters

dataset: GSgnnData

The GraphStorm data.

target_idx: dict of Tensors

The target node indexes for prediction.

unlabeled_idx: dict of Tensors

The unlabeled node indexes for semi-supervised training.

fanout: list of int, or dict of list

Neighbor sampling fanout. A list of int gives one fanout per GNN layer; a dict of lists gives the per-layer fanout for each edge type.
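To make the two fanout formats concrete, the sketch below expands either form into one {edge type: fanout} mapping per GNN layer. It is an illustration, not GraphStorm code: the helper name `per_layer_fanout`, the use of canonical `(src_ntype, relation, dst_ntype)` tuples as dict keys, and the rule that a plain list applies the same fanout to every edge type are assumptions.

```python
from typing import Dict, List, Sequence, Tuple, Union

# Assumed key format: a canonical edge type (src_ntype, relation, dst_ntype).
EType = Tuple[str, str, str]
FanoutSpec = Union[List[int], Dict[EType, List[int]]]


def per_layer_fanout(fanout: FanoutSpec,
                     etypes: Sequence[EType]) -> List[Dict[EType, int]]:
    """Hypothetical helper: expand a fanout spec into one
    {edge type: fanout} dict per GNN layer."""
    if isinstance(fanout, dict):
        # Every edge type must list the same number of layers.
        num_layers = {len(v) for v in fanout.values()}
        if len(num_layers) != 1:
            raise ValueError("all edge types must list the same number of layers")
        missing = set(etypes) - set(fanout)
        if missing:
            raise ValueError(f"no fanout given for edge types: {sorted(missing)}")
        (n,) = num_layers
        return [{et: fanout[et][layer] for et in etypes} for layer in range(n)]
    # A plain list: the same per-layer fanout for every edge type.
    return [{et: k for et in etypes} for k in fanout]
```

For example, `per_layer_fanout([10, 5], etypes)` samples 10 neighbors per edge type in the first layer and 5 in the second.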

batch_size: int

Mini-batch size: the total number of labeled and unlabeled nodes in a mini-batch.

label_field: str

Label field of the node task.

node_feats: str, list of str, or dict of list of str

Node feature fields in one of three formats:

  • string: All nodes have the same feature name.

  • list of string: All nodes have the same list of features.

  • dict of list of string: Each node type has a different set of node features.

Default: None
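All three accepted formats reduce to one canonical form: a mapping from type to a list of field names. The helper below, `normalize_feat_fields`, is a hypothetical sketch of that reduction and is not part of the GraphStorm API; the same logic would apply to `edge_feats` with edge types as keys. How GraphStorm treats types that are absent from a dict is not stated on this page, so the sketch simply leaves them out.

```python
from typing import Dict, Hashable, List, Sequence, Union

FeatSpec = Union[None, str, List[str], Dict[Hashable, List[str]]]


def normalize_feat_fields(feats: FeatSpec,
                          types: Sequence[Hashable]) -> Dict[Hashable, List[str]]:
    """Hypothetical helper: convert any accepted feature-field format
    into {type: [field, ...]}."""
    if feats is None:
        return {}  # Default: no features are loaded.
    if isinstance(feats, str):
        # One feature name shared by every type.
        return {t: [feats] for t in types}
    if isinstance(feats, list):
        # One list of feature names shared by every type.
        return {t: list(feats) for t in types}
    if isinstance(feats, dict):
        # Per-type feature lists; types missing from the dict get no entry here.
        return {t: list(fields) for t, fields in feats.items()}
    raise TypeError(f"unsupported feature spec: {type(feats).__name__}")
```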

edge_feats: str, list of str, or dict of list of str

Edge feature fields in one of three formats:

  • string: All edges have the same feature name.

  • list of string: All edges have the same list of features.

  • dict of list of string: Each edge type has a different set of edge features.

Default: None

train_task: bool

Whether or not it is the dataloader for training.

construct_feat_ntype: list of str

The node types whose node features need to be constructed.

construct_feat_fanout: int

The fanout used when constructing node features for feature-less nodes.
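This page does not say how features are constructed for feature-less nodes, only that `construct_feat_fanout` sets the fanout used. The toy sketch below shows the role such a fanout plays under an assumed scheme: each feature-less node averages the features of at most `fanout` randomly sampled neighbors that have features. The mean aggregation and the helper name `construct_feats` are illustrative assumptions, not GraphStorm's implementation.

```python
import random
from typing import Dict, List, Sequence


def construct_feats(nodes: Sequence[int],
                    neighbors: Dict[int, List[int]],
                    feats: Dict[int, List[float]],
                    fanout: int = 5,
                    seed: int = 0) -> Dict[int, List[float]]:
    """Hypothetical sketch: give each feature-less node the mean feature of
    at most `fanout` randomly sampled neighbors that have features."""
    rng = random.Random(seed)
    out: Dict[int, List[float]] = {}
    for n in nodes:
        # Only neighbors that already carry features can contribute.
        cands = [v for v in neighbors.get(n, []) if v in feats]
        if not cands:
            continue  # Nothing to aggregate from; left for the caller to handle.
        # The fanout caps how many neighbors are used per node.
        picked = rng.sample(cands, min(fanout, len(cands)))
        dim = len(feats[picked[0]])
        out[n] = [sum(feats[v][i] for v in picked) / len(picked) for i in range(dim)]
    return out
```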

__iter__()

Returns an iterator object.

__next__()

Returns a mini-batch of data for node tasks.

A mini-batch comprises three objects: 1) the input node IDs of the mini-batch, 2) the target nodes, and 3) the sampled subgraph as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.

Returns

  • dict of Tensors : the input node IDs of the mini-batch.

  • dict of Tensors : the target node indexes.

  • list of DGL MFGs : the list of DGL message flow graphs (MFGs) for message passing. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
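The relationship between the three returned objects follows DGL's MFG convention: sampling starts from the target nodes and proceeds backwards one GNN layer at a time, `blocks[0]` feeds the first layer, the input node IDs are the source nodes of `blocks[0]`, and the target nodes are the destination nodes of the last block. The pure-Python sketch below reproduces that structure on a homogeneous toy graph. It is a simplification: the real dataloader works with DGL graphs, returns dicts keyed by node type, and samples neighbors at random, whereas `sample_blocks` (a hypothetical helper) takes the first `k` in-neighbors so the result is deterministic.

```python
from typing import Dict, List, Sequence, Tuple

# A simplified stand-in for one MFG: (src_nodes, dst_nodes, (src, dst) edges).
Block = Tuple[List[int], List[int], List[Tuple[int, int]]]


def sample_blocks(in_nbrs: Dict[int, List[int]],
                  seeds: Sequence[int],
                  fanouts: Sequence[int]) -> Tuple[List[int], List[int], List[Block]]:
    """Return (input_nodes, seeds, blocks) for a homogeneous toy graph.

    fanouts[i] applies to GNN layer i, so sampling walks the list in reverse,
    starting from the seeds (the output of the last layer).
    """
    blocks: List[Block] = []
    dst = list(seeds)
    for k in reversed(fanouts):
        # Deterministic stand-in for random sampling: first k in-neighbors.
        edges = [(u, v) for v in dst for u in in_nbrs.get(v, [])[:k]]
        # As in a DGL MFG, the destination nodes come first among the sources.
        src = list(dst)
        for u, _ in edges:
            if u not in src:
                src.append(u)
        blocks.insert(0, (src, dst, edges))
        dst = src  # This layer's sources are the previous layer's destinations.
    # After the loop, dst holds the source nodes of blocks[0]: the input nodes.
    return dst, list(seeds), blocks
```

With two fanouts the sketch yields two blocks, and the destination nodes of `blocks[0]` equal the source nodes of `blocks[1]`, which is what lets the layers be chained during message passing.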

__len__()

Follows https://github.com/dmlc/dgl/blob/1.0.x/python/dgl/distributed/dist_dataloader.py#L116, where DistDataLoader.expected_idxs is the length (number of batches) of the dataloader. Because this dataloader wraps two dataloaders, one for the target nodes and one for the unlabeled nodes, iteration stops as soon as either of them raises StopIteration, so the length is the smaller of their two lengths.

Returns

int: The length (number of batches) of the dataloader.
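The length rule above can be illustrated with a minimal pure-Python loader that advances two batch iterators in lockstep. `PairedBatchLoader` is hypothetical, and splitting `batch_size` evenly between labeled and unlabeled nodes is an assumption made for the sketch (the parameter description only says the batch size is their sum). The number of batches per side is ceil(n / batch_size), which is our reading of how the linked DGL code computes `expected_idxs` when `drop_last` is False.

```python
import math
from typing import Iterator, List, Sequence, Tuple


class PairedBatchLoader:
    """Hypothetical sketch: walk labeled and unlabeled node IDs in lockstep
    and stop as soon as either side is exhausted."""

    def __init__(self, labeled: Sequence[int], unlabeled: Sequence[int],
                 batch_size: int):
        if batch_size < 2:
            raise ValueError("batch_size must leave room for both node sets")
        # Assumption: batch_size is split evenly between the two sides.
        self.labeled_bs = batch_size // 2
        self.unlabeled_bs = batch_size - self.labeled_bs
        self.labeled, self.unlabeled = list(labeled), list(unlabeled)

    @staticmethod
    def _num_batches(n: int, bs: int) -> int:
        # ceil(n / bs): the last partial batch counts (no drop_last).
        return math.ceil(n / bs)

    @staticmethod
    def _batches(ids: List[int], bs: int) -> Iterator[List[int]]:
        for start in range(0, len(ids), bs):
            yield ids[start:start + bs]

    def __iter__(self) -> "PairedBatchLoader":
        self._lab_it = self._batches(self.labeled, self.labeled_bs)
        self._unl_it = self._batches(self.unlabeled, self.unlabeled_bs)
        return self

    def __next__(self) -> Tuple[List[int], List[int]]:
        # next() raises StopIteration as soon as either side runs out,
        # which ends iteration over the paired loader.
        return next(self._lab_it), next(self._unl_it)

    def __len__(self) -> int:
        # The shorter side determines how many paired batches are produced.
        return min(self._num_batches(len(self.labeled), self.labeled_bs),
                   self._num_batches(len(self.unlabeled), self.unlabeled_bs))
```

With 10 labeled and 4 unlabeled nodes and `batch_size=4`, each side gets 2 nodes per batch, so the unlabeled side runs out after 2 batches and `PairedBatchLoader(range(10), range(4), batch_size=4)` yields exactly 2 paired mini-batches, matching its length.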