Coverage for /builds/kinetik161/ase/ase/utils/__init__.py: 83.91%
379 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import errno
2import functools
3import io
4import os
5import pickle
6import re
7import string
8import sys
9import time
10import warnings
11from contextlib import ExitStack, contextmanager
12from importlib import import_module
13from math import atan2, cos, degrees, gcd, radians, sin
14from pathlib import Path, PurePath
15from typing import Callable, Dict, List, Type, Union
17import numpy as np
19from ase.formula import formula_hill, formula_metal
# Public names exported by ``from ase.utils import *``.  Several entries are
# not defined in this module but re-exported from the imports above
# ('import_module' from importlib, 'gcd' from math, 'PurePath' from pathlib,
# 'formula_hill'/'formula_metal' from ase.formula) or are compatibility
# aliases ('basestring', 'pickleload'); presumably they are kept so that
# existing ``from ase.utils import ...`` statements keep working.
__all__ = ['basestring', 'import_module', 'seterr', 'plural',
           'devnull', 'gcd', 'convert_string_to_fd', 'Lock',
           'opencew', 'OpenLock', 'rotate', 'irotate', 'pbc2pbc', 'givens',
           'hsv2rgb', 'hsv', 'pickleload', 'reader',
           'formula_hill', 'formula_metal', 'PurePath', 'xwopen',
           'tokenize_version', 'get_python_package_path_description']
def tokenize_version(version_string: str):
    """Split a version string into a tuple that orders like a version.

    Every dot-separated component contributes two entries: its leading
    integer (``-1`` when the component does not start with a digit) and
    the text that follows that integer.

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """
    component_re = re.compile(r'(\d*)(.*)')

    def _leading_int(digits):
        # int('') raises ValueError; such components sort first as -1.
        try:
            return int(digits)
        except ValueError:
            return -1

    result = []
    for piece in version_string.split('.'):
        found = component_re.match(piece)
        assert found is not None, f'Cannot parse component {piece}'
        head, tail = found.group(1, 2)
        result.extend((_leading_int(head), tail))
    return tuple(result)
# Python 2+3 compatibility stuff (let's try to remove these things):
# ``basestring`` is just an alias for ``str`` on Python 3.
basestring = str
# ``pickle.load`` preconfigured with ``encoding='bytes'`` so that 8-bit
# string instances from Python 2 pickles are loaded as ``bytes`` objects
# instead of being decoded.
# NOTE(review): unpickling can execute arbitrary code; only use this on
# trusted files.
pickleload = functools.partial(pickle.load, encoding='bytes')
def deprecated(
    message: Union[str, Warning],
    category: Type[Warning] = FutureWarning,
    callback: Callable[[List, Dict], bool] = lambda args, kwargs: True
):
    """Build a decorator that marks a function as deprecated.

    Parameters
    ----------
    message : str or Warning
        What to emit.  When a ``Warning`` instance is passed, its own class
        (``message.__class__``) is the category and ``category`` has no
        effect.
    category : Type[Warning], default=FutureWarning
        Warning class used when ``message`` is a plain string.
    callback : Callable[[List, Dict], bool], default=lambda args, kwargs: True
        Invoked on every call of the decorated function with a *list* of
        the positional arguments and the *dict* of keyword arguments.  It
        must return ``True`` to emit the warning and ``False`` to stay
        silent.  It may edit the list and the dict in place; the decorated
        function is then called with the edited arguments.

    Returns
    -------
    deprecated_decorator : Callable
        Decorator that wraps a function (keeping its metadata through
        ``functools.wraps``) so that it conditionally warns and/or
        pre-processes its arguments before running.

    Example
    -------
    >>> # Inspect & replace a keyword parameter passed to a deprecated function
    >>> from typing import Any, Callable, Dict, List
    >>> import warnings
    >>> from ase.utils import deprecated

    >>> def alias_callback_factory(kwarg: str, alias: str) -> Callable:
    ...     def _replace_arg(_: List, kwargs: Dict[str, Any]) -> bool:
    ...         kwargs[kwarg] = kwargs[alias]
    ...         del kwargs[alias]
    ...         return True
    ...     return _replace_arg

    >>> MESSAGE = ("Calling this function with `atoms` is deprecated. "
    ...            "Use `optimizable` instead.")
    >>> @deprecated(
    ...     MESSAGE,
    ...     category=DeprecationWarning,
    ...     callback=alias_callback_factory("optimizable", "atoms")
    ... )
    ... def function(*, atoms=None, optimizable=None):
    ...     '''
    ...     .. deprecated:: 3.23.0
    ...         Calling this function with ``atoms`` is deprecated.
    ...         Use ``optimizable`` instead.
    ...     '''
    ...     print(f"atoms: {atoms}")
    ...     print(f"optimizable: {optimizable}")

    >>> with warnings.catch_warnings(record=True) as w:
    ...     warnings.simplefilter("always")
    ...     function(atoms="atoms")
    atoms: None
    optimizable: atoms

    >>> w[-1].category == DeprecationWarning
    True
    """

    def deprecated_decorator(func):
        @functools.wraps(func)
        def deprecated_function(*args, **kwargs):
            # A mutable copy lets the callback rewrite positional arguments.
            call_args = list(args)
            emit = callback(call_args, kwargs)
            if emit:
                # stacklevel=2 attributes the warning to the caller of the
                # deprecated function rather than to this wrapper.
                warnings.warn(message, category=category, stacklevel=2)
            return func(*call_args, **kwargs)

        return deprecated_function

    return deprecated_decorator
@contextmanager
def seterr(**kwargs):
    """Temporarily change NumPy's floating-point error handling.

    Keyword arguments are forwarded to ``np.seterr()``; see its
    documentation for the accepted keys and values.  The previous
    settings are restored on exit, also when the body raises.
    """
    saved_settings = np.seterr(**kwargs)
    try:
        yield
    finally:
        np.seterr(**saved_settings)
def plural(n, word):
    """Return ``'<n> <word>'``, appending an ``s`` unless ``n == 1``.

    ``n`` is formatted with ``%d``, so non-integral numbers are truncated.

    >>> from ase.utils import plural

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    if n != 1:
        return '%d %ss' % (n, word)
    return '1 ' + word
class DevNull:
    """File-like sink that discards output (deprecated).

    Writing, flushing and closing do nothing, reading returns ``''``,
    ``seek``/``tell`` report position 0 and ``isatty`` is False.  Every
    method call emits a ``DeprecationWarning`` recommending
    ``open(os.devnull)`` instead.
    """

    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0. Change to futurewarning later on.
    _use_os_devnull = deprecated('use open(os.devnull) instead',
                                 DeprecationWarning)

    @_use_os_devnull
    def isatty(self):
        return False

    @_use_os_devnull
    def read(self, n=-1):
        return ''

    @_use_os_devnull
    def write(self, string):
        return None

    @_use_os_devnull
    def flush(self):
        return None

    @_use_os_devnull
    def seek(self, offset, whence=0):
        return 0

    @_use_os_devnull
    def tell(self):
        return 0

    @_use_os_devnull
    def close(self):
        return None


# Shared module-level instance, exported as ``devnull`` in ``__all__``.
devnull = DevNull()
@deprecated('convert_string_to_fd does not facilitate proper resource '
            'management. '
            'Please use e.g. ase.utils.IOContext class instead.')
def convert_string_to_fd(name, world=None):
    """Return a text file object for writing output.

    * ``None``, or any rank other than 0: a newly opened ``os.devnull``.
    * ``'-'``: ``sys.stdout``.
    * ``str`` or ``PurePath``: the file opened for writing (truncated).
    * anything else: returned unchanged, assumed to be a file object.

    .. deprecated:: 3.22.1
        Please use e.g. :class:`ase.utils.IOContext` class instead.
    """
    if world is None:
        from ase.parallel import world
    if name is None or world.rank != 0:
        return open(os.devnull, 'w')
    if name == '-':
        return sys.stdout
    if not isinstance(name, (str, PurePath)):
        return name  # we assume name is already a file-descriptor
    return open(str(name), 'w')
# Flags that _opencew ORs into os.open(): create the file, fail with EEXIST
# if it already exists (O_EXCL), and open it write-only.
# Only Windows has O_BINARY (untranslated binary mode); getattr() falls
# back to 0 on platforms that do not define it.
CEW_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY | getattr(os, 'O_BINARY', 0)
@contextmanager
def xwopen(filename, world=None):
    """Create and open *filename* exclusively for writing, as a context.

    Yields whatever :func:`opencew` returns:
    * a real binary file object on rank 0 if it created the file;
    * a dummy (``os.devnull``) file object on the other ranks;
    * ``None`` on all ranks if the file already existed.

    A non-None file object is closed when the context exits.
    """
    handle = opencew(filename, world)
    with ExitStack() as cleanup:
        if handle is not None:
            cleanup.callback(handle.close)
        yield handle
# @deprecated('use "with xwopen(...) as fd: ..." to prevent resource leak')
def opencew(filename, world=None):
    """Create *filename* exclusively and open it for binary writing.

    Rank 0 gets the new file, other ranks get an ``os.devnull`` handle,
    and every rank gets ``None`` if the file already existed.  The caller
    is responsible for closing the returned file; prefer :func:`xwopen`.
    """
    return _opencew(filename, world=world)
def _opencew(filename, world=None):
    """Implementation of :func:`opencew`.

    Rank 0 tries to create *filename* with ``CEW_FLAGS`` (which include
    ``O_EXCL``) and opens it in binary write mode.  Every other rank opens
    ``os.devnull``.  Rank 0's errno is shared through ``world.sum_scalar``
    so that all ranks agree on the outcome.

    Returns
    -------
    The opened binary file object, or ``None`` on every rank if the file
    already existed (``EEXIST``).

    Raises
    ------
    OSError
        On every rank, if rank 0 failed with any other errno.
    """
    if world is None:
        import ase.parallel as parallel
        world = parallel.world

    # Handles opened here that must be closed unless returned to the caller.
    closelater = []

    def opener(file, flags):
        return os.open(file, flags | CEW_FLAGS)

    try:
        error = 0
        if world.rank == 0:
            try:
                fd = open(filename, 'wb', opener=opener)
            except OSError as ex:
                error = ex.errno
            else:
                closelater.append(fd)
        else:
            fd = open(os.devnull, 'wb')
            closelater.append(fd)

        # Synchronize: only rank 0 can set a nonzero errno, so the sum equals
        # rank 0's errno on every rank.
        error = world.sum_scalar(error)
        if error == errno.EEXIST:
            # The other ranks opened os.devnull above; the caller never sees
            # that handle, so close it here instead of leaking it.
            for handle in closelater:
                handle.close()
            return None
        if error:
            raise OSError(error, 'Error', filename)

        return fd
    except BaseException:
        for handle in closelater:
            handle.close()
        raise
def opencew_text(*args, **kwargs):
    """Text-mode variant of :func:`opencew`.

    All arguments are forwarded to :func:`opencew`.  Returns ``None`` when
    it does; otherwise the binary file is wrapped in ``io.TextIOWrapper``
    with its default encoding.
    """
    binary = opencew(*args, **kwargs)
    return None if binary is None else io.TextIOWrapper(binary)
class Lock:
    """Lock implemented by exclusively creating a file named *name*.

    ``acquire`` retries :func:`opencew` with exponential back-off (0.2 s,
    then doubling) until the file can be created or *timeout* seconds
    have elapsed, in which case ``TimeoutError`` is raised.  ``release``
    closes the handle and rank 0 deletes the file; both steps are
    surrounded by ``world.barrier()`` calls.

    NOTE(review): each rank measures its own elapsed time, so in parallel
    runs ranks might disagree about a timeout while the others are still
    inside the collective ``opencew`` call -- confirm before relying on a
    finite timeout with MPI.
    """

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        dt = 0.2
        # monotonic() is immune to wall-clock adjustments (NTP, DST, manual
        # changes) that could otherwise shorten or stretch the timeout.
        t1 = time.monotonic()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.monotonic() - t1)
            if time_left <= 0:
                raise TimeoutError
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so ``with Lock(...) as lock:`` is useful.
        return self

    def __exit__(self, type, value, tb):
        self.release()
class OpenLock:
    """No-op stand-in for :class:`Lock`.

    Offers the same ``acquire``/``release`` and context-manager interface
    but never blocks and does nothing.
    """

    def acquire(self):
        pass

    def release(self):
        pass

    def __enter__(self):
        # Return self so ``with OpenLock() as lock:`` binds the lock object
        # instead of None.
        return self

    def __exit__(self, type, value, tb):
        # Returning None (falsy) lets exceptions raised in the body propagate.
        pass
def _is_hex_hash(text):
    """Return True if *text* is a non-empty string of hexadecimal digits."""
    return bool(text) and all(c in string.hexdigits for c in text)


def _lookup_packed_ref(git_dpath, ref):
    """Return the hash stored for *ref* in ``.git/packed-refs``, or None.

    Entries have the form ``<hash> <refname>``; comment lines (``#``) and
    peeled-tag lines (``^``) are skipped.
    """
    packed_file = os.path.join(git_dpath, 'packed-refs')
    if not os.path.isfile(packed_file):
        return None
    with open(packed_file) as fd:
        for line in fd:
            line = line.strip()
            if not line or line[0] in '#^':
                continue
            sha, _, name = line.partition(' ')
            if name == ref:
                return sha if _is_hex_hash(sha) else None
    return None


def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.
    world: communicator, optional
        Only rank 0 searches; every other rank returns None.

    Returns the commit hash as a string, or None when no ``.git``
    directory, ``HEAD`` file or valid hash is found.  A symbolic ``HEAD``
    (``ref: ...``) is resolved through the loose ref file, falling back
    to ``packed-refs`` when no loose ref file exists.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, str):
        # Directory path
        dpath = arg
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # in case this is just symlinked into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    head_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(head_file):
        return None
    with open(head_file) as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
        if not os.path.isfile(ref_file):
            # After ``git gc``/``git pack-refs`` the ref may only exist in
            # the packed-refs file.
            return _lookup_packed_ref(git_dpath, ref)
    else:
        # Assuming detached HEAD state
        ref_file = head_file
    with open(ref_file) as fd:
        line = fd.readline().strip()
    # An empty line would pass an all(...) check vacuously; reject it.
    return line if _is_hex_hash(line) else None
def rotate(rotations, rotation=np.identity(3)):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Each comma-separated token is an angle in degrees followed by the axis
    letter ``x``, ``y`` or ``z``.  The elementary matrices are multiplied
    onto *rotation* from the right, in the order the tokens appear, so the
    order matters: '50x,40z' is different from '40z,50x'.

    An empty string returns a copy of *rotation*.  The default identity
    matrix is never modified: non-empty input produces new arrays through
    ``np.dot``.
    """
    if rotations == '':
        return rotation.copy()

    def elementary(axis, c, s):
        # Rotation matrix about x (0), y (1) or z (2).
        if axis == 0:
            return [(1, 0, 0), (0, c, s), (0, -s, c)]
        if axis == 1:
            return [(c, 0, -s), (0, 1, 0), (s, 0, c)]
        return [(c, s, 0), (-s, c, 0), (0, 0, 1)]

    # Parse every token before multiplying anything.
    parsed = [('xyz'.index(token[-1]), radians(float(token[:-1])))
              for token in rotations.split(',')]
    for axis, angle in parsed:
        rotation = np.dot(rotation, elementary(axis, cos(angle), sin(angle)))
    return rotation
def givens(a, b):
    """Compute a Givens rotation that zeroes *b*.

    Solve the equation system::

      [ c  s]   [a]   [r]
      [     ] . [ ] = [ ]
      [-s  c]   [b]   [0]

    and return ``(c, s, r)``.  The ratio with magnitude <= 1 is used to
    form the square root, which avoids overflow.
    """
    if b == 0:
        return np.sign(a), 0, abs(a)

    if abs(b) >= abs(a):
        cot = a / b
        u = np.sign(b) * (1 + cot**2)**0.5
        s = 1. / u
        return s * cot, s, b * u

    tan = b / a
    u = np.sign(a) * (1 + tan**2)**0.5
    c = 1. / u
    return c, c * tan, a * u
def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (in degrees) from a rotation matrix.

    The product ``initial @ rotation`` is decomposed with three successive
    Givens rotations (see :func:`givens`); each angle is recovered with
    ``atan2`` and converted to degrees.
    """
    m = np.dot(initial, rotation)
    cos_x, sin_x, r_x = givens(m[2, 2], m[1, 2])
    cos_y, sin_y, _ = givens(r_x, m[0, 2])
    cos_z, sin_z, _ = givens(
        cos_x * m[1, 1] - sin_x * m[2, 1],
        cos_y * m[0, 1] - sin_y * (sin_x * m[1, 1] + cos_x * m[2, 1]))
    return (degrees(atan2(sin_x, cos_x)),
            degrees(atan2(-sin_y, cos_y)),
            degrees(atan2(sin_z, cos_z)))
def pbc2pbc(pbc):
    """Normalize a periodic-boundary specification to a bool array of length 3.

    *pbc* may be a single value (broadcast to all three directions) or a
    length-3 sequence; entries are converted to ``bool``.
    """
    flags = np.zeros(3, dtype=bool)
    flags[:] = pbc
    return flags
def hsv2rgb(h, s, v):
    """Convert an HSV colour to RGB.

    See http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in [0, 360[
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if h is outside [0, 360[ (only checked when both
    v and s are nonzero).
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    i, f = divmod(h / 60., 1)
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        # h >= 360 gives i >= 6 and h < 0 gives i < 0; 360 itself is
        # rejected, so the message uses a half-open interval.
        raise RuntimeError('h must be in [0, 360)')


def hsv(array, s=.9, v=.9):
    """Map the values of *array* to colours of varying hue.

    Values are rescaled linearly so that the minimum maps to hue 0 and the
    maximum to hue 359 degrees; saturation *s* and value *v* are fixed.
    A constant array maps entirely to hue 0.

    Returns an array of shape ``array.shape + (3,)`` with RGB components
    in [0, 1].
    """
    array = np.asarray(array, dtype=float)
    if array.size == 0:
        return np.empty(array.shape + (3,))
    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        # All values equal: avoid 0/0 and use a single hue.
        hue = np.zeros_like(array)
    else:
        # Shift by the minimum (previously the minimum was *added*, which
        # pushed hues outside [0, 360) and made hsv2rgb raise).
        hue = (array - lowest) * 359. / span
    result = np.empty((hue.size, 3))
    for rgb, h in zip(result, hue.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return result.reshape(hue.shape + (3,))
# This code does roughly the same, but requires pylab (matplotlib)
# def cmap(array, name='hsv'):
#     import pylab
#     a = (array - array.min()) / np.ptp(array)
#     rgba = getattr(pylab.cm, name)(a)
#     return rgba[..., :3]  # return rgb only (not alpha)
def longsum(x):
    """Sum *x* in NumPy's extended precision and return a Python float.

    The elements are converted to ``np.longdouble`` before summation.  The
    width of ``np.longdouble`` is platform dependent: on x86 Linux it is
    typically the 80-bit x87 extended type, but on some platforms (e.g.
    MSVC-based Windows builds) it is identical to ``float64``.  The extra
    precision is therefore not guaranteed, and it is not generally 128-bit.
    The result is rounded back to a Python ``float``.
    """
    return float(np.asarray(x, dtype=np.longdouble).sum())
@contextmanager
def workdir(path, mkdir=False):
    """Temporarily change, and optionally create, the working directory.

    Parameters
    ----------
    path : str or path-like
        Directory to switch to.
    mkdir : bool
        If true, create *path* (including parents) first; an already
        existing directory is not an error.

    Yields
    ------
    pathlib.Path
        Absolute path of the new working directory.  The possibly
        relative *path* is not yielded because, after the ``chdir``, it
        would be interpreted relative to the new directory.

    The previous working directory is restored on exit, also when the
    body raises.
    """
    path = Path(path)
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)

    olddir = os.getcwd()
    os.chdir(path)
    try:
        yield Path.cwd()
    finally:
        os.chdir(olddir)
class iofunction:
    """Decorate func so it accepts either str or file.

    If the first argument is a ``str`` or :class:`pathlib.PurePath`, it is
    opened with this decorator's *mode*, passed to *func*, and closed again
    afterwards (also when *func* raises).  Any other first argument is
    passed through unchanged and is not closed.

    (Won't work on functions that return a generator: the file would be
    closed before the generator runs.)"""

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                # Already file-like: the caller owns it.
                return func(file, *args, **kwargs)
            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)
        return iofunc
def writer(func):
    """Let *func* accept a path (opened in mode ``'w'``) or a file object."""
    decorate = iofunction('w')
    return decorate(func)
def reader(func):
    """Let *func* accept a path (opened in mode ``'r'``) or a file object."""
    decorate = iofunction('r')
    return decorate(func)
582# The next two functions are for hotplugging into a JSONable class
583# using the jsonable decorator. We are supposed to have this kind of stuff
584# in ase.io.jsonio, but we'd rather import them from a 'basic' module
585# like ase/utils than one which triggers a lot of extra (cyclic) imports.
def write_json(self, fd):
    """Write to JSON file.

    Installed as the ``write`` method by :func:`jsonable`; delegates to
    ``ase.io.jsonio.write_json``, imported lazily to avoid import cycles.
    """
    from ase.io.jsonio import write_json as _dump_json
    _dump_json(fd, self)
@classmethod  # type: ignore[misc]
def read_json(cls, fd):
    """Read new instance from JSON file.

    Installed as the ``read`` classmethod by :func:`jsonable`; the decoded
    object must be an instance of *cls* (checked with ``assert``).
    """
    from ase.io.jsonio import read_json as _load_json
    instance = _load_json(fd)
    assert isinstance(instance, cls)
    return instance
def jsonable(name):
    """Decorator for facilitating JSON I/O with a class.

    Pokes JSON-based read and write functions into the class.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict(). If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    Raises TypeError (leaving the class untouched) if the class has no
    ``todict`` attribute."""
    def jsonableclass(cls):
        # Validate before modifying anything, so that a rejected class is
        # not left with a stray ``ase_objtype`` attribute.
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')
        cls.ase_objtype = name

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        cls.write = write_json
        cls.read = read_json
        return cls
    return jsonableclass
class ExperimentalFeatureWarning(Warning):
    """Warning category emitted by functions decorated with ``experimental``."""
def experimental(func):
    """Decorator for functions not ready for production use.

    Every call first emits an :class:`ExperimentalFeatureWarning` that
    names the function, then runs it.
    """
    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        notice = f'This function may change or misbehave: {func.__qualname__}()'
        warnings.warn(notice, ExperimentalFeatureWarning)
        return func(*args, **kwargs)

    return expfunc
def lazymethod(meth):
    """Decorator for lazy evaluation and caching of data.

    Example::

        class MyClass:

            @lazymethod
            def thing(self):
                return expensive_calculation()

    The body of ``thing`` runs only on the first call; its result is kept
    in the instance's ``_lazy_cache`` dict (created on demand, keyed by
    the method name) and returned on every later call."""
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        try:
            store = self._lazy_cache
        except AttributeError:
            store = {}
            self._lazy_cache = store

        if key in store:
            return store[key]
        value = meth(self)
        store[key] = value
        return value

    return getter
def atoms_to_spglib_cell(atoms):
    """Convert atoms into data suitable for calling spglib.

    Returns the tuple ``(atoms.get_cell(), atoms.get_scaled_positions(),
    atoms.get_atomic_numbers())``.
    """
    lattice = atoms.get_cell()
    positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return lattice, positions, numbers
def warn_legacy(feature_name):
    """Emit a FutureWarning that *feature_name* is untested legacy code."""
    text = (f'The {feature_name} feature is untested and ASE developers do not '
            'know whether it works or how to use it. Please rehabilitate it '
            '(by writing unittests) or it may be removed.')
    warnings.warn(text, category=FutureWarning)
def lazyproperty(meth):
    """Decorator like lazymethod, but making item available as a property.

    The value is computed on first attribute access and cached per instance.
    """
    cached_getter = lazymethod(meth)
    return property(cached_getter)
class IOContext:
    """Context manager that closes the files opened on behalf of an object.

    Files opened through :meth:`openfile` (or registered with
    :meth:`closelater`) are all closed by :meth:`close`, which also runs
    when a ``with`` block using the object exits.
    """

    @lazyproperty
    def _exitstack(self):
        # Created on first use and cached per instance by lazyproperty.
        return ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register *fd* to be closed by :meth:`close` and return it."""
        return self._exitstack.enter_context(fd)

    def close(self):
        """Close every file registered with :meth:`closelater`."""
        self._exitstack.close()

    def openfile(self, file, comm=None, mode='w'):
        """Open *file* and arrange for it to be closed by :meth:`close`.

        Parameters
        ----------
        file : str, path-like, file object, '-' or None
            An object with a ``close`` attribute is returned unchanged and
            is *not* registered for closing.  ``'-'`` returns
            ``sys.stdout``.  ``None`` gives an ``os.devnull`` handle.
        comm : object with a ``rank`` attribute, optional
            Defaults to ``ase.parallel.world``.  Ranks other than 0 always
            get an ``os.devnull`` handle.
        mode : str
            Passed to :func:`open`.  Text modes are opened as UTF-8.
        """
        from ase.parallel import world
        if comm is None:
            comm = world

        if hasattr(file, 'close'):
            return file  # File already opened, not for us to close.

        # 'b' may appear anywhere in a binary mode (e.g. 'rb+'), and open()
        # rejects an encoding argument for every binary mode.
        encoding = None if 'b' in mode else 'utf-8'

        if file is None or comm.rank != 0:
            return self.closelater(open(os.devnull, mode=mode,
                                        encoding=encoding))

        if file == '-':
            return sys.stdout

        return self.closelater(open(file, mode=mode, encoding=encoding))
def get_python_package_path_description(
        package, default='module has no path') -> str:
    """Describe where a python package/module lives, as a string.

    Returns the first entry of ``package.__path__``, or *default* when that
    path is empty.  Any exception raised along the way is reported as
    ``"<default> (<exception>)"`` instead of propagating, so the result is
    always a string.
    """
    try:
        paths = list(package.__path__)
        return str(paths[0]) if paths else default
    except Exception as ex:
        return f"{default} ({ex})"