GSgnnMrrLPEvaluator
- class graphstorm.eval.GSgnnMrrLPEvaluator(eval_frequency, eval_metric_list=None, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy='average_increase')
Bases:
GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface
Link Prediction Evaluator using “mrr” as metric.
GraphStorm built-in evaluator for Link Prediction tasks. It implements the GSgnnLPRankingEvalInterface and uses “mrr” as the default evaluation metric.
To create a customized LP evaluator that uses an evaluation metric other than “mrr”, users might need to: 1) define a new evaluation interface if the evaluation method requires different input arguments; 2) inherit the new evaluation interface in a customized LP evaluator; 3) define a customized LP trainer/inferrer to call the customized LP evaluator.
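The three steps above can be sketched with stand-in classes. Everything in this sketch is hypothetical and is not part of GraphStorm: the interface `LPScoreEvalInterface`, the evaluator `PairwiseAccLPEvaluator`, the `pairwise_acc` metric, and the edge type name `buys`. It assumes a metric that needs raw positive and negative scores rather than rankings, which is the case where a new interface is warranted. A real evaluator would also derive from GSgnnBaseEvaluator.

```python
import abc
from typing import Dict, List, Tuple

# Hypothetical input format: etype -> (positive scores, negative scores).
Scores = Dict[str, Tuple[List[float], List[float]]]


class LPScoreEvalInterface(abc.ABC):
    """Step 1 (hypothetical): an interface that takes raw scores, not rankings."""

    @abc.abstractmethod
    def evaluate(self, val_scores: Scores, test_scores: Scores, total_iters: int):
        ...

    @abc.abstractmethod
    def compute_score(self, scores: Scores, train: bool = True) -> Dict[str, float]:
        ...


class PairwiseAccLPEvaluator(LPScoreEvalInterface):
    """Step 2 (hypothetical): a customized evaluator implementing the new interface."""

    def __init__(self, eval_frequency: int):
        self.eval_frequency = eval_frequency

    def compute_score(self, scores: Scores, train: bool = True) -> Dict[str, float]:
        # Fraction of (positive, negative) pairs where the positive edge scores higher.
        wins = 0
        total = 0
        for pos, neg in scores.values():
            for p in pos:
                for n in neg:
                    wins += int(p > n)
                    total += 1
        return {"pairwise_acc": wins / total if total else 0.0}

    def evaluate(self, val_scores: Scores, test_scores: Scores, total_iters: int):
        val = self.compute_score(val_scores, train=False)["pairwise_acc"]
        test = self.compute_score(test_scores, train=False)["pairwise_acc"]
        return val, test


# Step 3 (hypothetical): the customized trainer/inferrer passes scores, not rankings.
evaluator = PairwiseAccLPEvaluator(eval_frequency=100)
val_scores = {"buys": ([0.9, 0.4], [0.1, 0.5])}
test_scores = {"buys": ([0.8], [0.2, 0.3])}
val_acc, test_acc = evaluator.evaluate(val_scores, test_scores, total_iters=100)
```

If the new metric can be computed from rankings alone, the existing GSgnnLPRankingEvalInterface signature should suffice and step 1 can be skipped.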
Parameters
- eval_frequency: int
How often to run evaluation, in number of training iterations.
- eval_metric_list: list of string
Evaluation metrics used during evaluation. Default: [‘mrr’]
- use_early_stop: bool
Set to True to enable early stopping.
- early_stop_burnin_rounds: int
Number of burn-in rounds before the early stop condition starts being checked.
- early_stop_rounds: int
The number of rounds of validation scores used to decide whether to stop early.
- early_stop_strategy: str
The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase and 2) average_increase.
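The reference does not spell out how the two strategies differ. The sketch below is one plausible reading for a higher-is-better metric such as MRR, not GraphStorm's implementation. The function names `average_stop` and `consecutive_stop` are hypothetical, and burn-in handling is omitted.

```python
from typing import List


def average_stop(history: List[float], current: float, rounds: int) -> bool:
    """Assumed 'average' rule: stop when the current validation score is
    no better than the mean of the last `rounds` validation scores."""
    if len(history) < rounds:
        return False
    window = history[-rounds:]
    return current <= sum(window) / len(window)


def consecutive_stop(history: List[float], current: float, rounds: int) -> bool:
    """Assumed 'consecutive' rule: stop when each of the last `rounds`
    validation scores failed to improve on the one before it."""
    scores = history + [current]
    if len(scores) < rounds + 1:
        return False
    tail = scores[-(rounds + 1):]
    return all(later <= earlier for earlier, later in zip(tail, tail[1:]))


# A single dip after steady gains trips the average rule but not the consecutive one.
dip_average = average_stop([0.30, 0.35, 0.40], 0.34, rounds=3)
dip_consecutive = consecutive_stop([0.30, 0.35, 0.40], 0.34, rounds=3)

# A sustained decline trips both rules.
decline_average = average_stop([0.40, 0.39, 0.38], 0.37, rounds=3)
decline_consecutive = consecutive_stop([0.40, 0.39, 0.38], 0.37, rounds=3)
```

Under this reading, the consecutive rule tolerates isolated dips, while the average rule reacts as soon as the score falls to or below its recent mean.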
- compute_score(rankings, train=True)
Compute evaluation score
Parameters
- rankings: dict of tensors
Rankings of positive scores in format of {etype: ranking}
- train: boolean
Whether the model is in training.
Returns
Evaluation metric values: dict
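As a minimal illustration of how a mean reciprocal rank could be derived from the {etype: ranking} input, the sketch below pools the ranks of all edge types and averages their reciprocals. It assumes 1-based ranks and uses plain Python lists in place of tensors. The helper name `mrr_from_rankings` and the edge types are hypothetical, and pooling across edge types is an assumption rather than a documented behaviour.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)


def mrr_from_rankings(rankings: Dict[EdgeType, List[int]]) -> Dict[str, float]:
    """Pool 1-based ranks from every edge type and return {"mrr": value}."""
    pooled = [rank for ranks in rankings.values() for rank in ranks]
    if not pooled:
        raise ValueError("rankings must contain at least one rank")
    return {"mrr": sum(1.0 / rank for rank in pooled) / len(pooled)}


# Hypothetical rankings of the positive edges for two edge types.
example_rankings = {
    ("user", "buys", "item"): [1, 2, 4],
    ("user", "views", "item"): [1],
}
scores = mrr_from_rankings(example_rankings)
# (1/1 + 1/2 + 1/4 + 1/1) / 4 = 0.6875
```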
- do_eval(total_iters, epoch_end=False)
Decide whether to do the evaluation in the current iteration or epoch.
Parameters
- total_iters: int
The total number of iterations that have been run.
- epoch_end: bool
Whether it is the end of an epoch
Returns
Whether to do evaluation: bool
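A decision of this kind can be sketched as follows. The rule shown is an assumption, not confirmed by this reference: evaluate at every epoch end, and otherwise whenever total_iters is a multiple of eval_frequency. The function name `should_evaluate` is hypothetical.

```python
def should_evaluate(total_iters: int, eval_frequency: int, epoch_end: bool = False) -> bool:
    """Assumed rule: always evaluate at epoch end; otherwise evaluate
    every `eval_frequency` iterations (a frequency of 0 disables this)."""
    if epoch_end:
        return True
    return eval_frequency > 0 and total_iters % eval_frequency == 0


on_schedule = should_evaluate(total_iters=100, eval_frequency=50)
off_schedule = should_evaluate(total_iters=101, eval_frequency=50)
at_epoch_end = should_evaluate(total_iters=101, eval_frequency=50, epoch_end=True)
```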
- evaluate(val_rankings, test_rankings, total_iters)
GSgnnLinkPredictionTrainer and GSgnnLinkPredictionInferrer will call this function to compute validation and test scores.
Parameters
- val_rankings: dict of tensors
Rankings of positive scores of validation edges for each edge type.
- test_rankings: dict of tensors
Rankings of positive scores of test edges for each edge type.
- total_iters: int
The current iteration number.
Returns
- val_mrr: float
Validation mrr score
- test_mrr: float
Test mrr score