Coverage for /builds/kinetik161/ase/ase/dft/bz.py: 93.96%
149 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1from math import cos, pi, sin
2from typing import Any, Dict
4import numpy as np
def bz_vertices(icell, dim=3):
    """Return the faces of the first Brillouin zone.

    The zone is built as the Voronoi cell around the origin of the
    3x3x3 block of reciprocal lattice points generated by *icell*.
    See https://xkcd.com/1421 ...

    Parameters:

    icell:
        3x3 reciprocal cell (rows are the reciprocal lattice vectors).
        It is copied first, so the caller's object is not modified.
    dim:
        Number of periodic directions.  For ``dim < 3`` (and ``dim < 2``)
        the ``[2, 2]`` (and ``[1, 1]``) element of the copy is set to
        1e-3 so that the 3D Voronoi construction still has a non-zero
        lattice vector along the missing direction(s).

    Returns a list of ``(vertices, normal)`` tuples, one per face that
    bounds the cell around the origin: *vertices* are the corner
    coordinates of the face and *normal* is a unit vector pointing from
    the origin towards the neighbouring lattice point.
    """
    from scipy.spatial import Voronoi

    icell = icell.copy()
    if dim < 3:
        icell[2, 2] = 1e-3
    if dim < 2:
        icell[1, 1] = 1e-3

    # Integer offsets -1, 0, 1 along each axis; column 13 is (0, 0, 0).
    offsets = (np.indices((3, 3, 3)) - 1).reshape((3, 27))
    lattice_points = np.dot(icell.T, offsets).T
    voronoi = Voronoi(lattice_points)

    faces = []
    ridges = zip(voronoi.ridge_vertices, voronoi.ridge_points)
    for corner_ids, site_pair in ridges:
        # Skip ridges containing the vertex index -1 (Voronoi vertex
        # outside the diagram, i.e. an unbounded ridge) and ridges that
        # do not border the cell of the origin (site 13).
        if -1 in corner_ids or 13 not in site_pair:
            continue
        # One of the two sites is the origin, so the sum is simply the
        # position of the neighbouring lattice point.
        normal = lattice_points[site_pair].sum(0)
        normal /= (normal**2).sum()**0.5
        faces.append((voronoi.vertices[corner_ids], normal))
    return faces
class FlatPlot:
    """Plotting backend for Brillouin zones drawn on ordinary 2D axes.

    Used for both 1D and 2D zones."""

    # Number of coordinates passed to the axes (2 even for a 1D zone).
    axis_dim = 2
    # Default keyword arguments for scattered k-points.
    point_options = {'zorder': 5}

    def new_axes(self, fig):
        """Return the current axes of *fig*."""
        return fig.gca()

    def adjust_view(self, ax, minp, maxp):
        """Make the view square and symmetric about the origin.

        Only *maxp* is used; the limits are +-1.05 * maxp on both axes.
        """
        ax.autoscale_view(tight=True)
        limit = maxp * 1.05
        for set_limits in (ax.set_xlim, ax.set_ylim):
            set_limits(-limit, limit)
        ax.set_aspect('equal')

    def draw_arrow(self, ax, vector, **kwargs):
        """Draw an arrow from the origin to the first two components
        of *vector*; extra keyword arguments go to ``ax.arrow``."""
        dx, dy = vector[0], vector[1]
        ax.arrow(0, 0, dx, dy,
                 lw=1,
                 length_includes_head=True,
                 head_width=0.03,
                 head_length=0.05,
                 **kwargs)

    def label_options(self, point):
        """Return text alignment options for a label placed at *point*.

        The sign of each coordinate (-1, 0 or 1) is used directly as an
        index, so -1 selects the last entry of each tuple below.
        """
        horizontal = ('right', 'left', 'right')
        vertical = ('bottom', 'bottom', 'top')

        x, y = point
        return {
            'ha': horizontal[int(np.sign(x))],
            'va': vertical[int(np.sign(y))],
            'zorder': 4,
        }
class SpacePlot:
    """Plotting backend for ordinary (3D) Brillouin zone plots."""

    # Number of coordinates passed to the axes.
    axis_dim = 3
    # Default keyword arguments for scattered k-points (none).
    point_options: Dict[str, Any] = {}

    def __init__(self, *, elev=None):
        """Set up the viewing direction and the 3D arrow helper.

        *elev* is the elevation angle in radians; pi / 6 is used when
        it is None.  The azimuth is fixed at pi / 5.
        """
        from matplotlib.patches import FancyArrowPatch
        from mpl_toolkits.mplot3d import Axes3D, proj3d
        Axes3D  # referenced only to keep pyflakes quiet

        class Arrow3D(FancyArrowPatch):
            """2D arrow patch that projects its 3D end points itself."""

            def __init__(self, ax, xs, ys, zs, *args, **kwargs):
                super().__init__((0, 0), (0, 0), *args, **kwargs)
                self._verts3d = xs, ys, zs
                self.ax = ax

            def draw(self, renderer):
                xs3d, ys3d, zs3d = self._verts3d
                projected = proj3d.proj_transform(xs3d, ys3d,
                                                  zs3d, self.ax.axes.M)
                xs, ys, zs = projected
                tail = (xs[0], ys[0])
                head = (xs[1], ys[1])
                self.set_positions(tail, head)
                FancyArrowPatch.draw(self, renderer)

            # FIXME: Compatibility fix for matplotlib 3.5.0: the handling
            # of 3D artists changed and every 3D artist now needs a
            # "do_3d_projection" method.  This class is a hack that
            # projects onto the 3D axes by hand in draw(), so there is
            # nothing to do here.  Ideally we would not need such a hack.
            def do_3d_projection(self, *_, **__):
                return 0

        self.arrow3d = Arrow3D
        self.azim = pi / 5
        self.elev = pi / 6 if elev is None else elev

        # Unit vector along the viewing direction defined by azim/elev.
        horizontal = cos(self.elev)
        self.view = [sin(self.azim) * horizontal,
                     cos(self.azim) * horizontal,
                     sin(self.elev)]

    def new_axes(self, fig):
        """Add and return a 3D subplot on *fig*."""
        return fig.add_subplot(projection='3d')

    def draw_arrow(self, ax, vector, **kwargs):
        """Draw an arrow from the origin to *vector* on the 3D axes."""
        arrow = self.arrow3d(
            ax,
            [0, vector[0]],
            [0, vector[1]],
            [0, vector[2]],
            mutation_scale=20,
            arrowstyle='-|>',
            **kwargs)
        ax.add_artist(arrow)

    def adjust_view(self, ax, minp, maxp):
        """Orient the camera and give all three axes the same range."""
        import matplotlib.pyplot as plt

        # view_init expects degrees; azim and elev are kept in radians.
        # (ax.set_aspect('equal') no longer works as of matplotlib 3.1.0.)
        ax.view_init(azim=self.azim / pi * 180, elev=self.elev / pi * 180)

        # We would like aspect 'equal', but matplotlib had a bug giving
        # wrong behaviour and raises NotImplementedError as of v3.1.0.
        # The workarounds known to StackOverflow and elsewhere all rely
        # on set_aspect('equal') plus something more, so they do not
        # help.  Instead we try to get square axes by making the figure
        # square, which is probably rather inexact.
        figure = ax.get_figure()
        width, height = plt.figaspect(1.0)
        figure.set_figheight(height)
        figure.set_figwidth(width)

        ax.set_proj_type('ortho')

        # Cheat a bit: shrink the limits to trim the surrounding space.
        lower = 0.9 * minp
        upper = 0.9 * maxp
        for set_limits in (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d):
            set_limits(lower, upper)

        ax.set_box_aspect([1, 1, 1])

    def label_options(self, point):
        """Return text alignment options; independent of *point*."""
        return {'ha': 'center', 'va': 'bottom'}
def normalize_name(name):
    """Convert a special-point label into a mathtext fragment.

    * ``'G'`` becomes ``'\\Gamma'``.
    * Labels of at most one character are returned unchanged.
    * Longer labels must consist of non-digits followed by optional
      digits; the digits become a subscript, e.g. ``'X1'`` gives
      ``'X_{1}'``.  Labels without digits are returned unchanged.

    Raises ValueError for longer labels that do not have that form.
    """
    import re

    if name == 'G':
        return '\\Gamma'

    if len(name) <= 1:
        return name

    m = re.match(r'^(\D+?)(\d*)$', name)
    if m is None:
        raise ValueError(f'Bad label: {name}')

    name, num = m.groups()
    return f'{name}_{{{num}}}' if num else name
def bz_plot(cell, vectors=False, paths=None, points=None,
            elev=None, scale=1, interactive=False,
            pointstyle=None, ax=None, show=False):
    """Plot the first Brillouin zone of *cell* with matplotlib.

    Parameters:

    cell:
        Cell object providing ``copy()``, ``rank`` and ``reciprocal()``.
        Its rank (1, 2 or 3) selects a flat or a 3D plot; all rows and
        columns beyond the rank must be zero.
    vectors:
        If true, draw the reciprocal lattice vectors as arrows.
    paths:
        Optional iterable of ``(names, points)`` pairs.  Each path is
        drawn as a red line and each point is labelled with its
        (normalized) name.
    points:
        Optional array-like of shape (n, 3) with k-points to scatter.
    elev:
        Elevation of the camera in radians for 3D plots (None selects
        the default of the 3D plotter).  Ignored for 1D/2D zones.
    scale:
        Not used by this function; kept so existing callers that pass
        it keep working.
    interactive:
        If true, faces pointing away from the viewer are drawn with
        solid instead of dotted lines (3D only).
    pointstyle:
        Extra keyword arguments for the k-point scatter plot; they
        override the defaults.
    ax:
        Axes to draw on.  A new one is created on the current figure
        when omitted.
    show:
        If true, call ``plt.show()`` before returning.

    Returns the axes that were drawn on.
    """
    import matplotlib.pyplot as plt

    if pointstyle is None:
        pointstyle = {}

    cell = cell.copy()

    dimensions = cell.rank
    assert dimensions > 0, 'No BZ for 0D!'
    if dimensions == 3:
        # Forward the caller's elevation.  It used to be dropped here,
        # so the ``elev`` argument silently had no effect.
        plotter = SpacePlot(elev=elev)
    else:
        plotter = FlatPlot()

    if ax is None:
        ax = plotter.new_axes(plt.gcf())

    # The non-periodic rows/columns must be empty.
    assert not cell[dimensions:, :].any()
    assert not cell[:, dimensions:].any()

    icell = cell.reciprocal()
    bz1 = bz_vertices(icell, dim=dimensions)

    maxp = 0.0
    minp = 0.0
    for face, normal in bz1:
        ls = '-'
        # Repeat the first vertex so that the outline is closed.
        xyz = np.concatenate([face, face[:1]]).T
        if dimensions == 3:
            # Faces whose normal points away from the viewer are dotted.
            if normal @ plotter.view < 0 and not interactive:
                ls = ':'

        ax.plot(*xyz[:plotter.axis_dim], c='k', ls=ls)
        maxp = max(maxp, face.max())
        minp = min(minp, face.min())

    if vectors:
        for i in range(dimensions):
            plotter.draw_arrow(ax, icell[i], color='k')

        # XXX Can this be removed?
        if dimensions == 3:
            maxp = max(maxp, 0.6 * icell.max())
        else:
            maxp = max(maxp, icell.max())

    if paths is not None:
        for names, path_points in paths:
            coords = np.array(path_points).T[:plotter.axis_dim, :]
            ax.plot(*coords, c='r', ls='-')

            for name, point in zip(names, path_points):
                name = normalize_name(name)
                point = point[:plotter.axis_dim]
                ax.text(*point, rf'$\mathrm{{{name}}}$',
                        color='g', **plotter.label_options(point))

    if points is not None:
        # asarray lets callers pass nested lists as well as arrays.
        kpoints = np.asarray(points)
        kw = {'c': 'b', **plotter.point_options, **pointstyle}
        ax.scatter(*kpoints[:, :plotter.axis_dim].T, **kw)

    ax.set_axis_off()

    plotter.adjust_view(ax, minp, maxp)

    if show:
        plt.show()

    return ax