Coverage for /builds/kinetik161/ase/ase/utils/structure_comparator.py: 95.22%
293 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1"""Determine symmetry equivalence of two structures.
2Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012)."""
3from collections import Counter
4from itertools import combinations, filterfalse, product
6import numpy as np
7from scipy.spatial import cKDTree as KDTree
9from ase import Atom, Atoms
10from ase.build.tools import niggli_reduce
def normalize(cell):
    """Scale each of the three row vectors of *cell* to unit length.

    The division is performed in place, row by row, and nothing is
    returned.
    """
    for row in (0, 1, 2):
        length = np.linalg.norm(cell[row])
        cell[row] /= length
class SpgLibNotFoundError(Exception):
    """Error signalling that spglib is needed but could not be imported."""

    def __init__(self, msg: str) -> None:
        # Forward the message unchanged to the base Exception.
        super().__init__(msg)
class SymmetryEquivalenceCheck:
    """Compare two structures to determine if they are symmetry equivalent.

    Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012).

    Parameters:

    angle_tol: float
        angle tolerance for the lattice vectors in degrees

    ltol: float
        relative tolerance for the length of the lattice vectors (per atom)

    stol: float
        position tolerance for the site comparison in units of
        (V/N)^(1/3) (average length between atoms)

    vol_tol: float
        volume tolerance in angstrom cubed to compare the volumes of
        the two structures

    scale_volume: bool
        if True the volumes of the two structures are scaled to be equal

    to_primitive: bool
        if True the structures are reduced to their primitive cells
        note that this feature requires spglib to be installed

    Examples:

    >>> from ase.build import bulk
    >>> from ase.utils.structure_comparator import SymmetryEquivalenceCheck
    >>> comp = SymmetryEquivalenceCheck()

    Compare a cell with a rotated version

    >>> a = bulk('Al', orthorhombic=True)
    >>> b = a.copy()
    >>> b.rotate(60, 'x', rotate_cell=True)
    >>> comp.compare(a, b)
    True

    Transform to the primitive cell and then compare

    >>> pa = bulk('Al')
    >>> comp.compare(a, pa)
    False
    >>> comp = SymmetryEquivalenceCheck(to_primitive=True)
    >>> comp.compare(a, pa)
    True

    Compare one structure with a list of other structures

    >>> import numpy as np
    >>> from ase import Atoms
    >>> s1 = Atoms('H3', positions=[[0.5, 0.5, 0],
    ...                             [0.5, 1.5, 0],
    ...                             [1.5, 1.5, 0]],
    ...            cell=[2, 2, 2], pbc=True)
    >>> comp = SymmetryEquivalenceCheck(stol=0.068)
    >>> s2_list = []
    >>> for d in np.linspace(0.1, 1.0, 5):
    ...     s2 = s1.copy()
    ...     s2.positions[0] += [d, 0, 0]
    ...     s2_list.append(s2)
    >>> comp.compare(s1, s2_list[:-1])
    False
    >>> comp.compare(s1, s2_list)
    True

    """

    def __init__(self, angle_tol=1.0, ltol=0.05, stol=0.05, vol_tol=0.1,
                 scale_volume=False, to_primitive=False):
        self.angle_tol = angle_tol * np.pi / 180.0  # convert to radians
        self.scale_volume = scale_volume
        self.stol = stol
        self.ltol = ltol
        self.vol_tol = vol_tol
        # Absolute position tolerance; recomputed in compare() from stol
        self.position_tolerance = 0.0
        self.to_primitive = to_primitive

        # Variables to be used in the compare function
        self.s1 = None
        self.s2 = None
        self.expanded_s1 = None
        self.expanded_s2 = None
        self.least_freq_element = None

    def _niggli_reduce(self, atoms):
        """Reduce to niggli cells.

        Reduce the atoms to niggli cells, then rotates the niggli cells to
        the so called "standard" orientation with one lattice vector along the
        x-axis and a second vector in the xy plane.
        """
        niggli_reduce(atoms)
        self._standarize_cell(atoms)

    def _standarize_cell(self, atoms):
        """Rotate the first vector such that it points along the x-axis.
        Then rotate around the first vector so the second vector is in the
        xy plane.

        The rotation angles are obtained with arctan2, so a vector with a
        vanishing projection onto the rotation plane (for instance a first
        lattice vector along y) yields a zero angle instead of a division
        by zero, and vectors lying exactly on a negative axis are still
        turned onto the positive one.

        The cell and positions of *atoms* are updated in place, the atoms
        are wrapped, and *atoms* is returned.
        """
        # Lattice vectors as columns so that rotations act from the left
        cell = atoms.get_cell().T
        total_rot_mat = np.eye(3)

        # Rotate around y: remove the z component of the first vector
        v1 = cell[:, 0]
        angle = np.arctan2(v1[2], v1[0])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        # Rotate around z: remove the y component of the first vector
        v1 = cell[:, 0]
        angle = np.arctan2(v1[1], v1[0])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[ca, sa, 0.0], [-sa, ca, 0.0], [0.0, 0.0, 1.0]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        # Rotate around x axis such that the second vector is in the xy plane
        v2 = cell[:, 1]
        angle = np.arctan2(v2[2], v2[1])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[1.0, 0.0, 0.0], [0.0, ca, sa], [0.0, -sa, ca]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        atoms.set_cell(cell.T)
        atoms.set_positions(total_rot_mat.dot(atoms.get_positions().T).T)
        atoms.wrap(pbc=[1, 1, 1])
        return atoms

    def _get_element_count(self, struct):
        """Count the number of elements in each of the structures."""
        return Counter(struct.numbers)

    def _get_angles(self, cell):
        """Get the internal angles of the unit cell.

        Returns the angles in radians ordered as
        [angle(0, 1), angle(0, 2), angle(1, 2)].
        """
        # Normalise a copy so the caller's cell is left untouched
        cell = cell.copy()
        normalize(cell)

        dot = cell.dot(cell.T)

        # Extract only the relevant dot products
        dot = [dot[0, 1], dot[0, 2], dot[1, 2]]

        # Return angles
        return np.arccos(dot)

    def _has_same_elements(self):
        """Check if two structures have same elements."""
        elem1 = self._get_element_count(self.s1)
        return elem1 == self._get_element_count(self.s2)

    def _has_same_angles(self):
        """Check that the Niggli unit vectors has the same internal angles."""
        ang1 = np.sort(self._get_angles(self.s1.get_cell()))
        ang2 = np.sort(self._get_angles(self.s2.get_cell()))

        return np.allclose(ang1, ang2, rtol=0, atol=self.angle_tol)

    def _has_same_volume(self):
        """Check that the volumes of s1 and s2 differ by less than vol_tol."""
        vol1 = self.s1.get_volume()
        vol2 = self.s2.get_volume()
        return np.abs(vol1 - vol2) < self.vol_tol

    def _scale_volumes(self):
        """Scale the cell of s2 to have the same volume as s1."""
        cell2 = self.s2.get_cell()
        # Get the volumes
        v2 = np.linalg.det(cell2)
        v1 = np.linalg.det(self.s1.get_cell())

        # Scale the cells. The determinants are signed, so take the absolute
        # ratio: a negative ratio would turn the cube root into NaN.
        coordinate_scaling = np.abs(v1 / v2)**(1.0 / 3.0)
        cell2 *= coordinate_scaling
        self.s2.set_cell(cell2, scale_atoms=True)

    def compare(self, s1, s2):
        """Compare the two structures.

        Return *True* if the two structures are equivalent, *False* otherwise.

        Parameters:

        s1: Atoms object.
            Transformation matrices are calculated based on this structure.
            The object passed in is not modified; all work is done on
            copies.

        s2: Atoms or list
            s1 can be compared to one structure or many structures supplied in
            a list. If s2 is a list it returns True if any structure in s2
            matches s1, False otherwise.
        """
        if self.to_primitive:
            s1 = self._reduce_to_primitive(s1)
        else:
            # The reference is translated and wrapped below; do that on a
            # copy so the caller's structure is left as it was passed in.
            s1 = s1.copy()
        self._set_least_frequent_element(s1)
        self._least_frequent_element_to_origin(s1)
        self.s1 = s1.copy()
        vol = self.s1.get_volume()
        self.expanded_s1 = None

        if isinstance(s2, Atoms):
            # Just make it a list of length 1
            s2 = [s2]

        matrices = None
        translations = None
        transposed_matrices = None
        for struct in s2:
            self.s2 = struct.copy()
            self.expanded_s2 = None

            if self.to_primitive:
                self.s2 = self._reduce_to_primitive(self.s2)

            # Compare number of elements in structures
            if len(self.s1) != len(self.s2):
                continue

            # Compare chemical formulae
            if not self._has_same_elements():
                continue

            # Compare angles. self.s1 is reset to an unreduced copy at the
            # end of every full iteration, so it is reduced on each pass.
            self._niggli_reduce(self.s1)
            self._niggli_reduce(self.s2)
            if not self._has_same_angles():
                continue

            # Compare volumes
            if self.scale_volume:
                self._scale_volumes()
            if not self._has_same_volume():
                continue

            if matrices is None:
                matrices = self._get_rotation_reflection_matrices()
            if matrices is None:
                continue

            if translations is None:
                translations = self._get_least_frequent_positions(self.s1)

            # After the candidate translation based on s1 has been computed
            # we need potentially to swap s1 and s2 for robust comparison
            self._least_frequent_element_to_origin(self.s2)
            switch = self._switch_reference_struct()
            if switch:
                # Remember the matrices and translations used before
                old_matrices = matrices
                old_translations = translations

                # If a s1 and s2 has been switched we need to use the
                # transposed version of the matrices to map atoms the
                # other way
                if transposed_matrices is None:
                    transposed_matrices = np.transpose(matrices,
                                                       axes=[0, 2, 1])
                matrices = transposed_matrices
                translations = self._get_least_frequent_positions(self.s1)

            # Calculate tolerance on positions
            self.position_tolerance = \
                self.stol * (vol / len(self.s2))**(1.0 / 3.0)

            if self._positions_match(matrices, translations):
                return True

            # Set the reference structure back to its original
            self.s1 = s1.copy()
            if switch:
                self.expanded_s1 = self.expanded_s2
                matrices = old_matrices
                translations = old_translations
        return False

    def _set_least_frequent_element(self, atoms):
        """Save the atomic number of the least frequent element."""
        elem1 = self._get_element_count(atoms)
        self.least_freq_element = elem1.most_common()[-1][0]

    def _get_least_frequent_positions(self, atoms):
        """Get the positions of the least frequent element in atoms."""
        pos = atoms.get_positions(wrap=True)
        return pos[atoms.numbers == self.least_freq_element]

    def _get_only_least_frequent_of(self, struct):
        """Get the atoms object with all other elements than the least frequent
        one removed. Wrap the positions to get everything in the cell."""
        pos = struct.get_positions(wrap=True)

        indices = struct.numbers == self.least_freq_element
        least_freq_struct = struct[indices]
        least_freq_struct.set_positions(pos[indices])

        return least_freq_struct

    def _switch_reference_struct(self):
        """There is an intrinsic asymmetry in the system because
        one of the atoms objects is being expanded, while the other is not.
        This can cause the algorithm to return different results
        depending on which structure is passed first.
        We adopt the convention of using the atoms object
        having the fewest atoms in its expanded cell as the
        reference object.
        We return True if a switch of structures has been performed."""

        # First expand the cells
        if self.expanded_s1 is None:
            self.expanded_s1 = self._expand(self.s1)
        if self.expanded_s2 is None:
            self.expanded_s2 = self._expand(self.s2)

        exp1 = self.expanded_s1
        exp2 = self.expanded_s2
        if len(exp1) < len(exp2):
            # The smaller expansion must end up in expanded_s2, which is
            # the one _positions_match builds its search tree from.
            # We have to swap s1 and s2
            s1_temp = self.s1.copy()
            self.s1 = self.s2
            self.s2 = s1_temp
            exp1_temp = self.expanded_s1.copy()
            self.expanded_s1 = self.expanded_s2
            self.expanded_s2 = exp1_temp
            return True
        return False

    def _positions_match(self, rotation_reflection_matrices, translations):
        """Check if the position and elements match.

        Note that this function changes self.s1 and self.s2 to the rotation and
        translation that matches best. Hence, it is crucial that this function
        calls the element comparison, not the other way around.
        """
        pos1_ref = self.s1.get_positions(wrap=True)

        # Get the expanded reference object
        exp2 = self.expanded_s2
        # Build a KD tree to enable fast look-up of nearest neighbours
        tree = KDTree(exp2.get_positions())
        for i in range(translations.shape[0]):
            # Translate
            pos1_trans = pos1_ref - translations[i]
            for matrix in rotation_reflection_matrices:
                # Rotate
                pos1 = matrix.dot(pos1_trans.T).T

                # Update the atoms positions
                self.s1.set_positions(pos1)
                self.s1.wrap(pbc=[1, 1, 1])
                if self._elements_match(self.s1, exp2, tree):
                    return True
        return False

    def _expand(self, ref_atoms, tol=0.0001):
        """If an atom is closer to a boundary than tol it is repeated at the
        opposite boundaries.

        This ensures that atoms having crossed the cell boundaries due to
        numerical noise are properly detected.

        The distance between a position and cell boundary is calculated as:
        dot(position, (b_vec x c_vec) / |b_vec x c_vec|), where x is the
        cross product.
        """
        syms = ref_atoms.get_chemical_symbols()
        cell = ref_atoms.get_cell()
        positions = ref_atoms.get_positions(wrap=True)
        expanded_atoms = ref_atoms.copy()

        # Calculate normal vectors to the unit cell faces
        normal_vectors = np.array([np.cross(cell[1, :], cell[2, :]),
                                   np.cross(cell[0, :], cell[2, :]),
                                   np.cross(cell[0, :], cell[1, :])])
        normalize(normal_vectors)

        # Get the distance to the unit cell faces from each atomic position
        pos2faces = np.abs(positions.dot(normal_vectors.T))

        # And the opposite faces
        pos2oppofaces = np.abs(np.dot(positions - np.sum(cell, axis=0),
                                      normal_vectors.T))

        for i, i2face in enumerate(pos2faces):
            # Append indices for positions close to the other faces
            # and convert to boolean array signifying if the position at
            # index i is close to the faces bordering origo (0, 1, 2) or
            # the opposite faces (3, 4, 5)
            i_close2face = np.append(i2face, pos2oppofaces[i]) < tol
            close_faces = np.nonzero(i_close2face)[0]
            # For each position i.e. row it holds that
            # 1 x True -> close to face -> 1 extra atom at opposite face
            # 2 x True -> close to edge -> 3 extra atoms at opposite edges
            # 3 x True -> close to corner -> 7 extra atoms opposite corners
            # E.g. to add atoms at all corners we need to use the cell
            # vectors: (a, b, c, a + b, a + c, b + c, a + b + c), we use
            # itertools.combinations to get them all
            for j in range(len(close_faces)):
                for c in combinations(close_faces, j + 1):
                    # Get the displacement vectors by adding the corresponding
                    # cell vectors, if the atom is close to an opposite face
                    # i.e. k > 2 subtract the cell vector
                    disp_vec = np.zeros(3)
                    for k in c:
                        disp_vec += cell[k % 3] * (int(k < 3) * 2 - 1)
                    pos = positions[i] + disp_vec
                    expanded_atoms.append(Atom(syms[i], position=pos))
        return expanded_atoms

    def _equal_elements_in_array(self, arr):
        """Return a true value if any entry occurs more than once in arr."""
        s = np.sort(arr)
        return np.any(s[1:] == s[:-1])

    def _elements_match(self, s1, s2, kdtree):
        """Check if all the elements in s1 match corresponding position in s2.

        Each atom of s1 is paired with its nearest neighbour in s2 through
        *kdtree*, which is expected to be built from the positions of s2.
        The match fails if a paired atom has a different atomic number, if a
        pair is further apart than self.position_tolerance, or if two atoms
        of s1 are paired with the same atom of s2.
        """
        dists, closest_in_s2 = kdtree.query(s1.get_positions())

        # Check if the elements are the same
        if not np.all(s2.numbers[closest_in_s2] == s1.numbers):
            return False

        # Check if any distance is too large
        if np.any(dists > self.position_tolerance):
            return False

        # Check for duplicates in what atom is closest
        if self._equal_elements_in_array(closest_in_s2):
            return False

        return True

    def _least_frequent_element_to_origin(self, atoms):
        """Put one of the least frequent elements at the origin.

        The atom is placed a tiny fraction (1e-6) of the cell diagonal away
        from the exact origin before the structure is wrapped.
        """
        least_freq_pos = self._get_least_frequent_positions(atoms)
        cell_diag = np.sum(atoms.get_cell(), axis=0)
        d = least_freq_pos[0] - 1e-6 * cell_diag
        atoms.positions -= d
        atoms.wrap(pbc=[1, 1, 1])

    def _get_rotation_reflection_matrices(self):
        """Compute candidates for the transformation matrix.

        Returns an array of candidate matrices, or None when no set of
        trial vectors in s1 reproduces the lengths and angles of the
        lattice vectors of s2.
        """
        atoms1_ref = self._get_only_least_frequent_of(self.s1)
        cell = self.s1.get_cell().T
        cell_diag = np.sum(cell, axis=1)
        angle_tol = self.angle_tol

        # Additional vector that is added to make sure that
        # there always is an atom at the origin
        delta_vec = 1E-6 * cell_diag

        # Store three reference vectors and their lengths
        ref_vec = self.s2.get_cell()
        ref_vec_lengths = np.linalg.norm(ref_vec, axis=1)

        # Compute ref vec angles
        # ref_angles are arranged as [angle12, angle13, angle23]
        ref_angles = np.array(self._get_angles(ref_vec))
        large_angles = ref_angles > np.pi / 2.0
        ref_angles[large_angles] = np.pi - ref_angles[large_angles]

        # Translate by one cell diagonal so that a central cell is
        # surrounded by cells in all directions
        sc_atom_search = atoms1_ref * (3, 3, 3)
        new_sc_pos = sc_atom_search.get_positions()
        new_sc_pos -= new_sc_pos[0] + cell_diag - delta_vec

        lengths = np.linalg.norm(new_sc_pos, axis=1)

        candidate_indices = []
        rtol = self.ltol / len(self.s1)
        for k in range(3):
            correct_lengths_mask = np.isclose(lengths,
                                              ref_vec_lengths[k],
                                              rtol=rtol, atol=0)
            # The first vector is not interesting
            correct_lengths_mask[0] = False

            # If no trial vectors can be found (for any direction)
            # then the candidates are different and we return None
            if not np.any(correct_lengths_mask):
                return None

            candidate_indices.append(np.nonzero(correct_lengths_mask)[0])

        # Now we calculate all relevant angles in one step. The relevant angles
        # are the ones made by the current candidates. We will have to keep
        # track of the indices in the angles matrix and the indices in the
        # position and length arrays.

        # Get all candidate indices (aci), only unique values
        aci = np.sort(list(set().union(*candidate_indices)))

        # Make a dictionary from original positions and lengths index to
        # index in angle matrix
        i2ang = dict(zip(aci, range(len(aci))))

        # Calculate the dot product divided by the lengths:
        # cos(angle) = dot(vec1, vec2) / |vec1| |vec2|
        cosa = np.inner(new_sc_pos[aci],
                        new_sc_pos[aci]) / np.outer(lengths[aci],
                                                    lengths[aci])
        # Make sure the inverse cosine will work
        cosa = np.clip(cosa, -1, 1)
        angles = np.arccos(cosa)
        # Do trick for enantiomorphic structures
        angles[angles > np.pi / 2] = np.pi - angles[angles > np.pi / 2]

        # Check which angles match the reference angles
        # Test for all combinations on candidates. filterfalse makes sure
        # that there are no duplicate candidates. product is the same as
        # nested for loops.
        refined_candidate_list = []
        for p in filterfalse(self._equal_elements_in_array,
                             product(*candidate_indices)):
            a = np.array([angles[i2ang[p[0]], i2ang[p[1]]],
                          angles[i2ang[p[0]], i2ang[p[2]]],
                          angles[i2ang[p[1]], i2ang[p[2]]]])

            if np.allclose(a, ref_angles, atol=angle_tol, rtol=0):
                refined_candidate_list.append(new_sc_pos[np.array(p)].T)

        # Get the rotation/reflection matrix [R] by:
        # [R] = [V][T]^-1, where [V] is the reference vectors and
        # [T] is the trial vectors
        if not refined_candidate_list:
            return None
        inverted_trial = np.linalg.inv(refined_candidate_list)

        # Equivalent to np.matmul(ref_vec.T, inverted_trial)
        candidate_trans_mat = np.dot(ref_vec.T, inverted_trial.T).T
        return candidate_trans_mat

    def _reduce_to_primitive(self, structure):
        """Reduce a structure to its primitive cell using spglib.

        Returns a new, fully periodic Atoms object built from the
        standardized primitive cell; *structure* itself is not modified.

        Raises SpgLibNotFoundError if spglib cannot be imported.
        """
        try:
            import spglib
        except ImportError as err:
            raise SpgLibNotFoundError(
                "SpgLib is required if to_primitive=True") from err
        cell = (structure.get_cell()).tolist()
        pos = structure.get_scaled_positions().tolist()
        numbers = structure.get_atomic_numbers()

        cell, scaled_pos, numbers = spglib.standardize_cell(
            (cell, pos, numbers), to_primitive=True)

        atoms = Atoms(
            scaled_positions=scaled_pos,
            numbers=numbers,
            cell=cell,
            pbc=True)
        return atoms