-
Notifications
You must be signed in to change notification settings - Fork 4
/
tsne.py
66 lines (53 loc) · 2.78 KB
/
tsne.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.manifold import TSNE
from utils.hash_functions import get_hash
def main(args):
# Init embeddings with shape [sample, embedding]
output_path = os.getenv('OUTPUT_PATH', os.path.join('output', 'embeddings'))
test_embeddings = torch.load(os.path.join(output_path, args.model, args.dataset, 'test', 'embeddings.pt'),
map_location='cpu')
# Multi-labels with shape [sample, label]
test_labels = torch.load(os.path.join(output_path, args.model, args.dataset, 'test', 'labels.pt'),
map_location='cpu')
np.random.seed(42)
binary_emb = get_hash(test_embeddings, 'trivial', length=768)
trivial_hash = get_hash(test_embeddings, 'trivial', length=32)
ls_hash = get_hash(test_embeddings, 'lsh', length=32)
# Select a subset to avoid computation costs
num_samples = 1000
indices = np.random.choice(len(test_embeddings), min(num_samples, len(test_embeddings)), replace=False)
sampled_binary_emb = binary_emb[indices]
sampled_embeddings = test_embeddings[indices]
sampled_trivial_hash = trivial_hash[indices]
sampled_ls_hash = ls_hash[indices]
sampled_labels = test_labels[indices].nonzero()[:, 1]
assert len(sampled_labels) == len(indices), "t-SNE plot is not working with multi-label datasets."
for i, (name, vector) in enumerate([('embedding', sampled_embeddings),
('binary', sampled_binary_emb),
('lsh', sampled_ls_hash),
('trivial', sampled_trivial_hash)]):
# Compute t-SNE embeddings
tsne = TSNE(n_components=2, random_state=42)
tsne_results = tsne.fit_transform(vector)
fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=sampled_labels, cmap='Dark2', alpha=0.9)
ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
fig.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)
if os.path.isdir(args.output_dir):
output_file = os.path.join(args.output_dir, f"{args.model}_{args.dataset}_tsne_{name}.pdf")
else:
output_file = args.output_dir
plt.savefig(output_file)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output_dir', type=str, default='output/figures',
help='Path to output dir')
parser.add_argument('-d', '--dataset', type=str, default='ForestNet4')
parser.add_argument('-m', '--model', type=str, default='PrithviViT')
args = parser.parse_args()
main(args)