Coverage for /builds/kinetik161/ase/ase/spectrum/dosdata.py: 100.00%

152 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1# Refactor of DOS-like data objects 

2# towards replacing ase.dft.dos and ase.dft.pdos 

3import warnings 

4from abc import ABCMeta, abstractmethod 

5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union 

6 

7import numpy as np 

8 

9from ase.utils.plotting import SimplePlottingAxes 

10 

11# This import is for the benefit of type-checking / mypy 

12if False: 

13 import matplotlib.axes 

14 

# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
# Note that the __add__ methods in this module intersect info dicts by
# building sets of their (key, value) items, so values must be hashable.
Info = Dict[str, str]

# Still no good solution to type checking with arrays.
# Floats accepts either a plain sequence of floats or a numpy array.
Floats = Union[Sequence[float], np.ndarray]

21 

22 

class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init.

    Subclasses must implement get_energies(), get_weights() and copy().
    """

    def __init__(self,
                 info: Union[Info, None] = None) -> None:
        """Store the metadata dict

        Args:
            info: metadata dict; if None, a new empty dict is created. A
                provided dict is stored by reference (not copied).

        Raises:
            TypeError: if info is neither a dict nor None
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to
            ``energies``

        Raises:
            ValueError: if width is not positive, or the smearing type is not
                recognised
        """

        self._check_positive_width(width)
        energies = np.asarray(energies, float)
        weights = self.get_weights()
        weights_grid = np.zeros(len(energies), float)

        # Accumulate one broadened peak per stored data point. This keeps
        # memory use at O(len(energies)) rather than building a full
        # (n_raw, n_sampled) matrix.
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Two objects are considered almost equal if ``other`` is an instance
        of this object's type, the info dicts are equal, and weights and
        energies have identical shapes and agree within np.allclose
        tolerances.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False

        for getter in ('get_weights', 'get_energies'):
            mine = np.asarray(getattr(self, getter)())
            theirs = np.asarray(getattr(other, getter)())
            # np.allclose broadcasts its arguments: without this check a
            # length mismatch would either raise ValueError or (for a
            # length-1 series) silently compare as equal.
            if mine.shape != theirs.shape:
                return False
            if not np.allclose(mine, theirs):
                return False
        return True

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Raises:
            ValueError: if smearing is not (case-insensitively) 'gauss'
        """
        if smearing.lower() == 'gauss':
            # Normalised Gaussian with standard deviation 'width'
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)
        else:
            msg = f'Requested smearing type not recognized. Got {smearing}'
            raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width: float) -> None:
        """Raise ValueError unless width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: Union[float, None] = None,
                    xmax: Union[float, None] = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the energy grid and the sampled DOS, with a
            copy of this object's info dict
        """

        if xmin is None or xmax is None:
            # Fetch the energies once; they are only needed for defaults
            raw_energies = self.get_energies()
            margin = padding * width
            if xmin is None:
                xmin = min(raw_energies) - margin
            if xmax is None:
                xmax = max(raw_energies) + margin

        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: Union[float, None] = None,
             xmax: Union[float, None] = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Union[str, None] = None,
             mplargs: Union[dict, None] = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a copy so that the automatic label is never written back
        # into a dict the caller may reuse for other plots.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        sampled = self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        return sampled.plot(ax=ax, xmin=xmin, xmax=xmax,
                            show=show, filename=filename,
                            mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]) -> str:
        """Generate an automatic legend label from info dict

        Returns info['label'] if that key exists; otherwise all items are
        joined as 'key: value' pairs separated by '; '.
        """
        if 'label' in info:
            return info['label']
        return '; '.join(f'{key}: {value}' for key, value in info.items())

200 

201 

class GeneralDOSData(DOSData):
    """DOSData built from explicit energy and weight series

    This is the base class for DOSData objects that accept/set separate
    "energies" and "weights" sequences of equal length at init.

    Only the 'info' is a mutable attribute; DOS data is set at init.

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate and store the energy/weight series

        Raises:
            ValueError: if energies and weights differ in length
        """
        super().__init__(info=info)

        if len(energies) != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # A single C-ordered float array backs both series:
        # row 0 holds the energies, row 1 holds the weights.
        self._data = np.empty((2, len(energies)), dtype=float, order='C')
        self._data[0], self._data[1] = energies, weights

    def get_energies(self) -> np.ndarray:
        """Return a fresh copy of the stored energies (row 0)"""
        return self._data[0].copy()

    def get_weights(self) -> np.ndarray:
        """Return a fresh copy of the stored weights (row 1)"""
        return self._data[1].copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same type with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(),
                   self.get_weights(),
                   info=self.info.copy())

238 

239 

class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate two data series, keeping only the shared metadata

        Raises:
            TypeError: if other is not a RawDOSData
        """
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Only (key, value) pairs present in both info dicts survive
        shared_items = set(self.info.items()).intersection(other.info.items())

        # Join the (2, n) arrays side by side along the data axis
        joined = np.concatenate((self._data, other._data), axis=1)

        # Start from an empty container and install the joined array directly
        combined = RawDOSData([], [], info=dict(shared_items))
        combined._data = joined
        return combined

    def plot_deltas(self,
                    ax: 'matplotlib.axes.Axes' = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        vlines_kwargs = {} if mplargs is None else mplargs

        # Each stored point becomes a vertical line from 0 up to its weight
        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as axes:
            axes.vlines(self.get_energies(), 0, self.get_weights(),
                        **vlines_kwargs)

        return axes

318 

319 

class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Union[Info, None] = None) -> None:
        """Validate the grid and store the data

        Raises:
            ValueError: if energies are not evenly spaced, or if energies and
                weights differ in length
        """
        n_entries = len(energies)
        # An evenly-spaced grid is reproduced by linspace between its ends
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if width is under twice that"""
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample with broadening, scaling the result by the grid spacing"""
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two objects that share an energy grid

        Raises:
            TypeError: if other is not a GridDOSData
            ValueError: if the energy grids differ in length or in values
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data, other._data) to avoid redundant copying operations. The
        # __init__ method of GridDOSData will write this to a new array, so on
        # this occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")

        own_energies, own_weights = self._data
        other_energies, other_weights = other._data

        if len(own_energies) != len(other_energies):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(own_energies, other_energies):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the weight data; this creates a new array
        new_weights = own_weights + other_weights

        return GridDOSData(own_energies, new_weights, info=new_info)

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: Union[float, None] = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns:
            (npts, width) to use. (0, None) means no resampling: this is
            returned only when neither npts nor width was given.
        """
        if width is not None:
            # A width implies resampling; fall back to the default npts
            return (npts or default_npts, float(width))
        if npts:
            return (npts, default_width)
        return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: Union[float, None] = None,
             xmax: Union[float, None] = None,
             width: Union[float, None] = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Union[str, None] = None,
             mplargs: Union[dict, None] = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a copy so that the automatic label is never written back
        # into a dict the caller may reuse for other plots.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # _interpret_smearing_args always pairs a non-zero npts with a
            # float width; the assert narrows the type for type checkers.
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax