Coverage for /builds/kinetik161/ase/ase/geometry/dimensionality/rank_determination.py: 99.17%

120 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

"""
Implements the Rank Determination Algorithm (RDA).

The method is described in:

    "Definition of a scoring parameter to identify low-dimensional
    materials components",
    P.M. Larsen, M. Pandey, M. Strange, and K. W. Jacobsen,
    Phys. Rev. Materials 3, 034003 (2019).
    https://doi.org/10.1103/PhysRevMaterials.3.034003
"""

from collections import defaultdict, deque

import numpy as np

from ase.geometry.dimensionality.disjoint_set import DisjointSet

16 

# NumPy has a large overhead for lots of small vectors; the cross product is
# particularly bad. Pure Python is a lot faster for these tiny operations.

19 

20 

def dot_product(A, B):
    """Return the scalar (dot) product of two sequences.

    Pairs are formed with ``zip``, so extra trailing elements of the longer
    sequence are ignored. Kept in pure Python because NumPy's per-call
    overhead dominates for tiny vectors.
    """
    total = 0
    for x, y in zip(A, B):
        total += x * y
    return total

23 

24 

def cross_product(a, b):
    """Return the 3D cross product ``a x b`` as a list of three components.

    Only indices 0, 1 and 2 of each argument are read.
    """
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    return [
        ay * bz - az * by,
        az * bx - ax * bz,
        ax * by - ay * bx,
    ]

27 

28 

def subtract(A, B):
    """Return the element-wise difference ``A - B`` as a new list.

    Stops at the end of the shorter argument.
    """
    return list(map(lambda x, y: x - y, A, B))

31 

32 

def rank_increase(a, b):
    """Report whether appending position ``b`` raises the affine rank of ``a``.

    ``a`` is the list of cell positions recorded so far for one component.
    Positions are only ever added when this returns True, so ``a`` holds
    affinely independent points:
      * 0 points: any point increases the rank;
      * 1 point:  ``b`` must differ from it;
      * 2 points: ``b`` must not be collinear with them;
      * 3 points: ``b`` must not be coplanar with them;
      * 4 points: the rank is already maximal in 3D.
    """
    count = len(a)
    if count == 0:
        return True
    if count == 1:
        return a[0] != b
    if count == 4:
        return False

    points = a + [b]
    origin = points[0]
    u = [p - q for p, q in zip(points[1], origin)]
    v = [p - q for p, q in zip(points[2], origin)]
    normal = [
        u[1] * v[2] - u[2] * v[1],
        u[2] * v[0] - u[0] * v[2],
        u[0] * v[1] - u[1] * v[0],
    ]

    if count == 2:
        # Three points are collinear exactly when the cross product vanishes.
        return any(normal)
    if count == 3:
        # Four points are coplanar exactly when the triple product vanishes.
        w = [p - q for p, q in zip(points[3], origin)]
        return sum([x * y for x, y in zip(normal, w)]) != 0
    raise Exception("This shouldn't be possible.")

49 

50 

def bfs(adjacency, start):
    """Traverse the component graph using BFS.

    The graph is traversed until the matrix rank of the subspace spanned by
    the visited components no longer increases.

    Parameters:

    adjacency: dict
        Maps each component ID to an iterable of ``(neighbour, offset)``
        pairs, where ``offset`` is a 3-component cell offset.
    start: hashable
        The component from which the traversal starts (placed at cell
        position ``(0, 0, 0)``).

    Returns:

    visited: set
        Every ``(component, cell_position)`` node taken off the queue,
        including nodes that did not raise the rank.
    rank: int
        One less than the number of independent cell positions recorded for
        ``start``; between 0 and 3, because ``rank_increase`` stops
        accepting positions once four have been recorded.
    """
    visited = set()
    # Per component: the independent cell positions recorded so far.
    cvisited = defaultdict(list)
    # A deque makes popping the front O(1); list.pop(0) is O(len(queue)).
    queue = deque([(start, (0, 0, 0))])
    while queue:
        vertex = queue.popleft()
        if vertex in visited:
            continue

        # Marked visited even if it does not raise the rank: the visited
        # sets are later compared between components to detect overlaps.
        visited.add(vertex)
        c, p = vertex
        if not rank_increase(cvisited[c], p):
            continue

        cvisited[c].append(p)

        for nc, offset in adjacency[c]:
            nbrpos = (p[0] + offset[0], p[1] + offset[1], p[2] + offset[2])
            nbrnode = (nc, nbrpos)
            if nbrnode in visited:
                continue

            # Cheap pre-filter only: the rank test is repeated on dequeue
            # because cvisited[nc] may have grown in the meantime.
            if rank_increase(cvisited[nc], nbrpos):
                queue.append(nbrnode)

    return visited, len(cvisited[start]) - 1

83 

84 

def traverse_component_graphs(adjacency):
    """Run the rank-limited BFS from every component in ``adjacency``.

    Returns a pair of dicts, both keyed by component ID: the set of nodes
    visited from that component, and the rank reported by ``bfs``.
    """
    all_visited, ranks = {}, {}
    for component in adjacency:
        all_visited[component], ranks[component] = bfs(adjacency, component)
    return all_visited, ranks

95 

96 

def build_adjacency_list(parents, bonds):
    """Collapse atom-level bonds into a component-level adjacency mapping.

    Parameters:

    parents: sequence
        The component ID of every atom, indexed by atom.
    bonds: iterable of (i, j, offset)
        Directed bonds from atom ``i`` to atom ``j`` in the cell at ``offset``.

    Returns a dict mapping every distinct component ID in ``parents`` to the
    set of ``(neighbour_component, offset)`` pairs reachable via ``bonds``.
    Components without bonds map to an empty set.
    """
    adjacency = {}
    for component in np.unique(parents):
        adjacency[component] = set()

    for i, j, offset in bonds:
        source, target = parents[i], parents[j]
        adjacency[source].add((target, offset))
    return adjacency

105 

106 

def get_dimensionality_histogram(ranks, roots):
    """Count the root components of each rank.

    Returns a 4-tuple whose k-th entry is the number of ``roots`` whose rank
    (looked up in ``ranks``) equals k.
    """
    counts = [0] * 4
    for root in roots:
        counts[ranks[root]] += 1
    return tuple(counts)

112 

113 

def merge_mutual_visits(all_visited, ranks, graph):
    """Find components with mutual visits and merge them.

    Two components whose traversals reached the same
    ``(component, cell_position)`` node are joined via ``graph.union``.

    Returns ``(merged, all_visited, ranks)``. When no union took effect the
    input dicts are returned unchanged. Otherwise visited nodes are pooled
    under each component's new root (from ``graph.find_all()``), and each
    root keeps the rank it already had.
    """
    merged = False
    # node -> components whose traversal has reached that node so far
    visitors = defaultdict(list)
    for component, nodes in all_visited.items():
        for node in nodes:
            for other in visitors[node]:
                # Components reaching a common node must agree on rank.
                assert ranks[other] == ranks[component]
                merged |= graph.union(other, component)
            visitors[node].append(component)

    if not merged:
        return merged, all_visited, ranks

    roots = graph.find_all()
    pooled_visits = defaultdict(set)
    pooled_ranks = {}
    for component, nodes in all_visited.items():
        root = roots[component]
        pooled_visits[root].update(nodes)
        pooled_ranks[root] = ranks[root]
    return merged, pooled_visits, pooled_ranks

136 

137 

class RDA:

    def __init__(self, num_atoms):
        """
        Initializes the RDA class.

        A disjoint set is used to maintain the component graph.

        Parameters:

        num_atoms: int    The number of atoms in the unit cell.
        """
        self.bonds = []
        self.graph = DisjointSet(num_atoms)
        self.adjacency = None
        self.hcached = None
        self.components_cached = None
        self.cdim_cached = None
        # Populated by check(); None means check() has not been run yet.
        self.all_visited = None
        self.ranks = None
        self.roots = None

    def insert_bond(self, i, j, offset):
        """
        Adds a bond to the list of graph edges.

        Graph components are merged if the bond does not cross a cell boundary.
        Bonds which cross cell boundaries can inappropriately connect
        components which are not connected in the infinite crystal. This is
        tested during graph traversal.

        Parameters:

        i: int          The index of the first atom.
        j: int          The index of the second atom.
        offset: sequence of 3 ints
                        The cell offset of the second atom. Any sequence
                        (tuple, list, ndarray) is accepted; it is stored
                        as a tuple.
        """
        # Normalise first: a list would never compare equal to (0, 0, 0)
        # and is unhashable in the adjacency sets; an ndarray would raise
        # on the truth test of the comparison below.
        offset = tuple(offset)

        if offset == (0, 0, 0):  # only want bonds in aperiodic unit cell
            self.graph.union(i, j)
        else:
            # Store both directions so the component graph is symmetric.
            roffset = tuple(-np.array(offset))
            self.bonds.append((i, j, offset))
            self.bonds.append((j, i, roffset))

    def check(self):
        """
        Determines the dimensionality histogram.

        The component graph is traversed (using BFS) until the matrix rank
        of the subspace spanned by the visited components no longer increases.
        Components found to overlap during traversal are merged in
        ``self.graph`` as a side effect.

        Returns:

        hist : tuple    Dimensionality histogram; entry k is the number of
                        components of rank k.
        """
        adjacency = build_adjacency_list(self.graph.find_all(),
                                         self.bonds)
        # Reuse the previous result when the component graph is unchanged.
        if adjacency == self.adjacency:
            return self.hcached

        self.adjacency = adjacency
        self.all_visited, self.ranks = traverse_component_graphs(adjacency)
        res = merge_mutual_visits(self.all_visited, self.ranks, self.graph)
        _, self.all_visited, self.ranks = res

        self.roots = np.unique(self.graph.find_all())
        h = get_dimensionality_histogram(self.ranks, self.roots)
        self.hcached = h
        return h

    def get_components(self):
        """
        Determines the dimensionality and constituent atoms of each component.

        Runs check() first if it has not been called yet.

        Returns:

        components: array   The component ID of every atom, as returned by
                            ``self.graph.find_all(relabel=True)``.
        dims: dict          Maps each (relabelled) component ID to its rank.
        """
        if self.roots is None:
            # check() populates self.ranks and self.roots.
            self.check()

        component_dim = {e: self.ranks[e] for e in self.roots}
        relabelled_components = self.graph.find_all(relabel=True)
        relabelled_dim = {}
        for k, v in component_dim.items():
            relabelled_dim[relabelled_components[k]] = v
        self.cdim_cached = relabelled_dim
        self.components_cached = relabelled_components

        return relabelled_components, relabelled_dim