Coverage for /builds/kinetik161/ase/ase/dft/bz.py: 93.96%

149 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1from math import cos, pi, sin 

2from typing import Any, Dict 

3 

4import numpy as np 

5 

6 

def bz_vertices(icell, dim=3):
    """Return the faces of the first Brillouin zone spanned by *icell*.

    See https://xkcd.com/1421 ...

    The zone is obtained as the Voronoi cell of the origin among the
    lattice points n1*b1 + n2*b2 + n3*b3 with every n_i in {-1, 0, 1},
    where b_i are the rows of *icell*.

    Parameters
    ----------
    icell:
        3x3 array-like of reciprocal lattice vectors (one per row).
        It is copied, never modified.
    dim:
        Number of periodic dimensions.  For dim < 3 (and dim < 2) the
        diagonal entries [2, 2] (and [1, 1]) of the copy are overwritten
        with 1e-3 so that the Voronoi construction stays three-dimensional.
        This presumably expects those rows to be zero; callers should confirm.

    Returns
    -------
    list of (vertices, normal) tuples, one per bounded face: the face's
    corner coordinates and the unit normal pointing towards the
    neighbouring lattice point.
    """
    from scipy.spatial import Voronoi

    icell = icell.copy()
    if dim < 3:
        icell[2, 2] = 1e-3
        if dim < 2:
            icell[1, 1] = 1e-3

    # Column 13 of the flattened 3x3x3 offset grid is (0, 0, 0), the origin.
    offsets = (np.indices((3, 3, 3)) - 1).reshape((3, 27))
    lattice_points = np.dot(icell.T, offsets).T
    voronoi = Voronoi(lattice_points)

    faces = []
    for vertex_ids, point_ids in zip(voronoi.ridge_vertices,
                                     voronoi.ridge_points):
        # Keep only bounded ridges that belong to the origin's cell.
        if -1 in vertex_ids or 13 not in point_ids:
            continue
        # One of the two points is the origin, so the sum is the neighbour.
        normal = lattice_points[point_ids].sum(0)
        normal /= (normal**2).sum()**0.5
        faces.append((voronoi.vertices[vertex_ids], normal))
    return faces

26 

27 

class FlatPlot:
    """Plotting backend for one- and two-dimensional Brillouin zones.

    Everything is drawn on an ordinary 2D matplotlib axes, also for a
    one-dimensional zone.
    """

    axis_dim = 2  # Coordinates used when plotting (still 2 for a 1D zone).
    point_options = {'zorder': 5}  # Extra scatter options for k-points.

    def new_axes(self, fig):
        """Return the current axes of *fig*."""
        return fig.gca()

    def adjust_view(self, ax, minp, maxp):
        """Use a square, origin-centred view 5% larger than *maxp*.

        *minp* is not used by this backend.
        """
        ax.autoscale_view(tight=True)
        limit = maxp * 1.05
        ax.set_xlim(-limit, limit)
        ax.set_ylim(-limit, limit)
        ax.set_aspect('equal')

    def draw_arrow(self, ax, vector, **kwargs):
        """Draw an arrow from the origin to the first two components of *vector*."""
        dx, dy = vector[0], vector[1]
        ax.arrow(0, 0, dx, dy, lw=1, length_includes_head=True,
                 head_width=0.03, head_length=0.05, **kwargs)

    def label_options(self, point):
        """Return text options for a label placed at the 2D *point*.

        The text goes on the side of the point facing away from the origin.
        Points with x == 0 get right-aligned text; points with y == 0 get
        bottom-aligned text.
        """
        x, y = point
        ha = 'left' if int(np.sign(x)) == 1 else 'right'
        va = 'top' if int(np.sign(y)) == -1 else 'bottom'
        return {'ha': ha, 'va': va, 'zorder': 4}

60 

61 

class SpacePlot:
    """Plotting backend for three-dimensional Brillouin zones.

    Draws on a matplotlib ``'3d'`` projection.  The azimuth is fixed at
    pi/5; the elevation is configurable.  Both angles are stored in radians.
    """

    axis_dim = 3  # Coordinates used when plotting.
    point_options: Dict[str, Any] = {}  # No extra scatter options in 3D.

    def __init__(self, *, elev=None):
        """Prepare the 3D arrow artist and the viewing direction.

        elev: elevation angle in radians; pi/6 when None.
        """
        from matplotlib.patches import FancyArrowPatch
        from mpl_toolkits.mplot3d import Axes3D, proj3d
        Axes3D  # referenced only so pyflakes does not flag the import

        class Arrow3D(FancyArrowPatch):
            """2D arrow patch whose end points are projected from 3D data."""

            def __init__(self, ax, xs, ys, zs, *args, **kwargs):
                FancyArrowPatch.__init__(self, (0, 0), (0, 0),
                                         *args, **kwargs)
                self._verts3d = xs, ys, zs
                self.ax = ax

            def draw(self, renderer):
                # Project both end points with the axes' projection matrix
                # at draw time, then draw as an ordinary 2D arrow.
                xs, ys, zs = proj3d.proj_transform(*self._verts3d,
                                                   self.ax.axes.M)
                start = (xs[0], ys[0])
                end = (xs[1], ys[1])
                self.set_positions(start, end)
                FancyArrowPatch.draw(self, renderer)

            # FIXME: compatibility hack for matplotlib 3.5.0, where every
            # 3D artist must provide do_3d_projection.  The projection is
            # already done manually in draw(), so nothing happens here.
            def do_3d_projection(self, *_, **__):
                return 0

        self.arrow3d = Arrow3D
        self.azim = pi / 5
        if elev is None:
            elev = pi / 6
        self.elev = elev
        cos_elev = cos(elev)
        # Unit vector pointing from the origin towards the viewer.
        self.view = [sin(self.azim) * cos_elev,
                     cos(self.azim) * cos_elev,
                     sin(elev)]

    def new_axes(self, fig):
        """Add and return a 3D subplot on *fig*."""
        return fig.add_subplot(projection='3d')

    def draw_arrow(self, ax, vector, **kwargs):
        """Add an arrow from the origin to the 3D point *vector* on *ax*."""
        arrow = self.arrow3d(ax,
                             [0, vector[0]], [0, vector[1]], [0, vector[2]],
                             mutation_scale=20, arrowstyle='-|>', **kwargs)
        ax.add_artist(arrow)

    def adjust_view(self, ax, minp, maxp):
        """Set the camera, orthographic projection and cubic limits."""
        import matplotlib.pyplot as plt

        ax.view_init(azim=self.azim / pi * 180, elev=self.elev / pi * 180)

        # An equal aspect ratio is what we want, but set_aspect('equal')
        # raised NotImplementedError for 3D axes as of matplotlib 3.1.0.
        # As an approximation, force a square figure.
        fig = ax.get_figure()
        width, height = plt.figaspect(1.0)
        fig.set_figheight(height)
        fig.set_figwidth(width)

        ax.set_proj_type('ortho')

        # Shrink the limits slightly (factor 0.9) to trim empty margins.
        lower = 0.9 * minp
        upper = 0.9 * maxp
        for set_limits in (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d):
            set_limits(lower, upper)

        ax.set_box_aspect([1, 1, 1])

    def label_options(self, point):
        """Return text options for a label: centred, above the point."""
        return {'ha': 'center', 'va': 'bottom'}

146 

147 

def normalize_name(name):
    """Convert a special-point label to LaTeX math-mode markup.

    ``'G'`` becomes ``'\\Gamma'``.  A multi-character label made of
    non-digit characters followed by optional trailing digits (e.g.
    ``'K1'``) becomes ``'K_{1}'``.  Labels of length 0 or 1 are returned
    unchanged.

    Raises ValueError for a multi-character label that does not fit that
    pattern.
    """
    if name == 'G':
        return '\\Gamma'
    if len(name) <= 1:
        return name

    import re
    match = re.match(r'^(\D+?)(\d*)$', name)
    if match is None:
        raise ValueError(f'Bad label: {name}')
    base, digits = match.groups()
    return f'{base}_{{{digits}}}' if digits else base

161 

162 

def bz_plot(cell, vectors=False, paths=None, points=None,
            elev=None, scale=1, interactive=False,
            pointstyle=None, ax=None, show=False):
    """Plot the first Brillouin zone of *cell* with matplotlib.

    Parameters
    ----------
    cell:
        Cell object providing ``copy()``, ``rank`` and ``reciprocal()``.
        Rows and columns beyond ``cell.rank`` must be zero.
    vectors:
        If true, draw the reciprocal lattice vectors as black arrows and
        enlarge the view so that they fit.
    paths:
        Optional iterable of ``(names, points)`` pairs.  Each path is drawn
        as a red line, and each point is labelled in green with its name
        (converted by ``normalize_name``).
    points:
        Optional k-points, one per row, drawn as a scatter plot.  Any
        array-like (e.g. a nested list) is accepted.
    elev:
        Elevation of the 3D view in radians (pi/6 if None).  Only used for
        three-dimensional cells.
    scale:
        Currently unused.
    interactive:
        If false, the faces of a 3D zone that point away from the viewer
        are drawn dotted.
    pointstyle:
        Extra keyword arguments for ``ax.scatter`` when drawing *points*.
    ax:
        Axes to draw on.  If None, new axes are created on the current
        figure.
    show:
        If true, call ``plt.show()`` before returning.

    Returns
    -------
    The matplotlib axes that was drawn on.
    """
    import matplotlib.pyplot as plt

    if pointstyle is None:
        pointstyle = {}

    cell = cell.copy()

    dimensions = cell.rank
    if dimensions == 3:
        # Forward the requested elevation; it was previously ignored.
        plotter = SpacePlot(elev=elev)
    else:
        plotter = FlatPlot()
        assert dimensions > 0, 'No BZ for 0D!'

    if ax is None:
        ax = plotter.new_axes(plt.gcf())

    assert not cell[dimensions:, :].any()
    assert not cell[:, dimensions:].any()

    icell = cell.reciprocal()
    kpoints = points
    bz1 = bz_vertices(icell, dim=dimensions)

    maxp = 0.0
    minp = 0.0
    for face_points, normal in bz1:
        ls = '-'
        # Repeat the first corner so that the drawn outline is closed.
        xyz = np.concatenate([face_points, face_points[:1]]).T
        if dimensions == 3:
            # Faces whose normal points away from the viewer are "hidden".
            if normal @ plotter.view < 0 and not interactive:
                ls = ':'

        ax.plot(*xyz[:plotter.axis_dim], c='k', ls=ls)
        maxp = max(maxp, face_points.max())
        minp = min(minp, face_points.min())

    if vectors:
        for i in range(dimensions):
            plotter.draw_arrow(ax, icell[i], color='k')

        # XXX Can this be removed?
        if dimensions == 3:
            maxp = max(maxp, 0.6 * icell.max())
        else:
            maxp = max(maxp, icell.max())

    if paths is not None:
        for names, path_points in paths:
            coords = np.array(path_points).T[:plotter.axis_dim, :]
            ax.plot(*coords, c='r', ls='-')

            for name, point in zip(names, path_points):
                name = normalize_name(name)
                point = point[:plotter.axis_dim]
                ax.text(*point, rf'$\mathrm{{{name}}}$',
                        color='g', **plotter.label_options(point))

    if kpoints is not None:
        # Accept any array-like so that 2D slicing below works for lists too.
        kpoints = np.asarray(kpoints)
        kw = {'c': 'b', **plotter.point_options, **pointstyle}
        ax.scatter(*kpoints[:, :plotter.axis_dim].T, **kw)

    ax.set_axis_off()

    plotter.adjust_view(ax, minp, maxp)

    if show:
        plt.show()

    return ax