LinkPredictWeightedRotatEDecoder

class graphstorm.model.LinkPredictWeightedRotatEDecoder(etypes, h_dim, gamma=12.0, edge_weight_fields=None)

Bases: LinkPredictRotatEDecoder

Link prediction decoder that uses the RotatE score function and supports edge weights.

When computing the loss, the edge weights are used to weight each edge's contribution to the loss.
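To show how an edge weight can adjust a link prediction loss, here is a minimal sketch of a weighted binary cross-entropy in plain Python. It assumes that each positive edge's loss term is multiplied by its edge weight, that negative edges get weight 1.0, and that the result is averaged over all terms. The function name `weighted_bce_loss` is hypothetical and is not part of the GraphStorm API.

```python
import math

def weighted_bce_loss(pos_scores, neg_scores, pos_weights):
    """Hypothetical weighted BCE over link prediction scores (logits).

    Assumption: the loss term of each positive edge is scaled by its
    edge weight; negative edges use an implicit weight of 1.0.
    """
    def log_sigmoid(x):
        # Numerically stable log(sigmoid(x)) for both signs of x.
        if x >= 0:
            return -math.log1p(math.exp(-x))
        return x - math.log1p(math.exp(x))

    # Positive edges: label 1, loss -w * log(sigmoid(s)).
    pos_terms = [-w * log_sigmoid(s) for s, w in zip(pos_scores, pos_weights)]
    # Negative edges: label 0, loss -log(1 - sigmoid(s)) = -log(sigmoid(-s)).
    neg_terms = [-log_sigmoid(-s) for s in neg_scores]
    terms = pos_terms + neg_terms
    return sum(terms) / len(terms)
```

With this scheme, an edge of weight 0 contributes nothing to the loss, and an edge of weight 2 contributes twice as much as it would with weight 1.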

Parameters

etypes: list of tuples

The canonical edge types of the graph in the format of [(src_ntype1, etype1, dst_ntype1), …]

h_dim: int

The input dimension size. It is the dimension for both source and destination node embeddings.

gamma: float

The gamma value for model weight initialization. Default: 12.0.
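As an illustration of how gamma can drive weight initialization, the sketch below follows the scheme of the original RotatE reference implementation: an embedding range of (gamma + epsilon) / rel_dim with epsilon = 2.0, uniform initialization within that range, and a rescaling of the raw values into rotation phases in [-pi, pi]. Whether this decoder uses exactly these constants is an assumption, and `init_rotate_relations`, `epsilon`, and `rel_dim = h_dim // 2` are hypothetical names chosen for the example.

```python
import math
import random

def init_rotate_relations(num_rels, h_dim, gamma=12.0, epsilon=2.0, seed=0):
    """Hypothetical RotatE-style relation initialization.

    Assumption: h_dim real values form h_dim // 2 complex coordinates,
    and each relation holds one rotation phase per complex coordinate.
    """
    rel_dim = h_dim // 2
    emb_range = (gamma + epsilon) / rel_dim
    rng = random.Random(seed)
    # Raw relation weights drawn uniformly from [-emb_range, emb_range].
    raw = [[rng.uniform(-emb_range, emb_range) for _ in range(rel_dim)]
           for _ in range(num_rels)]
    # Rescale raw weights so that they become phases in [-pi, pi].
    phases = [[x / (emb_range / math.pi) for x in row] for row in raw]
    return emb_range, phases

emb_range, phases = init_rotate_relations(num_rels=3, h_dim=128)
```

Under these assumptions, a larger gamma widens the initialization range, while the phases themselves always stay within [-pi, pi].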

edge_weight_fields: dict of str

The edge feature field(s) storing the edge weights. Default: None.

New in version 0.4.0: The LinkPredictWeightedRotatEDecoder.

forward(g, h, e_h)

Forward function.

This computes the RotatE score for the edges of every edge type.
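The plain-Python sketch below shows what computing a RotatE score per edge type involves. For each canonical edge type (src_ntype, etype, dst_ntype), it looks up the source and destination node embeddings, treats the first and second halves of each embedding as the real and imaginary parts of complex coordinates, rotates the source by the relation's phases, and returns gamma minus the summed complex distance to the destination. This mirrors the RotatE formulation, not the actual `forward(g, h, e_h)` signature: `rotate_score`, `score_all_etypes`, `edges_by_etype`, and `rel_phases` are hypothetical, and the real/imaginary split is an assumption.

```python
import cmath
import math

def rotate_score(h_src, h_dst, phases, gamma=12.0):
    """RotatE score: gamma - sum_k |h_k * exp(i * theta_k) - t_k|."""
    half = len(h_src) // 2
    dist = 0.0
    for k in range(half):
        # Assumption: first half = real parts, second half = imaginary parts.
        head = complex(h_src[k], h_src[k + half])
        tail = complex(h_dst[k], h_dst[k + half])
        rot = cmath.exp(1j * phases[k])  # unit-modulus rotation
        dist += abs(head * rot - tail)
    return gamma - dist

def score_all_etypes(edges_by_etype, h, rel_phases, gamma=12.0):
    """Score every (src, dst) edge of every canonical edge type."""
    scores = {}
    for etype, edge_list in edges_by_etype.items():
        src_ntype, _, dst_ntype = etype
        phases = rel_phases[etype]
        scores[etype] = [
            rotate_score(h[src_ntype][u], h[dst_ntype][v], phases, gamma)
            for u, v in edge_list
        ]
    return scores

etype = ("user", "clicks", "item")
h = {"user": [[1.0, 0.0]], "item": [[0.0, 1.0], [1.0, 0.0]]}
scores = score_all_etypes({etype: [(0, 0), (0, 1)]}, h, {etype: [math.pi / 2]})
```

Here the relation rotates the user embedding 1 by 90 degrees to i, so the first edge, whose destination is exactly i, gets the maximal score gamma, while the second edge, at distance sqrt(2), scores lower.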