Skip to content

Commit

Permalink
Add a GCNN module (#11)
Browse files Browse the repository at this point in the history
* Replace GIN with GraphConv

* Improve GCNN
  • Loading branch information
trsvchn authored Dec 18, 2023
1 parent 0f5a24b commit 8cd8f12
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion flexynesis/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import torch
from torch import nn
import torch_geometric.nn as gnn

__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "Classifier", "CNN"]

__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "Classifier", "CNN", "GCNN"]


class Encoder(nn.Module):
Expand Down Expand Up @@ -249,3 +251,18 @@ def forward(self, x):

x = x.squeeze(-1)
return x


class GCNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()

self.layer_1 = nn.Sequential(gnn.GraphConv(input_dim, hidden_dim), nn.ReLU())
self.layer_2 = nn.Sequential(gnn.GraphConv(hidden_dim, output_dim), nn.ReLU())
self.aggregation = gnn.aggr.SumAggregation()

def forward(self, x, edge_index, batch):
x = self.layer_1(x, edge_index)
x = self.layer_2(x, edge_index)
x = self.aggregation(x, batch)
return x

0 comments on commit 8cd8f12

Please sign in to comment.