GSgnnPerEtypeMrrLPEvaluator

class graphstorm.eval.GSgnnPerEtypeMrrLPEvaluator(eval_frequency, data, num_negative_edges_eval, lp_decoder_type, major_etype='ALL', use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy='average_increase')

Bases: GSgnnMrrLPEvaluator

The class for link prediction evaluation using the MRR metric. It returns a separate MRR score for each edge type (etype).
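For reference, the usual definition of mean reciprocal rank, computed separately for each edge type r, is given below. The symbols E_r (the set of positive evaluation edges of type r) and rank_e (the 1-based rank of positive edge e's score among itself and its sampled negative edges) are notation introduced here for illustration, not names used by GraphStorm.

```latex
\mathrm{MRR}_r = \frac{1}{|E_r|} \sum_{e \in E_r} \frac{1}{\mathrm{rank}_e}
```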

Parameters

eval_frequency: int

The frequency (number of iterations) at which evaluation is performed.

data: GSgnnEdgeData

The processed dataset.

num_negative_edges_eval: int

Number of negative edges sampled for each positive edge during evaluation.

lp_decoder_type: str

Link prediction decoder type.

major_etype: tuple or str

Canonical etype used for selecting the best model. If set to 'ALL' (the default), the general MRR over all edge types is used.

use_early_stop: bool

Set to True to enable early stopping.

early_stop_burnin_rounds: int

The number of burn-in rounds before the evaluator starts checking the early stop condition.

early_stop_rounds: int

The number of rounds of validation scores used to decide whether to stop early.

early_stop_strategy: str

The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase and 2) average_increase.
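To make the burn-in and the two strategies concrete, here is a minimal sketch of an early-stop check. The helper should_stop is hypothetical, and its parameters burnin_rounds and stop_rounds stand in for early_stop_burnin_rounds and early_stop_rounds. The semantics are an assumption inferred from the strategy names, not GraphStorm's implementation: consecutive_increase stops when none of the last stop_rounds evaluations improved on the one before it, and average_increase stops when the latest score does not exceed the mean of the previous stop_rounds scores.

```python
from typing import List


def should_stop(val_scores: List[float], burnin_rounds: int, stop_rounds: int,
                strategy: str = "average_increase") -> bool:
    """Return True if training should stop early (hypothetical rule).

    val_scores holds one validation score per evaluation round, oldest first;
    higher is assumed to be better, as for MRR.
    """
    # Ignore the burn-in rounds completely.
    history = val_scores[burnin_rounds:]
    # Need the current score plus `stop_rounds` earlier scores to compare against.
    if len(history) <= stop_rounds:
        return False
    window = history[-(stop_rounds + 1):]
    if strategy == "consecutive_increase":
        # Stop if none of the last `stop_rounds` rounds improved on the round before it.
        return all(cur <= prev for prev, cur in zip(window, window[1:]))
    if strategy == "average_increase":
        # Stop if the current score does not exceed the mean of the previous rounds.
        previous = window[:-1]
        return window[-1] <= sum(previous) / len(previous)
    raise ValueError(f"unknown early stop strategy: {strategy}")


scores = [0.9, 0.5, 0.6, 0.65]
print(should_stop(scores, burnin_rounds=0, stop_rounds=3, strategy="consecutive_increase"))  # False
print(should_stop(scores, burnin_rounds=0, stop_rounds=3, strategy="average_increase"))      # True
```

On this noisy curve the two rules disagree: consecutive_increase keeps training because the score recovered after one drop, while average_increase stops because the latest score is still below the recent average.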

compute_score(rankings, train=False)

Compute evaluation scores.

Parameters

rankings: dict of tensors

Rankings of positive scores in the format {etype: ranking}.

train: bool

TODO: Reserved for future use cases where scores for training are generated in a different way (more efficient but less accurate) than scores for testing.

Returns

Evaluation metric values: dict
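The per-etype computation that compute_score performs on the rankings dict can be sketched in plain Python. This is a minimal sketch, assuming each ranking is a list of 1-based ranks rather than a tensor; the helper per_etype_mrr is hypothetical, and the sketch does not cover any aggregation across distributed workers.

```python
from typing import Dict, List, Tuple

# Canonical edge type: (source node type, relation, destination node type).
EType = Tuple[str, str, str]


def per_etype_mrr(rankings: Dict[EType, List[int]]) -> Dict[EType, float]:
    """Return the mean reciprocal rank for each edge type.

    Each entry in a ranking list is the 1-based rank of one positive edge's
    score among that edge and its sampled negative edges.
    """
    metrics = {}
    for etype, ranks in rankings.items():
        if not ranks:
            raise ValueError(f"no rankings for edge type {etype}")
        metrics[etype] = sum(1.0 / r for r in ranks) / len(ranks)
    return metrics


rankings = {
    ("user", "buys", "item"): [1, 2, 4],
    ("user", "views", "item"): [1, 1],
}
# buys: (1 + 1/2 + 1/4) / 3 = 0.5833..., views: (1 + 1) / 2 = 1.0
print(per_etype_mrr(rankings))
```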

do_eval(total_iters, epoch_end=False)

Decide whether to perform evaluation in the current iteration or epoch.

Parameters

total_iters: int

The total number of iterations that have been taken.

epoch_end: bool

Whether it is the end of an epoch.

Returns

Whether to do evaluation: bool
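A frequency-based decision of this kind can be sketched as follows. The rule is an assumption for illustration, not taken from GraphStorm: evaluation always runs when epoch_end is True, and otherwise runs every eval_frequency iterations, with 0 disabling mid-epoch evaluation. The standalone function signature is hypothetical; in the class, eval_frequency comes from the constructor.

```python
def do_eval(total_iters: int, eval_frequency: int, epoch_end: bool = False) -> bool:
    """Return True if evaluation should run now (hypothetical rule)."""
    # Always evaluate when an epoch finishes.
    if epoch_end:
        return True
    # Otherwise evaluate every `eval_frequency` iterations; a frequency of 0
    # disables mid-epoch evaluation in this sketch.
    return eval_frequency > 0 and total_iters % eval_frequency == 0


print([it for it in range(1, 11) if do_eval(it, eval_frequency=4)])  # [4, 8]
```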

evaluate(val_scores, test_scores, total_iters)

GSgnnLinkPredictionModel.fit() calls this function to perform user-defined evaluation.

Parameters

val_scores: dict of tensors

Rankings of positive scores of validation edges for each edge type.

test_scores: dict of tensors

Rankings of positive scores of test edges for each edge type.

total_iters: int

The current iteration number.

Returns

val_mrr: float

Validation MRR.

test_mrr: float

Test MRR.
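Because this evaluator produces one MRR per edge type, a single number is still needed to pick the best model. The sketch below shows one way major_etype could reduce the per-etype scores, assuming that 'ALL' weights every edge type equally (an assumption; the exact "general MRR" computation is not specified here). The helper major_score is hypothetical.

```python
from typing import Dict, Tuple, Union

# Canonical edge type: (source node type, relation, destination node type).
EType = Tuple[str, str, str]


def major_score(mrr_per_etype: Dict[EType, float],
                major_etype: Union[str, EType] = "ALL") -> float:
    """Reduce per-edge-type MRRs to the single score used to pick the best model."""
    if major_etype == "ALL":
        # Assumption: 'ALL' averages the per-etype MRRs with equal weight.
        return sum(mrr_per_etype.values()) / len(mrr_per_etype)
    # Otherwise use only the MRR of the chosen canonical edge type.
    return mrr_per_etype[major_etype]


val_mrr = {("user", "buys", "item"): 0.5, ("user", "views", "item"): 1.0}
print(major_score(val_mrr))                            # 0.75
print(major_score(val_mrr, ("user", "buys", "item")))  # 0.5
```

Selecting a single canonical etype is useful when one relation is the actual prediction target and the others mainly provide auxiliary training signal.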