Coverage for /builds/kinetik161/ase/ase/spectrum/doscollection.py: 97.84%
185 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import collections
2from functools import reduce, singledispatch
3from typing import (Any, Dict, Iterable, List, Optional, Sequence, TypeVar,
4 Union, overload)
6import numpy as np
8from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData
9from ase.utils.plotting import SimplePlottingAxes
11# This import is for the benefit of type-checking / mypy
12if False:
13 import matplotlib.axes
class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    Entries are held in a list and exposed through the read-only Sequence
    interface: an integer index returns a single entry, while a slice returns
    a new collection of the same type.
    """

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        # list() also consumes one-shot iterables such as generators
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to the sampled energies and rows corresponding to the
        collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at the given energies, in
            rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, resampled onto a grid

        The data are broadened and resampled exactly once, by
        self.sample_grid(). The resulting GridDOSCollection is then plotted
        without further resampling. Each line is labelled from the info of
        the corresponding DOSData entry (see GridDOSCollection.plot).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        gridded = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)

        # The grid above is already broadened. Calling the grid's plot() with
        # its defaults (npts=0, width=None) plots the stored data directly.
        # Passing npts/width again would broaden the data a second time.
        # xmin/xmax are forwarded only to bound the plotted range.
        return gridded.plot(xmin=xmin, xmax=xmax,
                            ax=ax, show=show, filename=filename,
                            mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: Optional[float] = None,
                    xmax: Optional[float] = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Every entry is sampled onto the same grid, so the results are returned
        together as a GridDOSCollection.

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, the lowest
                energy of any entry minus padding * width is used
            xmax: Maximum sampled energy value; if unspecified, the highest
                energy of any entry plus padding * width is used
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to the sample_grid()
                method of each entry
            smearing: selection of broadening kernel, passed to the
                sample_grid() method of each entry

        Returns:
            GridDOSCollection holding the sampled DOS of every entry on the
            common energy grid

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows. If
                omitted, every entry gets an empty info dict.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Optional[Sequence[Info]],
                                ) -> Sequence[Info]:
        """Return info, or one empty dict per weights row if info is None

        Raises:
            ValueError: if info is given and len(info) != len(weights)
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one entry for an integer index, or a new collection for a
        slice. NumPy integer scalars are accepted as well as Python ints."""
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes

        True only if other has the same type and length and every pair of
        entries satisfies the entries' own _almost_equals().
        """
        if not isinstance(other, type(self)):
            return False
        elif not len(self) == len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        A single entry is copied rather than returned directly, so the result
        can be modified without touching the collection.

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """Filter entries by whether their info contains every given item

        With negative=True, keep the entries that do NOT contain them all.
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

            dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                                DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

            dc.select(b='1')

        will return an identical object to dc, while ::

            dc.select(a='1')

        will return a DOSCollection with only the first item and ::

            dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

            dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                                DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

            dc.select_not(b='2')

        will return an identical object to dc, while ::

            dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

            dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    def sum_by(self, *info_keys: str) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

            dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                                DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                                DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

            dc.sum_by('b')

        will return a collection equivalent to ::

            DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                           + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                           DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Each entry contributes to exactly one summed group. Entries that lack
        some (or all) of info_keys are grouped with the entries that have
        exactly the same subset of key/value pairs.
        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
            and info_keys = ('a', 'c')

            then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Group on the exact key/value combination of each entry. Re-selecting
        # each combo with select() would be wrong, because select() matches by
        # subset. An entry with more of info_keys (or any of them, when the
        # combo is empty) would then land in several groups and be summed
        # more than once.
        groups: Dict[tuple, List[DOSData]] = {}
        for data in self:
            groups.setdefault(_matching_info_tuples(data), []).append(data)

        # Sorting inside each combo lets equal combos share one group key;
        # the combos are then sorted() to ensure consistent output across
        # sessions. Entry order within each group follows the collection.
        collection_data = [type(self)(groups[combo]).sum_all()
                           for combo in sorted(groups)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

            DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

            DOSCollection([dosdata1]) + dosdata2

        will return ::

            DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)
@singledispatch
def _add_to_collection(other: Union[DOSData, DOSCollection],
                       collection: DOSCollection) -> DOSCollection:
    """Join ``other`` onto ``collection`` and return a new collection

    This is the generic (fallback) implementation. DOSData arguments are
    handled by a separately registered implementation.

    Args:
        other: object on the right-hand side of ``collection + other``
        collection: the collection on the left-hand side

    Returns:
        New collection of type(collection) holding the entries of collection
        followed by those of other

    Raises:
        TypeError: if other is a DOSCollection of an incompatible type, or is
            not a DOSCollection at all
    """
    collection_type = type(collection)

    if isinstance(other, collection_type):
        return collection_type([*collection, *other])

    if isinstance(other, DOSCollection):
        raise TypeError("Only DOSCollection objects of the same type may "
                        "be joined with '+'.")

    raise TypeError("DOSCollection may only be joined to DOSData or "
                    "DOSCollection objects with '+'.")
@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Return a new collection of the same type with ``other`` appended

    Registered for DOSData, so ``collection + dosdata`` appends a single
    entry rather than joining two collections.
    """
    entries = list(collection)
    entries.append(other)
    return type(collection)(entries)
class RawDOSCollection(DOSCollection):
    """DOSCollection whose entries must all be RawDOSData objects"""

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        """Store the entries, then reject any that is not a RawDOSData

        Raises:
            TypeError: if any entry is not a RawDOSData
        """
        super().__init__(dos_series)
        if any(not isinstance(entry, RawDOSData) for entry in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")
class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData objects sharing one energy axis

    Instead of a list of entries, the weights of all entries are stored in a
    single 2-D array (one row per entry). A separate list holds each entry's
    info dict. GridDOSData objects are rebuilt from these on indexing.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Floats] = None) -> None:
        """Collect GridDOSData entries onto a common energy axis

        Args:
            dos_series: GridDOSData entries to store
            energies: common energy axis. If omitted, it is taken from the
                first entry, so it is required when dos_series is empty.

        Raises:
            ValueError: if dos_series is empty and energies is not given, or
                if an entry's energy axis differs from the common one
            TypeError: if an entry is not a GridDOSData
        """
        dos_list = list(dos_series)
        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights = np.empty((len(dos_list), len(self._energies)), float)
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            # Compare shapes first so that np.allclose never broadcasts
            # mismatched axes.
            if (dos_data.get_energies().shape != self._energies.shape
                    or not np.allclose(dos_data.get_energies(),
                                       self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the common energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Union[Sequence[Floats], np.ndarray]:
        """Return a copy of the weights array (one row per entry)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one GridDOSData for an integer index, or a new
        GridDOSCollection for a slice. NumPy integer scalars are accepted as
        well as Python ints."""
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            # Pass the energy axis explicitly so that a slice selecting no
            # entries (e.g. dc[len(dc):]) gives an empty collection instead
            # of raising ValueError. select()/select_not() do the same.
            return type(self)([self[i] for i in range(len(self))[item]],
                              energies=self._energies)
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None
                  ) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows. If
                omitted, every entry gets an empty info dict.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, has no rows, or its row length
                differs from len(energies)
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        # np.array (unlike np.asarray) always copies. This stops the
        # collection from sharing memory with a caller-owned float array,
        # in keeping with get_all_weights() returning copies.
        weights_array = np.array(weights, dtype=float)
        if weights_array.ndim != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Construct from the first row only, which sets up the energy axis.
        # Then install the full weights/info directly, so the rows are not
        # built and checked one by one.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

            dc = GridDOSCollection([GridDOSData(x, y1,
                                                info={'a': '1', 'b': '1'}),
                                    GridDOSData(x, y2,
                                                info={'a': '2', 'b': '1'})])

        then ::

            dc.select(b='1')

        will return an identical object to dc, while ::

            dc.select(a='1')

        will return a DOSCollection with only the first item and ::

            dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        If nothing matches, an empty collection on the same energy axis is
        returned.
        """

        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

            dc = GridDOSCollection([GridDOSData(x, y1,
                                                info={'a': '1', 'b': '1'}),
                                    GridDOSData(x, y2,
                                                info={'a': '2', 'b': '1'})])

        then ::

            dc.select_not(b='2')

        will return an identical object to dc, while ::

            dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

            dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        If every item is excluded, an empty collection on the same energy
        axis is returned.
        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: Optional[float] = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, resampled onto a grid

        Each line is labelled with DOSData.label_from_info() applied to the
        info of the corresponding entry. The labels are set after the lines
        are drawn, so they take precedence over any 'label' in mplargs. A
        legend is added to the axes.

        Args:
            npts:
                Number of points in resampled x-axis. Zero (default) requests
                no resampling, in which case the stored data is plotted
                directly. The final choice of npts and width is made by
                GridDOSData._interpret_smearing_args().
            xmin, xmax:
                output data range; this limits the resampling range as well as
                the plotting output (also when no resampling is performed)
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)  # narrows the type for mypy
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)
            # _plot_broadened fits the x-axis to the plotted energies. Apply
            # an explicit xmin/xmax on top of that, so the documented range
            # also holds when the stored grid is plotted without resampling.
            # A None limit is left unchanged by set_xlim.
            if xmin is not None or xmax is not None:
                ax.set_xlim(left=xmin, right=xmax)

        return ax

    @staticmethod
    def _plot_broadened(ax: 'matplotlib.axes.Axes',
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Optional[Dict]):
        """Plot DOS data with labels to axes

        all_y holds one row per line. It is transposed so that matplotlib
        draws one line per row. The labels are applied line-by-line afterwards
        and a legend is added. The x-axis is fitted to the energy range and
        the y-axis starts at zero.

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method."""
        if mplargs is None:
            mplargs = {}

        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)