Coverage for /builds/kinetik161/ase/ase/io/utils.py: 92.67%

191 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1from itertools import islice 

2from math import sqrt 

3from typing import IO 

4 

5import numpy as np 

6 

7from ase.data import atomic_numbers, covalent_radii 

8from ase.data.colors import jmol_colors 

9from ase.io.formats import string2index 

10from ase.utils import rotate 

11 

12 

class PlottingVariables:
    """Project an Atoms object (and optionally its unit cell) onto a canvas.

    The constructor rotates the atomic positions and the dashed unit-cell
    edges, multiplies them by ``scale`` and shifts them so that the drawing
    fits a canvas of size ``self.w`` x ``self.h``.  The resulting attributes
    (``positions``, ``d``, ``colors``, ``T``, ``D``, ``natoms``, ...) are what
    drawing back-ends such as :func:`make_patch_list` consume.
    """

    # removed writer - self
    def __init__(self, atoms, rotation='', show_unit_cell=2,
                 radii=None, bbox=None, colors=None, scale=20,
                 maxwidth=500, extra_offset=(0., 0.)):
        """
        Parameters:

        atoms: Atoms
            Structure to draw.
        rotation: str or 3x3 array
            A rotation string passed to :func:`ase.utils.rotate`, or a matrix
            applied as ``np.dot(positions, rotation)``.
        show_unit_cell: int
            0 hides the cell; any positive value draws its edges as dashes;
            2 additionally includes the cell corners in the automatically
            computed bounding box.
        radii: None, float or sequence
            None uses covalent radii, a float scales the covalent radii and
            anything else is taken as per-atom radii.
        bbox: None or (x0, y0, x1, y1)
            Region, in rotated but unscaled coordinates, to draw.  When None,
            a box enclosing every atom sphere (widened by 5%) is used.
        colors: None or sequence
            Per-atom colours; defaults to the jmol colour table.
        scale: float
            Multiplier from atomic coordinates to drawing units.  Reduced
            automatically if the automatic box would exceed ``maxwidth``.
        maxwidth: float
            Maximum canvas width; only used when ``bbox`` is None.
        extra_offset: (float, float)
            Padding added to the canvas width/height; the drawing is shifted
            by the same amount.
        """
        self.numbers = atoms.get_atomic_numbers()
        self.colors = colors
        if colors is None:
            # Atomic numbers beyond the colour table reuse its last entry.
            ncolors = len(jmol_colors)
            self.colors = jmol_colors[self.numbers.clip(max=ncolors - 1)]

        if radii is None:
            radii = covalent_radii[self.numbers]
        elif isinstance(radii, float):
            # A single float is a scale factor on the covalent radii.
            radii = covalent_radii[self.numbers] * radii
        else:
            radii = np.array(radii)

        natoms = len(atoms)

        if isinstance(rotation, str):
            rotation = rotate(rotation)

        cell = atoms.get_cell()
        disp = atoms.get_celldisp().flatten()

        if show_unit_cell > 0:
            # L: centres of the dashes forming the cell edges,
            # T: index of the cell vector each dash runs along,
            # D: per-cell-vector half-length vector of a dash.
            L, T, D = cell_to_lines(self, cell)
            # cell_to_lines works relative to the origin; shift the dashes by
            # the cell displacement so they coincide with cell_vertices.
            L = L + disp
            cell_vertices = np.empty((2, 2, 2, 3))
            for c1 in range(2):
                for c2 in range(2):
                    for c3 in range(2):
                        cell_vertices[c1, c2, c3] = np.dot([c1, c2, c3],
                                                           cell) + disp
            cell_vertices.shape = (8, 3)
            cell_vertices = np.dot(cell_vertices, rotation)
        else:
            L = np.empty((0, 3))
            T = None
            D = None
            cell_vertices = None

        nlines = len(L)

        # Atoms first, then dash centres, so a single rotation, scaling and
        # offset transforms both.
        positions = np.empty((natoms + nlines, 3))
        R = atoms.get_positions()
        positions[:natoms] = R
        positions[natoms:] = L

        # Hide (T = -1) every dash whose two end points both lie inside the
        # sphere of one and the same atom.
        r2 = radii**2
        for n in range(nlines):
            d = D[T[n]]
            if ((((R - L[n] - d)**2).sum(1) < r2) &
                    (((R - L[n] + d)**2).sum(1) < r2)).any():
                T[n] = -1

        positions = np.dot(positions, rotation)
        R = positions[:natoms]

        if bbox is None:
            X1 = (R - radii[:, None]).min(0)
            X2 = (R + radii[:, None]).max(0)
            if show_unit_cell == 2:
                X1 = np.minimum(X1, cell_vertices.min(0))
                X2 = np.maximum(X2, cell_vertices.max(0))
            M = (X1 + X2) / 2
            S = 1.05 * (X2 - X1)
            w = scale * S[0]
            if w > maxwidth:
                w = maxwidth
                scale = w / S[0]
            h = scale * S[1]
            offset = np.array([scale * M[0] - w / 2, scale * M[1] - h / 2, 0],
                              dtype=float)
        else:
            w = (bbox[2] - bbox[0]) * scale
            h = (bbox[3] - bbox[1]) * scale
            # Force a float array: with an integer bbox and scale the
            # fractional extra_offset subtracted below would be truncated.
            offset = np.array([bbox[0], bbox[1], 0], dtype=float) * scale

        offset[0] = offset[0] - extra_offset[0]
        offset[1] = offset[1] - extra_offset[1]
        self.w = w + extra_offset[0]
        self.h = h + extra_offset[1]

        positions *= scale
        positions -= offset

        if nlines > 0:
            # Rotate and scale the dash half-lengths; only x and y are drawn.
            D = np.dot(D, rotation)[:, :2] * scale

        if cell_vertices is not None:
            cell_vertices *= scale
            cell_vertices -= offset

        cell = np.dot(cell, rotation)
        cell *= scale

        self.cell = cell
        self.positions = positions
        self.D = D
        self.T = T
        self.cell_vertices = cell_vertices
        self.natoms = natoms
        # Atom diameters in drawing units.
        self.d = 2 * scale * radii
        self.constraints = atoms.constraints

        # extension for partial occupancies
        self.frac_occ = False
        self.tags = None
        self.occs = None

        try:
            occs = atoms.info['occupancy']
        except KeyError:
            pass
        else:
            self.occs = occs
            self.tags = atoms.get_tags()
            self.frac_occ = True

129 

130 

def cell_to_lines(writer, cell):
    """Cut the twelve edges of ``cell`` into short dashes.

    Each cell vector ``cell[c]`` yields four parallel edges: the vector
    translated by nothing, by ``cell[c - 2]``, by ``cell[c - 1]`` and by both.
    Each of those edges is split into ``max(2, int(|cell[c]| / 0.3))`` dashes.
    Dashes and the gaps between them have equal length, and the first and
    last dash touch the end points of the edge.

    ``writer`` is accepted for call compatibility and is not used.

    Returns ``(positions, T, D)``:

    positions: (nlines, 3) float array
        Centre of every dash.
    T: (nlines,) int array
        Index of the cell vector each dash runs along.
    D: (3, 3) float array
        ``D[c]`` is the half-length vector of a dash along ``cell[c]``.
    """
    # XXX a zero-length cell vector still produces 2 degenerate
    # (zero-length) dashes per edge; this needs to be updated.
    counts = [max(2, int(sqrt((cell[axis] ** 2).sum()) / 0.3))
              for axis in range(3)]
    total = 4 * sum(counts)

    centres = np.empty((total, 3))
    directions = np.empty(total, int)
    half_steps = np.zeros((3, 3))

    row = 0
    for axis, count in enumerate(counts):
        step = cell[axis] / (4 * count - 2)
        half_steps[axis] = step
        # Dash centres sit at odd multiples 1, 5, 9, ... of the half step.
        along = np.arange(1, 4 * count + 1, 4)[:, None] * step
        directions[row:row + 4 * count] = axis
        side_a, side_b = cell[axis - 2], cell[axis - 1]
        for i, j in ((0, 0), (0, 1), (1, 0), (1, 1)):
            centres[row:row + count] = along + i * side_a + j * side_b
            row += count

    return centres, directions, half_steps

159 

160 

def make_patch_list(writer):
    """Build the matplotlib patches for a ``PlottingVariables``-like object.

    Entries of ``writer.positions`` are visited in ascending order of their
    third (rotated z) coordinate.  Indices below ``writer.natoms`` are atoms;
    the rest are the dash centres of the unit-cell edges.

    * Atoms without partial occupancy become black-edged circles.  An atom
      is skipped when its circle's bounding square lies entirely outside the
      ``writer.w`` x ``writer.h`` canvas.
    * Atoms with partial occupancy (``writer.frac_occ``) get:
      - a white circle first if the site occupancies sum to less than 1;
      - then, largest occupancy first, one wedge per species, spanning
        ``360 * occupancy`` degrees in jmol colours;
      - a species whose occupancy rounds to 1 is instead drawn as a full
        circle in the atom's colour.
    * Cell dashes become short path segments, unless hidden (``T == -1``).

    Returns the list of patches in drawing order.
    """
    from matplotlib.patches import Circle, PathPatch, Wedge
    from matplotlib.path import Path

    patches = []
    for idx in writer.positions[:, 2].argsort():
        xy = writer.positions[idx, :2]

        if idx >= writer.natoms:
            # Unit-cell dash: draw from xy - D[c] to xy + D[c].
            direction = writer.T[idx - writer.natoms]
            if direction == -1:
                continue
            half = writer.D[direction]
            patches.append(PathPatch(Path((xy + half, xy - half))))
            continue

        radius = writer.d[idx] / 2

        if not writer.frac_occ:
            visible = (xy[1] + radius > 0 and xy[1] - radius < writer.h
                       and xy[0] + radius > 0 and xy[0] - radius < writer.w)
            if visible:
                patches.append(Circle(xy, radius,
                                      facecolor=writer.colors[idx],
                                      edgecolor='black'))
            continue

        occupancies = writer.occs[str(writer.tags[idx])]
        if np.sum(list(occupancies.values())) < 1.0:
            # Not fully occupied: white background disc.
            patches.append(Circle(xy, radius, facecolor='#ffffff',
                                  edgecolor='black'))

        angle = 0
        ranked = sorted(occupancies.items(), key=lambda item: item[1],
                        reverse=True)
        for symbol, fraction in ranked:
            if np.round(fraction, decimals=4) == 1.0:
                patches.append(Circle(xy, radius,
                                      facecolor=writer.colors[idx],
                                      edgecolor='black'))
                continue
            # jmol colors for the moment
            sweep = 360. * fraction
            patches.append(Wedge(
                xy, radius, angle, angle + sweep,
                facecolor=jmol_colors[atomic_numbers[symbol]],
                edgecolor='black'))
            angle += sweep

    return patches

214 

215 

class ImageChunk:
    """Base class for a piece of a file that holds everything needed to
    rebuild one atoms object.

    Format readers are intended to subclass it and override :meth:`build`.
    """

    def build(self, **kwargs):
        """Return the atoms object reconstructed from this chunk.

        The base implementation constructs nothing and returns ``None``.
        """
        return None

223 

224 

class ImageIterator:
    """Iterate over chunks, to return the corresponding Atoms objects.

    Only the atoms objects that correspond to the requested indices are
    built.  ``ichunks`` is a callable that takes the open file and returns
    an iterator over ``ImageChunk``-like objects (anything with a ``build``
    method).  See extxyz.py:iread_xyz as an example.
    """

    def __init__(self, ichunks):
        self.ichunks = ichunks

    def __call__(self, fd: IO, index=None, **kwargs):
        """Yield ``chunk.build(**kwargs)`` for every chunk selected by *index*.

        *index* may be None or ``':'`` (everything), an int, a slice, or a
        string, which is first parsed with ``string2index``.
        """
        if isinstance(index, str):
            index = string2index(index)

        if index is None or index == ':':
            index = slice(None, None, None)

        if not isinstance(index, (slice, str)):
            # Single integer -> one-element slice.  For index == -1 the stop
            # would be 0 (an empty slice), hence ``or None``.
            index = slice(index, (index + 1) or None)

        for chunk in self._getslice(fd, index):
            yield chunk.build(**kwargs)

    def _getslice(self, fd: IO, indices: slice):
        """Return an iterator over the chunks selected by *indices*.

        Non-negative slices are handled lazily with ``islice``.  Slices with
        negative start, stop or step require a seekable *fd*: the chunks are
        counted once, the file is rewound, and the slice is resolved against
        that count.

        Raises ValueError for negative indices on a non-seekable stream.
        """
        try:
            iterator = islice(self.ichunks(fd),
                              indices.start, indices.stop,
                              indices.step)
        except ValueError:
            # Negative indices. Go through the whole thing to get the length,
            # which allows us to evaluate the slice, and then read it again
            if not hasattr(fd, 'seekable') or not fd.seekable():
                raise ValueError('Negative indices only supported for '
                                 'seekable streams')

            startpos = fd.tell()
            nchunks = 0
            for _ in self.ichunks(fd):
                nchunks += 1
            fd.seek(startpos)
            start, stop, step = indices.indices(nchunks)
            if step > 0:
                iterator = islice(self.ichunks(fd), start, stop, step)
            else:
                # islice cannot walk backwards: read the same positions in
                # ascending order (smallest wanted index first, stride
                # -step), then reverse them.
                wanted = range(start, stop, step)
                if not wanted:
                    return iter(())
                selected = list(islice(self.ichunks(fd),
                                       wanted[-1], wanted[0] + 1, -step))
                iterator = reversed(selected)
        return iterator

269 

270 

def verify_cell_for_export(cell, check_orthorhombric=True):
    """Check that a cell can be exported to a mustem xtl file or to the
    prismatic/computem xyz format.

    Parameters:

    cell: cell object
        Cell to be checked (its ``orthorhombic`` and ``rank`` attributes
        are used).

    check_orthorhombric: bool
        If True, raise a ``ValueError`` when the cell is NOT orthorhombic.
        If False, skip the orthorhombicity check.  (The parameter keeps its
        historical spelling for backwards compatibility.)

    Raise a ``ValueError`` if the cell is not suitable for export:

    - if the cell is not orthorhombic (only when check_orthorhombric=True)
    - if the cell size is not fully defined (``cell.rank < 3``)
    """

    if check_orthorhombric and not cell.orthorhombic:
        raise ValueError('To export to this format, the cell needs to be '
                         'orthorhombic.')
    if cell.rank < 3:
        raise ValueError('To export to this format, the cell size needs '
                         'to be set: current cell is {}.'.format(cell))

296 

297 

def verify_dictionary(atoms, dictionary, dictionary_name):
    """
    Verify that a dictionary has a key for each symbol present in the atoms
    object.

    Parameters:

    atoms: Atoms
        Object whose ``symbols`` must all be keys of ``dictionary``.

    dictionary: dict
        Dictionary to be checked.

    dictionary_name: str
        Name of the dictionary to be displayed in the error message.

    Raise a ``ValueError`` naming the first symbol (in atom order) that is
    missing from ``dictionary``.
    """
    # dict.fromkeys de-duplicates while keeping atom order, so the reported
    # key is deterministic (set iteration order of strings varies between
    # interpreter runs because of hash randomisation).
    for key in dict.fromkeys(atoms.symbols):
        if key not in dictionary:
            raise ValueError('Missing the {} key in the `{}` dictionary.'
                             ''.format(key, dictionary_name))