Coverage for /builds/kinetik161/ase/ase/utils/__init__.py: 83.91%

379 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import errno 

2import functools 

3import io 

4import os 

5import pickle 

6import re 

7import string 

8import sys 

9import time 

10import warnings 

11from contextlib import ExitStack, contextmanager 

12from importlib import import_module 

13from math import atan2, cos, degrees, gcd, radians, sin 

14from pathlib import Path, PurePath 

15from typing import Callable, Dict, List, Type, Union 

16 

17import numpy as np 

18 

19from ase.formula import formula_hill, formula_metal 

20 

# Public API of ase.utils (order preserved from the historical definition).
__all__ = [
    'basestring', 'import_module', 'seterr', 'plural',
    'devnull', 'gcd', 'convert_string_to_fd', 'Lock',
    'opencew', 'OpenLock', 'rotate', 'irotate', 'pbc2pbc', 'givens',
    'hsv2rgb', 'hsv', 'pickleload', 'reader',
    'formula_hill', 'formula_metal', 'PurePath', 'xwopen',
    'tokenize_version', 'get_python_package_path_description',
]

27 

28 

def tokenize_version(version_string: str):
    """Split a dotted version string into a tuple suitable for comparisons.

    Each dot-separated component contributes two entries: its leading
    integer (``-1`` when the component does not start with a digit) and
    the remaining suffix string, e.g. ``'3.8rc1'`` -> ``(3, '', 8, 'rc1')``.

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """
    def _split(component):
        # Leading digits (possibly none) followed by whatever is left.
        found = re.match(r'(\d*)(.*)', component)
        assert found is not None, f'Cannot parse component {component}'
        digits, rest = found.group(1, 2)
        try:
            leading = int(digits)
        except ValueError:
            leading = -1
        return leading, rest

    parts = []
    for component in version_string.split('.'):
        parts.extend(_split(component))
    return tuple(parts)

45 

46 

# Python 2+3 compatibility stuff (let's try to remove these things):
# ``basestring`` is simply an alias of the builtin ``str``.
basestring = str
# ``pickle.load`` pre-bound with ``encoding='bytes'``; per the pickle docs
# this decodes 8-bit string instances from Python 2 pickles as ``bytes``.
# WARNING: unpickling can execute arbitrary code -- only use this on
# trusted files.
pickleload = functools.partial(pickle.load, encoding='bytes')

50 

51 

def deprecated(
    message: Union[str, Warning],
    category: Type[Warning] = FutureWarning,
    callback: Callable[[List, Dict], bool] = lambda args, kwargs: True
):
    """Build a decorator that marks a function as deprecated.

    Parameters
    ----------
    message : str or Warning
        What to emit.  When ``message`` is a Warning instance, its own
        class is used as the category and ``category`` is ignored.
    category : Type[Warning], default=FutureWarning
        Warning class used when ``message`` is a plain string.  Ignored
        (in favour of ``message.__class__``) when ``message`` is a
        ``Warning`` instance.
    callback : Callable[[List, Dict], bool], default=lambda args, kwargs: True
        Invoked on every call of the decorated function with a list of the
        positional arguments and the dict of keyword arguments.  It decides
        whether to warn (return ``True``) or stay silent (return ``False``)
        and may modify the list/dict in place: the decorated function is
        then called with the (possibly modified) list and dict unpacked as
        its positional and keyword arguments.

    Returns
    -------
    deprecated_decorator : Callable
        A decorator that wraps a function so that calling it can
        conditionally emit the deprecation warning and/or pre-process its
        arguments.

    Example
    -------
    >>> # Inspect & replace a keyword parameter passed to a deprecated function
    >>> from typing import Any, Callable, Dict, List
    >>> import warnings
    >>> from ase.utils import deprecated

    >>> def alias_callback_factory(kwarg: str, alias: str) -> Callable:
    ...     def _replace_arg(_: List, kwargs: Dict[str, Any]) -> bool:
    ...         kwargs[kwarg] = kwargs[alias]
    ...         del kwargs[alias]
    ...         return True
    ...     return _replace_arg

    >>> MESSAGE = ("Calling this function with `atoms` is deprecated. "
    ...            "Use `optimizable` instead.")
    >>> @deprecated(
    ...     MESSAGE,
    ...     category=DeprecationWarning,
    ...     callback=alias_callback_factory("optimizable", "atoms")
    ... )
    ... def function(*, atoms=None, optimizable=None):
    ...     '''
    ...     .. deprecated:: 3.23.0
    ...         Calling this function with ``atoms`` is deprecated.
    ...         Use ``optimizable`` instead.
    ...     '''
    ...     print(f"atoms: {atoms}")
    ...     print(f"optimizable: {optimizable}")

    >>> with warnings.catch_warnings(record=True) as w:
    ...     warnings.simplefilter("always")
    ...     function(atoms="atoms")
    atoms: None
    optimizable: atoms

    >>> w[-1].category == DeprecationWarning
    True
    """
    def decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A mutable copy so the callback can rewrite positional args.
            positional = list(args)
            should_warn = callback(positional, kwargs)
            if should_warn:
                # stacklevel=2 points the warning at the caller of func.
                warnings.warn(message, category=category, stacklevel=2)
            return func(*positional, **kwargs)

        return wrapper

    return decorate

138 

139 

@contextmanager
def seterr(**kwargs):
    """Temporarily change how NumPy handles floating-point errors.

    The keyword arguments are forwarded to ``np.seterr()`` on entry (see
    its documentation for the accepted keys and actions).  The previous
    settings are restored on exit, even when the body raises.
    """
    saved_state = np.seterr(**kwargs)
    try:
        yield
    finally:
        np.seterr(**saved_state)

151 

152 

def plural(n, word):
    """Return ``n`` followed by ``word``, with an ``s`` appended unless n == 1.

    >>> from ase.utils import plural

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    if n != 1:
        # '%d' formats n as an integer (floats are truncated).
        return '%d %ss' % (n, word)
    return '1 ' + word

164 

165 

class DevNull:
    """Deprecated do-nothing file-like object.

    Writes are discarded, reads return ``''`` and ``seek``/``tell`` report
    position 0.  Every method call emits a DeprecationWarning; use
    ``open(os.devnull)`` instead.
    """
    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0. Change to futurewarning later on.
    _use_os_devnull = deprecated('use open(os.devnull) instead',
                                 DeprecationWarning)

    @_use_os_devnull
    def read(self, n=-1):
        """Return an empty string, as at end of file."""
        return ''

    @_use_os_devnull
    def write(self, string):
        """Discard ``string``."""

    @_use_os_devnull
    def flush(self):
        """Do nothing."""

    @_use_os_devnull
    def seek(self, offset, whence=0):
        """Ignore the request and report position 0."""
        return 0

    @_use_os_devnull
    def tell(self):
        """Report position 0."""
        return 0

    @_use_os_devnull
    def isatty(self):
        """This is never a terminal."""
        return False

    @_use_os_devnull
    def close(self):
        """Do nothing (``closed`` stays False)."""


# Shared module-level instance (deprecated along with the class).
devnull = DevNull()

204 

205 

@deprecated('convert_string_to_fd does not facilitate proper resource '
            'management. '
            'Please use e.g. ase.utils.IOContext class instead.')
def convert_string_to_fd(name, world=None):
    """Return a file object for text output selected by ``name``.

    * ``None``, or any rank other than 0: a new handle on ``os.devnull``.
    * ``'-'``: ``sys.stdout``.
    * ``str`` or ``PurePath``: the named file, opened (truncated) for writing.
    * anything else: returned unchanged, assumed to be a file object.

    ``world`` defaults to ``ase.parallel.world``.

    .. deprecated:: 3.22.1
        Please use e.g. :class:`ase.utils.IOContext` class instead.
    """
    if world is None:
        from ase.parallel import world

    silent = name is None or world.rank != 0
    if silent:
        return open(os.devnull, 'w')
    if name == '-':
        return sys.stdout
    if not isinstance(name, (str, PurePath)):
        # We assume name is already a file-descriptor.
        return name
    return open(str(name), 'w')  # str for py3.5 pathlib

227 

228 

# Flags for "create exclusively, write only".  Only Windows has
# O_BINARY; on other platforms the getattr fallback contributes 0.
CEW_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY
CEW_FLAGS |= getattr(os, 'O_BINARY', 0)

231 

232 

@contextmanager
def xwopen(filename, world=None):
    """Context manager that creates and opens filename exclusively for writing.

    Yields whatever ``opencew(filename, world)`` returns: a binary file
    object on the master rank (a dummy handle on the other ranks) when the
    master rank got exclusive write access, otherwise ``None`` on all
    ranks.  A non-None handle is closed when the ``with`` block exits.
    """
    handle = opencew(filename, world)
    try:
        yield handle
    finally:
        if handle is not None:
            handle.close()

248 

249 

# NOTE: deprecating this in favour of ``with xwopen(...) as fd: ...``
# (which guarantees the handle gets closed) has been considered.
def opencew(filename, world=None):
    """Exclusively create ``filename`` for writing; thin wrapper of ``_opencew``.

    The caller is responsible for closing the returned file object.
    """
    result = _opencew(filename, world)
    return result

253 

254 

255def _opencew(filename, world=None): 

256 import ase.parallel as parallel 

257 if world is None: 

258 world = parallel.world 

259 

260 closelater = [] 

261 

262 def opener(file, flags): 

263 return os.open(file, flags | CEW_FLAGS) 

264 

265 try: 

266 error = 0 

267 if world.rank == 0: 

268 try: 

269 fd = open(filename, 'wb', opener=opener) 

270 except OSError as ex: 

271 error = ex.errno 

272 else: 

273 closelater.append(fd) 

274 else: 

275 fd = open(os.devnull, 'wb') 

276 closelater.append(fd) 

277 

278 # Synchronize: 

279 error = world.sum_scalar(error) 

280 if error == errno.EEXIST: 

281 return None 

282 if error: 

283 raise OSError(error, 'Error', filename) 

284 

285 return fd 

286 except BaseException: 

287 for fd in closelater: 

288 fd.close() 

289 raise 

290 

291 

def opencew_text(*args, **kwargs):
    """Text-mode variant of ``opencew``.

    All arguments are forwarded to ``opencew``.  The binary handle it
    returns is wrapped in ``io.TextIOWrapper`` (with its default encoding);
    ``None`` is passed through unchanged.
    """
    raw = opencew(*args, **kwargs)
    return None if raw is None else io.TextIOWrapper(raw)

297 

298 

class Lock:
    """File-based lock shared by all ranks of ``world``.

    ``acquire`` repeatedly tries to create the lock file exclusively (via
    ``opencew``), backing off exponentially between attempts, until it
    succeeds or ``timeout`` seconds have elapsed.  ``release`` closes the
    file and removes it on rank 0, between two barriers.
    """

    # Upper bound in seconds for a single backoff sleep in ``acquire``;
    # without it, a waiter with a long (or infinite) timeout could end up
    # sleeping for hours after the lock had been released.
    max_sleep = 5.0

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file has been created by this process.

        Raises TimeoutError if that has not happened within ``self.timeout``
        seconds.
        """
        dt = 0.2
        t1 = time.time()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.time() - t1)
            if time_left <= 0:
                raise TimeoutError
            time.sleep(min(dt, time_left))
            dt = min(2 * dt, self.max_sleep)

    def release(self):
        """Close and delete the lock file (deletion happens on rank 0)."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, type, value, tb):
        self.release()

335 

336 

class OpenLock:
    """No-op stand-in for ``Lock``: acquiring and releasing do nothing.

    Usable wherever a lock object is expected but no locking is needed,
    including as a context manager (which yields the lock itself).
    """

    def acquire(self):
        """Do nothing."""

    def release(self):
        """Do nothing."""

    def __enter__(self):
        # Return self so that ``with lock as held:`` binds the lock object.
        return self

    def __exit__(self, type, value, tb):
        # Returning None lets any exception from the with-block propagate.
        pass

349 

350 

def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str or os.PathLike (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.
    world: communicator, optional
        Only rank 0 searches; every other rank returns None.  Defaults to
        ``ase.parallel.world``.

    Returns the commit hash as a string of hexadecimal digits.  Returns
    None if there is no ``.git`` directory or ``HEAD`` file, if the ref file
    ``HEAD`` points to does not exist, or if the ref file does not contain a
    non-empty hexadecimal hash.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, (str, os.PathLike)):
        # Directory path
        dpath = os.fspath(arg)
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # in case this is just symlinked into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    HEAD_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(HEAD_file):
        return None
    with open(HEAD_file) as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
    else:
        # Assuming detached HEAD state
        ref_file = HEAD_file
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file) as fd:
        line = fd.readline().strip()
    # all() is vacuously True for '', so an empty file must be rejected.
    if line and all(c in string.hexdigits for c in line):
        return line
    return None

399 

400 

def rotate(rotations, rotation=np.identity(3)):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Each comma-separated token is an angle in degrees followed by the axis
    letter; the corresponding matrices are right-multiplied onto
    ``rotation`` in the order given, so '50x,40z' differs from '40z,50x'.
    An empty string returns a copy of ``rotation``.
    """
    if rotations == '':
        return rotation.copy()

    for token in rotations.split(','):
        axis = 'xyz'.index(token[-1])
        angle = radians(float(token[:-1]))
        s = sin(angle)
        c = cos(angle)
        if axis == 0:
            step = [(1, 0, 0),
                    (0, c, s),
                    (0, -s, c)]
        elif axis == 1:
            step = [(c, 0, -s),
                    (0, 1, 0),
                    (s, 0, c)]
        else:
            step = [(c, s, 0),
                    (-s, c, 0),
                    (0, 0, 1)]
        rotation = np.dot(rotation, step)
    return rotation

428 

429 

def givens(a, b):
    """Solve the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    Returns ``(c, s, r)``.  The ratio of the smaller to the larger of
    ``|a|`` and ``|b|`` is used so that squaring it cannot overflow.
    """
    sign = np.sign
    if b == 0:
        return sign(a), 0, abs(a)
    if abs(b) >= abs(a):
        ratio = a / b
        u = sign(b) * (1 + ratio**2)**0.5
        s = 1. / u
        return s * ratio, s, b * u
    ratio = b / a
    u = sign(a) * (1 + ratio**2)**0.5
    c = 1. / u
    return c, c * ratio, a * u

455 

456 

def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (in degrees) from rotation matrix.

    Uses three Givens rotations on ``initial @ rotation``.
    """
    m = np.dot(initial, rotation)
    cos_x, sin_x, norm_x = givens(m[2, 2], m[1, 2])
    cos_y, sin_y, _ = givens(norm_x, m[0, 2])
    cos_z, sin_z, _ = givens(
        cos_x * m[1, 1] - sin_x * m[2, 1],
        cos_y * m[0, 1] - sin_y * (sin_x * m[1, 1] + cos_x * m[2, 1]))
    return (degrees(atan2(sin_x, cos_x)),
            degrees(atan2(-sin_y, cos_y)),
            degrees(atan2(sin_z, cos_z)))

468 

469 

def pbc2pbc(pbc):
    """Return ``pbc`` as a new boolean array of length 3.

    The value is assigned into a length-3 bool array, so a scalar is
    broadcast to all three directions and a length-3 sequence is converted
    element-wise.
    """
    flags = np.zeros(3, dtype=bool)
    flags[:] = pbc
    return flags

474 

475 

def hsv2rgb(h, s, v):
    """Convert an HSV colour to RGB, see
    http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in degrees; hue is periodic, so values outside [0, 360[ are
        wrapped (e.g. 360 -> 0, -120 -> 240)
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if h is not a finite number.
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    # Wrap the hue into [0, 360[.  The extra ``% 6`` handles tiny negative
    # hues, for which ``h % 360`` rounds up to exactly 360.
    i, f = divmod((h % 360) / 60., 1)
    i %= 6
    p = v * (1 - s)
    q = v * (1 - s * (1 - (1 - f)))
    t = v * (1 - s * (1 - f))

    # NaN/inf hues make every comparison False and end up here.
    if not 0 <= i < 6:
        raise RuntimeError('h must be a finite number')
    sectors = ((v, t, p), (q, v, p), (p, v, t),
               (p, q, v), (t, p, v), (v, p, q))
    return sectors[int(i)]


def hsv(array, s=.9, v=.9):
    """Map the values of ``array`` onto RGB colours of varying hue.

    Values are rescaled linearly so that the minimum gets hue 0 and the
    maximum hue 359 (degrees), then converted with ``hsv2rgb`` using
    saturation ``s`` and value ``v``.  A constant input maps every element
    to hue 0.

    Returns a float array of shape ``array.shape + (3,)``.
    """
    array = np.asarray(array, dtype=float)
    result = np.empty((array.size, 3))
    if array.size:
        lo = array.min()
        span = array.max() - lo
        if span > 0:
            # Subtract the minimum so the hues span exactly [0, 359].
            hues = (array - lo) * 359. / span
        else:
            # Constant input: avoid 0/0 and use a single hue.
            hues = np.zeros_like(array)
        for rgb, h in zip(result, hues.flat):
            rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, array.shape + (3,))

517 

518 

# Roughly equivalent code using matplotlib colormaps (requires pylab, so it
# is kept only as a sketch):
# def cmap(array, name='hsv'):
#     import pylab
#     a = (array - array.min()) / np.ptp(array)
#     rgba = getattr(pylab.cm, name)(a)
#     return rgba[..., :-1]  # return rgb only (not alpha)

525 

526 

def longsum(x):
    """Sum all elements of ``x`` in extended precision; return a float.

    The elements are converted to ``np.longdouble`` before summing, which
    reduces rounding error compared with a float64 sum.  Note that the
    precision of ``np.longdouble`` is platform dependent.  It is often
    80-bit x87 extended precision (padded to 128 bits in memory), and on
    some platforms it is no wider than float64, so this is not a true
    128-bit floating point sum.
    """
    return float(np.asarray(x, dtype=np.longdouble).sum())

530 

531 

@contextmanager
def workdir(path, mkdir=False):
    """Temporarily change, and optionally create, working directory.

    With ``mkdir=True`` the directory (and any missing parents) is created
    first.  The previous working directory is restored on exit, even if
    the body raises.  Nothing is yielded.
    """
    previous = os.getcwd()
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    os.chdir(target)
    try:
        yield
    finally:
        os.chdir(previous)

545 

546 

class iofunction:
    """Decorate func so it accepts either str or file.

    A ``str``/``PurePath`` first argument is opened with ``mode``, passed to
    ``func`` and closed again afterwards.  Anything else is passed through
    untouched (and left open).

    (Won't work on functions that return a generator.)"""

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                # Already a file object: the caller owns it.
                return func(file, *args, **kwargs)
            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)
        return iofunc

572 

573 

def writer(func):
    """Decorate ``func`` so its first argument may be a path, opened with mode 'w'."""
    decorate = iofunction('w')
    return decorate(func)

576 

577 

def reader(func):
    """Decorate ``func`` so its first argument may be a path, opened with mode 'r'."""
    decorate = iofunction('r')
    return decorate(func)

580 

581 

# The next two functions are attached to classes by the ``jsonable``
# decorator.  Conceptually they belong in ase.io.jsonio, but importing them
# from a basic module like ase.utils avoids triggering a lot of extra
# (cyclic) imports; ase.io.jsonio itself is only imported inside the
# functions, when they are called.

586 

def write_json(self, fd):
    """Write ``self`` to ``fd`` as JSON.

    ``jsonable`` installs this as the class's ``write`` method; ``fd`` is
    handed unchanged to ``ase.io.jsonio.write_json``.
    """
    from ase.io.jsonio import write_json as _dump
    _dump(fd, self)

591 

592 

@classmethod  # type: ignore[misc]
def read_json(cls, fd):
    """Read new instance from JSON file.

    ``jsonable`` installs this as the class's ``read`` classmethod; ``fd``
    is handed unchanged to ``ase.io.jsonio.read_json``.

    Raises TypeError if the decoded object is not an instance of ``cls``.
    """
    from ase.io.jsonio import read_json as _read_json
    obj = _read_json(fd)
    # An explicit check (rather than ``assert``) also runs under python -O.
    if not isinstance(obj, cls):
        raise TypeError(f'Expected {cls.__name__} object in JSON data, '
                        f'got {type(obj).__name__}')
    return obj

600 

601 

def jsonable(name):
    """Decorator for facilitating JSON I/O with a class.

    Pokes JSON-based read and write functions into the class.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict().  If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    Raises TypeError (at decoration time) if the class has no todict();
    the class is left unmodified in that case.
    """
    def jsonableclass(cls):
        # Validate before mutating so a rejected class is left untouched.
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')
        cls.ase_objtype = name

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        cls.write = write_json
        cls.read = read_json
        return cls
    return jsonableclass

625 

626 

class ExperimentalFeatureWarning(Warning):
    """Category of the warning emitted by ``@experimental`` functions."""


def experimental(func):
    """Decorator for functions not ready for production use.

    Every call emits an ``ExperimentalFeatureWarning`` naming the function
    (attributed to the calling line), then forwards all arguments to
    ``func`` and returns its result.
    """
    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        # stacklevel=2 reports the caller's location rather than this
        # wrapper, so users can see where the experimental call happens.
        warnings.warn('This function may change or misbehave: {}()'
                      .format(func.__qualname__),
                      ExperimentalFeatureWarning, stacklevel=2)
        return func(*args, **kwargs)
    return expfunc

640 

641 

def lazymethod(meth):
    """Decorator for lazy evaluation and caching of data.

    Example::

      class MyClass:

         @lazymethod
         def thing(self):
             return expensive_calculation()

    The method body only runs on the first call of ``thing()``.  Its
    result is stored in the instance's ``_lazy_cache`` dict under the
    method name, and later calls return that stored value."""
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        try:
            store = self._lazy_cache
        except AttributeError:
            store = {}
            self._lazy_cache = store

        if key not in store:
            store[key] = meth(self)
        return store[key]
    return getter

669 

670 

def atoms_to_spglib_cell(atoms):
    """Return ``(cell, scaled_positions, atomic_numbers)`` of ``atoms`` for spglib."""
    cell = atoms.get_cell()
    scaled = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return cell, scaled, numbers

676 

677 

def warn_legacy(feature_name):
    """Emit a FutureWarning flagging ``feature_name`` as untested legacy code."""
    text = (f'The {feature_name} feature is untested and ASE developers do not '
            'know whether it works or how to use it. Please rehabilitate it '
            '(by writing unittests) or it may be removed.')
    warnings.warn(text, FutureWarning)

684 

685 

def lazyproperty(meth):
    """Like ``lazymethod``, but expose the cached value as a read-only property."""
    cached_getter = lazymethod(meth)
    return property(cached_getter)

689 

690 

class IOContext:
    """Manage the lifetime of files opened through ``openfile``.

    Handles opened by ``openfile`` (and anything passed to ``closelater``)
    are registered on a lazily created ``ExitStack``.  They are closed by
    ``close()``, or on exit when the object is used as a context manager.
    """

    @lazyproperty
    def _exitstack(self):
        return ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register ``fd`` to be closed by ``close()`` and return it."""
        return self._exitstack.enter_context(fd)

    def close(self):
        """Close everything registered via ``closelater``/``openfile``."""
        self._exitstack.close()

    def openfile(self, file, comm=None, mode='w'):
        """Return a file object for ``file``.

        Cases are checked in this order:

        * an object with a ``close`` method is returned unchanged (the
          caller owns it, and it is not closed here);
        * ``None``, or any rank other than 0 of ``comm``: a handle on
          ``os.devnull``;
        * ``'-'``: ``sys.stdout``;
        * anything else is opened with ``open(file, mode)``.

        Text modes use UTF-8, and binary modes (any mode containing ``'b'``)
        use no encoding.  Handles opened here are closed by ``close()``.
        ``comm`` defaults to ``ase.parallel.world``.
        """
        if comm is None:
            from ase.parallel import world
            comm = world

        if hasattr(file, 'close'):
            return file  # File already opened, not for us to close.

        # Test for 'b' anywhere in the mode, so that modes such as 'rb+'
        # and 'wb+' count as binary; binary modes must not get an encoding.
        encoding = None if 'b' in mode else 'utf-8'

        if file is None or comm.rank != 0:
            return self.closelater(open(os.devnull, mode=mode,
                                        encoding=encoding))

        if file == '-':
            return sys.stdout

        return self.closelater(open(file, mode=mode, encoding=encoding))

726 

727 

def get_python_package_path_description(
        package, default='module has no path') -> str:
    """Describe where a python package/module lives, as a string.

    Returns the first entry of ``package.__path__``, or ``default`` when
    that path is empty.  Any exception raised along the way is reported
    as ``"<default> (<exception>)"``, so a string is always returned.
    """
    try:
        paths = list(package.__path__)
        return str(paths[0]) if paths else default
    except Exception as err:
        return f"{default} ({err})"