-
Notifications
You must be signed in to change notification settings - Fork 5
/
mtcl.py
executable file
·156 lines (131 loc) · 5.31 KB
/
mtcl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from layer import *
class graph_constructor(nn.Module):
    '''
    Graph Constructor is the Multivariate Time Series Correlation Layer
    (MTCL) in the paper.

    Builds a non-negative adjacency matrix from two sets of node vectors.
    The dense score is ``relu(tanh(alpha * (M1 @ M2.T - M2 @ M1.T)))``;
    ``forward`` additionally keeps only ``k`` entries per row.

    Args:
        nnodes: size of the two embedding tables (only used when
            ``static_feat`` is None).
        k: number of entries kept in every row by ``forward``.
        dim: output width of the embeddings and linear projections.
        device: stored on the instance as ``self.device``; tensors created
            in ``forward`` follow the device of the computed scores.
        alpha: scale applied inside both ``tanh`` calls.
        static_feat: optional feature matrix indexed as
            ``static_feat[idx, :]``; when given it replaces the learned
            embeddings (assumed 2-D, since ``shape[1]`` is read as the
            input width).
    '''

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_constructor, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)
        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def _dense_adj(self, idx):
        '''Return the dense (un-pruned) adjacency scores for nodes ``idx``.'''
        if self.static_feat is None:
            nodevec1 = self.emb1(idx)
            nodevec2 = self.emb2(idx)
        else:
            nodevec1 = self.static_feat[idx, :]
            nodevec2 = nodevec1
        nodevec1 = torch.tanh(self.alpha * self.lin1(nodevec1))
        nodevec2 = torch.tanh(self.alpha * self.lin2(nodevec2))
        # Difference of the two cross products: swapping the operands flips
        # the sign, so after the relu at most one direction of a pair is kept.
        a = (torch.mm(nodevec1, nodevec2.transpose(1, 0))
             - torch.mm(nodevec2, nodevec1.transpose(1, 0)))
        return F.relu(torch.tanh(self.alpha * a))

    def forward(self, idx):
        '''
        Return the adjacency matrix with only ``k`` entries kept per row.

        Args:
            idx: 1-D tensor of node indices (``torch.mm`` needs the looked-up
                node vectors to be 2-D).

        Returns:
            Tensor of shape ``(len(idx), len(idx))``; entries outside the
            per-row top-k selection are zero. The selection adds random
            noise, so the result depends on the torch RNG state.
        '''
        adj = self._dense_adj(idx)
        # topk raises if asked for more entries than a row has.
        k = min(self.k, adj.size(1))
        # The original comment calls this "rand for numerical stability";
        # the noise only influences which indices are selected, not the
        # values that are returned.
        _, top_idx = (adj + torch.rand_like(adj) * 0.01).topk(k, 1)
        # Build the 0/1 mask next to adj so device and dtype always match.
        mask = torch.zeros_like(adj)
        mask.scatter_(1, top_idx, 1.0)
        return adj * mask

    def fullA(self, idx):
        '''
        Return the dense adjacency matrix without the top-k pruning.

        Args:
            idx: 1-D tensor of node indices.

        Returns:
            Tensor of shape ``(len(idx), len(idx))`` with non-negative
            entries.
        '''
        return self._dense_adj(idx)
################## ADDITIONAL GRAPH LEARNING LAYERS THAT MAY OPTIONALLY BE APPLIED ##################
class graph_global(nn.Module):
    '''
    Graph learner that holds one free ``nnodes x nnodes`` parameter matrix.

    Args:
        nnodes: side length of the learned matrix ``A``.
        k, dim, alpha, static_feat: accepted but not used by this class
            (presumably kept so the constructor matches the other graph
            layers in this file).
        device: device the initial values of ``A`` are moved to.
    '''

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_global, self).__init__()
        self.nnodes = nnodes
        # Move the data first and wrap it afterwards: calling .to() on the
        # Parameter itself would, whenever it has to copy, yield a plain
        # non-leaf tensor that the module does not register for training.
        self.A = nn.Parameter(torch.randn(nnodes, nnodes).to(device), requires_grad=True)

    def forward(self, idx):
        '''
        Return ``relu(A)``.

        Args:
            idx: ignored; the full matrix is always returned.

        Returns:
            Tensor of shape ``(nnodes, nnodes)`` with non-negative entries.
        '''
        return F.relu(self.A)
class graph_undirected(nn.Module):
    '''
    Graph learner that scores node pairs with a single set of node vectors.

    The dense score is ``relu(tanh(alpha * (M @ M.T)))`` where ``M`` is the
    projected node representation; ``forward`` then keeps ``k`` entries per
    row. Because the pruning is done row by row, the returned matrix is not
    guaranteed to be symmetric even though the dense scores are built from
    ``M @ M.T``.

    Args:
        nnodes: size of the embedding table (only used when ``static_feat``
            is None).
        k: number of entries kept in every row.
        dim: output width of the embedding and the linear projection.
        device: stored on the instance as ``self.device``; tensors created
            in ``forward`` follow the device of the computed scores.
        alpha: scale applied inside both ``tanh`` calls.
        static_feat: optional feature matrix indexed as
            ``static_feat[idx, :]``; when given it replaces the learned
            embedding (assumed 2-D, since ``shape[1]`` is read as the input
            width).
    '''

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_undirected, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def forward(self, idx):
        '''
        Return the adjacency matrix with only ``k`` entries kept per row.

        Args:
            idx: 1-D tensor of node indices (``torch.mm`` needs the looked-up
                node vectors to be 2-D).

        Returns:
            Tensor of shape ``(len(idx), len(idx))`` with non-negative
            entries; entries outside the per-row top-k selection are zero.
        '''
        # Both operands of the product use the same embedding and the same
        # projection, so the node vectors only need to be computed once.
        if self.static_feat is None:
            nodevec = self.emb1(idx)
        else:
            nodevec = self.static_feat[idx, :]
        nodevec = torch.tanh(self.alpha * self.lin1(nodevec))
        a = torch.mm(nodevec, nodevec.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        # topk raises if asked for more entries than a row has.
        k = min(self.k, adj.size(1))
        _, top_idx = adj.topk(k, 1)
        # Build the 0/1 mask next to adj so device and dtype always match.
        mask = torch.zeros_like(adj)
        mask.scatter_(1, top_idx, 1.0)
        return adj * mask
class graph_directed(nn.Module):
    '''
    Graph learner that scores node pairs with two separate sets of node
    vectors.

    The dense score is ``relu(tanh(alpha * (M1 @ M2.T)))``; ``forward`` then
    keeps ``k`` entries per row.

    Args:
        nnodes: size of the two embedding tables (only used when
            ``static_feat`` is None).
        k: number of entries kept in every row.
        dim: output width of the embeddings and linear projections.
        device: stored on the instance as ``self.device``; tensors created
            in ``forward`` follow the device of the computed scores.
        alpha: scale applied inside both ``tanh`` calls.
        static_feat: optional feature matrix indexed as
            ``static_feat[idx, :]``; when given it replaces the learned
            embeddings (assumed 2-D, since ``shape[1]`` is read as the
            input width).
    '''

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(graph_directed, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)
        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def forward(self, idx):
        '''
        Return the adjacency matrix with only ``k`` entries kept per row.

        Args:
            idx: 1-D tensor of node indices (``torch.mm`` needs the looked-up
                node vectors to be 2-D).

        Returns:
            Tensor of shape ``(len(idx), len(idx))`` with non-negative
            entries; entries outside the per-row top-k selection are zero.
        '''
        if self.static_feat is None:
            nodevec1 = self.emb1(idx)
            nodevec2 = self.emb2(idx)
        else:
            nodevec1 = self.static_feat[idx, :]
            nodevec2 = nodevec1
        nodevec1 = torch.tanh(self.alpha * self.lin1(nodevec1))
        nodevec2 = torch.tanh(self.alpha * self.lin2(nodevec2))
        a = torch.mm(nodevec1, nodevec2.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        # topk raises if asked for more entries than a row has.
        k = min(self.k, adj.size(1))
        _, top_idx = adj.topk(k, 1)
        # Build the 0/1 mask next to adj so device and dtype always match.
        mask = torch.zeros_like(adj)
        mask.scatter_(1, top_idx, 1.0)
        return adj * mask