Coverage for /builds/kinetik161/ase/ase/spectrum/doscollection.py: 97.84%

185 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import collections 

2from functools import reduce, singledispatch 

3from typing import (Any, Dict, Iterable, List, Optional, Sequence, TypeVar, 

4 Union, overload) 

5 

6import numpy as np 

7 

8from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData 

9from ase.utils.plotting import SimplePlottingAxes 

10 

11# This import is for the benefit of type-checking / mypy 

12if False: 

13 import matplotlib.axes 

14 

15 

class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    The collection behaves as a read-only sequence: integer indexing returns
    the stored DOSData item and slicing returns a new collection of the same
    type.
    """

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the
        ._sample() method of those DOSData items, returning a 2-D array with
        columns corresponding to energies and rows corresponding to the
        collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is
                currently supported)

        Returns:
            Weights sampled from a broadened DOS at the given energies, in
            rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, resampled onto a grid

        The data are broadened once onto an evenly-spaced grid (see
        sample_grid()) and the resulting GridDOSCollection is plotted as-is.
        Line labels are derived from the info of each contained DOSData
        (e.g. its 'label' key) and shown in a legend.

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        grid_collection = self.sample_grid(npts,
                                           xmin=xmin, xmax=xmax,
                                           width=width, smearing=smearing)

        # The grid data are already broadened. npts=0 with width=None asks
        # GridDOSCollection.plot() to plot the stored data without
        # resampling; passing npts/width again would broaden a second time.
        return grid_collection.plot(npts=0,
                                    xmin=xmin, xmax=xmax,
                                    width=None, smearing=smearing,
                                    ax=ax, show=show, filename=filename,
                                    mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, the lowest
                energy in the collection minus padding * width is used
            xmax: Maximum sampled energy value; if unspecified, the highest
                energy in the collection plus padding * width is used
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to the sample_grid()
                method of each contained DOSData
            smearing: selection of broadening kernel, passed to the
                sample_grid() method of each contained DOSData

        Returns:
            GridDOSCollection with one sampled entry per item in this
            collection, all sharing the same energy grid

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Optional[Sequence[Info]],
                                ) -> Sequence[Info]:
        """Return info, or one empty dict per weights row if info is None

        Raises ValueError if the lengths of info and weights disagree.
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        # NumPy integer scalars (e.g. from np.argmax) are valid indices too.
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif not len(self) == len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        A single-item collection returns a copy of that item.

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """Return the items whose info contains (or, if negative, does not
        contain) every key/value pair in info_selection"""
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    def sum_by(self, *info_keys: str) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        i.e. the first entry is the sum of the first two inputs and the
        second is a copy of the third input.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Entries are grouped by the *exact* set of requested key/value pairs
        they carry, so an entry lacking some of info_keys forms its own group
        and is never summed into a group that has more matching keys.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
            and info_keys = ('a', 'c')

            then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Group by exact key/value combination; insertion order within each
        # group follows the collection order. (Using select() here would
        # match by subset and count entries with extra keys more than once.)
        groups: Dict[tuple, List[DOSData]] = {}
        for data in self:
            groups.setdefault(_matching_info_tuples(data), []).append(data)

        # Combos are sorted() to ensure consistent output across sessions.
        collection_data = [type(self)(groups[combo]).sum_all()
                           for combo in sorted(groups)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)

350 

351 

@singledispatch
def _add_to_collection(other: Union[DOSData, DOSCollection],
                       collection: DOSCollection) -> DOSCollection:
    """Join `other` onto `collection`, returning a new collection

    This is the fallback implementation used by DOSCollection.__add__ for
    any `other` without a more specific registered handler. Collections are
    concatenated only if `other` is an instance of the type of `collection`.

    Raises:
        TypeError: if `other` is a DOSCollection of an incompatible type, or
            is neither a DOSCollection nor a registered type
    """
    target_type = type(collection)

    if isinstance(other, target_type):
        return target_type([*collection, *other])

    if isinstance(other, DOSCollection):
        raise TypeError("Only DOSCollection objects of the same type may "
                        "be joined with '+'.")

    raise TypeError("DOSCollection may only be joined to DOSData or "
                    "DOSCollection objects with '+'.")

363 

364 

@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Append a single DOSData item, returning a new collection

    The result has the same type as `collection`; the input collection is
    not modified.
    """
    collection_type = type(collection)
    return collection_type([*collection, other])

369 

370 

class RawDOSCollection(DOSCollection):
    """DOSCollection restricted to RawDOSData items

    Construction raises TypeError if any contained item is not RawDOSData.
    """

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        super().__init__(dos_series)
        has_foreign_item = any(not isinstance(item, RawDOSData)
                               for item in self)
        if has_foreign_item:
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")

378 

379 

class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData objects sharing one energy grid

    Weights are stored together as a 2-D array (one row per dataset) next to
    a list of per-row info dicts; items are rebuilt as GridDOSData on access.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Floats] = None) -> None:
        """Build the collection from GridDOSData items

        Args:
            dos_series: GridDOSData items; all must share the energy axis
            energies: energy grid to use. Required if dos_series is empty;
                otherwise taken from the first item when not given.

        Raises:
            TypeError: if any item is not a GridDOSData
            ValueError: if no energies can be determined, or an item's
                energy axis differs from the collection's
        """
        dos_list = list(dos_series)

        # Check types first so a bad *first* item is reported as TypeError
        # rather than failing while its energies are read below.
        if not all(isinstance(dos_data, GridDOSData)
                   for dos_data in dos_list):
            raise TypeError("GridDOSCollection can only store "
                            "GridDOSData objects.")

        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights = np.empty((len(dos_list), len(self._energies)), float)
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if (dos_data.get_energies().shape != self._energies.shape
                    or not np.allclose(dos_data.get_energies(),
                                       self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the shared energy grid"""
        return self._energies.copy()

    def get_all_weights(self) -> Union[Sequence[Floats], np.ndarray]:
        """Return a copy of the 2-D weights array (one row per dataset)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        # NumPy integer scalars are accepted as well as plain ints.
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            # Pass the grid explicitly: an empty slice has no item to take
            # the energies from, and __init__ would otherwise raise.
            return type(self)([self[i] for i in range(len(self))[item]],
                              energies=self.get_energies())
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] = None) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, is empty, or its rows do not
                match the length of energies
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        # np.array (not asarray) always copies, so the collection never
        # aliases a float ndarray owned by the caller.
        weights_array = np.array(weights, dtype=float)
        if weights_array.ndim != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Build from the first row only (to set up the energy grid), then
        # install the full weights array directly instead of re-checking
        # every row.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        If nothing matches, an empty collection on the same energy grid is
        returned.

        """

        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        If every item matches, an empty collection on the same energy grid is
        returned.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, optionally resampled onto a grid

        Line labels are derived from the info of each entry (via
        DOSData.label_from_info) and shown in a legend.

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly. npts and width are passed through
                GridDOSData._interpret_smearing_args(), which may substitute
                default values.
            xmin, xmax:
                output data range; this limits the resampling range as well as
                the plotting output
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)
            # Without resampling the x-limits would span the stored grid;
            # honour an explicitly requested range. A None side is left
            # unchanged by set_xlim.
            if xmin is not None or xmax is not None:
                ax.set_xlim(left=xmin, right=xmax)

        return ax

    @staticmethod
    def _plot_broadened(ax: 'matplotlib.axes.Axes',
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Optional[Dict]):
        """Plot DOS data with labels to axes

        Each row of all_y is drawn as one line against energies, labelled
        from all_labels; a legend is added, x-limits span the energies and
        the y-axis starts at zero.

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method."""
        if mplargs is None:
            mplargs = {}

        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)