GSgnnInstanceEvaluator
- class graphstorm.eval.GSgnnInstanceEvaluator(eval_frequency, eval_metric, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy='average_increase')
Bases: object
Template class for user-defined evaluator.
Parameters
- eval_frequency: int
How often to run evaluation, measured in training iterations.
- eval_metric: list of str
Evaluation metrics used during evaluation.
- use_early_stop: bool
Set to True to enable early stopping.
- early_stop_burnin_rounds: int
The number of burn-in rounds to complete before checking the early stop condition.
- early_stop_rounds: int
The number of rounds of validation scores used to decide whether to stop early.
- early_stop_strategy: str
The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase and 2) average_increase.
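The sketch below shows one plausible reading of the two strategy names. It is not GraphStorm's implementation: the class names, the `should_stop` method, and the assumption that a higher validation score is better are all hypothetical and chosen only for illustration.

```python
from collections import deque


class AverageStopSketch:
    """Hypothetical judge: stop when the new score falls below the mean
    of the previous `rounds` scores (assumes higher scores are better)."""

    def __init__(self, rounds=3):
        self.window = deque(maxlen=rounds)

    def should_stop(self, score):
        # Until the window is full there is not enough history to judge.
        if len(self.window) < self.window.maxlen:
            self.window.append(score)
            return False
        stop = score < sum(self.window) / len(self.window)
        # The deque drops its oldest entry automatically when full.
        self.window.append(score)
        return stop


class ConsecutiveStopSketch:
    """Hypothetical judge: stop after `rounds` evaluations in a row
    that fail to set a new best score (assumes higher scores are better)."""

    def __init__(self, rounds=3):
        self.rounds = rounds
        self.best = None
        self.bad_rounds = 0

    def should_stop(self, score):
        if self.best is None or score > self.best:
            self.best = score
            self.bad_rounds = 0
        else:
            self.bad_rounds += 1
        return self.bad_rounds >= self.rounds


avg_results = [AverageStopSketch.should_stop(j, s) for j, s in
               zip([AverageStopSketch(3)] * 5, [0.5, 0.6, 0.7, 0.65, 0.5])]
cons_judge = ConsecutiveStopSketch(2)
cons_results = [cons_judge.should_stop(s) for s in [0.5, 0.6, 0.55, 0.58]]
```

In this reading, `early_stop_rounds` sets the window size (or the number of tolerated non-improving rounds), and a burn-in period would simply skip calls to `should_stop` for the first `early_stop_burnin_rounds` evaluations.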
- abstract compute_score(pred, labels)
Compute the evaluation score.
Parameters
- pred:
Prediction result
- labels:
Label
- do_eval(total_iters, epoch_end=False)
Decide whether to run evaluation in the current iteration or epoch.
Parameters
- total_iters: int
The total number of iterations completed so far.
- epoch_end: bool
Whether it is the end of an epoch
Returns
Whether to run evaluation: bool
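The decision described above can be sketched as a small standalone function. This is a minimal illustration under stated assumptions, not the library's code: `do_eval_sketch` is a hypothetical name, `eval_frequency` is passed explicitly instead of being read from the evaluator, and the rules "always evaluate at the end of an epoch" and "a frequency of 0 disables mid-epoch evaluation" are assumptions.

```python
def do_eval_sketch(total_iters, eval_frequency, epoch_end=False):
    """Hypothetical stand-in for the decision made by do_eval."""
    # Assumption: always evaluate at the end of an epoch.
    if epoch_end:
        return True
    # Assumption: otherwise evaluate every `eval_frequency` iterations,
    # and treat a non-positive frequency as "never evaluate mid-epoch".
    return eval_frequency > 0 and total_iters % eval_frequency == 0
```

With `eval_frequency=50`, this returns True at iterations 50, 100, 150, and so on, and at any iteration where `epoch_end` is True.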
- abstract evaluate(val_pred, test_pred, val_labels, test_labels, total_iters)
GSgnnLinkPredictionModel.fit() will call this function to do user-defined evaluation.
Parameters
- val_pred: tensor
The tensor stores the prediction results on the validation nodes.
- test_pred: tensor
The tensor stores the prediction results on the test nodes.
- val_labels: tensor
The tensor stores the labels of the validation nodes.
- test_labels: tensor
The tensor stores the labels of the test nodes.
- total_iters: int
The current iteration number.
Returns
- eval_score: float
Validation score
- test_score: float
Test score
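As an illustration of how the two abstract methods fit together, here is a minimal sketch of a subclass that reports accuracy. To keep it runnable without GraphStorm installed, it subclasses a hypothetical stand-in base class (`InstanceEvaluatorSketch`) that only mirrors the documented method signatures, and it uses plain Python lists in place of tensors. `AccuracyEvaluator` and the choice of accuracy as the metric are also assumptions.

```python
from abc import ABC, abstractmethod


class InstanceEvaluatorSketch(ABC):
    """Hypothetical stand-in mirroring the abstract methods documented above."""

    @abstractmethod
    def compute_score(self, pred, labels):
        ...

    @abstractmethod
    def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters):
        ...


class AccuracyEvaluator(InstanceEvaluatorSketch):
    def compute_score(self, pred, labels):
        # Fraction of predictions that equal their labels.
        correct = sum(int(p == y) for p, y in zip(pred, labels))
        return correct / len(labels)

    def evaluate(self, val_pred, test_pred, val_labels, test_labels, total_iters):
        # total_iters is accepted to match the documented signature;
        # this sketch does not use it.
        eval_score = self.compute_score(val_pred, val_labels)
        test_score = self.compute_score(test_pred, test_labels)
        return eval_score, test_score


evaluator = AccuracyEvaluator()
eval_score, test_score = evaluator.evaluate(
    val_pred=[1, 0, 1, 1], test_pred=[0, 0, 1, 1],
    val_labels=[1, 0, 0, 1], test_labels=[0, 1, 0, 1],
    total_iters=100,
)
```

A real subclass would derive from GSgnnInstanceEvaluator, pass `eval_frequency` and `eval_metric` to its constructor, and operate on tensors rather than lists.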