GSgnnPerEtypeLPEvaluator

class graphstorm.eval.GSgnnPerEtypeLPEvaluator(eval_frequency, eval_metric_list=None, major_etype='ALL', use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy='average_increase')

Bases: GSgnnBaseEvaluator, GSgnnLPRankingEvalInterface

Evaluator for link prediction tasks that uses mrr and/or hit@k as metric(s) and returns per-edge-type scores.

Parameters

eval_frequency: int

How often to run evaluation, in number of iterations.

eval_metric_list: list of str

Evaluation metrics used during evaluation, for example, ["mrr", "hit_at_10"]. Default: ["mrr"].

major_etype: tuple

A canonical edge type used for selecting the best model. Default: 'ALL', which uses the sum of the metric scores over all edge types.

use_early_stop: bool

Set to True to enable early stopping. Default: False.

early_stop_burnin_rounds: int

Burn-in rounds (number of evaluations) before starting to check for the early stop condition. Default: 0.

early_stop_rounds: int

The number of rounds (evaluations) whose validation scores are used to decide whether to stop early. Default: 3.

early_stop_strategy: str

The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase, and 2) average_increase. Default: average_increase.

New in version 0.4.0: The GSgnnPerEtypeLPEvaluator.
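How the early-stop parameters interact can be sketched in plain Python. This is only an illustration under stated assumptions, not GraphStorm's implementation: the class name AverageIncreaseSketch is hypothetical, and it assumes that average_increase means "stop when the newest validation score is no better than the average of the previous early_stop_rounds scores", with higher scores being better.

```python
from collections import deque


class AverageIncreaseSketch:
    """Hypothetical stand-in for an average_increase early-stop check."""

    def __init__(self, early_stop_rounds=3, early_stop_burnin_rounds=0):
        # Sliding window over the most recent validation scores.
        self.window = deque(maxlen=early_stop_rounds)
        self.burnin = early_stop_burnin_rounds
        self.num_evals = 0

    def should_stop(self, val_score):
        self.num_evals += 1
        # Assumption: evaluations inside the burn-in period are never checked.
        if self.num_evals <= self.burnin:
            return False
        # Not enough history yet: record the score and keep training.
        if len(self.window) < self.window.maxlen:
            self.window.append(val_score)
            return False
        average = sum(self.window) / len(self.window)
        # Assumption: higher is better, as with mrr and hit@k.
        stop = val_score <= average
        self.window.append(val_score)  # the oldest score drops out
        return stop


judge = AverageIncreaseSketch(early_stop_rounds=3)
decisions = [judge.should_stop(s) for s in [0.30, 0.35, 0.40, 0.41, 0.36]]
print(decisions)
```

In this sketch the first three evaluations only fill the window, so the earliest possible stop is at the fourth evaluation after burn-in.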

evaluate(val_rankings, test_rankings, total_iters, **kwargs)

GSgnnLinkPredictionTrainer and GSgnnLinkPredictionInferrer will call this function to compute validation and test scores.

Parameters

val_rankings: dict of tensors

Rankings of positive scores of validation edges for each edge type in the format of {etype: ranking}.

test_rankings: dict of tensors

Rankings of positive scores of test edges for each edge type in the format of {etype: ranking}.

total_iters: int

The current iteration number.

kwargs: dict

Keyword arguments to pass downstream to metric calculation functions.

Currently we support:

val_candidate_sizes: dict of tensors

The size of each candidate list (positive + negative pairs) in the validation set, in the format of {etype: size_tensor}. If all candidate lists have the same size this will be a single-value tensor per etype.

test_candidate_sizes: dict of tensors

The size of each candidate list (positive + negative pairs) in the test set, in the format of {etype: size_tensor}. If all candidate lists have the same size this will be a single-value tensor per etype.

New in version 0.4.0.

Returns

val_score: dict of dict of float

Validation score in the format of {metric: {etype: val_score}}. If val_rankings is None, returns {metric: "N/A"}.

test_score: dict of dict of float

Test score in the format of {metric: {etype: test_score}}. If test_rankings is None, returns {metric: "N/A"}.
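The two return shapes described above can be handled by one small helper. The sketch below is illustrative only: format_scores is a hypothetical function, the score values are made up, and the edge-type keys are assumed to be canonical (source, relation, destination) tuples.

```python
def format_scores(scores):
    """Flatten {metric: {etype: score}} or {metric: "N/A"} into report lines."""
    lines = []
    for metric, per_etype in scores.items():
        if not isinstance(per_etype, dict):
            # No rankings were supplied, so the metric maps to "N/A".
            lines.append(f"{metric}: {per_etype}")
            continue
        for etype, score in per_etype.items():
            lines.append(f"{metric} {'/'.join(etype)}: {score:.4f}")
    return lines


# Made-up values in the documented return formats.
val_score = {"mrr": {("user", "rates", "movie"): 0.6111,
                     ("user", "follows", "user"): 0.75}}
test_score = {"mrr": "N/A"}

for line in format_scores(val_score) + format_scores(test_score):
    print(line)
```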

compute_score(rankings, train=True, **kwargs)

Compute per edge type evaluation score.

Parameters

rankings: dict of tensors

Rankings of positive scores in the format of {etype: ranking}.

train: boolean

Whether the score is computed during model training.

kwargs: dict

Keyword arguments to pass downstream to the metric computation.

Currently we support:

candidate_sizes: dict of tensors, optional

The size of each candidate list corresponding to each value in the rankings, in the format of {etype: sizes}. If a tensor for an edge type has a single element we use that as the size of all lists.

New in version 0.4.0.

Returns

return_metrics: dict of dict of float

Per edge type evaluation score in the format of {metric: {etype: score}}.
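To make the {metric: {etype: score}} layout concrete, here is a minimal sketch that applies the standard definitions of MRR (mean of 1/rank) and hit@k (fraction of ranks at most k) to each edge type. It is not GraphStorm's code: per_etype_scores is a hypothetical helper, plain Python lists stand in for ranking tensors, and candidate_sizes is ignored.

```python
def per_etype_scores(rankings, metrics=("mrr", "hit_at_10")):
    """Return {metric: {etype: score}} from {etype: list of 1-based ranks}."""
    scores = {}
    for metric in metrics:
        scores[metric] = {}
        for etype, ranks in rankings.items():
            if metric == "mrr":
                # Mean reciprocal rank of the positive edges.
                value = sum(1.0 / r for r in ranks) / len(ranks)
            elif metric.startswith("hit_at_"):
                k = int(metric[len("hit_at_"):])
                # Fraction of positive edges ranked within the top k.
                value = sum(1 for r in ranks if r <= k) / len(ranks)
            else:
                raise ValueError(f"Unsupported metric in this sketch: {metric}")
            scores[metric][etype] = value
    return scores


# Made-up rankings for two canonical edge types.
rankings = {
    ("user", "rates", "movie"): [1, 2, 4],
    ("user", "follows", "user"): [1, 20],
}
print(per_etype_scores(rankings))
```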

get_val_score_rank(val_score)

Get the rank of the validation score for the major_etype set at initialization by comparing it with the historical validation scores. With the default major_etype, the rank is computed on the sum of the validation scores of all edge types.

Parameters

val_score: dict of dict

A dict in the format of {metric: {etype: score}}.

Returns

rank: int

The rank of the validation score for the major_etype set at initialization. With the default major_etype, the rank is based on the sum of the validation scores of all edge types.
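The ranking logic described above can be sketched as follows. This is an assumption-laden illustration rather than the library's implementation: val_score_rank is a hypothetical function, it uses the first metric in the dict, treats higher scores as better, and defines rank 1 as "best score seen so far".

```python
def val_score_rank(val_score, history, major_etype="ALL"):
    """Rank the current validation score against earlier ones (1 = best)."""
    # Assumption: only the first metric in the dict drives model selection.
    per_etype = next(iter(val_score.values()))
    if major_etype == "ALL":
        # Default: sum the scores of all edge types.
        current = sum(per_etype.values())
    else:
        current = per_etype[major_etype]
    # One plus the number of strictly better historical scores.
    rank = 1 + sum(1 for past in history if past > current)
    history.append(current)
    return rank


rates = ("user", "rates", "movie")
follows = ("user", "follows", "user")
history = []
for scores in ({"mrr": {rates: 0.5, follows: 0.4}},
               {"mrr": {rates: 0.6, follows: 0.2}},
               {"mrr": {rates: 0.7, follows: 0.3}}):
    print(val_score_rank(scores, history))
```

Passing a canonical edge type as major_etype ranks on that edge type alone, so the same sequence of evaluations can produce different ranks than the summed default.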