Coverage for /builds/kinetik161/ase/ase/cli/template.py: 94.00%

200 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import string 

2 

3import numpy as np 

4 

5from ase.data import chemical_symbols 

6from ase.io import string2index 

7from ase.io.formats import parse_filename 

8 

9# default fields 

10 

11 

def field_specs_on_conditions(calculator_outputs, rank_order):
    """Return the default field specifications for a comparison table.

    Each specification is a string ``name[:hier[:scent]]`` as consumed by
    ``parse_field_specs``.

    Parameters
    ----------
    calculator_outputs
        When truthy the default columns include the force-difference
        fields ``df`` and ``rdf``; otherwise the displacement components
        ``dx``, ``dy`` and ``dz`` are listed instead.
    rank_order
        Name of the field to tag with ``':0:1'``, or ``None``.  With
        ``None`` the leading index column becomes ``'i:0:1'``.  Otherwise
        the index column becomes ``'i:1'`` and the named field is tagged,
        being appended to the list first if it is not a default column.

    Returns
    -------
    list of str
    """
    if calculator_outputs:
        specs = ['i:0', 'el', 'd', 'rd', 'df', 'rdf']
    else:
        specs = ['i:0', 'el', 'dx', 'dy', 'dz', 'd', 'rd']

    if rank_order is None:
        specs[0] += ':1'
        return specs

    specs[0] = 'i:1'
    if rank_order in specs:
        specs = [spec + ':0:1' if spec == rank_order else spec
                 for spec in specs]
    else:
        specs.append(rank_order + ':0:1')
    return specs

28 

29 

def summary_functions_on_conditions(has_calc):
    """Return the summary functions to print below a table.

    ``rmsd`` is always included; ``energy_delta`` is added after it when
    *has_calc* is truthy.
    """
    functions = [rmsd]
    if has_calc:
        functions.append(energy_delta)
    return functions

34 

35 

def header_alias(h):
    """Replace keyboard characters with Unicode symbols
    for pretty printing.

    Parameters
    ----------
    h : str
        Field name, e.g. ``'i'``, ``'dx'``, ``'rd'`` or ``'afx'``.

    Returns
    -------
    str
        The column header.  ``'i'``, ``'an'``, ``'t'`` and ``'el'`` map to
        fixed words; a leading ``'d'`` turns every ``'d'`` into ``'Δ'``; a
        leading ``'r'`` yields ``'rank '`` followed by the alias of the
        remainder; a leading ``'a'`` turns every ``'a'`` into ``'<'`` and
        appends ``'>'``.  Anything else, including the empty string, is
        returned unchanged.
    """
    fixed = {'i': 'index', 'an': 'atomic #', 't': 'tag', 'el': 'element'}
    if h in fixed:
        return fixed[h]
    # startswith() rather than h[0] so that an empty header (reachable via
    # a bare 'r') is passed through instead of raising IndexError.
    if h.startswith('d'):
        return h.replace('d', 'Δ')
    if h.startswith('r'):
        return 'rank ' + header_alias(h[1:])
    if h.startswith('a'):
        return h.replace('a', '<') + '>'
    return h

55 

56 

def prec_round(a, prec=2):
    """
    To make hierarchical sorting different from non-hierarchical sorting
    with floats.

    Round the mantissa of *a* to *prec* decimals while keeping its sign and
    power of ten, so values that agree to that precision compare equal
    when used as sort keys.  Zero is returned unchanged.
    """
    if a == 0:
        return a
    sign = 1 if a > 0 else -1
    log = np.log10(sign * a)
    exponent = log // 1   # integer power of ten
    fraction = log % 1    # log10 of the mantissa, in [0, 1)
    return sign * np.round(10**fraction, prec) * 10**exponent


# Element-wise version used on whole sorting arrays.  ``otypes`` is pinned
# to float: without it np.vectorize infers the output dtype from the first
# element (an integer 0 would truncate every later result to int) and
# refuses size-0 input.
prec_round = np.vectorize(prec_round, otypes=[float])

72 

73# end most settings 

74 

# this will sort alphabetically by chemical symbol: each symbol is keyed by
# its alphabetical rank.  The rank is the inverse permutation of argsort,
# i.e. argsort applied twice; a single argsort would pair symbol k with the
# index of the k-th smallest symbol, which is not an alphabetical order.
_symbol_ranks = np.argsort(np.argsort(chemical_symbols))
num2sym = dict(zip(_symbol_ranks, chemical_symbols))
# to sort by atomic number, uncomment below
# num2sym = dict(zip(range(len(chemical_symbols)), chemical_symbols))
sym2num = {v: k for k, v in num2sym.items()}

80 

# Field names that get_field_data computes from tags, atomic numbers,
# indices, symbols or positions; any other name is computed from forces.
atoms_props = [
    'dx', 'dy', 'dz', 'd',
    't', 'an', 'i', 'el',
    'p1', 'p2',
    'p1x', 'p1y', 'p1z',
    'p2x', 'p2y', 'p2z',
]

98 

99 

def _select_component(vectors, field):
    """Return the x, y or z column of *vectors* named by the last letter of
    *field*, or the row-wise Euclidean norm if *field* names no axis."""
    for axis, letter in enumerate('xyz'):
        if field.endswith(letter):
            return vectors[:, axis]
    return np.linalg.norm(vectors, axis=1)


def get_field_data(atoms1, atoms2, field):
    """Return one value per atom for the column named *field*.

    A leading ``'r'`` requests the rank of each value instead of the value
    itself (rank 0 for the largest).  Names listed in ``atoms_props`` are
    computed from tags (``t``), atomic numbers (``an``), sortable element
    codes (``el``), indices (``i``) or positions (``d*`` for the
    displacement ``atoms2 - atoms1``, ``p1*`` / ``p2*`` for either
    structure).  Every other name is computed from forces: ``d*`` is the
    difference ``atoms2 - atoms1``, ``a*`` the average, and otherwise the
    second character selects ``atoms1`` (``'1'``) or ``atoms2``.  A
    trailing ``x``/``y``/``z`` picks that component; without one the
    vector norm is used.
    """
    ranked = field[0] == 'r'
    if ranked:
        field = field[1:]

    if field in atoms_props:
        if field == 't':
            data = atoms1.get_tags()
        elif field == 'an':
            data = atoms1.numbers
        elif field == 'el':
            data = np.array([sym2num[sym] for sym in atoms1.symbols])
        elif field == 'i':
            data = np.arange(len(atoms1))
        else:
            if field.startswith('d'):
                vectors = atoms2.positions - atoms1.positions
            elif field.startswith('p'):
                source = atoms1 if field[1] == '1' else atoms2
                vectors = source.positions
            data = _select_component(vectors, field)
    else:
        if field[0] == 'd':
            vectors = atoms2.get_forces() - atoms1.get_forces()
        elif field[0] == 'a':
            vectors = (atoms2.get_forces() + atoms1.get_forces()) / 2
        elif field[1] == '1':
            vectors = atoms1.get_forces()
        else:
            vectors = atoms2.get_forces()
        data = _select_component(vectors, field)

    if not ranked:
        return data
    # Double argsort of the negated data turns values into ranks, with the
    # largest value receiving rank 0.
    return np.argsort(np.argsort(-data))

157 

158 

159# Summary Functions 

160 

def rmsd(atoms1, atoms2):
    """Return ``'RMSD=<value>'`` for the root-mean-square displacement
    between the positions of *atoms1* and *atoms2*, formatted ``+.1E``."""
    displacement = atoms2.positions - atoms1.positions
    squared_lengths = np.linalg.norm(displacement, axis=1)**2
    value = np.sqrt(squared_lengths.mean())
    return f'RMSD={value:+.1E}'

165 

166 

def energy_delta(atoms1, atoms2):
    """Return a one-line summary of the potential energies of *atoms1* and
    *atoms2* and of their difference (second minus first)."""
    first = atoms1.get_potential_energy()
    second = atoms2.get_potential_energy()
    template = 'E1 = {:+.1E}, E2 = {:+.1E}, dE = {:+1.1E}'
    return template.format(first, second, second - first)

171 

172 

def parse_field_specs(field_specs):
    """Split ``name[:hier[:scent]]`` specifications into parallel sequences.

    Parameters
    ----------
    field_specs : iterable of str
        A missing ``hier`` or ``scent`` part is recorded as ``-1``.
        Specifications with more than three colon-separated parts are
        skipped.

    Returns
    -------
    fields : list of str
        The field names, in input order.
    hier : numpy.ndarray
        Field positions ordered by descending hierarchy number, ready to
        index the key array handed to ``np.lexsort``.  Fields with a
        negative hierarchy number are numbered after the largest explicit
        one, in input order.
    scent : numpy.ndarray
        The ``scent`` integer of each field, in input order.
    """
    fields, levels, scents = [], [], []
    for spec in field_specs:
        parts = spec.split(':')
        if len(parts) > 3:
            continue
        scents.append(int(parts[2]) if len(parts) == 3 else -1)
        levels.append(int(parts[1]) if len(parts) >= 2 else -1)
        fields.append(parts[0])

    # Give every field without a usable level the next free one.
    next_level = max(levels)
    for position, level in enumerate(levels):
        if level < 0:
            next_level += 1
            levels[position] = next_level

    # reversed by convention of numpy lexsort
    order = np.argsort(levels)[::-1]
    return fields, order, np.array(scents)

199 

200# Class definitions 

201 

202 

class MapFormatter(string.Formatter):
    """Formatter that adds an ``h`` presentation type.

    A format spec ending in ``h`` looks the value up in ``num2sym`` (after
    converting it to ``int``) and formats the resulting string with the
    same spec ending in ``s`` instead.  All other specs are handled by
    ``string.Formatter`` unchanged.
    """

    def format_field(self, value, spec):
        if not spec.endswith('h'):
            return super().format_field(value, spec)
        symbol = num2sym[int(value)]
        return super().format_field(symbol, spec[:-1] + 's')

213 

214 

class TableFormat:
    """Column formats and rule characters used to render a ``Table``.

    Parameters
    ----------
    columnwidth : int
        Width of every column, in characters.
    precision : int
        ``precision - 1`` digits are placed after the decimal point of
        float columns (with the default ``'E'`` representation this gives
        ``precision`` significant digits).
    representation : str
        Presentation type used for float columns, e.g. ``'E'``.
    toprule, midrule, bottomrule : str
        Characters repeated to draw the rules of the table.

    Attributes
    ----------
    formatter
        Bound ``MapFormatter().format``, which understands the ``h`` type.
    fmt_class : dict
        Format string for each kind of column (``'signed float'``,
        ``'unsigned float'``, ``'int'``, ``'str'``, ``'conv'``).
    fmt : dict
        Format string for each known field name.
    """

    def __init__(self,
                 columnwidth=9,
                 precision=2,
                 representation='E',
                 toprule='=',
                 midrule='-',
                 bottomrule='='):

        self.precision = precision
        self.representation = representation
        self.columnwidth = columnwidth
        self.formatter = MapFormatter().format
        self.toprule = toprule
        self.midrule = midrule
        self.bottomrule = bottomrule

        self.fmt_class = {
            # the space flag reserves a sign position for positive values
            'signed float': "{{: ^{}.{}{}}}".format(
                self.columnwidth,
                self.precision - 1,
                self.representation),
            'unsigned float': "{{:^{}.{}{}}}".format(
                self.columnwidth,
                self.precision - 1,
                self.representation),
            'int': "{{:^{}n}}".format(
                self.columnwidth),
            'str': "{{:^{}s}}".format(
                self.columnwidth),
            'conv': "{{:^{}h}}".format(
                self.columnwidth)}

        signed_floats = [
            'dx', 'dy', 'dz',
            'dfx', 'dfy', 'dfz',
            'afx', 'afy', 'afz',
            'p1x', 'p2x', 'p1y', 'p2y', 'p1z', 'p2z',
            'f1x', 'f2x', 'f1y', 'f2y', 'f1z', 'f2z']
        unsigned_floats = ['d', 'df', 'af', 'p1', 'p2', 'f1', 'f2']
        plain_integers = ['i', 'an', 't']
        # Ranks are integers.  The ranked forms of the integer and element
        # columns are included as well: get_field_data and header_alias
        # accept an 'r' prefix on any field, so a missing entry here would
        # raise KeyError when the table body is formatted.
        ranks = ['r' + name for name in signed_floats + unsigned_floats]
        ranks += ['r' + name for name in plain_integers + ['el']]

        fmt = {}
        for name in signed_floats:
            fmt[name] = self.fmt_class['signed float']
        for name in unsigned_floats:
            fmt[name] = self.fmt_class['unsigned float']
        for name in plain_integers + ranks:
            fmt[name] = self.fmt_class['int']
        # element codes are converted back to symbols by the 'h' type
        fmt['el'] = self.fmt_class['conv']

        self.fmt = fmt

282 

283 

class Table:
    """Render a per-atom comparison of two structures as text or CSV.

    Parameters
    ----------
    field_specs : list of str
        Column specifications, parsed with ``parse_field_specs``.
    summary_functions : list of callable, optional
        Each is called as ``f(atoms1, atoms2)`` and must return a string;
        the results are printed below the table.  Defaults to none.
    tableformat : TableFormat, optional
        Formatting settings; a default ``TableFormat()`` is created when
        omitted.
    max_lines : int, optional
        Maximum number of body rows to print.  ``None`` prints all rows.
    title : str
        First line of the output.
    tablewidth : int, optional
        Length of the rules; defaults to the column width times the
        number of fields.
    """

    def __init__(self,
                 field_specs,
                 summary_functions=None,
                 tableformat=None,
                 max_lines=None,
                 title='',
                 tablewidth=None):

        self.max_lines = max_lines
        # A fresh list per instance: a shared ``[]`` default would be
        # visible to, and mutable by, every table created without one.
        if summary_functions is None:
            summary_functions = []
        self.summary_functions = summary_functions
        self.field_specs = field_specs

        self.fields, self.hier, self.scent = parse_field_specs(
            self.field_specs)
        self.nfields = len(self.fields)

        # formatting
        if tableformat is None:
            self.tableformat = TableFormat()
        else:
            self.tableformat = tableformat

        if tablewidth is None:
            self.tablewidth = self.tableformat.columnwidth * self.nfields
        else:
            self.tablewidth = tablewidth

        self.title = title

    def make(self, atoms1, atoms2, csv=False):
        """Return the complete table (title, rules, header, body and
        summary) as one newline-joined string."""
        header = self.make_header(csv=csv)
        body = self.make_body(atoms1, atoms2, csv=csv)
        if self.max_lines is not None:
            # The body is a single newline-joined string, so it has to be
            # split before slicing; slicing the string itself would cut
            # after max_lines characters instead of max_lines rows.
            body = '\n'.join(body.split('\n')[:self.max_lines])
        summary = self.make_summary(atoms1, atoms2)

        return '\n'.join([self.title,
                          self.tableformat.toprule * self.tablewidth,
                          header,
                          self.tableformat.midrule * self.tablewidth,
                          body,
                          self.tableformat.bottomrule * self.tablewidth,
                          summary])

    def make_header(self, csv=False):
        """Return the header row: comma-separated aliases for CSV,
        otherwise aliases centred in fixed-width columns."""
        headers = [header_alias(field) for field in self.fields]
        if csv:
            return ','.join(headers)

        fields = self.tableformat.fmt_class['str'] * self.nfields
        return self.tableformat.formatter(fields, *headers)

    def make_summary(self, atoms1, atoms2):
        """Return the output of every summary function, one per line."""
        return '\n'.join([summary_function(atoms1, atoms2)
                          for summary_function in self.summary_functions])

    def make_body(self, atoms1, atoms2, csv=False):
        """Return all data rows, sorted, as one newline-joined string."""
        field_data = np.array([get_field_data(atoms1, atoms2, field)
                               for field in self.fields])

        # Build the sort keys: scale each field by its scent (so -1
        # reverses the order of that field), arrange the keys in the order
        # np.lexsort expects, and round them so that nearly equal values
        # tie and the next key decides.
        sorting_array = field_data * self.scent[:, np.newaxis]
        sorting_array = sorting_array[self.hier]
        sorting_array = prec_round(sorting_array, self.tableformat.precision)

        # One row per atom after sorting and transposing.
        field_data = field_data[:, np.lexsort(sorting_array)].transpose()

        if csv:
            rowformat = ','.join(
                ['{:h}' if field == 'el' else '{{:.{}E}}'.format(
                    self.tableformat.precision) for field in self.fields])
        else:
            rowformat = ''.join([self.tableformat.fmt[field]
                                 for field in self.fields])
        body = [
            self.tableformat.formatter(
                rowformat,
                *row) for row in field_data]
        return '\n'.join(body)

363 

364 

# Index passed to parse_filename for filenames without an '@' suffix.
default_index = string2index(':')


def slice_split(filename):
    """Split *filename* into a file name and an image index.

    ``parse_filename`` is given ``None`` as the index when *filename*
    contains ``'@'`` and ``default_index`` otherwise.

    Returns
    -------
    tuple
        ``(filename, index)`` as produced by ``parse_filename``.
    """
    fallback = None if '@' in filename else default_index
    name, index = parse_filename(filename, fallback)
    return name, index

373 return filename, index