Coverage for /builds/kinetik161/ase/ase/spectrum/dosdata.py: 100.00%
152 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1# Refactor of DOS-like data objects
2# towards replacing ase.dft.dos and ase.dft.pdos
3import warnings
4from abc import ABCMeta, abstractmethod
5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union
7import numpy as np
9from ase.utils.plotting import SimplePlottingAxes
11# This import is for the benefit of type-checking / mypy
12if False:
13 import matplotlib.axes
# Type of the metadata ("info") dictionary attached to each DOSData object.
# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
Info = Dict[str, str]

# Type used for energy/weight series: either a plain sequence of floats or a
# NumPy array. Still no good solution to type checking with arrays, so the
# array shape and dtype are not expressed here.
Floats = Union[Sequence[float], np.ndarray]
class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    Concrete subclasses must implement get_energies(), get_weights() and
    copy(); the sampling, comparison and plotting helpers defined here are
    built on top of those accessors.
    """

    def __init__(self,
                 info: Info = None) -> None:
        """Store the metadata dictionary

        Args:
            info: metadata for this data series. A dict is stored by
                reference (it is not copied); None gives a new empty dict.

        Raises:
            TypeError: if info is neither a dict nor None
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to
            energies

        Raises:
            ValueError: if width is not positive, or if the smearing type is
                not recognized
        """

        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Accumulate one broadened peak per stored data point. This keeps
        # memory use proportional to the size of the output grid.
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Two objects are considered almost equal when other is an instance of
        type(self), the info dicts are equal, and the weights and energies
        have the same shape and agree within the np.allclose tolerances.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False

        # Check shapes explicitly: np.allclose broadcasts its arguments, so a
        # length-1 series could otherwise compare equal to a longer one, and
        # incompatible lengths would raise instead of returning False.
        weights = np.asarray(self.get_weights())
        other_weights = np.asarray(other.get_weights())
        if weights.shape != other_weights.shape:
            return False
        if not np.allclose(weights, other_weights):
            return False

        energies = np.asarray(self.get_energies())
        other_energies = np.asarray(other.get_energies())
        if energies.shape != other_energies.shape:
            return False
        return bool(np.allclose(energies, other_energies))

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Raises:
            ValueError: if smearing is not 'gauss' (compared case-insensitively)
        """
        if smearing.lower() == 'gauss':
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)
        else:
            msg = 'Requested smearing type not recognized. Got {}'.format(
                smearing)
            raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless the broadening width is greater than zero"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the energy grid and the sampled DOS, with a
            copy of this object's info dict
        """

        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a copy: writing the automatic label into the caller's dict
        # would make it leak into any later plot that reuses the same dict.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        sampled = self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                   width=width,
                                   smearing=smearing)
        return sampled.plot(ax=ax, xmin=xmin, xmax=xmax,
                            show=show, filename=filename,
                            mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]):
        """Generate an automatic legend label from info dict

        The value of the 'label' key is returned if present; otherwise all
        items are joined as 'key: value' pairs separated by '; '.
        """
        if 'label' in info:
            return info['label']
        return '; '.join(f'{key}: {value}' for key, value in info.items())
class GeneralDOSData(DOSData):
    """Base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    This is the base class for DOSData objects that accept/set separate
    "energies" and "weights" sequences of equal length at init.

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate and store the energy/weight series

        Raises:
            ValueError: if energies and weights differ in length
        """
        super().__init__(info=info)

        npts = len(energies)
        if npts != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # One float array with two rows: row 0 is energy, row 1 is weight
        self._data = np.empty((2, npts), dtype=float, order='C')
        self._data[0] = energies
        self._data[1] = weights

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies"""
        return self._data[0].copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights"""
        return self._data[1].copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same type with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(), self.get_weights(),
                   info=self.info.copy())
class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate the data of two RawDOSData, keeping shared info only

        Raises:
            TypeError: if other is not a RawDOSData
        """
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Only the (key, value) pairs present in both info dicts survive
        common_items = set(self.info.items()) & set(other.info.items())

        # Start from an empty object and swap in the joined
        # (energy, weight) columns of both operands
        combined = RawDOSData([], [], info=dict(common_items))
        combined._data = np.concatenate((self._data, other._data), axis=1)
        return combined

    def plot_deltas(self,
                    ax: 'matplotlib.axes.Axes' = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        vlines_kwargs = {} if mplargs is None else mplargs

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            # Each stored point becomes a vertical line from 0 to its weight
            ax.vlines(self.get_energies(), 0, self.get_weights(),
                      **vlines_kwargs)

        return ax
class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate and store data sampled on an evenly-spaced energy grid

        Raises:
            ValueError: if energies are not evenly spaced, or if energies and
                weights differ in length
        """
        n_entries = len(energies)
        # An even grid must match a linspace between its own end points
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if width is small relative to it

        The spacing is taken from the first two energy values. A warning is
        issued when width is less than twice this spacing.
        """
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the broadened data at chosen points

        The result of the parent-class sampling is multiplied by the spacing
        of the stored energy grid.
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two GridDOSData defined on the same energy grid

        Raises:
            TypeError: if other is not a GridDOSData
            ValueError: if the energy grids differ in length or in values
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the weights; the (shared) energy grid is reused unchanged
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns:
            (npts, width) to use. If only one of the two was given, the other
            is filled in from its default; if neither was given, (0, None) is
            returned, meaning no resampling.
        """
        if width is not None:
            return (npts or default_npts, float(width))
        if npts:
            return (npts, default_width)
        return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a copy: writing the automatic label into the caller's dict
        # would make it leak into any later plot that reuses the same dict.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # A non-zero npts always comes with a float width; this assert
            # narrows Optional[float] for the benefit of type checkers.
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax