GSgnnNodeTrainData

class graphstorm.dataloading.GSgnnNodeTrainData(graph_name, part_config, train_ntypes, eval_ntypes=None, label_field=None, node_feat_field=None, edge_feat_field=None)

Bases: GSgnnNodeData

Training data for node tasks

GSgnnNodeTrainData prepares the data for training node prediction tasks.

Parameters

graph_name : str

The name of the graph.

part_config : str

The path of the partition configuration file.

train_ntypes : str or list of str

Target node types for training.

eval_ntypes : str or list of str, optional

Target node types for evaluation.

label_field : str

The node data field that stores the labels.

node_feat_field : str or dict of list of str

The node data fields to extract node features from. Use a dict when different node types have different feature names.

edge_feat_field : str or dict of list of str

The edge data fields that store the edge features. Use a dict when different edge types have different feature names.
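train_ntypes, eval_ntypes, node_feat_field, and edge_feat_field each accept either a single value or a collection. The sketch below shows one way such arguments can be normalized into a per-type form. It is a minimal illustration under assumptions: the helpers normalize_ntypes and resolve_feat_fields are hypothetical and not part of GraphStorm, and treating a node type that is absent from the dict as having no features is an assumption. The same idea would apply to edge_feat_field, with edge types as the dict keys.

```python
from typing import Dict, List, Optional, Union

# A feature-field argument: one shared field name, or per-type lists of field names.
FeatField = Optional[Union[str, Dict[str, List[str]]]]


def normalize_ntypes(ntypes: Union[str, List[str]]) -> List[str]:
    """Hypothetical helper: accept one node type or a list, always return a list."""
    return [ntypes] if isinstance(ntypes, str) else list(ntypes)


def resolve_feat_fields(feat_field: FeatField, ntypes: List[str]) -> Dict[str, List[str]]:
    """Hypothetical helper: map each node type to the feature fields to read for it."""
    if feat_field is None:
        return {}  # no input features at all
    if isinstance(feat_field, str):
        # A single string is shared by every node type.
        return {ntype: [feat_field] for ntype in ntypes}
    # A dict lists fields per node type; types absent from it get no features (assumption).
    return {ntype: list(feat_field[ntype]) for ntype in ntypes if ntype in feat_field}


if __name__ == "__main__":
    ntypes = normalize_ntypes(['n1', 'n2'])
    print(resolve_feat_fields('feat', ntypes))                # {'n1': ['feat'], 'n2': ['feat']}
    print(resolve_feat_fields({'n1': ['age', 'emb']}, ntypes))  # {'n1': ['age', 'emb']}
```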

Examples

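from graphstorm.dataloading import GSgnnNodeTrainData
from graphstorm.dataloading import GSgnnNodeDataLoader

part_config = 'path/to/dummy.json'  # placeholder: your partition config file

# Train on node type 'n1' with labels in 'label' and input features in 'feat'.
np_data = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
                             train_ntypes=['n1'], label_field='label',
                             node_feat_field='feat')
# Sample up to 15 and 10 neighbors for the two GNN layers, 128 seed nodes per batch.
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1': [0]},
                                    fanout=[15, 10], batch_size=128)

get_edge_feats(input_edges, edge_feat_field, device='cpu')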

Get the edge features

Parameters

input_edges : Tensor or dict of Tensors

The input edge IDs

edge_feat_field : str or dict of list of str

The edge data fields that store the edge features to retrieve.

device : PyTorch device

The device where the returned edge features are stored.

Returns

dict of Tensors : The returned edge features.
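get_edge_feats returns a dict keyed by edge type. The following sketch illustrates the gathering step with plain Python lists standing in for tensors: for each edge type it looks up the requested fields for the given edge IDs and concatenates them along the feature dimension. The in-memory store layout, the canonical (source type, relation, destination type) keys, and the concatenation of multiple fields are assumptions made for illustration; get_edge_feats_sketch is hypothetical, omits moving results to device, and accepts only the dict form of input_edges.

```python
from typing import Dict, List, Tuple, Union

EType = Tuple[str, str, str]  # canonical edge type: (src node type, relation, dst node type)
# Hypothetical store: edge type -> field name -> one feature row per edge ID.
EdgeStore = Dict[EType, Dict[str, List[List[float]]]]


def get_edge_feats_sketch(store: EdgeStore,
                          input_edges: Dict[EType, List[int]],
                          edge_feat_field: Union[str, Dict[EType, List[str]]]
                          ) -> Dict[EType, List[List[float]]]:
    """Gather and concatenate the requested feature fields for each edge type."""
    feats = {}
    for etype, eids in input_edges.items():
        if isinstance(edge_feat_field, str):
            fields = [edge_feat_field]
        else:
            fields = edge_feat_field.get(etype, [])
        if not fields:
            continue  # no features requested for this edge type
        # Row i of the result holds the concatenated fields of edge eids[i].
        feats[etype] = [[v for field in fields for v in store[etype][field][eid]]
                        for eid in eids]
    return feats


if __name__ == "__main__":
    rel = ('n1', 'r', 'n2')
    store = {rel: {'w': [[0.1], [0.2], [0.3]], 'ts': [[1.0], [2.0], [3.0]]}}
    print(get_edge_feats_sketch(store, {rel: [2, 0]}, {rel: ['w', 'ts']}))
    # {('n1', 'r', 'n2'): [[0.3, 3.0], [0.1, 1.0]]}
```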

get_labels(nids, device='cpu')

Get the node labels

Parameters

nids : Tensor or dict of Tensors

The IDs of the seed nodes.

device : PyTorch device

The device where the returned node labels are stored.

Returns

dict of Tensors : The returned node labels.
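nids may be a single tensor or a dict keyed by node type, while the result is always a dict. The sketch below shows one plausible way to reconcile the two forms, with Python lists in place of tensors. It assumes, without confirmation from this page, that a bare tensor is only accepted when labels belong to a single node type; the label store and get_labels_sketch are hypothetical.

```python
from typing import Dict, List, Sequence, Union

# Hypothetical store: node type -> label of each node, indexed by node ID.
LabelStore = Dict[str, List[int]]


def get_labels_sketch(labels: LabelStore,
                      nids: Union[Sequence[int], Dict[str, Sequence[int]]]
                      ) -> Dict[str, List[int]]:
    """Look up labels for seed nodes, always returning a dict keyed by node type."""
    if not isinstance(nids, dict):
        # Assumption: a bare ID sequence is unambiguous only with one labeled node type.
        if len(labels) != 1:
            raise ValueError("pass a dict of node IDs when several node types have labels")
        nids = {next(iter(labels)): nids}
    return {ntype: [labels[ntype][nid] for nid in ids] for ntype, ids in nids.items()}


if __name__ == "__main__":
    labels = {'n1': [3, 0, 2, 1]}
    print(get_labels_sketch(labels, [1, 3]))          # {'n1': [0, 1]}
    print(get_labels_sketch(labels, {'n1': [0, 2]}))  # {'n1': [3, 2]}
```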

get_node_feats(input_nodes, device='cpu')

Get the node features

Parameters

input_nodes : Tensor or dict of Tensors

The input node IDs

device : PyTorch device

The device where the returned node features are stored.

Returns

dict of Tensors : The returned node features.

prepare_data(g)

Prepare the dataset.

Parameters

g : DistGraph

The distributed DGL graph.
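This page does not spell out what prepare_data does. As a hedged illustration only: node-task training data typically needs the IDs of the training nodes of each type in train_ntypes, and in distributed training those IDs are divided among trainer processes. The sketch assumes boolean per-node masks (the name train_mask is an assumption) and a simple strided split across ranks; GraphStorm's actual implementation may read different fields and may split by partition ownership instead, as DGL's dgl.distributed.node_split does. masks_to_idxs and split_for_rank are hypothetical helpers.

```python
from typing import Dict, List


def masks_to_idxs(masks: Dict[str, List[bool]]) -> Dict[str, List[int]]:
    """Convert per-node-type boolean masks into lists of selected node IDs."""
    return {ntype: [nid for nid, keep in enumerate(mask) if keep]
            for ntype, mask in masks.items()}


def split_for_rank(idxs: Dict[str, List[int]], rank: int,
                   world_size: int) -> Dict[str, List[int]]:
    """Give each trainer every world_size-th ID, so shares differ in size by at most one."""
    return {ntype: ids[rank::world_size] for ntype, ids in idxs.items()}


if __name__ == "__main__":
    # Hypothetical train_mask for node type 'n1' (an entry of train_ntypes).
    train_idxs = masks_to_idxs({'n1': [True, False, True, True, False, True]})
    print(train_idxs)                        # {'n1': [0, 2, 3, 5]}
    print(split_for_rank(train_idxs, 0, 2))  # {'n1': [0, 3]}
    print(split_for_rank(train_idxs, 1, 2))  # {'n1': [2, 5]}
```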