Coverage for /builds/kinetik161/ase/ase/cli/template.py: 94.00%
200 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import string
3import numpy as np
5from ase.data import chemical_symbols
6from ase.io import string2index
7from ase.io.formats import parse_filename
# default fields
def field_specs_on_conditions(calculator_outputs, rank_order):
    """Build the default list of field specifications.

    Parameters
    ----------
    calculator_outputs
        If truthy, the defaults include the force columns ``df`` and
        ``rdf``; otherwise they include the displacement components
        ``dx``, ``dy`` and ``dz``.
    rank_order : str or None
        Name of the field to tag with ``':0:1'``.  If it is already one
        of the default fields that entry is tagged in place, otherwise a
        new tagged entry is appended.  With ``None`` the index column is
        tagged with ``':1'`` instead.

    Returns
    -------
    list of str
        Field specifications of the form ``name[:hierarchy[:scent]]``.
    """
    if calculator_outputs:
        specs = ['i:0', 'el', 'd', 'rd', 'df', 'rdf']
    else:
        specs = ['i:0', 'el', 'dx', 'dy', 'dz', 'd', 'rd']

    # No explicit ranking field: only extend the index column's spec.
    if rank_order is None:
        specs[0] += ':1'
        return specs

    # The index column drops to hierarchy level 1 so that the ranking
    # field (tagged with level 0) takes precedence.
    specs[0] = 'i:1'
    if rank_order in specs:
        return [spec + ':0:1' if spec == rank_order else spec
                for spec in specs]

    specs.append(rank_order + ':0:1')
    return specs
def summary_functions_on_conditions(has_calc):
    """Select the summary functions to print below a table.

    ``rmsd`` is always included; ``energy_delta`` is added only when
    *has_calc* is truthy.
    """
    summaries = [rmsd]
    if has_calc:
        summaries.append(energy_delta)
    return summaries
def header_alias(h):
    """Replace keyboard characters with Unicode symbols
    for pretty printing.

    Parameters
    ----------
    h : str
        Field name, e.g. ``'i'``, ``'dx'``, ``'rdf'`` or ``'afx'``.

    Returns
    -------
    str
        The column title: fixed aliases for ``i``, ``an``, ``t`` and
        ``el``; ``d`` becomes ``Δ``; a leading ``r`` becomes the prefix
        ``'rank '`` followed by the alias of the remainder; a leading
        ``a`` wraps the name as ``<...>``.  Any other name, including the
        empty string, is returned unchanged.
    """
    fixed_aliases = {
        'i': 'index',
        'an': 'atomic #',
        't': 'tag',
        'el': 'element',
    }
    if h in fixed_aliases:
        return fixed_aliases[h]

    # startswith() instead of h[0]: an empty name (reached e.g. through
    # the recursion below for a bare 'r') must not raise IndexError.
    if h.startswith('d'):
        return h.replace('d', 'Δ')
    if h.startswith('r'):
        return 'rank ' + header_alias(h[1:])
    if h.startswith('a'):
        return h.replace('a', '<') + '>'
    return h
def prec_round(a, prec=2):
    """
    Round *a* to *prec* decimals of its base-10 mantissa.

    To make hierarchical sorting different from non-hierarchical sorting
    with floats: values that agree after this rounding compare equal, so
    lower-priority sort keys can break the tie.

    The value is split as ``a = sign * 10**c * 10**m`` with ``m`` the
    integer part and ``c`` the fractional part of ``log10(|a|)``; the
    factor ``10**c`` is rounded to *prec* decimals.  Zero is returned
    unchanged because its logarithm is undefined.
    """
    if a == 0:
        return a

    sign = 1 if a > 0 else -1
    log_magnitude = np.log10(sign * a)
    exponent = log_magnitude // 1
    mantissa_log = log_magnitude % 1
    return sign * np.round(10**mantissa_log, prec) * 10**exponent


# otypes is fixed to float on purpose.  Without it np.vectorize refuses
# size-0 input and infers the output dtype from the first element, so an
# integer zero in first position would truncate every later result to int.
prec_round = np.vectorize(prec_round, otypes=[float])
# end of most settings
# this will sort alphabetically by chemical symbol:
# num2sym maps the alphabetical rank of a symbol to the symbol, so that
# sorting by sym2num[symbol] orders rows alphabetically.  (Pairing
# np.argsort(chemical_symbols) directly with chemical_symbols would map
# *sorting indices*, not ranks, and yield a non-alphabetical order.)
num2sym = dict(enumerate(sorted(chemical_symbols)))
# to sort by atomic number, use this instead
# num2sym = dict(enumerate(chemical_symbols))
sym2num = {v: k for k, v in num2sym.items()}
# Field names that get_field_data resolves without asking for forces.
atoms_props = [
    # displacement between the two images: components and norm
    'dx', 'dy', 'dz', 'd',
    # per-atom tag, atomic number, index and element
    't', 'an', 'i', 'el',
    # positions in the first / second image: norm, then components
    'p1', 'p2',
    'p1x', 'p1y', 'p1z',
    'p2x', 'p2y', 'p2z',
]
# Force-derived field names: force difference (df*), mean force (af*) and
# the forces of the first / second image (f1*, f2*), each available as a
# norm (no suffix) or as an x/y/z component.
_FORCE_FIELDS = frozenset(
    prefix + axis
    for prefix in ('df', 'af', 'f1', 'f2')
    for axis in ('', 'x', 'y', 'z'))

_AXIS_COLUMN = {'x': 0, 'y': 1, 'z': 2}


def _vector_field(vectors, field):
    """Return the column of *vectors* selected by the last letter of
    *field* (x, y or z), or the row-wise norm if it names no axis."""
    column = _AXIS_COLUMN.get(field[-1])
    if column is None:
        return np.linalg.norm(vectors, axis=1)
    return vectors[:, column]


def get_field_data(atoms1, atoms2, field):
    """Return the per-atom data column for one field.

    Parameters
    ----------
    atoms1, atoms2
        The two images to compare.  ``positions``, ``numbers``,
        ``symbols``, ``get_tags()`` and ``get_forces()`` are read from
        them as the field requires.
    field : str
        A name from ``atoms_props`` or a force field (``df``, ``af``,
        ``f1``, ``f2``, optionally followed by ``x``, ``y`` or ``z``).
        A leading ``r`` requests the rank of the values instead of the
        values themselves.

    Returns
    -------
    numpy.ndarray
        One value per atom.  For rank fields, 0 marks the largest value.

    Raises
    ------
    ValueError
        If *field* is not a known field name.  Unknown names used to be
        interpreted silently as forces of *atoms2*.
    """
    rank_order = field.startswith('r')
    if rank_order:
        field = field[1:]

    if field in atoms_props:
        if field == 't':
            data = atoms1.get_tags()
        elif field == 'an':
            data = atoms1.numbers
        elif field == 'el':
            # elements are sorted through their integer code in sym2num
            data = np.array([sym2num[sym] for sym in atoms1.symbols])
        elif field == 'i':
            data = np.arange(len(atoms1))
        else:
            # remaining properties are displacements (d*) or positions (p*)
            if field.startswith('d'):
                vectors = atoms2.positions - atoms1.positions
            elif field[1] == '1':
                vectors = atoms1.positions
            else:
                vectors = atoms2.positions
            data = _vector_field(vectors, field)
    elif field in _FORCE_FIELDS:
        if field.startswith('d'):
            vectors = atoms2.get_forces() - atoms1.get_forces()
        elif field.startswith('a'):
            vectors = (atoms2.get_forces() + atoms1.get_forces()) / 2
        elif field[1] == '1':
            vectors = atoms1.get_forces()
        else:
            vectors = atoms2.get_forces()
        data = _vector_field(vectors, field)
    else:
        raise ValueError(f'Unknown field: {field!r}')

    if rank_order:
        # double argsort of the negated data: position in descending order
        return np.argsort(np.argsort(-data))

    return data
# Summary Functions
def rmsd(atoms1, atoms2):
    """Return the root-mean-square displacement between the ``positions``
    of *atoms1* and *atoms2*, formatted as ``'RMSD=<value>'``."""
    displacement = atoms2.positions - atoms1.positions
    squared_norms = np.linalg.norm(displacement, axis=1)**2
    value = np.sqrt(squared_norms.mean())
    return f'RMSD={value:+.1E}'
def energy_delta(atoms1, atoms2):
    """Return a one-line summary of the potential energies of both
    images and their difference (second minus first)."""
    first, second = (atoms.get_potential_energy()
                     for atoms in (atoms1, atoms2))
    return f'E1 = {first:+.1E}, E2 = {second:+.1E}, dE = {second - first:+1.1E}'
def parse_field_specs(field_specs):
    """Split field specifications into names, sort order and sort signs.

    Parameters
    ----------
    field_specs : iterable of str
        Entries of the form ``name``, ``name:hierarchy`` or
        ``name:hierarchy:scent`` with integer ``hierarchy`` and ``scent``.
        A missing ``scent`` defaults to -1.  Fields with a missing (or
        negative) ``hierarchy`` are placed after all explicit levels, in
        the order in which they appear.

    Returns
    -------
    fields : list of str
        The field names, in input order.
    hier : numpy.ndarray
        Indices into *fields*, ordered from the highest hierarchy level
        to the lowest (the key order expected by ``numpy.lexsort``).
    scent : numpy.ndarray
        Integer multiplier of each field, in input order.

    Raises
    ------
    ValueError
        If a specification has more than three ``:``-separated parts, or
        if a hierarchy / scent part is not an integer.
    """
    fields = []
    hier = []
    scent = []
    for spec in field_specs:
        name, *modifiers = spec.split(':')
        if len(modifiers) > 2:
            # previously such entries were dropped without any notice
            raise ValueError(
                f'Invalid field specification {spec!r}: expected '
                'name[:hierarchy[:scent]]')
        fields.append(name)
        hier.append(int(modifiers[0]) if modifiers else -1)
        scent.append(int(modifiers[1]) if len(modifiers) == 2 else -1)

    # Unranked fields get consecutive levels above the highest explicit one.
    # default=-1 keeps an empty specification list from crashing in max().
    next_level = max(hier, default=-1)
    for position, level in enumerate(hier):
        if level < 0:
            next_level += 1
            hier[position] = next_level

    # reversed by convention of numpy lexsort
    order = np.argsort(hier)[::-1]
    return fields, order, np.array(scent, dtype=int)
# Class definitions
class MapFormatter(string.Formatter):
    """Formatter with an extra ``h`` presentation type.

    A format spec ending in ``h`` converts the value to ``int``, looks it
    up in ``num2sym`` and formats the resulting string with the same spec
    in which ``h`` is replaced by ``s``.  All other specs are handled by
    ``string.Formatter`` unchanged.
    """

    def format_field(self, value, spec):
        if not spec.endswith('h'):
            return super().format_field(value, spec)

        symbol = num2sym[int(value)]
        return super().format_field(symbol, spec[:-1] + 's')
class TableFormat:
    """Column formats and rule characters for rendering a table.

    Attributes set by the constructor:

    * ``columnwidth``, ``precision``, ``representation`` -- the values
      the format strings are built from; floats are written with
      ``precision - 1`` digits after the decimal point.
    * ``toprule``, ``midrule``, ``bottomrule`` -- the rule characters.
    * ``formatter`` -- the ``format`` method of a ``MapFormatter``.
    * ``fmt_class`` -- format string per column class (``'signed float'``,
      ``'unsigned float'``, ``'int'``, ``'str'``, ``'conv'``).
    * ``fmt`` -- format string per field name.
    """

    def __init__(self,
                 columnwidth=9,
                 precision=2,
                 representation='E',
                 toprule='=',
                 midrule='-',
                 bottomrule='='):

        self.precision = precision
        self.representation = representation
        self.columnwidth = columnwidth
        self.formatter = MapFormatter().format
        self.toprule = toprule
        self.midrule = midrule
        self.bottomrule = bottomrule

        width = self.columnwidth
        # width.digits + presentation type, shared by both float classes
        float_spec = f'{width}.{self.precision - 1}{self.representation}'
        self.fmt_class = {
            # the blank sign option reserves a column for the minus sign
            'signed float': f'{{: ^{float_spec}}}',
            'unsigned float': f'{{:^{float_spec}}}',
            'int': f'{{:^{width}n}}',
            'str': f'{{:^{width}s}}',
            # 'h' is the symbol-lookup type understood by MapFormatter
            'conv': f'{{:^{width}h}}'}

        signed_floats = [
            'dx', 'dy', 'dz',
            'dfx', 'dfy', 'dfz',
            'afx', 'afy', 'afz',
            'p1x', 'p2x', 'p1y', 'p2y', 'p1z', 'p2z',
            'f1x', 'f2x', 'f1y', 'f2y', 'f1z', 'f2z']
        unsigned_floats = ['d', 'df', 'af', 'p1', 'p2', 'f1', 'f2']
        # every float field has a rank variant, which is an integer column
        integers = ['i', 'an', 't'] + [
            'r' + name for name in signed_floats + unsigned_floats]

        fmt = dict.fromkeys(signed_floats, self.fmt_class['signed float'])
        fmt.update(
            dict.fromkeys(unsigned_floats, self.fmt_class['unsigned float']))
        fmt.update(dict.fromkeys(integers, self.fmt_class['int']))
        fmt['el'] = self.fmt_class['conv']

        self.fmt = fmt
class Table:
    """Render a per-atom comparison of two images as text.

    Parameters
    ----------
    field_specs : list of str
        Column specifications, parsed by ``parse_field_specs``.
    summary_functions : list of callable, optional
        Functions called as ``f(atoms1, atoms2)``; each returns one line
        printed below the table.  Defaults to no summary lines.
    tableformat : TableFormat, optional
        Formatting settings; a default ``TableFormat()`` if omitted.
    max_lines : int, optional
        Maximum number of body rows to output; no limit if ``None``.
    title : str
        First line of the output.
    tablewidth : int, optional
        Length of the horizontal rules; defaults to the column width
        times the number of fields.
    """

    def __init__(self,
                 field_specs,
                 summary_functions=None,
                 tableformat=None,
                 max_lines=None,
                 title='',
                 tablewidth=None):

        self.max_lines = max_lines
        # None instead of a literal [] default: a default list would be
        # one object shared by every Table created without this argument.
        self.summary_functions = (
            [] if summary_functions is None else summary_functions)
        self.field_specs = field_specs

        self.fields, self.hier, self.scent = parse_field_specs(
            self.field_specs)
        self.nfields = len(self.fields)

        # formatting
        if tableformat is None:
            self.tableformat = TableFormat()
        else:
            self.tableformat = tableformat

        if tablewidth is None:
            self.tablewidth = self.tableformat.columnwidth * self.nfields
        else:
            self.tablewidth = tablewidth

        self.title = title

    def make(self, atoms1, atoms2, csv=False):
        """Return the complete table (title, rules, header, body and
        summary) comparing *atoms1* with *atoms2* as one string."""
        header = self.make_header(csv=csv)
        body = self.make_body(atoms1, atoms2, csv=csv)
        if self.max_lines is not None:
            # The body is a single newline-joined string, so it has to be
            # cut by rows; slicing the string itself would cut characters.
            body = '\n'.join(body.split('\n')[:self.max_lines])
        summary = self.make_summary(atoms1, atoms2)

        return '\n'.join([self.title,
                          self.tableformat.toprule * self.tablewidth,
                          header,
                          self.tableformat.midrule * self.tablewidth,
                          body,
                          self.tableformat.bottomrule * self.tablewidth,
                          summary])

    def make_header(self, csv=False):
        """Return the header line: comma-separated column titles if *csv*,
        otherwise titles centred in fixed-width columns."""
        headers = [header_alias(field) for field in self.fields]
        if csv:
            return ','.join(headers)

        fields = self.tableformat.fmt_class['str'] * self.nfields
        return self.tableformat.formatter(fields, *headers)

    def make_summary(self, atoms1, atoms2):
        """Return the output of all summary functions, one per line."""
        return '\n'.join([summary_function(atoms1, atoms2)
                          for summary_function in self.summary_functions])

    def make_body(self, atoms1, atoms2, csv=False):
        """Return the sorted table rows as one newline-joined string."""
        field_data = np.array([get_field_data(atoms1, atoms2, field)
                               for field in self.fields])

        # Build the sort keys: each field is multiplied by its scent,
        # the keys are put in the order np.lexsort expects, and they are
        # rounded so that nearly equal values tie and lower-priority
        # keys can decide.
        sorting_array = field_data * self.scent[:, np.newaxis]
        sorting_array = sorting_array[self.hier]
        sorting_array = prec_round(sorting_array, self.tableformat.precision)

        # reorder the atoms, then transpose to one row per atom
        field_data = field_data[:, np.lexsort(sorting_array)].transpose()

        if csv:
            rowformat = ','.join(
                ['{:h}' if field == 'el' else '{{:.{}E}}'.format(
                    self.tableformat.precision) for field in self.fields])
        else:
            rowformat = ''.join([self.tableformat.fmt[field]
                                 for field in self.fields])
        body = [
            self.tableformat.formatter(
                rowformat,
                *row) for row in field_data]
        return '\n'.join(body)
# Index passed to parse_filename when a filename carries no '@' part.
default_index = string2index(':')


def slice_split(filename):
    """Split *filename* into the bare filename and its image index.

    A name containing ``'@'`` is parsed with no fallback index; any other
    name is parsed with ``default_index`` as the fallback.

    Returns the ``(filename, index)`` pair produced by ``parse_filename``.
    """
    fallback = None if '@' in filename else default_index
    name, index = parse_filename(filename, fallback)
    return name, index