GSgnnNodeTrainData
- class graphstorm.dataloading.GSgnnNodeTrainData(graph_name, part_config, train_ntypes, eval_ntypes=None, label_field=None, node_feat_field=None, edge_feat_field=None)
Bases: GSgnnNodeData
Training data for node tasks.
GSgnnNodeTrainData prepares the data for training node prediction models.
Parameters
- graph_name: str
The graph name.
- part_config: str
The path of the partition configuration file.
- train_ntypes: str or list of str
Target node types for training.
- eval_ntypes: str or list of str
Target node types for evaluation.
- label_field: str
The field for storing labels.
- node_feat_field: str or dict of list of str
Fields to extract node features. It’s a dict if different node types have different feature names.
- edge_feat_field: str or dict of list of str
The field of the edge features. It’s a dict if different edge types have different feature names.
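The two accepted forms of node_feat_field can be illustrated with a small, library-independent sketch. The helper expand_feat_fields below is hypothetical and is not part of GraphStorm; it assumes that a single string stands for the same field name on every node type, while a dict assigns a list of field names to each node type. The node type names and feature names are made up for illustration.

```python
from typing import Dict, List, Union

# A feature-field spec is either one name or a per-type mapping.
FeatSpec = Union[str, Dict[str, List[str]]]


def expand_feat_fields(spec: FeatSpec, ntypes: List[str]) -> Dict[str, List[str]]:
    """Hypothetical helper (not a GraphStorm API): expand a feature-field
    spec into an explicit {node_type: [field, ...]} mapping."""
    if isinstance(spec, str):
        # Assumption: one name is shared by all node types.
        return {ntype: [spec] for ntype in ntypes}
    # Per-type lists; copy them so the caller's dict is not mutated.
    return {ntype: list(fields) for ntype, fields in spec.items()}


# Same feature name on every node type.
shared = expand_feat_fields("feat", ["n1", "n2"])

# Different feature names per node type.
per_type = expand_feat_fields({"n1": ["feat"], "n2": ["emb", "degree"]},
                              ["n1", "n2"])
```

In this sketch the string form is shorthand for the dict form, so code that consumes the spec only has to handle one shape.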
Examples
from graphstorm.dataloading import GSgnnNodeTrainData
from graphstorm.dataloading import GSgnnNodeDataLoader

np_data = GSgnnNodeTrainData(graph_name='dummy', part_config=part_config,
                             train_ntypes=['n1'], label_field='label',
                             node_feat_field='feat')
np_dataloader = GSgnnNodeDataLoader(np_data, target_idx={'n1':[0]},
                                    fanout=[15, 10], batch_size=128)
- get_edge_feats(input_edges, edge_feat_field, device='cpu')
Get the edge features
Parameters
- input_edges: Tensor or dict of Tensors
The input edge IDs.
- edge_feat_field: str or dict of list of str
The edge data fields that store the edge features to retrieve.
- device: PyTorch device
The device where the returned edge features are stored.
Returns
dict of Tensors : The returned edge features.
- get_labels(nids, device='cpu')
Get the node labels
Parameters
- nids: Tensor or dict of Tensors
The seed nodes.
- device: PyTorch device
The device where the returned node labels are stored.
Returns
dict of Tensors : The returned node labels.
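The input and output shapes of get_labels can be pictured with a library-independent sketch. Everything below is an illustration under stated assumptions, not GraphStorm code: lookup_labels is a hypothetical function, plain Python lists stand in for tensors, and the rule that a bare ID collection is wrapped under the single training node type is an assumption about how a "Tensor or dict of Tensors" argument could be normalized.

```python
from typing import Dict, List, Union

# Seed node IDs: either a bare list or a {node_type: ids} mapping.
NodeIds = Union[List[int], Dict[str, List[int]]]


def lookup_labels(labels: Dict[str, List[int]], nids: NodeIds,
                  train_ntypes: List[str]) -> Dict[str, List[int]]:
    """Hypothetical stand-in for a label lookup; lists replace tensors."""
    if not isinstance(nids, dict):
        # Assumption: bare IDs are only unambiguous with one node type.
        if len(train_ntypes) != 1:
            raise ValueError("pass a dict of IDs when there are several node types")
        nids = {train_ntypes[0]: nids}
    # Gather the stored label of every seed node, per node type.
    return {ntype: [labels[ntype][i] for i in ids]
            for ntype, ids in nids.items()}


stored = {"n1": [0, 1, 1, 0, 2]}  # made-up labels for five 'n1' nodes
from_dict = lookup_labels(stored, {"n1": [4, 0]}, ["n1"])
from_list = lookup_labels(stored, [1, 2], ["n1"])
```

Both calls return a dict keyed by node type, which mirrors the documented return type of a dict of Tensors.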