Coverage for /builds/kinetik161/ase/ase/spectrum/band_structure.py: 79.10%

177 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import numpy as np 

2 

3import ase # Annotations 

4from ase.calculators.calculator import PropertyNotImplementedError 

5from ase.utils import jsonable 

6 

7 

def calculate_band_structure(atoms, path=None, scf_kwargs=None,
                             bs_kwargs=None, kpts_tol=1e-6, cell_tol=1e-6):
    """Calculate band structure.

    The purpose of this function is to abstract a band structure calculation
    so the workflow does not depend on the calculator.

    First trigger SCF calculation if necessary, then set arguments
    on the calculator for band structure calculation, then return
    calculated band structure.

    The difference from get_band_structure() is that the latter
    expects the calculation to already have been done.

    Parameters:

    atoms: Atoms
        Atoms object with an attached calculator.
    path: BandPath or None
        Band path to evaluate.  Defaults to ``atoms.cell.bandpath()``.
    scf_kwargs: dict or None
        Keyword arguments passed to ``calc.set()`` before the SCF step.
    bs_kwargs: dict or None
        Extra keyword arguments passed to ``calc.set()`` together with
        ``kpts=path`` for the band structure step.
    kpts_tol: float
        Maximum allowed absolute deviation between the path's k-points and
        the calculator's k-points after the band structure step.
    cell_tol: float
        Maximum allowed ``celldiff`` between the path's cell and the
        atoms' cell.

    Raises ValueError if the band path does not match the periodicity or
    cell of the atoms, or if the atoms have no calculator.  Raises
    RuntimeError if the calculator's k-points differ from the path's.
    """
    if path is None:
        path = atoms.cell.bandpath()

    from ase.lattice import celldiff  # Should this be a method on cell?
    if any(path.cell.any(1) != atoms.pbc):
        raise ValueError('The band path\'s cell, {}, does not match the '
                         'periodicity {} of the atoms'
                         .format(path.cell, atoms.pbc))
    cell_err = celldiff(path.cell, atoms.cell.uncomplete(atoms.pbc))
    if cell_err > cell_tol:
        raise ValueError('Atoms and band path have different unit cells. '
                         'Please reduce atoms to standard form. '
                         'Cell lengths and angles are {} vs {}'
                         .format(atoms.cell.cellpar(), path.cell.cellpar()))

    calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    if scf_kwargs is not None:
        calc.set(**scf_kwargs)

    # Proposed standard mechanism for calculators to advertise that they
    # use the bandpath keyword to handle band structures rather than
    # a double (SCF + BS) run.
    use_bandpath_kw = getattr(calc, 'accepts_bandpath_keyword', False)
    if use_bandpath_kw:
        calc.set(bandpath=path)
        atoms.get_potential_energy()
        return calc.band_structure()

    # SCF step.
    atoms.get_potential_energy()

    # The reference is read here, after the SCF run and before the
    # calculator is reconfigured for the band structure k-points.
    if hasattr(calc, 'get_fermi_level'):
        # What is the protocol for a calculator to tell whether
        # it has fermi_energy?
        eref = calc.get_fermi_level()
    else:
        eref = 0.0

    if bs_kwargs is None:
        bs_kwargs = {}

    calc.set(kpts=path, **bs_kwargs)
    calc.results.clear()  # XXX get rid of me

    # Calculators are too inconsistent here:
    # * atoms.get_potential_energy() will fail when total energy is
    #   not in results after BS calculation (Espresso)
    # * calc.calculate(atoms) doesn't ask for any quantity, so some
    #   calculators may not calculate anything at all
    # * 'bandstructure' is not a recognized property we can ask for
    try:
        atoms.get_potential_energy()
    except PropertyNotImplementedError:
        pass

    ibzkpts = np.asarray(calc.get_ibz_k_points())
    path_kpts = np.asarray(path.kpts)
    # Compare shapes before subtracting: a different number of k-points
    # would otherwise either fail to broadcast (an unhelpful numpy error)
    # or broadcast silently (e.g. a single k-point) and make the
    # element-wise comparison below meaningless.
    if ibzkpts.shape != path_kpts.shape:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'shape {} != {}'.format(ibzkpts.shape,
                                                   path_kpts.shape))
    # .max() of an empty array raises, so treat "no k-points" as no error.
    kpts_err = np.abs(path_kpts - ibzkpts).max() if path_kpts.size else 0.0
    if kpts_err > kpts_tol:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'err={} > tol={}'.format(kpts_err, kpts_tol))

    bs = get_band_structure(atoms, path=path, reference=eref)
    return bs

87 

88 

def get_band_structure(atoms=None, calc=None, path=None, reference=None):
    """Create band structure object from Atoms or calculator.

    atoms: Atoms or None
        Atoms object; defaults to ``calc.atoms``.
    calc: Calculator or None
        Calculator holding the eigenvalues; defaults to ``atoms.calc``.
    path: BandPath or None
        Band path matching the calculator's k-points.  If None, a path is
        reconstructed from the calculator's k-points and the special points
        of ``atoms.cell.bandpath()``.
    reference: float or None
        Reference energy.  If None, ``calc.get_fermi_level()`` is used,
        falling back to 0.0 when that returns None.

    Raises ValueError if neither atoms nor calc is given, or if no
    calculator can be found.
    """
    # path and reference are used internally at the moment, but
    # the exact implementation will probably change. WIP.
    #
    # XXX We throw away info about the bandpath when we create the calculator.
    # If we have kept the bandpath, we can provide it as an argument here.
    # It would be wise to check that the bandpath kpoints are the same as
    # those stored in the calculator.
    if atoms is None and calc is None:
        raise ValueError('Either atoms or calc must be given')
    atoms = atoms if atoms is not None else calc.atoms
    calc = calc if calc is not None else atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    kpts = calc.get_ibz_k_points()

    # Spin is the outer loop, k-point the inner one: shape (nspins, nkpts, ...)
    energies = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                          for k in range(len(kpts))]
                         for s in range(calc.get_number_of_spins())])

    if path is None:
        from ase.dft.kpoints import (BandPath, find_bandpath_kinks,
                                     resolve_custom_points)
        standard_path = atoms.cell.bandpath(npoints=0)
        # Kpoints are already evaluated, we just need to put them into
        # the path (whether they fit our idea of what the path is, or not).
        #
        # Depending on how the path was established, the kpoints might
        # be valid high-symmetry points, but since there are multiple
        # high-symmetry points of each type, they may not coincide
        # with ours if the bandpath was generated by another code.
        #
        # Here we hack it so the BandPath has proper points even if they
        # come from some weird source.
        #
        # This operation (manually hacking the bandpath) is liable to break.
        # TODO: Make it available as a proper (documented) bandpath method.
        kinks = find_bandpath_kinks(atoms.cell, kpts, eps=1e-5)
        pathspec, special_points = resolve_custom_points(
            kpts[kinks], standard_path.special_points, eps=1e-5)
        path = BandPath(standard_path.cell,
                        kpts=kpts,
                        path=pathspec,
                        special_points=special_points)

    # XXX If we *did* get the path, now would be a good time to check
    # that it matches the cell!  Although the path can only be passed
    # because we internally want to not re-evaluate the Bravais
    # lattice type.  (We actually need an eps parameter, too.)

    if reference is None:
        # Fermi level should come from the GS calculation, not the BS one!
        reference = calc.get_fermi_level()

    if reference is None:
        # Fermi level may not be available, e.g., with non-Fermi smearing.
        # XXX Actually get_fermi_level() should raise an error when Fermi
        # level wasn't available, so we should fix that.
        reference = 0.0

    return BandStructure(path=path,
                         energies=energies,
                         reference=reference)

152 

153 

class BandStructurePlot:
    """Matplotlib helper that draws a band structure object.

    The first call to :meth:`plot` or :meth:`plot_with_colors` prepares the
    axes (k-point ticks, separators, reference line).  Later calls without
    an explicit ``ax`` reuse those axes, so several band structures can be
    overlaid.
    """

    def __init__(self, bs):
        self.bs = bs
        self.ax = None  # set by prepare_plot()
        self.xcoords = None  # linear k-point axis, set by prepare_plot()
        self.show_legend = False

    def _get_axes(self, ax, emin, emax, ylabel):
        """Return the axes to draw on, preparing them on first use."""
        if self.ax is None:
            return self.prepare_plot(ax, emin, emax, ylabel)
        if ax is None:
            # Axes were prepared by an earlier call; reuse them instead of
            # falling through with ax=None.
            return self.ax
        return ax

    def plot(self, ax=None, spin=None, emin=-10, emax=5, filename=None,
             show=False, ylabel=None, colors=None, label=None,
             spin_labels=('spin up', 'spin down'), loc=None, **plotkwargs):
        """Plot band-structure.

        spin: int or None
            Spin channel. Default behaviour is to plot both spin up and down
            for spin-polarized calculations.
        emin, emax: float
            Lower and upper limits of the energy axis (applied when the
            axes are prepared, i.e. on the first call).
        filename: str
            Write image to a file.
        ax: Axes
            Matplotlib Axes object. Will be created if not supplied.
        show: bool
            Show the image.
        ylabel: str or None
            Label of the energy axis; defaults to 'energies [eV]'.
        colors: sequence of colors or None
            One color per plotted spin channel; defaults to 'g' for one
            channel and 'yb' for two.
        label: str or None
            Legend label.  With two spin channels it is prefixed to the
            corresponding entry of spin_labels.
        spin_labels: sequence of two str
            Legend labels of the two spin channels.
        loc: str or None
            Legend location passed to ``legend()``.
        **plotkwargs:
            Extra keyword arguments for ``ax.plot()``; they override the
            color taken from ``colors``.

        Returns the Axes that were drawn on.
        """
        ax = self._get_axes(ax, emin, emax, ylabel)

        if spin is None:
            e_skn = self.bs.energies
        else:
            e_skn = self.bs.energies[spin, np.newaxis]

        nspins = len(e_skn)
        if colors is None:
            colors = 'g' if nspins == 1 else 'yb'

        for ispin, e_kn in enumerate(e_skn):
            kwargs = dict(color=colors[ispin])
            kwargs.update(plotkwargs)
            if nspins == 2:
                if label:
                    lbl = label + ' ' + spin_labels[ispin]
                else:
                    lbl = spin_labels[ispin]
            else:
                lbl = label
            # Only the first band carries the label, so each spin channel
            # shows up once in the legend.
            ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)
            for e_k in e_kn.T[1:]:
                ax.plot(self.xcoords, e_k, **kwargs)

        self.show_legend = label is not None or nspins == 2
        self.finish_plot(filename, show, loc, ax=ax)

        return ax

    def plot_with_colors(self, ax=None, emin=-10, emax=5, filename=None,
                         show=False, energies=None, colors=None,
                         ylabel=None, clabel='$s_z$', cmin=-1.0, cmax=1.0,
                         sortcolors=False, loc=None, s=2):
        """Plot band-structure with colors.

        energies: array, shape (nbands, nkpts)
            Energies to scatter; each row is drawn against the k-point axis.
        colors: array, same shape as energies
            Values mapped onto the color scale between cmin and cmax.
        clabel: str
            Label of the colorbar.
        cmin, cmax: float
            Limits of the color scale.
        sortcolors: bool
            Draw points in order of increasing color value, so points with
            larger values end up on top.
        s: float
            Marker size passed to ``scatter()``.
        ax, emin, emax, filename, show, ylabel, loc:
            As in :meth:`plot`.

        Returns the Axes that were drawn on.
        """
        if energies is None or colors is None:
            raise ValueError('energies and colors must be given')

        ax = self._get_axes(ax, emin, emax, ylabel)

        energies = np.asarray(energies)
        colors = np.asarray(colors)
        shape = energies.shape
        # One copy of the k-point axis per row (band) of energies, so the
        # x-coordinates line up element-wise with energies.
        xcoords = np.broadcast_to(self.xcoords, shape)
        if sortcolors:
            perm = colors.argsort(axis=None)
            energies = energies.ravel()[perm].reshape(shape)
            colors = colors.ravel()[perm].reshape(shape)
            xcoords = xcoords.ravel()[perm].reshape(shape)

        for e_k, c_k, x_k in zip(energies, colors, xcoords):
            things = ax.scatter(x_k, e_k, c=c_k, s=s,
                                vmin=cmin, vmax=cmax)

        # Attach the colorbar to the axes drawn on, not pyplot's current one.
        cbar = ax.figure.colorbar(things, ax=ax)
        cbar.set_label(clabel)

        self.finish_plot(filename, show, loc, ax=ax)

        return ax

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up axes: k-point ticks and labels, separators, reference line.

        Creates a new figure and axes when ax is None.  Stores the axes in
        self.ax and the linear k-point coordinates in self.xcoords, and
        returns the axes.
        """
        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.figure().add_subplot(111)

        def pretty(kpt):
            # 'G' -> Gamma; two-character names get a subscript ('K1' -> K_1).
            if kpt == 'G':
                kpt = r'$\Gamma$'
            elif len(kpt) == 2:
                kpt = kpt[0] + '$_' + kpt[1] + '$'
            return kpt

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        label_xcoords = list(label_xcoords)
        labels = [pretty(name) for name in orig_labels]

        # Merge labels of consecutive special points that share the same
        # x-coordinate (typically a discontinuity in the path) into 'A,B'.
        i = 1
        while i < len(labels):
            if label_xcoords[i - 1] == label_xcoords[i]:
                labels[i - 1] = labels[i - 1] + ',' + labels[i]
                labels.pop(i)
                label_xcoords.pop(i)
            else:
                i += 1

        for x in label_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        ylabel = ylabel if ylabel is not None else 'energies [eV]'

        ax.set_xticks(label_xcoords)
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        ax.axis(xmin=0, xmax=self.xcoords[-1], ymin=emin, ymax=emax)
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc, ax=None):
        """Add the legend if requested, then save and/or show the figure.

        ax: Axes or None
            Axes whose legend and figure are used; defaults to self.ax and
            falls back to pyplot's current axes.
        """
        if ax is None:
            ax = self.ax
        if ax is None:
            import matplotlib.pyplot as plt
            ax = plt.gca()

        if self.show_legend:
            leg = ax.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            ax.figure.savefig(filename)

        if show:
            import matplotlib.pyplot as plt
            plt.show()

296 

297 

@jsonable('bandstructure')
class BandStructure:
    """A band structure consists of an array of eigenvalues and a bandpath.

    BandStructure objects support JSON I/O.
    """

    def __init__(self, path, energies, reference=0.0):
        """Create a band structure.

        path: BandPath
            Band path; ``len(path.kpts)`` must equal the number of k-points
            in ``energies``.
        energies: array_like, shape (nspins, nkpoints, nbands)
            Eigenvalues; nspins must be 1 or 2.
        reference: float
            Reference energy; must be a scalar.

        Raises ValueError if the inputs are inconsistent.  (Explicit
        exceptions instead of asserts, so the checks survive ``python -O``.)
        """
        self._path = path
        self._energies = np.asarray(energies)
        if self._energies.ndim != 3 or self._energies.shape[0] not in (1, 2):
            raise ValueError('energies must have shape (nspins, nkpoints, '
                             'nbands) with nspins 1 or 2; got shape {}'
                             .format(self._energies.shape))
        if self._energies.shape[1] != len(path.kpts):
            raise ValueError('energies has {} k-points but the band path '
                             'has {}'.format(self._energies.shape[1],
                                             len(path.kpts)))
        if not np.isscalar(reference):
            raise ValueError('reference must be a scalar, got {!r}'
                             .format(reference))
        self._reference = reference

    @property
    def energies(self) -> np.ndarray:
        """The energies of this band structure.

        This is a numpy array of shape (nspins, nkpoints, nbands)."""
        return self._energies

    @property
    def path(self) -> 'ase.dft.kpoints.BandPath':
        """The :class:`~ase.dft.kpoints.BandPath` of this band structure."""
        return self._path

    @property
    def reference(self) -> float:
        """The reference energy.

        Semantics may vary; typically a Fermi energy or zero,
        depending on how the band structure was created."""
        return self._reference

    def subtract_reference(self) -> 'BandStructure':
        """Return new band structure with reference energy subtracted."""
        return BandStructure(self.path, self.energies - self.reference,
                             reference=0.0)

    def todict(self):
        """Return the constructor arguments as a dict (used for JSON I/O)."""
        return dict(path=self.path,
                    energies=self.energies,
                    reference=self.reference)

    def get_labels(self, eps=1e-5):
        """Return the linear k-point axis of the band path.

        Delegates to ``self.path.get_linear_kpoint_axis(eps=eps)``; the
        plotting code unpacks the result as
        ``(xcoords, label_xcoords, labels)``.
        """
        return self.path.get_linear_kpoint_axis(eps=eps)

    def plot(self, *args, **kwargs):
        """Plot this band structure.

        All arguments are forwarded to ``BandStructurePlot.plot``; returns
        the Axes it returns.
        """
        bsp = BandStructurePlot(self)
        return bsp.plot(*args, **kwargs)

    def __repr__(self):
        return ('{}(path={!r}, energies=[{} values], reference={})'
                .format(self.__class__.__name__, self.path,
                        '{}x{}x{}'.format(*self.energies.shape),
                        self.reference))