Coverage for /builds/kinetik161/ase/ase/spectrum/band_structure.py: 79.10%
177 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import numpy as np
3import ase # Annotations
4from ase.calculators.calculator import PropertyNotImplementedError
5from ase.utils import jsonable
def calculate_band_structure(atoms, path=None, scf_kwargs=None,
                             bs_kwargs=None, kpts_tol=1e-6, cell_tol=1e-6):
    """Calculate band structure.

    The purpose of this function is to abstract a band structure calculation
    so the workflow does not depend on the calculator.

    First trigger SCF calculation if necessary, then set arguments
    on the calculator for band structure calculation, then return
    calculated band structure.

    The difference from get_band_structure() is that the latter
    expects the calculation to already have been done.

    Parameters:

    atoms: Atoms
        Structure with an attached calculator.
    path: BandPath or None
        Band path to evaluate; defaults to ``atoms.cell.bandpath()``.
    scf_kwargs: dict or None
        Keywords passed to ``calc.set()`` before the SCF run.
    bs_kwargs: dict or None
        Extra keywords passed to ``calc.set()`` together with
        ``kpts=path`` for the band-structure run.
    kpts_tol: float
        Maximum allowed deviation between the path's k-points and the
        calculator's IBZ k-points after the band-structure run.
    cell_tol: float
        Maximum allowed ``celldiff`` between the path's cell and the
        atoms' cell.

    Raises ValueError if the path does not match the atoms' cell or
    periodicity, or if the atoms have no calculator.  Raises RuntimeError
    if the calculator ends up with k-points different from the path's."""
    if path is None:
        path = atoms.cell.bandpath()

    from ase.lattice import celldiff  # Should this be a method on cell?
    if any(path.cell.any(1) != atoms.pbc):
        raise ValueError('The band path\'s cell, {}, does not match the '
                         'periodicity {} of the atoms'
                         .format(path.cell, atoms.pbc))
    cell_err = celldiff(path.cell, atoms.cell.uncomplete(atoms.pbc))
    if cell_err > cell_tol:
        raise ValueError('Atoms and band path have different unit cells.  '
                         'Please reduce atoms to standard form.  '
                         'Cell lengths and angles are {} vs {}'
                         .format(atoms.cell.cellpar(), path.cell.cellpar()))

    calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    if scf_kwargs is not None:
        calc.set(**scf_kwargs)

    # Proposed standard mechanism for calculators to advertise that they
    # use the bandpath keyword to handle band structures rather than
    # a double (SCF + BS) run.
    use_bandpath_kw = getattr(calc, 'accepts_bandpath_keyword', False)
    if use_bandpath_kw:
        calc.set(bandpath=path)
        atoms.get_potential_energy()
        return calc.band_structure()

    # SCF run.
    atoms.get_potential_energy()

    # The reference energy must come from the SCF run.  If it is missing
    # here we fall back to 0.0 explicitly: passing None on would make
    # get_band_structure() ask the *band-structure* run for its Fermi
    # level, which is not the one we want.
    eref = None
    if hasattr(calc, 'get_fermi_level'):
        # What is the protocol for a calculator to tell whether
        # it has fermi_energy?
        eref = calc.get_fermi_level()
    if eref is None:
        eref = 0.0

    if bs_kwargs is None:
        bs_kwargs = {}

    calc.set(kpts=path, **bs_kwargs)
    calc.results.clear()  # XXX get rid of me

    # Calculators are too inconsistent here:
    # * atoms.get_potential_energy() will fail when total energy is
    #   not in results after BS calculation (Espresso)
    # * calc.calculate(atoms) doesn't ask for any quantity, so some
    #   calculators may not calculate anything at all
    # * 'bandstructure' is not a recognized property we can ask for
    try:
        atoms.get_potential_energy()
    except PropertyNotImplementedError:
        pass

    # Compare shapes first: a different number of k-points could otherwise
    # broadcast silently (e.g. a single k-point) or raise an opaque numpy
    # error instead of the informative RuntimeError below.
    path_kpts = np.asarray(path.kpts)
    ibzkpts = np.asarray(calc.get_ibz_k_points())
    if ibzkpts.shape != path_kpts.shape:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'shapes {} != {}'.format(ibzkpts.shape,
                                                    path_kpts.shape))
    kpts_err = np.abs(path_kpts - ibzkpts).max()
    if kpts_err > kpts_tol:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'err={} > tol={}'.format(kpts_err, kpts_tol))

    bs = get_band_structure(atoms, path=path, reference=eref)
    return bs
def get_band_structure(atoms=None, calc=None, path=None, reference=None):
    """Create band structure object from Atoms or calculator.

    At least one of ``atoms`` and ``calc`` must be given; the missing one
    is taken from the other (``calc.atoms`` / ``atoms.calc``).  The
    eigenvalues are read for every spin and every IBZ k-point of the
    calculator.

    ``path`` and ``reference`` are used internally at the moment, but
    the exact implementation will probably change.  WIP.  When ``path``
    is None, a BandPath is reconstructed from the calculator's k-points.
    When ``reference`` is None, the calculator's Fermi level is used if
    available, else 0.0.

    Raises ValueError if neither atoms nor calc is given, or if no
    calculator can be found."""
    # XXX We throw away info about the bandpath when we create the calculator.
    # If we have kept the bandpath, we can provide it as an argument here.
    # It would be wise to check that the bandpath kpoints are the same as
    # those stored in the calculator.
    if atoms is None and calc is None:
        raise ValueError('Either atoms or calc must be given')
    if atoms is None:
        atoms = calc.atoms
    if calc is None:
        calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    kpts = calc.get_ibz_k_points()
    nkpts = len(kpts)
    nspins = calc.get_number_of_spins()
    # Shape (nspins, nkpts, nbands).
    energies = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                          for k in range(nkpts)]
                         for s in range(nspins)])

    if path is None:
        from ase.dft.kpoints import (BandPath, find_bandpath_kinks,
                                     resolve_custom_points)
        standard_path = atoms.cell.bandpath(npoints=0)
        # Kpoints are already evaluated, we just need to put them into
        # the path (whether they fit our idea of what the path is, or not).
        #
        # Depending on how the path was established, the kpoints might
        # be valid high-symmetry points, but since there are multiple
        # high-symmetry points of each type, they may not coincide
        # with ours if the bandpath was generated by another code.
        #
        # Here we hack it so the BandPath has proper points even if they
        # come from some weird source.
        #
        # This operation (manually hacking the bandpath) is liable to break.
        # TODO: Make it available as a proper (documented) bandpath method.
        kinks = find_bandpath_kinks(atoms.cell, kpts, eps=1e-5)
        pathspec, special_points = resolve_custom_points(
            kpts[kinks], standard_path.special_points, eps=1e-5)
        path = BandPath(standard_path.cell,
                        kpts=kpts,
                        path=pathspec,
                        special_points=special_points)

    # XXX If we *did* get the path, now would be a good time to check
    # that it matches the cell!  Although the path can only be passed
    # because we internally want to not re-evaluate the Bravais
    # lattice type.  (We actually need an eps parameter, too.)

    if reference is None:
        # Fermi level should come from the GS calculation, not the BS one!
        # Probe for the method the same way calculate_band_structure() does.
        get_fermi_level = getattr(calc, 'get_fermi_level', None)
        if get_fermi_level is not None:
            reference = get_fermi_level()

    if reference is None:
        # Fermi level may not be available, e.g., with non-Fermi smearing.
        # XXX Actually get_fermi_level() should raise an error when Fermi
        # level wasn't available, so we should fix that.
        reference = 0.0

    return BandStructure(path=path,
                         energies=energies,
                         reference=reference)
class BandStructurePlot:
    """Matplotlib plotter for a :class:`BandStructure`.

    The first call to :meth:`plot` or :meth:`plot_with_colors` prepares the
    axes: ticks at special points, vertical separators, and a dotted line
    at the reference energy.  Later calls on the same object draw onto
    those prepared axes.
    """

    def __init__(self, bs):
        self.bs = bs
        # Axes set by prepare_plot(); None until the first plot call.
        self.ax = None
        # x coordinate of every k-point along the path (from bs.get_labels()).
        self.xcoords = None
        self.show_legend = False

    def _get_axes(self, ax, emin, emax, ylabel):
        """Return the axes to draw on, preparing them on first use."""
        if self.ax is None:
            return self.prepare_plot(ax, emin, emax, ylabel)
        # Already prepared: reuse those axes unless the caller explicitly
        # supplied another one (previously ax=None crashed here).
        return self.ax if ax is None else ax

    def plot(self, ax=None, spin=None, emin=-10, emax=5, filename=None,
             show=False, ylabel=None, colors=None, label=None,
             spin_labels=('spin up', 'spin down'), loc=None, **plotkwargs):
        """Plot band-structure.

        spin: int or None
            Spin channel.  Default behaviour is to plot both spin up and down
            for spin-polarized calculations.
        emin,emax: float
            Minimum and maximum energy of the y axis.
        filename: str
            Write image to a file.
        ax: Axes
            MatPlotLib Axes object.  Will be created if not supplied.
        show: bool
            Show the image.
        ylabel: str or None
            y-axis label; defaults to 'energies [eV]'.
        colors: sequence or None
            One matplotlib colour per plotted spin channel; defaults to 'g'
            for one channel and 'yb' for two.
        label: str or None
            Legend label; combined with spin_labels for two spin channels.
        spin_labels: sequence of two str
            Legend suffixes for the two spin channels.
        loc: legend location passed to ``Axes.legend``.
        plotkwargs: extra keywords forwarded to ``Axes.plot``.

        Returns the Axes drawn on.
        """
        ax = self._get_axes(ax, emin, emax, ylabel)

        if spin is None:
            e_skn = self.bs.energies
        else:
            # Keep a leading spin axis of length one.
            e_skn = self.bs.energies[spin, np.newaxis]

        nspins = len(e_skn)
        if colors is None:
            colors = 'g' if nspins == 1 else 'yb'

        for s, e_kn in enumerate(e_skn):
            kwargs = dict(color=colors[s])
            kwargs.update(plotkwargs)  # caller's keywords take precedence
            if nspins == 2:
                if label:
                    lbl = label + ' ' + spin_labels[s]
                else:
                    lbl = spin_labels[s]
            else:
                lbl = label
            # Only the first band carries the label so the legend gets one
            # entry per spin channel rather than one per band.
            ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)
            for e_k in e_kn.T[1:]:
                ax.plot(self.xcoords, e_k, **kwargs)

        self.show_legend = label is not None or nspins == 2
        self.finish_plot(filename, show, loc, ax=ax)

        return ax

    def plot_with_colors(self, ax=None, emin=-10, emax=5, filename=None,
                         show=False, energies=None, colors=None,
                         ylabel=None, clabel='$s_z$', cmin=-1.0, cmax=1.0,
                         sortcolors=False, loc=None, s=2):
        """Plot band-structure with colors.

        energies and colors are arrays of identical shape
        (nbands, nkpoints): one row per band, one column per k-point of
        the path.  Each point is drawn as a scatter marker coloured by the
        matching entry of colors on the [cmin, cmax] scale, and a colour
        bar labelled clabel is added.  With sortcolors=True the points
        are drawn in ascending colour order, so higher values end up on
        top.

        Raises ValueError if energies or colors is not given.
        Returns the Axes drawn on.
        """
        import matplotlib.pyplot as plt

        if energies is None or colors is None:
            raise ValueError('plot_with_colors() requires both energies '
                             'and colors')
        energies = np.asarray(energies)
        colors = np.asarray(colors)

        ax = self._get_axes(ax, emin, emax, ylabel)

        shape = energies.shape
        # One row of path x coordinates per band row of energies.  (Tiling
        # by shape[1] dropped bands via zip whenever nbands > nkpoints.)
        xcoords = np.vstack([self.xcoords] * shape[0])
        if sortcolors:
            perm = colors.argsort(axis=None)
            energies = energies.ravel()[perm].reshape(shape)
            colors = colors.ravel()[perm].reshape(shape)
            xcoords = xcoords.ravel()[perm].reshape(shape)

        things = None
        for e_k, c_k, x_k in zip(energies, colors, xcoords):
            things = ax.scatter(x_k, e_k, c=c_k, s=s,
                                vmin=cmin, vmax=cmax)

        if things is not None:
            cbar = plt.colorbar(things, ax=ax)
            cbar.set_label(clabel)

        self.finish_plot(filename, show, loc, ax=ax)

        return ax

    @staticmethod
    def _pretty_label(name):
        """Format a special-point name for matplotlib (G -> Gamma, X1 -> X_1)."""
        if name == 'G':
            return r'$\Gamma$'
        if len(name) == 2:
            return name[0] + '$_' + name[1] + '$'
        return name

    @staticmethod
    def _merge_coincident_labels(xcoords, labels):
        """Join labels sharing an x coordinate into one 'A,B' tick label."""
        merged_x = []
        merged_labels = []
        for x, lbl in zip(xcoords, labels):
            if merged_x and merged_x[-1] == x:
                merged_labels[-1] += ',' + lbl
            else:
                merged_x.append(x)
                merged_labels.append(lbl)
        return merged_x, merged_labels

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up axes: special-point ticks, separators, reference line.

        Creates a new figure/axes when ax is None, stores the k-point
        x coordinates in self.xcoords and the axes in self.ax, and
        returns the axes.
        """
        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.figure().add_subplot(111)

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        label_xcoords, labels = self._merge_coincident_labels(
            label_xcoords,
            [self._pretty_label(name) for name in orig_labels])

        # Separator lines at interior special points only.
        for x in label_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        ylabel = ylabel if ylabel is not None else 'energies [eV]'

        ax.set_xticks(label_xcoords)
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        ax.axis(xmin=0, xmax=self.xcoords[-1], ymin=emin, ymax=emax)
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc, ax=None):
        """Add the legend if requested, then save and/or show the figure.

        ax defaults to the prepared axes (self.ax); the legend is attached
        to those axes rather than to whatever pyplot considers current.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = self.ax

        if self.show_legend:
            if ax is not None:
                leg = ax.legend(loc=loc)
            else:
                leg = plt.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            plt.savefig(filename)

        if show:
            plt.show()
@jsonable('bandstructure')
class BandStructure:
    """A band structure consists of an array of eigenvalues and a bandpath.

    BandStructure objects support JSON I/O.
    """

    def __init__(self, path, energies, reference=0.0):
        """Create a band structure.

        path: BandPath whose k-points the energies were evaluated at.
        energies: array-like of shape (nspins, nkpoints, nbands) with
            nspins equal to 1 or 2 and nkpoints equal to len(path.kpts).
        reference: scalar reference energy (e.g. a Fermi level).

        Raises ValueError for a wrongly shaped energies array and
        TypeError for a non-scalar reference.  These are explicit
        exceptions rather than ``assert`` so they survive ``python -O``.
        """
        self._path = path
        self._energies = np.asarray(energies)

        shape = self._energies.shape
        if self._energies.ndim != 3 or shape[0] not in (1, 2):
            raise ValueError('energies must have shape '
                             '(nspins, nkpoints, nbands) with nspins 1 or 2; '
                             'got shape {}'.format(shape))
        if shape[1] != len(path.kpts):
            raise ValueError('energies have {} k-points but the band path '
                             'has {}'.format(shape[1], len(path.kpts)))
        if not np.isscalar(reference):
            raise TypeError('reference must be a scalar, got {!r}'
                            .format(reference))
        self._reference = reference

    @property
    def energies(self) -> np.ndarray:
        """The energies of this band structure.

        This is a numpy array of shape (nspins, nkpoints, nbands)."""
        return self._energies

    @property
    def path(self) -> 'ase.dft.kpoints.BandPath':
        """The :class:`~ase.dft.kpoints.BandPath` of this band structure."""
        return self._path

    @property
    def reference(self) -> float:
        """The reference energy.

        Semantics may vary; typically a Fermi energy or zero,
        depending on how the band structure was created."""
        return self._reference

    def subtract_reference(self) -> 'BandStructure':
        """Return new band structure with reference energy subtracted."""
        return BandStructure(self.path, self.energies - self.reference,
                             reference=0.0)

    def todict(self):
        # Keys mirror the constructor arguments so the JSON round trip
        # (via @jsonable) can rebuild the object.
        return dict(path=self.path,
                    energies=self.energies,
                    reference=self.reference)

    def get_labels(self, eps=1e-5):
        """See :meth:`ase.dft.kpoints.BandPath.get_linear_kpoint_axis`.

        Returns the tuple (x coordinates of all k-points, x coordinates of
        the special points, special-point labels) for this band path."""
        return self.path.get_linear_kpoint_axis(eps=eps)

    def plot(self, *args, **kwargs):
        """Plot this band structure.

        All arguments are forwarded to :meth:`BandStructurePlot.plot`."""
        bsp = BandStructurePlot(self)
        return bsp.plot(*args, **kwargs)

    def __repr__(self):
        return ('{}(path={!r}, energies=[{} values], reference={})'
                .format(self.__class__.__name__, self.path,
                        '{}x{}x{}'.format(*self.energies.shape),
                        self.reference))