-
Notifications
You must be signed in to change notification settings - Fork 0
/
heatmaps.py
56 lines (46 loc) · 1.64 KB
/
heatmaps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import json
import matplotlib
import os
import signac
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from collections import defaultdict
# Aggregate every job's analysis_data.json into one table and persist it as
# aggregated_results.csv.  Each row also carries the job's state point so the
# heatmap stage can pivot on state-point keys even when analysis_data.json does
# not repeat them; values read from the JSON file win on any key collision.
# NOTE(review): assumes analysis_data.json holds a JSON object (dict) — the
# heatmap stage needs named columns, so other shapes cannot work anyway.
all_results = []
project = signac.get_project()
for job in project:
    try:
        with open(job.fn('analysis_data.json'), 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Missing analysis_data.json for job: {job.id}")
        continue
    except json.JSONDecodeError as err:
        # Best-effort, like the missing-file case: report and skip instead of
        # aborting the whole aggregation over one corrupt file.
        print(f"Unreadable analysis_data.json for job: {job.id} ({err})")
        continue
    # job.statepoint() is a plain-dict copy; data overrides it key-by-key.
    all_results.append({**job.statepoint(), **data})
if not all_results:
    # Otherwise an empty CSV is written and the later read_csv fails with an
    # opaque EmptyDataError.
    raise SystemExit("No analysis_data.json found in any job; nothing to aggregate.")
df = pd.DataFrame(all_results)
df.to_csv("aggregated_results.csv", index=False)
# Find the state-point keys whose value differs between jobs; only those are
# worth plotting.  "seed" is excluded (presumably replicate runs, which the
# heatmap pivot tables average over with aggfunc="mean" — TODO confirm).
IGNORED_STATEPOINT_KEYS = {"seed"}
statepoint_values = defaultdict(set)
for job in project:
    # job.statepoint() returns a plain-dict copy, so nested values are ordinary
    # lists/dicts that json can serialise below.
    for key, value in job.statepoint().items():
        if key in IGNORED_STATEPOINT_KEYS:
            continue
        try:
            statepoint_values[key].add(value)
        except TypeError:
            # Lists/dicts are legal state-point values but unhashable; count
            # distinct values via a canonical JSON string instead of crashing.
            statepoint_values[key].add(json.dumps(value, sort_keys=True))
# Insertion order (first-seen key order) is kept so plot file names are stable.
keys_of_interest = [
    key for key, values in statepoint_values.items() if len(values) > 1
]
# Render one heatmap per (unordered pair of varying state-point keys, analysed
# quantity) into ./heatmaps, averaging each cell over all remaining parameters.
data_values = ["MSD_correlation", "eccentricity"]
# Work from the persisted aggregate so this stage sees the same table as the CSV.
df = pd.read_csv("aggregated_results.csv")
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs('heatmaps', exist_ok=True)
# Iterate through pairs of keys of interest (key1 always precedes key2).
for i, key1 in enumerate(keys_of_interest):
    for key2 in keys_of_interest[i + 1:]:
        for data_value in data_values:
            missing = [col for col in (key1, key2, data_value) if col not in df.columns]
            if missing:
                # pivot_table would raise KeyError; report and move on instead.
                print(f"Skipping {data_value} vs {key1}/{key2}: missing column(s) {missing}")
                continue
            pivot_table = df.pivot_table(
                index=key1, columns=key2, values=data_value, aggfunc="mean")
            if pivot_table.empty:
                # All-NaN input is dropped by pivot_table; heatmap cannot draw
                # a table with no cells.
                print(f"Skipping {data_value} vs {key1}/{key2}: no data to plot")
                continue
            fig, ax = plt.subplots(figsize=(10, 8))
            try:
                sns.heatmap(pivot_table, annot=True, cmap="viridis",
                            cbar_kws={'label': data_value}, ax=ax)
                ax.set_title(f"{data_value} as a function of {key1} and {key2}")
                ax.set_xlabel(key2)
                ax.set_ylabel(key1)
                fig.savefig(f"heatmaps/heatmap_{data_value}_{key1}_vs_{key2}.png")
            finally:
                # Release the figure even if drawing or saving fails.
                plt.close(fig)