GSgnnLinkPredictionModelBase
- class graphstorm.model.GSgnnLinkPredictionModelBase(*args, **kwargs)
Bases: GSgnnModelBase, GSgnnLinkPredictionModelInterface

GraphStorm GNN model base class for link-prediction tasks.
This base class extends GraphStorm's GSgnnModelBase and GSgnnLinkPredictionModelInterface. To define a customized link-prediction GNN model and train it in GraphStorm, the model class must inherit from this base class and implement the required methods: forward(), predict(), save_model(), restore_model(), and create_optimizer().
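The sketch below illustrates that contract without importing GraphStorm. `ModelBaseSketch`, `LinkPredictionInterfaceSketch`, `LinkPredictionModelBaseSketch` and `ToyDotProductModel` are hypothetical stand-ins with deliberately simplified signatures (edges as `(u, v)` tuples, plain Python lists instead of tensors). The real GraphStorm method signatures differ, so check the GraphStorm API reference before writing an actual model. Here Python's `abc` module enforces the rule described above: a subclass that inherits from both bases cannot be instantiated until it implements all five methods.

```python
import json
import math
import os
from abc import ABC, abstractmethod


class ModelBaseSketch(ABC):
    """Hypothetical stand-in for GSgnnModelBase (model I/O and optimizer hooks)."""

    @abstractmethod
    def save_model(self, model_path): ...

    @abstractmethod
    def restore_model(self, restore_model_path): ...

    @abstractmethod
    def create_optimizer(self): ...


class LinkPredictionInterfaceSketch(ABC):
    """Hypothetical stand-in for GSgnnLinkPredictionModelInterface."""

    @abstractmethod
    def forward(self, pos_edges, neg_edges): ...

    @abstractmethod
    def predict(self, edges): ...


class LinkPredictionModelBaseSketch(ModelBaseSketch, LinkPredictionInterfaceSketch):
    """Mirrors the documented inheritance: both bases, no concrete methods."""


class ToyDotProductModel(LinkPredictionModelBaseSketch):
    """Hypothetical toy model: scores edge (u, v) as a dot product of node embeddings."""

    def __init__(self, num_nodes, dim=2):
        # Deterministic toy embeddings so the example is reproducible.
        self.emb = [[(i + j) / 10.0 for j in range(dim)] for i in range(num_nodes)]

    def _score(self, u, v):
        return sum(a * b for a, b in zip(self.emb[u], self.emb[v]))

    def forward(self, pos_edges, neg_edges):
        # Mean binary cross-entropy on logits: positives -> label 1, negatives -> label 0.
        # -log(sigmoid(s)) == log1p(exp(-s)); -log(1 - sigmoid(s)) == log1p(exp(s)).
        loss = sum(math.log1p(math.exp(-self._score(u, v))) for u, v in pos_edges)
        loss += sum(math.log1p(math.exp(self._score(u, v))) for u, v in neg_edges)
        return loss / (len(pos_edges) + len(neg_edges))

    def predict(self, edges):
        # Raw link scores; higher means a more likely edge.
        return [self._score(u, v) for u, v in edges]

    def save_model(self, model_path):
        os.makedirs(model_path, exist_ok=True)
        with open(os.path.join(model_path, "model.json"), "w") as f:
            json.dump(self.emb, f)

    def restore_model(self, restore_model_path):
        with open(os.path.join(restore_model_path, "model.json")) as f:
            self.emb = json.load(f)

    def create_optimizer(self):
        # A real model would return an optimizer object; this toy returns a plain config.
        return {"name": "sgd", "lr": 0.01}
```

Leaving out any of the five methods (for example, defining only `forward()`) makes instantiation fail with a `TypeError`, which surfaces an incomplete model class early rather than partway through training.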