diff --git a/flexynesis/modules.py b/flexynesis/modules.py index 58d67c6..6c5a57b 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -2,8 +2,10 @@ import torch from torch import nn +import torch_geometric.nn as gnn -__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "Classifier", "CNN"] + +__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "Classifier", "CNN", "GCNN"] class Encoder(nn.Module): @@ -249,3 +251,18 @@ def forward(self, x): x = x.squeeze(-1) return x + + +class GCNN(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super().__init__() + + self.layer_1 = nn.Sequential(gnn.GraphConv(input_dim, hidden_dim), nn.ReLU()) + self.layer_2 = nn.Sequential(gnn.GraphConv(hidden_dim, output_dim), nn.ReLU()) + self.aggregation = gnn.aggr.SumAggregation() + + def forward(self, x, edge_index, batch): + x = self.layer_1(x, edge_index) + x = self.layer_2(x, edge_index) + x = self.aggregation(x, batch) + return x