EntityClassifier
- class graphstorm.model.EntityClassifier(in_dim, num_classes, multilabel, dropout=0, norm=None)
Bases: GSLayerDecoder
Decoder for node classification tasks.
Parameters
- in_dim: int
The input dimension size.
- num_classes: int
The number of classes to predict.
- multilabel: bool
Whether this is a multi-label classification decoder.
- dropout: float
Dropout rate. Default: 0.
- norm: str
Normalization method. Not used; reserved for more complex node classifier implementations. Default: None.
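The constructor arguments fix the shape of the decoder. The framework-free sketch below is a minimal illustration only. It assumes, without confirmation from this reference, that the decoder holds a single bias-free weight matrix of shape (in_dim, num_classes), and it simply stores norm without using it. ToyEntityClassifier, its seed argument, and its argument checks are hypothetical and not part of GraphStorm.

```python
import random


class ToyEntityClassifier:
    """Hypothetical stand-in for EntityClassifier's constructor (not GraphStorm code)."""

    def __init__(self, in_dim, num_classes, multilabel, dropout=0, norm=None, seed=0):
        # Sanity checks added by this sketch; not documented GraphStorm behavior.
        if in_dim <= 0 or num_classes <= 0:
            raise ValueError("in_dim and num_classes must be positive")
        if not 0 <= dropout < 1:
            raise ValueError("dropout must be in [0, 1)")
        self.in_dim = in_dim
        self.num_classes = num_classes
        self.multilabel = multilabel
        self.dropout = dropout
        self.norm = norm  # accepted but unused, mirroring the documented behavior
        rng = random.Random(seed)
        # Assumed layout: one weight per (input feature, class) pair, no bias term.
        self.weight = [[rng.uniform(-0.1, 0.1) for _ in range(num_classes)]
                       for _ in range(in_dim)]


clf = ToyEntityClassifier(in_dim=4, num_classes=3, multilabel=False, dropout=0.1)
print(len(clf.weight), len(clf.weight[0]))  # 4 3
```

Under this assumption, in_dim must match the size of the node embeddings fed to the decoder, and num_classes sets the width of every output row.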
- forward(inputs)
Node classification decoder forward computation.
Parameters
- inputs: Tensor
The input embeddings.
Returns
Tensor: the prediction logits.
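The forward computation can be pictured as a projection from embedding space to class space. The sketch below assumes that the logits are a plain matrix product of the input embeddings with a (in_dim, num_classes) weight matrix. It ignores dropout, since this reference does not say whether dropout is applied in forward. toy_forward is a hypothetical helper that works on nested lists instead of tensors.

```python
def toy_forward(inputs, weight):
    """Return logits = inputs @ weight for a batch of node embeddings.

    inputs: N rows of in_dim floats; weight: in_dim rows of num_classes floats.
    """
    in_dim = len(weight)
    num_classes = len(weight[0])
    logits = []
    for row in inputs:
        if len(row) != in_dim:
            raise ValueError("embedding size does not match in_dim")
        # One logit per class: dot product of the embedding with that class's column.
        logits.append([sum(row[k] * weight[k][c] for k in range(in_dim))
                       for c in range(num_classes)])
    return logits


weight = [[1.0, 0.0, -1.0],
          [0.5, 2.0, 0.0]]          # in_dim=2, num_classes=3
inputs = [[1.0, 1.0],
          [0.0, 2.0]]               # embeddings of two nodes
print(toy_forward(inputs, weight))  # [[1.5, 2.0, -1.0], [1.0, 4.0, 0.0]]
```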
- predict(inputs)
Node classification prediction computation.
Parameters
- inputs: Tensor
The input embeddings.
Returns
Tensor: argmax of the prediction results, or the maximum of the prediction results if multilabel is True.
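For the single-label case, prediction reduces to picking the highest-scoring class for each node. The sketch below covers only that argmax branch and does not attempt the multi-label branch. toy_predict is a hypothetical helper operating on lists of logits; on ties it returns the first maximal index, the same rule torch.argmax documents.

```python
def toy_predict(logits):
    """Single-label prediction: index of the largest logit for each node.

    max() returns the first index on ties, so equal logits resolve to the lower class id.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


logits = [[1.5, 2.0, -1.0],
          [3.0, 0.5, 0.0]]
print(toy_predict(logits))  # [1, 0]
```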
- predict_proba(inputs)
Node classification prediction computation that returns normalized prediction results.
Parameters
- inputs: Tensor
The input embeddings.
Returns
Tensor: normalized prediction results.
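This reference does not say how the results are normalized. A common choice, assumed here rather than confirmed, is a softmax over classes for single-label decoders (each row sums to 1) and an independent sigmoid per class for multi-label decoders. toy_predict_proba is a hypothetical helper illustrating both.

```python
import math


def toy_predict_proba(logits, multilabel=False):
    """Normalize logits: softmax per row, or (assumed) sigmoid per class if multilabel."""
    if multilabel:
        # Each class gets its own probability; rows need not sum to 1.
        return [[1.0 / (1.0 + math.exp(-x)) for x in row] for row in logits]
    probs = []
    for row in logits:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


p = toy_predict_proba([[0.0, 0.0], [2.0, 0.0]])
print(p[0])  # [0.5, 0.5]
```

Subtracting the row maximum before exponentiating leaves the softmax result unchanged while avoiding overflow for large logits.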
- property in_dims
Return the input dimension size, which is given at class initialization.
- property out_dims
Return the output dimension size, which is given at class initialization.
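Both dimensions are read-only views of values fixed at construction. The sketch below shows that pattern with Python properties. ToyDecoderDims is a hypothetical class, and mapping out_dims to num_classes is an assumption based on the parameter list above.

```python
class ToyDecoderDims:
    """Hypothetical class showing read-only in_dims/out_dims properties."""

    def __init__(self, in_dim, num_classes):
        self._in_dim = in_dim
        self._num_classes = num_classes

    @property
    def in_dims(self):
        # Input size given at construction time.
        return self._in_dim

    @property
    def out_dims(self):
        # Assumed to equal num_classes: one output per class.
        return self._num_classes


d = ToyDecoderDims(in_dim=128, num_classes=10)
print(d.in_dims, d.out_dims)  # 128 10
try:
    d.in_dims = 64  # a property without a setter cannot be assigned
except AttributeError:
    print("read-only")
```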