Coverage for /builds/kinetik161/ase/ase/spacegroup/utils.py: 90.16%

61 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1from typing import List 

2 

3import numpy as np 

4 

5from ase import Atoms 

6 

7from .spacegroup import _SPACEGROUP, Spacegroup 

8 

9__all__ = ('get_basis', ) 

10 

11 

def _has_spglib() -> bool:
    """Report whether the optional ``spglib`` dependency can be imported.

    :returns: ``True`` if ``import spglib`` succeeds, ``False`` if it
        raises :class:`ImportError`.
    """
    try:
        import spglib  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    else:
        return True

20 

21 

def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by iteratively removing equivalent sites.

    The first remaining scaled position is taken as a basis site. Every
    remaining position equivalent to it under ``spacegroup`` is then
    removed, and the procedure repeats on what is left until no positions
    remain.

    The loop is iterative rather than recursive, so large low-symmetry
    cells (one basis site per atom) cannot exhaust Python's recursion limit.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: array of basis sites in scaled coordinates, one row per site,
        in the order they are first encountered in ``atoms``. An empty
        ``atoms`` yields an empty 1-D array.
    :raises AssertionError: if an iteration removes no position, i.e. the
        chosen basis site was not found among its own equivalent sites.
    """
    remaining = atoms.get_scaled_positions()
    spacegroup = Spacegroup(spacegroup)

    all_basis = []
    while len(remaining) > 0:
        basis = remaining[0]
        all_basis.append(basis.tolist())  # Add the site as a basis

        # All sites generated from ``basis`` by the space group operations.
        sites, _ = spacegroup.equivalent_sites(basis)
        sites = np.asarray(sites)

        # equivalent[i] is True when remaining[i] matches any of ``sites``.
        # This is the same criterion as np.allclose(site, position, atol=tol)
        # (with ``position`` as the reference value), evaluated for all
        # (position, site) pairs at once instead of in a Python double loop.
        equivalent = np.isclose(sites[np.newaxis, :, :],
                                remaining[:, np.newaxis, :],
                                atol=tol).all(axis=2).any(axis=1)

        # At least the site itself should have been removed. Raise
        # explicitly (instead of ``assert``) so the check survives
        # ``python -O`` and the loop can never spin forever. The exception
        # type is kept as AssertionError for backward compatibility.
        if not equivalent.any():
            raise AssertionError(
                'No position was removed: the basis site was not found '
                'among its own equivalent sites.')

        remaining = remaining[~equivalent]

    return np.array(all_basis)

71 

72 

def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Return the reduced basis of ``atoms`` as determined by spglib.

    The spglib package must be installed for this to work.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: the scaled positions of the atoms at the indices returned
        by ``_get_reduced_indices``.
    :raises ImportError: if spglib is not installed.
    """
    spglib_available = _has_spglib()
    if spglib_available:
        positions = atoms.get_scaled_positions()
        keep = _get_reduced_indices(atoms, tol=tol)
        return positions[keep]

    # Point the caller at a workable alternative instead of just failing.
    raise ImportError(
        'This function requires spglib. Use "get_basis" and specify '
        'the spacegroup instead, or install spglib.')

90 

91 

def _can_use_spglib(spacegroup: _SPACEGROUP = None) -> bool:
    """Decide whether ``get_basis`` may dispatch to the spglib implementation.

    This is possible only when spglib is installed and no explicit space
    group was requested. The spglib implementation does not currently
    accept a space group.
    """
    return _has_spglib() and spacegroup is None

103 

104 

# Dispatcher function for choosing the get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced basis of an atoms object.

    Either an ASE-native algorithm or an spglib-based one can be used.
    The native ASE version requires a space group to be specified. The
    (current) spglib version cannot use one: it infers the symmetry
    itself, and any ``spacegroup`` passed here is ignored.

    By default the implementation is chosen automatically, based on the
    ``spacegroup`` parameter and on whether spglib is installed. spglib is
    used only when it is installed and no space group is given.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
        ``'ase'`` requires ``spacegroup``; ``'spglib'`` ignores it.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: the basis sites in scaled coordinates.
    :raises ValueError: if ``method`` is not one of the allowed values, or
        if the ASE implementation is selected but no ``spacegroup`` is given.
    :raises ImportError: if ``method='spglib'`` but spglib is not installed.
    """
    ALLOWED_METHODS = ('auto', 'ase', 'spglib')

    if method not in ALLOWED_METHODS:
        raise ValueError(
            f'Expected one of {ALLOWED_METHODS} methods, got {method}')

    if method == 'auto':
        # Use spglib only when it is installed and no explicit space group
        # was given.
        use_spglib = _can_use_spglib(spacegroup=spacegroup)
    else:
        # User told us which implementation they wanted
        use_spglib = method == 'spglib'

    if use_spglib:
        # The spglib implementation infers the symmetry itself and cannot
        # handle an explicit space group right now, so ``spacegroup`` is
        # not passed on. This may change in the future.
        return _get_basis_spglib(atoms, tol=tol)

    # Native ASE implementation. We get here either because method='ase'
    # was requested, or because 'auto' found spglib missing or a space
    # group was given.
    if spacegroup is None:
        raise ValueError(
            'A space group must be specified for the native ASE '
            'implementation. Try using the spglib version instead, '
            'or explicitly specifying a space group.')
    return _get_basis_ase(atoms, spacegroup, tol=tol)

161 

162 

def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> List[int]:
    """Get a list of the reduced atomic indices using spglib.

    spglib's ``equivalent_atoms`` array maps every atom to the index of a
    representative atom of its symmetry-equivalence class. The distinct
    representative indices are returned in ascending order, which keeps
    the result deterministic (iterating a ``set`` has no guaranteed order).

    Note: Does no checks to see if spglib is installed.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons,
        passed to spglib as ``symprec``
    :returns: sorted list of distinct representative atom indices.
    :raises RuntimeError: if spglib returns no symmetry dataset.
    """
    import spglib

    # Create input for spglib
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib.get_symmetry_dataset(spglib_cell, symprec=tol)
    if symmetry_data is None:
        # Some spglib versions signal failure by returning None instead of
        # raising; fail with a clear message rather than a TypeError below.
        raise RuntimeError(
            'spglib could not determine the symmetry of the atoms object '
            f'(symprec={tol}).')

    # Newer spglib (>= 2.5) returns a dataset object with attribute access
    # (dict-style access is deprecated there); older versions return a dict.
    try:
        equivalent_atoms = symmetry_data.equivalent_atoms
    except AttributeError:
        equivalent_atoms = symmetry_data['equivalent_atoms']

    return np.unique(equivalent_atoms).tolist()