Coverage for /builds/kinetik161/ase/ase/spacegroup/utils.py: 90.16%
61 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1from typing import List
3import numpy as np
5from ase import Atoms
7from .spacegroup import _SPACEGROUP, Spacegroup
# Public API of this module: only the dispatcher ``get_basis`` is exported.
__all__ = ("get_basis",)
12def _has_spglib() -> bool:
13 """Check if spglib is available"""
14 try:
15 import spglib
16 assert spglib # silence flakes
17 except ImportError:
18 return False
19 return True
def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by repeatedly removing symmetry-equivalent sites.

    Takes the first remaining scaled position as a basis site, removes every
    position equivalent to it under ``spacegroup``, then continues with the
    next position that has not yet been assigned, until none are left.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``. Two scaled positions are considered the same
        site when every component of their difference, reduced to the
        nearest lattice image, is within ``tol`` (absolute tolerance only).
    :returns: array of shape ``(n_basis, 3)`` holding the scaled positions
        of the basis sites, in the order they were first encountered.
        An empty ``atoms`` gives an array of shape ``(0, 3)``.
    """
    remaining = atoms.get_scaled_positions()
    spacegroup = Spacegroup(spacegroup)

    basis = []
    # Iterative instead of recursive: a recursive version needs one stack
    # frame per basis site and raises RecursionError for large cells with
    # many inequivalent sites (e.g. big P1 structures).
    while len(remaining) > 0:
        site = remaining[0]
        basis.append(site)

        # All images of ``site`` generated by the space-group operations.
        images, _ = spacegroup.equivalent_sites(site)
        images = np.asarray(images, dtype=float).reshape(-1, 3)

        matched = np.zeros(len(remaining), dtype=bool)
        for image in images:
            delta = remaining - image
            # Minimum-image convention: positions differing by a lattice
            # vector (e.g. 0.0 vs 0.9999999999) are the same site.
            delta -= np.round(delta)
            matched |= np.all(np.abs(delta) <= tol, axis=1)

        # The chosen site itself is always consumed, so the loop is
        # guaranteed to make progress.
        matched[0] = True
        remaining = remaining[~matched]

    return np.array(basis, dtype=float).reshape(-1, 3)
def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Reduce ``atoms`` to its symmetry-inequivalent sites using spglib.

    The spglib package must be installed for this to work.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: the scaled positions of the atoms selected by
        ``_get_reduced_indices``.
    :raises ImportError: if spglib cannot be imported.
    """
    if _has_spglib():
        positions = atoms.get_scaled_positions()
        keep = _get_reduced_indices(atoms, tol=tol)
        return positions[keep]
    # Point the caller at a workable alternative rather than failing bare.
    raise ImportError(
        'This function requires spglib. Use "get_basis" and specify '
        'the spacegroup instead, or install spglib.')
def _can_use_spglib(spacegroup: _SPACEGROUP = None) -> bool:
    """Decide whether ``get_basis`` may dispatch to the spglib implementation.

    :param spacegroup: the space group requested by the caller, if any.
    :returns: ``True`` only when spglib is importable AND no explicit space
        group was given, since the spglib implementation cannot (yet) take
        one.
    """
    spglib_available = _has_spglib()
    return spglib_available and spacegroup is None
# Dispatcher that chooses which get_basis implementation to run.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced basis of an atoms object.

    Two implementations exist: an ASE-native algorithm, which needs an
    explicit space group, and an spglib-based one, which (currently) cannot
    take a space group. With ``method='auto'`` (the default), the spglib
    version is used when spglib is installed and no ``spacegroup`` is given.
    Otherwise the ASE version is used.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
        Note: with ``method='spglib'`` a given ``spacegroup`` is ignored.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: scaled positions of the basis sites.
    :raises ValueError: if ``method`` is not an allowed value, or if the ASE
        implementation is selected without a ``spacegroup``.
    """
    allowed_methods = ('auto', 'ase', 'spglib')
    if method not in allowed_methods:
        raise ValueError(
            f'Expected one of {allowed_methods} methods, got {method}')

    # 'auto' uses spglib whenever it is usable; otherwise honour the request.
    use_spglib = (_can_use_spglib(spacegroup=spacegroup)
                  if method == 'auto' else method == 'spglib')

    if use_spglib:
        # The space group is deliberately not forwarded: the spglib
        # implementation cannot accept one (yet).
        return _get_basis_spglib(atoms, tol=tol)

    # ASE-native path. Reached when 'ase' was requested explicitly, or when
    # 'auto' found spglib unusable (not installed, or a space group given).
    if spacegroup is None:
        raise ValueError(
            'A space group must be specified for the native ASE '
            'implementation. Try using the spglib version instead, '
            'or explicitly specifying a space group.')
    return _get_basis_ase(atoms, spacegroup, tol=tol)
def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> List[int]:
    """Get a list of the reduced atomic indices using spglib.

    Each returned index is a representative atom taken from spglib's
    ``equivalent_atoms`` mapping, one per set of symmetry-equivalent atoms.
    The list is sorted ascending so the resulting basis order is
    reproducible.
    Note: Does no checks to see if spglib is installed.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons,
        passed to spglib as ``symprec``.
    :returns: sorted list of representative atom indices (Python ``int``).
    :raises RuntimeError: if spglib returns no symmetry dataset.
    """
    import spglib

    # Create input for spglib
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    dataset = spglib.get_symmetry_dataset(spglib_cell, symprec=tol)
    if dataset is None:
        # spglib signals failure by returning None; without this check the
        # lookup below would fail with an opaque TypeError.
        raise RuntimeError(
            'spglib could not determine the symmetry of the structure '
            '(symprec={})'.format(tol))

    # NOTE(review): recent spglib releases return a dataset object with
    # attribute access (dict-style access deprecated); older releases return
    # a plain dict. Support both.
    try:
        equivalent_atoms = dataset.equivalent_atoms
    except AttributeError:
        equivalent_atoms = dataset['equivalent_atoms']

    # ``set`` has no ordering guarantee; sort for a deterministic basis.
    return sorted({int(index) for index in equivalent_atoms})