From 8cd8f1220e98415df78c1d1e171f882c0baf73dc Mon Sep 17 00:00:00 2001 From: Taras Savchyn <30748114+trsvchn@users.noreply.github.com> Date: Mon, 18 Dec 2023 01:48:38 +0100 Subject: [PATCH] Add a GCNN module (#11) * Replace GIN with GraphConv * Improve GCNN --- flexynesis/modules.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/flexynesis/modules.py b/flexynesis/modules.py index 58d67c6..6c5a57b 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -2,8 +2,10 @@ import torch from torch import nn +import torch_geometric.nn as gnn -__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "Classifier", "CNN"] + +__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "Classifier", "CNN", "GCNN"] class Encoder(nn.Module): @@ -249,3 +251,18 @@ def forward(self, x): x = x.squeeze(-1) return x + + +class GCNN(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super().__init__() + + self.layer_1 = nn.Sequential(gnn.GraphConv(input_dim, hidden_dim), nn.ReLU()) + self.layer_2 = nn.Sequential(gnn.GraphConv(hidden_dim, output_dim), nn.ReLU()) + self.aggregation = gnn.aggr.SumAggregation() + + def forward(self, x, edge_index, batch): + x = self.layer_1(x, edge_index) + x = self.layer_2(x, edge_index) + x = self.aggregation(x, batch) + return x