Coverage for /builds/kinetik161/ase/ase/io/zmatrix.py: 96.12%
129 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import re
2from collections import namedtuple
3from numbers import Real
4from string import digits
5from typing import Dict, List, Optional, Tuple, Union
7import numpy as np
9from ase import Atoms
10from ase.units import Angstrom, Bohr, nm
# Z-matrix rows (and variable definitions) are separated by either a
# newline or a semicolon.
_re_linesplit = re.compile(r'\n|;')

# A variable definition separates its name from its value with "="
# (whitespace around it allowed) or with plain whitespace.
_re_defs = re.compile(r'\s*=\s*|\s+')

# One parsed internal-coordinate row: the 0-based reference atom indices
# together with the bond length, bending angle and dihedral angle.
_ZMatrixRow = namedtuple(
    typename='_ZMatrixRow',
    field_names=['ind1', 'dist', 'ind2', 'a_bend', 'ind3', 'a_dihedral'],
)

# A Cartesian position: either an (x, y, z) tuple or a NumPy array.
ThreeFloats = Union[Tuple[float, float, float], np.ndarray]
class _ZMatrixToAtoms:
    """Incrementally convert Z-matrix rows into chemical symbols and positions.

    Rows are fed one at a time to :meth:`add_row`; :meth:`to_atoms` then
    builds the ``Atoms`` object. Recognised row layouts (whitespace
    separated; atom references are 1-based indices or, as long as all atom
    names seen so far are unique, previously used atom names):

    * ``name`` -- first row only; the atom is placed at the origin.
    * ``name 0 x y z`` -- explicit Cartesian coordinates (reference ``0``).
    * ``name i dist`` -- second row only; ``dist`` from atom ``i`` along +x.
    * ``name i dist j bend`` -- third row only; no dihedral reference.
    * ``name i dist j bend k dihedral`` -- any later row (extra trailing
      tokens are ignored).

    Numeric fields may be replaced by names defined in ``defs``, optionally
    prefixed with ``+`` or ``-``. Malformed input raises ``ValueError``.
    """

    # Multiplicative conversion factors for each accepted unit name; angles
    # end up in radians, as consumed by np.cos/np.sin below.
    known_units = dict(
        distance={'angstrom': Angstrom, 'bohr': Bohr, 'au': Bohr, 'nm': nm},
        angle={'radians': 1., 'degrees': np.pi / 180},
    )

    def __init__(self, dconv: Union[str, Real], aconv: Union[str, Real],
                 defs: Optional[Union[Dict[str, float],
                                      str, List[str]]] = None) -> None:
        self.dconv = self.get_units('distance', dconv)  # type: float
        self.aconv = self.get_units('angle', aconv)  # type: float
        self.set_defs(defs)
        # Atom name -> 0-based row index. Reset to None as soon as a name is
        # repeated, after which only numeric references can be resolved.
        self.name_to_index: Optional[Dict[str, int]] = {}
        self.symbols: List[str] = []
        self.positions: List[ThreeFloats] = []

    @property
    def nrows(self):
        """Number of atoms (rows) added so far."""
        return len(self.symbols)

    def get_units(self, kind: str, value: Union[str, Real]) -> float:
        """Return the conversion factor for ``kind`` ('distance'/'angle').

        ``value`` is either a numeric factor (returned as float) or a unit
        name looked up case-insensitively in ``known_units``.

        Raises ValueError for an unknown unit name.
        """
        if isinstance(value, Real):
            return float(value)
        out = self.known_units[kind].get(value.lower())
        if out is None:
            raise ValueError("Unknown {} units: {}"
                             .format(kind, value))
        return out

    def set_defs(self, defs: Union[Dict[str, float], str,
                                   List[str], None]) -> None:
        """Store symbolic variable definitions used by :meth:`get_var`.

        ``defs`` may be a mapping of name -> value, a string of
        ``name value`` / ``name = value`` entries separated by newlines or
        semicolons, or a list of such entries. Blank entries are skipped and
        surrounding whitespace is ignored. A value may itself refer to a
        name defined by an earlier entry.

        Raises ValueError for an entry that is not a name/value pair or
        whose value cannot be resolved.
        """
        self.defs = {}  # type: Dict[str, float]
        if defs is None:
            return

        if isinstance(defs, dict):
            self.defs.update(defs)
            return

        if isinstance(defs, str):
            defs = _re_linesplit.split(defs.strip())

        for row in defs:
            row = row.strip()
            if not row:
                # Tolerate blank lines and trailing/doubled semicolons.
                continue
            try:
                key, val = _re_defs.split(row)
            except ValueError as err:
                raise ValueError('Invalid Z-matrix variable definition: {}'
                                 .format(row)) from err
            self.defs[key] = self.get_var(val)

    def get_var(self, val: str) -> float:
        """Convert a token to float, or resolve it as a (signed) variable.

        Raises ValueError if ``val`` is neither a number nor a known name.
        """
        try:
            return float(val)
        except ValueError as e:
            val_out = self.defs.get(val.lstrip('+-'))
            if val_out is None:
                raise ValueError('Invalid value encountered in Z-matrix: {}'
                                 .format(val)) from e
            return val_out * (-1 if val.startswith('-') else 1)

    def get_index(self, name: str) -> int:
        """Find the 0-based index for an atom reference.

        Numeric references are 1-based, so ``'0'`` maps to -1, which
        :meth:`parse_row` treats as the Cartesian-coordinates marker.

        Raises ValueError if a non-numeric name cannot be resolved.
        """
        try:
            return int(name) - 1
        except ValueError as e:
            if self.name_to_index is None or name not in self.name_to_index:
                raise ValueError('Failed to determine index for name "{}"'
                                 .format(name)) from e
            return self.name_to_index[name]

    def set_index(self, name: str) -> None:
        """Assign index to a given atom name for name -> index lookup"""
        if self.name_to_index is None:
            return

        if name in self.name_to_index:
            # "name" has been encountered before, so name_to_index is no
            # longer meaningful. Destroy the map.
            self.name_to_index = None
            return

        self.name_to_index[name] = self.nrows

    def validate_indices(self, *indices: int) -> None:
        """Raise ValueError if the reference indices of a row are invalid.

        Every index must refer to an already-defined atom
        (``0 <= index < nrows``) and no index may repeat within the row.
        """
        # Negative indices would otherwise silently wrap around when used
        # to index self.positions.
        if any(index < 0 or index >= self.nrows for index in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom indices {}, at least one of which '
                             "hasn't been defined yet!"
                             .format(self.nrows, indices))

        if len(indices) != len(set(indices)):
            raise ValueError('An atom index has been used more than once a '
                             'row of the Z-matrix! Row numbers {}, '
                             'referred indices: {}'
                             .format(self.nrows, indices))

    def parse_row(self, row: str) -> Tuple[
            str, Union[_ZMatrixRow, ThreeFloats],
    ]:
        """Parse one Z-matrix row into ``(name, data)``.

        ``data`` is an array of Cartesian coordinates (first atom, second
        atom, or an explicit ``name 0 x y z`` row) or a :class:`_ZMatrixRow`
        holding 0-based reference indices and values already multiplied by
        ``dconv``/``aconv``.

        Raises ValueError if the row is empty, has the wrong number of
        tokens for its position, or refers to invalid atoms.
        """
        tokens = row.split()
        if not tokens:
            raise ValueError('Empty row encountered in Z-matrix (row {})'
                             .format(self.nrows))
        name = tokens[0]
        self.set_index(name)
        ntokens = len(tokens)

        if ntokens == 1:
            if self.nrows != 0:
                raise ValueError('Only the first Z-matrix row may consist of '
                                 'an atom name alone; row {}: {!r}'
                                 .format(self.nrows, row))
            return name, np.zeros(3, dtype=float)

        ind1 = self.get_index(tokens[1])
        if ind1 == -1:
            # Reference "0" marks a row of explicit Cartesian coordinates.
            if ntokens != 5:
                raise ValueError('A Cartesian Z-matrix row needs exactly '
                                 'three coordinates; row {}: {!r}'
                                 .format(self.nrows, row))
            xyz = np.array([self.get_var(tok) for tok in tokens[2:]],
                           dtype=float)
            # Cartesian coordinates are distances too, so convert them.
            return name, self.dconv * xyz

        # Internal-coordinate rows: 3 tokens for the 2nd atom, 5 for the
        # 3rd and at least 7 (extras ignored) for every later atom.
        n_expected = {1: 3, 2: 5}.get(self.nrows, 7)
        if min(ntokens, 7) != n_expected:
            raise ValueError('Z-matrix row {} has {} tokens, expected {}: {!r}'
                             .format(self.nrows, ntokens,
                                     n_expected if n_expected < 7
                                     else 'at least 7', row))

        dist = self.dconv * self.get_var(tokens[2])

        if ntokens == 3:
            self.validate_indices(ind1)
            # Second atom: displaced along +x from its reference atom.
            return name, self.positions[ind1] + np.array([dist, 0., 0.])

        ind2 = self.get_index(tokens[3])
        a_bend = self.aconv * self.get_var(tokens[4])

        if ntokens == 5:
            self.validate_indices(ind1, ind2)
            return name, _ZMatrixRow(ind1, dist, ind2, a_bend, None, None)

        ind3 = self.get_index(tokens[5])
        a_dihedral = self.aconv * self.get_var(tokens[6])
        self.validate_indices(ind1, ind2, ind3)
        return name, _ZMatrixRow(ind1, dist, ind2, a_bend, ind3,
                                 a_dihedral)

    def add_atom(self, name: str, pos: ThreeFloats) -> None:
        """Append an atom; its symbol is ``name`` without digits, capitalized."""
        self.symbols.append(
            ''.join([c for c in name if c not in digits]).capitalize()
        )
        self.positions.append(pos)

    def add_row(self, row: str) -> None:
        """Parse one Z-matrix row and append the atom it describes.

        Rows that are empty or contain only whitespace are ignored.
        """
        if not row.strip():
            return

        name, zrow = self.parse_row(row)
        if not isinstance(zrow, _ZMatrixRow):
            # parse_row already produced Cartesian coordinates.
            pos = zrow
        elif zrow.ind3 is None:
            # Third atom: only a bond distance and an angle are provided.
            pos = self._bend_position(zrow)
        else:
            pos = self._dihedral_position(zrow)
        self.add_atom(name, pos)

    def _bend_position(self, zrow: _ZMatrixRow) -> np.ndarray:
        """Place an atom from a bond length and bending angle only.

        The atom is ``zrow.dist`` away from atom ``ind1`` at angle ``a_bend``
        to the ``ind1 -> ind2`` bond, in the plane spanned by that bond and
        the +y direction (+z if the bond is parallel to y). With the first
        two atoms on the x axis this is the xy plane with positive y.
        """
        origin = self.positions[zrow.ind1]
        axis = self.positions[zrow.ind2] - origin
        length = np.linalg.norm(axis)
        if length == 0:
            raise ValueError('Cannot place Z-matrix row {}: reference atoms '
                             '{} and {} coincide'
                             .format(self.nrows, zrow.ind1, zrow.ind2))
        axis = axis / length

        # Part of +y orthogonal to the bond axis; fall back to +z when the
        # bond itself lies along y.
        perp = np.array([0., 1., 0.]) - axis[1] * axis
        if np.linalg.norm(perp) < 1e-8:
            perp = np.array([0., 0., 1.]) - axis[2] * axis
        perp /= np.linalg.norm(perp)

        return origin + zrow.dist * (np.cos(zrow.a_bend) * axis
                                     + np.sin(zrow.a_bend) * perp)

    def _dihedral_position(self, zrow: _ZMatrixRow) -> np.ndarray:
        """Place an atom from a bond length, bending angle and dihedral."""
        # ax1 is the dihedral axis, which is defined by the bond vector
        # between the two inner atoms in the dihedral, ind1 and ind2
        ax1 = self.positions[zrow.ind2] - self.positions[zrow.ind1]
        ax1 /= np.linalg.norm(ax1)

        # ax2 lies within the 1-2-3 plane, and it is perpendicular
        # to the dihedral axis
        ax2 = self.positions[zrow.ind2] - self.positions[zrow.ind3]
        ax2 -= ax1 * (ax2 @ ax1)
        ax2 /= np.linalg.norm(ax2)

        # ax3 is a vector that forms the appropriate dihedral angle, though
        # the bending angle is 90 degrees, rather than a_bend. It is formed
        # from a linear combination of ax2 and (ax2 x ax1)
        ax3 = (ax2 * np.cos(zrow.a_dihedral)
               + np.cross(ax2, ax1) * np.sin(zrow.a_dihedral))

        # The final position vector is a linear combination of ax1 and ax3.
        pos = ax1 * np.cos(zrow.a_bend) - ax3 * np.sin(zrow.a_bend)
        pos *= zrow.dist / np.linalg.norm(pos)
        pos += self.positions[zrow.ind1]
        return pos

    def to_atoms(self) -> Atoms:
        """Build an ``Atoms`` object from the rows added so far."""
        return Atoms(self.symbols, self.positions)
def parse_zmatrix(zmat: Union[str, List[str]],
                  distance_units: Union[str, Real] = 'angstrom',
                  angle_units: Union[str, Real] = 'degrees',
                  defs: Optional[Union[Dict[str, float], str,
                                       List[str]]] = None) -> Atoms:
    """Converts a Z-matrix into an Atoms object.

    Parameters:

    zmat: Iterable or str
        The Z-matrix to be parsed. Iteration over `zmat` should yield the rows
        of the Z-matrix. If `zmat` is a str, it will be automatically split
        into a list at newlines and semicolons. Blank rows are skipped.
    distance_units: str or float, optional
        The units of distance in the provided Z-matrix: one of 'angstrom',
        'bohr', 'au' or 'nm' (case-insensitive), or a numeric conversion
        factor. Defaults to Angstrom.
    angle_units: str or float, optional
        The units for angles in the provided Z-matrix: 'degrees' or
        'radians' (case-insensitive), or a numeric factor converting the
        given values to radians. Defaults to degrees.
    defs: dict, str or list of str, optional
        If `zmat` contains symbols for bond distances, bending angles, and/or
        dihedral angles instead of numeric values, then the definition of
        those symbols should be passed to this function using this keyword
        argument. A str is split at newlines and semicolons; each entry has
        the form ``name value`` or ``name = value``.
        Note: The symbol definitions are typically printed adjacent to the
        Z-matrix itself, but this function will not automatically separate
        the symbol definitions from the Z-matrix.

    Returns:

    atoms: Atoms object

    Raises:

    ValueError
        If a unit name is unknown, a symbolic value is undefined, or a row
        refers to atoms that have not been defined yet.
    """
    zmatrix = _ZMatrixToAtoms(distance_units, angle_units, defs=defs)

    # zmat should be a list containing the rows of the z-matrix.
    # for convenience, allow block strings and split at newlines.
    if isinstance(zmat, str):
        zmat = _re_linesplit.split(zmat.strip())

    for row in zmat:
        # Blank rows (interior empty lines, trailing semicolons) carry no
        # atom; parsing them would fail on the missing atom name.
        if row.strip():
            zmatrix.add_row(row)

    return zmatrix.to_atoms()