Coverage for /builds/kinetik161/ase/ase/geometry/dimensionality/rank_determination.py: 99.17%
120 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
"""
Implements the Rank Determination Algorithm (RDA).

The method is described in:

    P. M. Larsen, M. Pandey, M. Strange, and K. W. Jacobsen,
    "Definition of a scoring parameter to identify low-dimensional
    materials components",
    Phys. Rev. Materials 3, 034003 (2019).
    https://doi.org/10.1103/PhysRevMaterials.3.034003
"""
11from collections import defaultdict
13import numpy as np
15from ase.geometry.dimensionality.disjoint_set import DisjointSet
# NumPy carries a large per-call overhead when working on many tiny vectors,
# and its cross product is especially slow in that regime. The small vector
# helpers below are therefore written in plain Python, which is much faster
# for 3-component inputs.


def dot_product(A, B):
    """Return the dot product of the sequences ``A`` and ``B``.

    Pairs are formed with ``zip``, so extra trailing elements of the longer
    sequence are ignored. Two empty sequences give 0.
    """
    return sum(x * y for x, y in zip(A, B))
def cross_product(a, b):
    """Return the cross product ``a x b`` of two 3-vectors as a list.

    Only indices 0, 1 and 2 of each argument are read.
    """
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]
def subtract(A, B):
    """Return the element-wise difference ``A - B`` as a new list.

    Like ``zip``, iteration stops at the end of the shorter sequence.
    """
    return list(map(lambda x, y: x - y, A, B))
def rank_increase(a, b):
    """Report whether appending point ``b`` raises the rank of the points in ``a``.

    ``a`` is the list of points accepted so far (at most four are ever
    accepted); ``b`` is a candidate point. The geometric tests compare
    exactly against zero, so inputs are presumably integer offsets.

    Returns:

    bool
        * no points yet: always True;
        * one point: True if ``b`` differs from it;
        * two points: True if the three points are not collinear;
        * three points: True if the four points are not coplanar;
        * four points: always False (the space is already spanned).

    Raises:

    Exception
        If ``a`` holds more than four points.
    """
    n = len(a)
    if n == 0:
        return True
    if n == 1:
        return a[0] != b
    if n == 4:
        return False

    points = a + [b]
    origin = points[0]
    normal = cross_product(subtract(points[1], origin),
                           subtract(points[2], origin))
    if n == 2:
        # A zero cross product means the three points are collinear.
        return any(normal)
    if n == 3:
        # A zero triple product means the four points are coplanar.
        return dot_product(normal, subtract(points[3], origin)) != 0
    raise Exception("This shouldn't be possible.")
def bfs(adjacency, start):
    """Traverse the component graph using BFS.

    The graph is traversed until the matrix rank of the subspace spanned by
    the visited components no longer increases.

    Parameters:

    adjacency: dict
        Maps each component to an iterable of ``(neighbour, offset)`` pairs,
        where ``offset`` is a 3-component cell offset (as produced by
        ``build_adjacency_list``).
    start:
        Component at which the traversal begins; it is placed at cell
        position (0, 0, 0).

    Returns:

    visited: set
        Every ``(component, position)`` node taken from the queue.
    rank: int
        Number of rank-increasing positions accepted for ``start`` minus
        one, i.e. a value between 0 and 3.
    """
    # deque.popleft() is O(1); the previous list.pop(0) was O(n) per pop.
    from collections import deque

    visited = set()
    # Positions accepted so far for each component. A position is accepted
    # only when it increases the rank of that component's point set.
    cvisited = defaultdict(list)
    queue = deque([(start, (0, 0, 0))])
    while queue:
        vertex = queue.popleft()
        if vertex in visited:
            continue

        visited.add(vertex)
        c, p = vertex
        if not rank_increase(cvisited[c], p):
            continue

        cvisited[c].append(p)

        for nc, offset in adjacency[c]:
            nbrpos = (p[0] + offset[0], p[1] + offset[1], p[2] + offset[2])
            nbrnode = (nc, nbrpos)
            if nbrnode in visited:
                continue

            # Only enqueue neighbours that could still raise the rank.
            if rank_increase(cvisited[nc], nbrpos):
                queue.append(nbrnode)

    return visited, len(cvisited[start]) - 1
def traverse_component_graphs(adjacency):
    """Run ``bfs`` once from every component in ``adjacency``.

    Returns:

    all_visited: dict
        Component -> set of nodes visited by the BFS started there.
    ranks: dict
        Component -> rank reported by that BFS.
    """
    traversals = {component: bfs(adjacency, component)
                  for component in adjacency}
    all_visited = {component: result[0]
                   for component, result in traversals.items()}
    ranks = {component: result[1]
             for component, result in traversals.items()}
    return all_visited, ranks
def build_adjacency_list(parents, bonds):
    """Build the component-level adjacency list.

    Parameters:

    parents: array
        Component label of every atom, indexed by atom number.
    bonds: iterable of (i, j, offset)
        Bonds from atom ``i`` to atom ``j``, where ``offset`` is the cell
        offset of ``j``. ``offset`` must be hashable.

    Returns:

    dict
        Maps every distinct label in ``parents`` to the set of
        ``(neighbour_component, offset)`` pairs reached by a single bond.
    """
    adjacency = {}
    for component in np.unique(parents):
        adjacency[component] = set()

    for i, j, offset in bonds:
        adjacency[parents[i]].add((parents[j], offset))
    return adjacency
def get_dimensionality_histogram(ranks, roots):
    """Tally the components by rank.

    Returns a 4-tuple whose k-th entry is the number of entries of ``roots``
    whose rank (looked up in ``ranks``) is k.
    """
    histogram = [0, 0, 0, 0]
    for rank in (ranks[root] for root in roots):
        histogram[rank] += 1
    return tuple(histogram)
def merge_mutual_visits(all_visited, ranks, graph):
    """Find components with mutual visits and merge them.

    Two components are united in ``graph`` when their traversals visited a
    common ``(component, position)`` node.

    Parameters:

    all_visited: dict
        Component -> set of nodes visited by the BFS from that component.
    ranks: dict
        Component -> rank reported by that BFS.
    graph:
        Disjoint-set of components. It is modified in place through
        ``union`` and then queried with ``find_all``.

    Returns:

    (merged, visits, ranks)
        ``merged`` accumulates the return values of ``graph.union`` with
        ``|``. If it is falsy, the input dicts are returned unchanged.
        Otherwise new dicts keyed by the merged root components are returned.
    """
    merged = False
    # node -> components whose traversal has already visited that node
    seen_by = defaultdict(list)
    for component, visited in all_visited.items():
        for node in visited:
            for other in seen_by[node]:
                # Components sharing a visited node are expected to agree
                # on rank.
                assert ranks[other] == ranks[component]
                merged |= graph.union(other, component)
            seen_by[node].append(component)

    if not merged:
        return merged, all_visited, ranks

    roots = graph.find_all()
    merged_visits = defaultdict(set)
    merged_ranks = {}
    for component, visited in all_visited.items():
        root = roots[component]
        merged_visits[root].update(visited)
        merged_ranks[root] = ranks[root]
    return merged, merged_visits, merged_ranks
class RDA:
    """Rank Determination Algorithm on a periodic bond graph.

    Bonds inside the unit cell (zero offset) merge atoms immediately in a
    disjoint set. Bonds that cross a cell boundary are stored in ``bonds``
    and resolved by graph traversal in :meth:`check`.
    """

    def __init__(self, num_atoms):
        """
        Initializes the RDA class.

        A disjoint set is used to maintain the component graph.

        Parameters:

        num_atoms: int    The number of atoms in the unit cell.
        """
        self.bonds = []
        self.graph = DisjointSet(num_atoms)
        self.adjacency = None
        self.hcached = None
        self.components_cached = None
        self.cdim_cached = None
        # Results of the most recent check(); None until check() has run.
        self.all_visited = None
        self.ranks = None
        self.roots = None

    def insert_bond(self, i, j, offset):
        """
        Adds a bond to the list of graph edges.

        Graph components are merged if the bond does not cross a cell boundary.
        Bonds which cross cell boundaries can inappropriately connect
        components which are not connected in the infinite crystal. This is
        tested during graph traversal.

        Parameters:

        i: int         The index of the first atom.
        j: int         The index of the second atom.
        offset: tuple  The cell offset of the second atom. Any 3-element
                       sequence (tuple, list or array) is accepted. It is
                       stored as a tuple so that it compares equal to
                       (0, 0, 0) and can be hashed in the adjacency sets.
        """
        # Normalize so list/array offsets behave like tuples: a list never
        # equals (0, 0, 0), and a list is unhashable in build_adjacency_list.
        offset = tuple(offset)

        if offset == (0, 0, 0):
            # Bond lies inside the unit cell: merge the two components now.
            self.graph.union(i, j)
        else:
            # Store the bond in both directions; the reverse bond carries
            # the negated offset.
            roffset = tuple(-x for x in offset)
            self.bonds += [(i, j, offset)]
            self.bonds += [(j, i, roffset)]

    def check(self):
        """
        Determines the dimensionality histogram.

        The component graph is traversed (using BFS) until the matrix rank
        of the subspace spanned by the visited components no longer increases.
        The histogram is cached and returned directly if the component
        adjacency list has not changed since the previous call.

        Returns:

        hist : tuple  Dimensionality histogram.
        """
        adjacency = build_adjacency_list(self.graph.find_all(),
                                         self.bonds)
        if adjacency == self.adjacency:
            return self.hcached

        self.adjacency = adjacency
        self.all_visited, self.ranks = traverse_component_graphs(adjacency)
        res = merge_mutual_visits(self.all_visited, self.ranks, self.graph)
        _, self.all_visited, self.ranks = res

        self.roots = np.unique(self.graph.find_all())
        h = get_dimensionality_histogram(self.ranks, self.roots)
        self.hcached = h
        return h

    def get_components(self):
        """
        Determines the dimensionality and constituent atoms of each component.

        Runs :meth:`check` first if it has not been called yet, because the
        ranks and roots it computes are required here.

        Returns:

        components: array  The component ID of every atom.
        cdim: dict         Maps each component ID to its dimensionality.
        """
        if self.hcached is None:
            self.check()

        component_dim = {e: self.ranks[e] for e in self.roots}
        relabelled_components = self.graph.find_all(relabel=True)
        relabelled_dim = {}
        for k, v in component_dim.items():
            relabelled_dim[relabelled_components[k]] = v
        self.cdim_cached = relabelled_dim
        self.components_cached = relabelled_components

        return relabelled_components, relabelled_dim