-
Notifications
You must be signed in to change notification settings - Fork 0
/
xpm3d.py
69 lines (54 loc) · 2.04 KB
/
xpm3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
from scipy.interpolate import griddata
import os
def plot_pm3d(
    x, y, z, bins=None, xlabel=None, ylabel=None, title=None, fout=None,
    xmin=None, xmax=None, ymin=None, ymax=None, colormap=None, interpolate=None
):
    """Draw a 2D colour map of scattered data, like gnuplot's ``pm3d`` map.

    See http://www.gnuplot.info/docs/loc6259.html. The scattered samples are
    interpolated onto a regular ``bins x bins`` grid spanning
    ``[xmin, xmax] x [ymin, ymax]`` and shown with ``imshow`` plus a colorbar.

    Args:
        x, y, z (array-like): 1D arrays of equal length with ``z = f(x, y)``.
        bins (int): grid nodes per axis. Defaults to ``int(len(x) / 2 * 0.8)``,
            never fewer than 2.
        xlabel, ylabel (str): axis labels, defaulting to ``'x'`` and ``'y'``.
        title (str): optional plot title.
        fout (str or path-like): if given, the figure is saved to this path
            before being shown. Nothing is saved by default.
        xmin, xmax, ymin, ymax (float): interpolation/plot window. Each defaults
            to the corresponding data minimum/maximum.
        colormap (str): name of a registered matplotlib colormap. ``None`` or an
            unknown name falls back to ``'rainbow'``.
        interpolate (str): ``'nearest'``, ``'linear'`` or ``'cubic'`` (passed to
            ``scipy.interpolate.griddata``). Anything else falls back to
            ``'linear'``.

    Raises:
        ValueError: if ``x``, ``y`` and ``z`` differ in length or are empty.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not len(x) == len(y) == len(z):
        raise ValueError(
            f'x, y and z must have the same length, got {len(x)}, {len(y)}, {len(z)}'
        )
    if len(x) == 0:
        raise ValueError('x, y and z must not be empty')

    if xmin is None:
        xmin = x.min()
    if xmax is None:
        xmax = x.max()
    if ymin is None:
        ymin = y.min()
    if ymax is None:
        ymax = y.max()
    if xlabel is None:
        xlabel = 'x'
    if ylabel is None:
        ylabel = 'y'
    # Accept any registered colormap name; keep the original 'rainbow' fallback.
    if not isinstance(colormap, str) or colormap not in colormaps:
        colormap = 'rainbow'
    if interpolate is None or interpolate not in ['nearest', 'linear', 'cubic']:
        interpolate = 'linear'
    if bins is None:
        # Original sizing heuristic, clamped so linspace still spans the window.
        bins = max(2, int(len(x) / 2 * 0.8))

    zz = _grid_interpolate(x, y, z, xmin, xmax, ymin, ymax, bins, interpolate)

    fig, ax = plt.subplots(1, 1, figsize=(30, 30))
    # origin='lower' puts row 0 (ymin) at the bottom, matching the y axis.
    image = ax.imshow(
        zz, extent=[xmin, xmax, ymin, ymax], origin='lower', cmap=colormaps[colormap]
    )
    if title is not None:
        ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.colorbar(image, ax=ax)
    # Save before show(): some backends close the figure once the window is dismissed.
    if fout is not None:
        fig.savefig(fout)
    plt.show()


def _grid_interpolate(x, y, z, xmin, xmax, ymin, ymax, bins, method):
    """Resample scattered samples ``z = f(x, y)`` onto a regular square grid.

    Returns an array of shape ``(bins, bins)`` whose rows follow the y axis and
    whose columns follow the x axis, which is the layout ``imshow`` expects.
    With ``method`` 'linear' or 'cubic', nodes outside the convex hull of the
    input points are NaN (``griddata``'s default ``fill_value``).
    """
    xx = np.linspace(xmin, xmax, bins)
    yy = np.linspace(ymin, ymax, bins)
    # Default 'xy' indexing: gx and gy both have shape (len(yy), len(xx)).
    gx, gy = np.meshgrid(xx, yy)
    points = np.column_stack((x, y))
    # Passing xi as a tuple makes griddata return an array shaped like gx,
    # so no manual flatten/reshape (and no risk of swapping axes) is needed.
    return griddata(points, z, (gx, gy), method=method)
if __name__ == "__main__":
    import sys

    # Input: whitespace-separated columns "phi psi fes ..."; lines starting with '#' are skipped.
    file = 'fes_2D.dat'
    # Abort when the input is missing (the original test was inverted and
    # aborted precisely when the file existed).
    if not os.path.isfile(file):
        print(f'Fatal: not a file: {file}', file=sys.stderr)
        sys.exit(1)
    # np.float_ was removed in NumPy 2.0; np.float64 is the same type.
    # ndmin=2 keeps a single-row file two-dimensional so column slicing works.
    data = np.loadtxt(file, comments='#', dtype=np.float64, ndmin=2)
    x = data[:, 0]
    y = data[:, 1]
    z = data[:, 2]
    # Fixed [-3.14, 3.14] window: presumably phi/psi are angles in radians — confirm.
    plot_pm3d(x, y, z, xmin=-3.14, xmax=3.14, ymin=-3.14, ymax=3.14)