GSgnnLinkPredictionModelBase

class graphstorm.model.GSgnnLinkPredictionModelBase(*args, **kwargs)

Bases: GSgnnModelBase, GSgnnLinkPredictionModelInterface

GraphStorm GNN model base class for link-prediction tasks.

This base class extends GraphStorm's GSgnnModelBase and GSgnnLinkPredictionModelInterface. To define a customized link-prediction GNN model and train it in GraphStorm, users must make their model class inherit from this base class and implement the required methods, including forward(), predict(), save_model(), restore_model() and create_optimizer().
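The snippet below is a minimal, framework-free sketch of the inheritance contract described above. It does not use GraphStorm and assumes nothing about its real signatures. ModelBaseStub, LinkPredictionInterfaceStub, LinkPredictionModelBaseStub and DotProductLPModel are hypothetical stand-ins, and every argument list is invented for illustration. The stubs use Python's abc module so that a subclass cannot be instantiated until it implements all five required methods. Whether GraphStorm itself enforces this at instantiation time is not stated on this page.

```python
import json
from abc import ABC, abstractmethod


# Hypothetical stand-ins for GSgnnModelBase and GSgnnLinkPredictionModelInterface.
# They mirror only the method NAMES listed in the docs; argument lists are invented.
class ModelBaseStub(ABC):
    @abstractmethod
    def save_model(self, model_path): ...

    @abstractmethod
    def restore_model(self, restore_model_path): ...

    @abstractmethod
    def create_optimizer(self): ...


class LinkPredictionInterfaceStub(ABC):
    @abstractmethod
    def forward(self, pos_edges, neg_edges): ...

    @abstractmethod
    def predict(self, edges): ...


class LinkPredictionModelBaseStub(ModelBaseStub, LinkPredictionInterfaceStub):
    """Hypothetical stand-in for GSgnnLinkPredictionModelBase: combines both bases."""


class DotProductLPModel(LinkPredictionModelBaseStub):
    """Toy link predictor (hypothetical): score(u, v) = dot(emb[u], emb[v])."""

    def __init__(self, embeddings, margin=1.0, lr=0.01):
        self.emb = embeddings  # node id -> list of floats
        self.margin = margin
        self.lr = lr

    def _score(self, u, v):
        return sum(a * b for a, b in zip(self.emb[u], self.emb[v]))

    def forward(self, pos_edges, neg_edges):
        # Pairwise hinge loss: every negative edge should score at least
        # `margin` below every positive edge; violations are averaged.
        pos = [self._score(u, v) for u, v in pos_edges]
        neg = [self._score(u, v) for u, v in neg_edges]
        pairs = [(p, n) for p in pos for n in neg]
        if not pairs:
            return 0.0
        return sum(max(0.0, self.margin - p + n) for p, n in pairs) / len(pairs)

    def predict(self, edges):
        # Return one score per (src, dst) edge; higher means "more likely a link".
        return [self._score(u, v) for u, v in edges]

    def save_model(self, model_path):
        with open(model_path, "w") as f:
            json.dump(self.emb, f)

    def restore_model(self, restore_model_path):
        with open(restore_model_path) as f:
            # JSON object keys are strings, so convert node ids back to int.
            self.emb = {int(k): v for k, v in json.load(f).items()}

    def create_optimizer(self):
        # Placeholder: a real model would return an optimizer over its
        # trainable parameters; here we only return its configuration.
        return {"type": "sgd", "lr": self.lr}


model = DotProductLPModel({0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]})
print(model.predict([(0, 1), (0, 2)]))  # prints [1.0, 0.0]
```

Replacing the stubs with graphstorm.model.GSgnnLinkPredictionModelBase would also require matching GraphStorm's actual method signatures, which this page does not list.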