GSgnnInstanceEvaluator

class graphstorm.eval.GSgnnInstanceEvaluator(eval_frequency, eval_metric, use_early_stop=False, early_stop_burnin_rounds=0, early_stop_rounds=3, early_stop_strategy='average_increase')

Bases: object

Template class for user-defined evaluators.

Parameters

eval_frequency: int

How often to run evaluation, measured in number of iterations.

eval_metric: list of str

The evaluation metrics used during evaluation.

use_early_stop: bool

Set to True to enable early stopping.

early_stop_burnin_rounds: int

The number of burn-in rounds to run before starting to check the early-stop condition.

early_stop_rounds: int

The number of rounds of validation scores used to decide whether to stop early.

early_stop_strategy: str

The early stop strategy. GraphStorm supports two strategies: 1) consecutive_increase and 2) average_increase.
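
To make the early-stop parameters concrete, here is a minimal sketch of how the two strategy names could be interpreted. It is an assumption for illustration, not GraphStorm's implementation: the function names are hypothetical, a higher validation score is assumed to be better, and `history` is assumed to hold earlier validation scores, oldest first.

```python
from statistics import mean
from typing import List


def consecutive_increase_stop(history: List[float], score: float, rounds: int) -> bool:
    """Hypothetical: stop when the score failed to improve in each of the last `rounds` rounds."""
    scores = history + [score]
    if len(scores) <= rounds:
        return False  # not enough rounds to judge yet
    recent = scores[-(rounds + 1):]
    # Every round must be no better than the round before it.
    return all(cur <= prev for prev, cur in zip(recent, recent[1:]))


def average_increase_stop(history: List[float], score: float, rounds: int) -> bool:
    """Hypothetical: stop when the score is no better than the mean of the last `rounds` scores."""
    if len(history) < rounds:
        return False
    return score <= mean(history[-rounds:])


def should_stop(history: List[float], score: float, burnin_rounds: int = 0,
                rounds: int = 3, strategy: str = "average_increase") -> bool:
    """Hypothetical early-stop check mirroring the constructor parameters."""
    if len(history) < burnin_rounds:
        return False  # still inside the burn-in period
    if strategy == "consecutive_increase":
        return consecutive_increase_stop(history, score, rounds)
    if strategy == "average_increase":
        return average_increase_stop(history, score, rounds)
    raise ValueError(f"unknown early stop strategy: {strategy}")
```

Under these assumptions, the average-based rule reacts to a single bad round relative to recent history, while the consecutive rule only fires after a sustained run of non-improving rounds.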

abstract compute_score(pred, labels)

Compute the evaluation score.

Parameters

pred:

Prediction results.

labels:

Labels.
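
Because compute_score is abstract, a subclass has to supply its body. The sketch below shows one possible body that computes classification accuracy from predicted class indices. It is a hypothetical stand-in: the class name AccuracyScorer is invented, it uses plain Python lists rather than tensors, and it assumes a single float score; a real implementation would subclass graphstorm.eval.GSgnnInstanceEvaluator.

```python
from typing import Sequence


class AccuracyScorer:
    """Hypothetical stand-in showing a compute_score body.

    In practice this would subclass graphstorm.eval.GSgnnInstanceEvaluator.
    """

    def compute_score(self, pred: Sequence[int], labels: Sequence[int]) -> float:
        if len(pred) != len(labels):
            raise ValueError("pred and labels must have the same length")
        if len(labels) == 0:
            return 0.0  # avoid division by zero on empty input
        # Fraction of predictions that match the label at the same position.
        correct = sum(int(p == y) for p, y in zip(pred, labels))
        return correct / len(labels)
```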

do_eval(total_iters, epoch_end=False)

Decide whether to run evaluation in the current iteration or at the end of the current epoch.

Parameters

total_iters: int

The total number of iterations that have been taken so far.

epoch_end: bool

Whether the current call is at the end of an epoch.

Returns

Whether to run evaluation: bool
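
One plausible reading of the do_eval decision is sketched below. It is an assumption, not the library's code: the class name FrequencyGate is hypothetical, the end of an epoch is assumed to always trigger evaluation, and an eval_frequency of 0 is assumed to disable iteration-based evaluation.

```python
class FrequencyGate:
    """Hypothetical illustration of the do_eval decision."""

    def __init__(self, eval_frequency: int):
        self.eval_frequency = eval_frequency

    def do_eval(self, total_iters: int, epoch_end: bool = False) -> bool:
        if epoch_end:
            return True  # assumption: always evaluate at the end of an epoch
        # Otherwise evaluate every `eval_frequency` iterations; 0 disables this path.
        return self.eval_frequency > 0 and total_iters % self.eval_frequency == 0
```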

abstract evaluate(val_pred, test_pred, val_labels, test_labels, total_iters)

GSgnnLinkPredictionModel.fit() calls this function to perform user-defined evaluation.

Parameters

val_pred: tensor

The tensor that stores the prediction results on the validation nodes.

test_pred: tensor

The tensor that stores the prediction results on the test nodes.

val_labels: tensor

The tensor that stores the labels of the validation nodes.

test_labels: tensor

The tensor that stores the labels of the test nodes.

total_iters: int

The current iteration number.

Returns

eval_score: float

The validation score.

test_score: float

The test score.
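
As a sketch of how evaluate can build on compute_score, the hypothetical evaluator below scores the validation and test predictions separately and returns them as a (validation score, test score) pair, matching the documented return values. The class name AccuracyEvaluator is invented, plain Python lists stand in for tensors, and a real implementation would subclass graphstorm.eval.GSgnnInstanceEvaluator.

```python
from typing import Sequence, Tuple


class AccuracyEvaluator:
    """Hypothetical user-defined evaluator.

    A real one would subclass graphstorm.eval.GSgnnInstanceEvaluator and receive tensors.
    """

    def compute_score(self, pred: Sequence[int], labels: Sequence[int]) -> float:
        # Accuracy: fraction of positions where the prediction equals the label.
        correct = sum(int(p == y) for p, y in zip(pred, labels))
        return correct / len(labels) if len(labels) > 0 else 0.0

    def evaluate(self, val_pred, test_pred, val_labels, test_labels,
                 total_iters: int) -> Tuple[float, float]:
        # total_iters is unused here; it is kept to match the documented signature.
        val_score = self.compute_score(val_pred, val_labels)
        test_score = self.compute_score(test_pred, test_labels)
        return val_score, test_score
```

Reusing compute_score keeps the metric definition in one place, so the validation and test scores are always computed the same way.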