Coverage for /builds/kinetik161/ase/ase/io/zmatrix.py: 96.12%

129 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import re 

2from collections import namedtuple 

3from numbers import Real 

4from string import digits 

5from typing import Dict, List, Optional, Tuple, Union 

6 

7import numpy as np 

8 

9from ase import Atoms 

10from ase.units import Angstrom, Bohr, nm 

11 

# Rows of a Z-matrix (and of a variable-definition block) are separated
# either by newlines or by semicolons.
_re_linesplit = re.compile(r'\n|;')
# A variable definition such as "r1 = 1.09" or "r1 1.09" is split into its
# name and value at an "=" sign (optionally padded by whitespace) or at a
# run of whitespace.
_re_defs = re.compile(r'\s*=\s*|\s+')


# One parsed internal-coordinate row: the (zero-based) indices of the
# reference atoms together with the bond distance, bending angle and
# dihedral angle measured against them.
_ZMatrixRow = namedtuple(
    '_ZMatrixRow',
    ['ind1', 'dist', 'ind2', 'a_bend', 'ind3', 'a_dihedral'],
)


# A Cartesian position: a 3-tuple of floats or a numpy array (expected to
# hold the three coordinates).
ThreeFloats = Union[Tuple[float, float, float], np.ndarray]

24 

25 

class _ZMatrixToAtoms:
    """Build an Atoms object row by row from a Z-matrix.

    Rows are whitespace separated and take one of these forms::

        name                              first atom, placed at the origin
        name i1 dist                      second atom, placed at (dist, 0, 0)
        name i1 dist i2 bend              third atom
        name i1 dist i2 bend i3 dihedral  any later atom
        name 0 x y z                      explicit Cartesian position

    ``i1``, ``i2`` and ``i3`` are 1-based row numbers or names of earlier
    atoms. ``dist``, ``bend`` and ``dihedral`` (and ``x``, ``y``, ``z``) are
    numbers or, optionally signed, names of variables from ``defs``.
    Distances are multiplied by ``dconv`` and angles by ``aconv``; the
    Cartesian values of an ``i1 == 0`` row are used without conversion.
    """

    # Conversion factors for the unit names accepted by get_units().
    known_units = dict(
        distance={'angstrom': Angstrom, 'bohr': Bohr, 'au': Bohr, 'nm': nm},
        angle={'radians': 1., 'degrees': np.pi / 180},
    )

    def __init__(self, dconv: Union[str, Real], aconv: Union[str, Real],
                 defs: Optional[Union[Dict[str, float],
                                      str, List[str]]] = None) -> None:
        """Set up unit conversion factors and variable definitions.

        Parameters:

        dconv: str or float
            Distance unit name (see ``known_units['distance']``) or a
            numeric factor multiplied onto every Z-matrix distance.
        aconv: str or float
            Angle unit name (see ``known_units['angle']``) or a numeric
            factor multiplied onto every Z-matrix angle.
        defs: dict, str, list of str or None
            Variable definitions; see ``set_defs``.
        """
        self.dconv = self.get_units('distance', dconv)  # type: float
        self.aconv = self.get_units('angle', aconv)  # type: float
        self.set_defs(defs)
        # Atom name -> row index; set to None once any name is repeated,
        # because names are then ambiguous.
        self.name_to_index: Optional[Dict[str, int]] = {}
        self.symbols: List[str] = []
        self.positions: List[ThreeFloats] = []

    @property
    def nrows(self):
        """Number of atoms added so far (the index of the next atom)."""
        return len(self.symbols)

    def get_units(self, kind: str, value: Union[str, Real]) -> float:
        """Return the conversion factor for ``value``.

        Numbers are returned as float. Strings are looked up
        case-insensitively in ``known_units[kind]``.

        Raises ValueError for an unknown unit name.
        """
        if isinstance(value, Real):
            return float(value)
        out = self.known_units[kind].get(value.lower())
        if out is None:
            raise ValueError("Unknown {} units: {}"
                             .format(kind, value))
        return out

    def set_defs(self, defs: Union[Dict[str, float], str,
                                   List[str], None]) -> None:
        """Store the variable definitions referenced by the Z-matrix.

        ``defs`` may be a mapping of names to values, a string holding
        ``name value`` or ``name = value`` definitions separated by
        newlines or semicolons, or a list of such definition strings.
        Surrounding whitespace is stripped and blank definitions are
        ignored. A definition may refer to an earlier-defined name.

        Raises ValueError if a definition does not consist of exactly one
        name and one value, or if its value cannot be resolved.
        """
        self.defs = {}  # type: Dict[str, float]
        if defs is None:
            return

        if isinstance(defs, dict):
            self.defs.update(defs)
            return

        if isinstance(defs, str):
            defs = _re_linesplit.split(defs.strip())

        for row in defs:
            # Strip first so that indentation or trailing blanks do not
            # produce empty fields from the split below.
            row = row.strip()
            if not row:
                continue
            parts = _re_defs.split(row)
            if len(parts) != 2:
                raise ValueError('Invalid variable definition in Z-matrix: '
                                 '{}'.format(row))
            key, val = parts
            self.defs[key] = self.get_var(val)

    def get_var(self, val: str) -> float:
        """Resolve a numeric literal or an optionally signed variable name.

        Raises ValueError if ``val`` is neither a number nor a defined
        variable name.
        """
        try:
            return float(val)
        except ValueError as e:
            val_out = self.defs.get(val.lstrip('+-'))
            if val_out is None:
                raise ValueError('Invalid value encountered in Z-matrix: {}'
                                 .format(val)) from e
            # A leading "-" negates the referenced variable.
            return val_out * (-1 if val.startswith('-') else 1)

    def get_index(self, name: str) -> int:
        """Find the zero-based index for a 1-based row number or atom name.

        Raises ValueError if ``name`` is not a number and cannot be found
        in the name map (or the map was discarded because of repeats).
        """
        try:
            return int(name) - 1
        except ValueError as e:
            if self.name_to_index is None or name not in self.name_to_index:
                raise ValueError('Failed to determine index for name "{}"'
                                 .format(name)) from e
            return self.name_to_index[name]

    def set_index(self, name: str) -> None:
        """Assign the next row index to ``name`` for name -> index lookup."""
        if self.name_to_index is None:
            return

        if name in self.name_to_index:
            # "name" has been encountered before, so name_to_index is no
            # longer meaningful. Destroy the map.
            self.name_to_index = None
            return

        self.name_to_index[name] = self.nrows

    def validate_indices(self, *indices: int) -> None:
        """Raise ValueError if the reference indices of a row are invalid.

        Every index must refer to an atom from an earlier row
        (0 <= index < nrows), and no index may appear more than once.
        """
        if any(index >= self.nrows for index in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom indices {}, at least one of which '
                             "hasn't been defined yet!"
                             .format(self.nrows, indices))

        if len(indices) != len(set(indices)):
            raise ValueError('An atom index has been used more than once a '
                             'row of the Z-matrix! Row numbers {}, '
                             'referred indices: {}'
                             .format(self.nrows, indices))

        # Negative indices (a reference written as "0" or below) would
        # otherwise silently wrap around to the last atoms in the list.
        if any(index < 0 for index in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom indices {}, at least one of which does '
                             'not refer to a previously defined atom (atom '
                             'numbers start at 1)!'
                             .format(self.nrows, indices))

    def parse_row(self, row: str) -> Tuple[
        str, Union[_ZMatrixRow, ThreeFloats],
    ]:
        """Split a row into its atom name and either a position or a row.

        Registers the row's atom name for later lookup. The first and
        second atoms, and Cartesian rows (reference index 0), are turned
        into a position array directly. All other rows yield a
        ``_ZMatrixRow`` with zero-based reference indices and with the
        distance and angles already multiplied by ``dconv``/``aconv``
        (``ind3``/``a_dihedral`` are None for the third atom).
        """
        tokens = row.split()
        name = tokens[0]
        self.set_index(name)
        if len(tokens) == 1:
            assert self.nrows == 0
            return name, np.zeros(3, dtype=float)

        ind1 = self.get_index(tokens[1])
        if ind1 == -1:
            # Reference "0": the remaining three tokens are x, y, z.
            assert len(tokens) == 5
            return name, np.array(list(map(self.get_var, tokens[2:])),
                                  dtype=float)

        dist = self.dconv * self.get_var(tokens[2])

        if len(tokens) == 3:
            assert self.nrows == 1
            self.validate_indices(ind1)
            return name, np.array([dist, 0, 0], dtype=float)

        ind2 = self.get_index(tokens[3])
        a_bend = self.aconv * self.get_var(tokens[4])

        if len(tokens) == 5:
            assert self.nrows == 2
            self.validate_indices(ind1, ind2)
            return name, _ZMatrixRow(ind1, dist, ind2, a_bend, None, None)

        ind3 = self.get_index(tokens[5])
        a_dihedral = self.aconv * self.get_var(tokens[6])
        self.validate_indices(ind1, ind2, ind3)
        return name, _ZMatrixRow(ind1, dist, ind2, a_bend, ind3,
                                 a_dihedral)

    def add_atom(self, name: str, pos: ThreeFloats) -> None:
        """Append an atom; its symbol is ``name`` without digits, capitalized.
        """
        self.symbols.append(
            ''.join([c for c in name if c not in digits]).capitalize()
        )
        self.positions.append(pos)

    def add_row(self, row: str) -> None:
        """Parse one Z-matrix row and append the resulting atom.

        Empty or whitespace-only rows are skipped, so blank lines or a
        trailing ";" in the input do not raise an error.
        """
        if not row.strip():
            return

        name, zrow = self.parse_row(row)

        if not isinstance(zrow, _ZMatrixRow):
            # First/second atom or a Cartesian row: position is final.
            self.add_atom(name, zrow)
            return

        if zrow.ind3 is None:
            # Third atom: only a bond distance and a bending angle are
            # given. This assumes atoms 0 and 1 lie on the x axis with atom
            # 1 on the +x side of atom 0 (as the first two row forms place
            # them), so the sign of ind2 - ind1 gives the direction of the
            # ind1 -> ind2 bond. NOTE(review): this does not hold if the
            # first atoms were given as Cartesian rows.
            pos = self.positions[zrow.ind1].copy()
            pos[0] += zrow.dist * np.cos(zrow.a_bend) * (zrow.ind2 - zrow.ind1)
            pos[1] += zrow.dist * np.sin(zrow.a_bend)
            self.add_atom(name, pos)
            return

        # ax1 is the dihedral axis, which is defined by the bond vector
        # between the two inner atoms in the dihedral, ind1 and ind2
        ax1 = self.positions[zrow.ind2] - self.positions[zrow.ind1]
        ax1 /= np.linalg.norm(ax1)

        # ax2 lies within the 1-2-3 plane, and it is perpendicular
        # to the dihedral axis
        ax2 = self.positions[zrow.ind2] - self.positions[zrow.ind3]
        ax2 -= ax1 * (ax2 @ ax1)
        ax2 /= np.linalg.norm(ax2)

        # ax3 is a vector that forms the appropriate dihedral angle, though
        # the bending angle is 90 degrees, rather than a_bend. It is formed
        # from a linear combination of ax2 and (ax2 x ax1)
        ax3 = (ax2 * np.cos(zrow.a_dihedral)
               + np.cross(ax2, ax1) * np.sin(zrow.a_dihedral))

        # The final position vector is a linear combination of ax1 and ax3,
        # rescaled to the bond length and placed relative to atom ind1.
        pos = ax1 * np.cos(zrow.a_bend) - ax3 * np.sin(zrow.a_bend)
        pos *= zrow.dist / np.linalg.norm(pos)
        pos += self.positions[zrow.ind1]
        self.add_atom(name, pos)

    def to_atoms(self) -> Atoms:
        """Return the atoms collected so far as an Atoms object."""
        return Atoms(self.symbols, self.positions)

204 

205 

def parse_zmatrix(zmat: Union[str, List[str]],
                  distance_units: Union[str, Real] = 'angstrom',
                  angle_units: Union[str, Real] = 'degrees',
                  defs: Optional[Union[Dict[str, float], str,
                                       List[str]]] = None) -> Atoms:
    """Convert a Z-matrix into an Atoms object.

    Parameters:

    zmat: str or iterable of str
        The rows of the Z-matrix. A single string is stripped and split
        into rows at newlines and at semicolons.
    distance_units: str or float, optional
        Unit of the distances in ``zmat``: 'angstrom', 'bohr', 'au' or
        'nm' (case-insensitive), or a numeric factor applied to every
        distance. Defaults to 'angstrom'.
    angle_units: str or float, optional
        Unit of the angles in ``zmat``: 'degrees' or 'radians'
        (case-insensitive), or a numeric factor converting them to
        radians. Defaults to 'degrees'.
    defs: dict, str or list of str, optional
        Values for variable names that ``zmat`` uses in place of numeric
        bond distances, bending angles or dihedral angles, given in the
        same units as ``zmat``. Either a mapping of names to values, or
        ``name value`` / ``name = value`` definitions as a list or as a
        single string separated by newlines or semicolons.
        Note: such definitions are often printed next to the Z-matrix,
        but they are not extracted from ``zmat`` automatically and must
        be passed here separately.

    Returns:

    atoms: Atoms object
    """
    builder = _ZMatrixToAtoms(distance_units, angle_units, defs=defs)

    # Accept a block string for convenience; otherwise iterate as given.
    rows = (_re_linesplit.split(zmat.strip()) if isinstance(zmat, str)
            else zmat)
    for line in rows:
        builder.add_row(line)

    return builder.to_atoms()