Coverage for /builds/kinetik161/ase/ase/db/cli.py: 62.30%
244 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import json
2import sys
3from collections import defaultdict
4from contextlib import contextmanager
5from pathlib import Path
6from typing import Iterable, Iterator
8import ase.io
9from ase.db import connect
10from ase.db.core import convert_str_to_int_float_or_str
11from ase.db.row import row2dct
12from ase.db.table import Table, all_columns
13from ase.utils import plural
def count_keys(db, query):
    """Print how many of the selected rows carry each key.

    Iterates over ``db.select(query)``, tallies every name found in each
    row's ``_keys`` and prints one ``key: count`` line per key, with the
    key column padded to the width of the longest key.

    Prints nothing when the selection is empty or no selected row has any
    keys.
    """
    keys = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            keys[key] += 1

    if not keys:
        # max() below would raise ValueError on an empty sequence.
        return

    # +1 leaves room for the ':' appended to each key.
    n = max(len(key) for key in keys) + 1
    for key, number in keys.items():
        print('{:{}} {}'.format(key + ':', n, number))
def main(args):
    """Run the database command selected by the parsed arguments *args*.

    *args* is an argparse-style namespace.  The database named by
    ``args.database`` is opened and then exactly one action is carried
    out; the options are checked in this order and the first one that is
    set wins: ``analyse``, ``show_keys``, ``show_values``,
    ``add_from_file``, ``count``, ``insert_into``, ``explain``,
    ``show_metadata``, ``set_metadata``, adding/deleting key-value pairs,
    ``delete``, ``plot``, ``json``, ``long``, ``open_web_browser``.  If
    none of them is set, a table of the selected rows is written (as CSV
    when ``args.csv`` is set).

    Note that ``args.sort`` and ``args.limit`` are normalised in place.

    Raises:
        ValueError: if an entry of ``args.add_key_value_pairs`` is not of
            the form ``key=value``.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A purely numeric query is passed on as an integer.
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only, so values may contain '='.
            key, sep, value = pair.partition('=')
            if not sep:
                raise ValueError(
                    f'Expected key=value, got {pair!r}')
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    delete_keys = args.delete_keys.split(',') if args.delete_keys else []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*messages):
        """Print *messages* unless output has been silenced."""
        if verbosity > 0:
            print(*messages)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        # Keys that have at least one non-string value; these are
        # summarised as a [min..max] range instead of a value listing.
        numbers = set()
        for row in db.select(query):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if not isinstance(value, str):
                        numbers.add(key)

        n = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers:
                print('{:{}} [{}..{}]'
                      .format(key + ':', n, min(vals), max(vals)))
            else:
                print('{:{}} {}'
                      .format(key + ':', n,
                              ', '.join(f'{v}({count})'
                                        for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print(f'{plural(n, "row")}')
        return

    if args.insert_into:
        if args.limit == -1:
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        # nkvp ends up as the net number of key-value pairs that are new
        # on the copied rows (pairs after the update minus pairs before).
        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out(f'Inserted {plural(nrows, "row")}')
        return

    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        # Collect the ids first so the selection is finished before any
        # row is modified.
        ids = [row['id'] for row in db.select(query)]
        n_added = 0
        n_removed = 0
        with db:
            for row_id in ids:
                m, n = db.update(row_id, delete_keys=delete_keys,
                                 **add_key_value_pairs)
                n_added += m
                n_removed += n
        out('Added %s (%s updated)' %
            (plural(n_added, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - n_added, 'pair')))
        out('Removed', plural(n_removed, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = f'Delete {plural(len(ids), "row")}? (yes/No): '
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out(f'Deleted {plural(len(ids), "row")}')
        return

    if args.plot:
        # Format: "tag1,tag2:x,y1,y2" or just "x,y1,y2".
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        # String x-values are mapped to consecutive integers; `labels`
        # remembers them in first-seen order for the tick labels.
        X = {}
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.DBApp().run_db(db)
        return

    # Default action: write a table of the selected rows.
    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        # "++" adds every key found on the selected rows as a column.
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            # No leading +/-: the spec replaces the default columns.
            columns = []
        for col in c.split(','):
            if not col:
                # Empty entry (e.g. "+" alone or a trailing comma);
                # col[0] below would raise IndexError.
                continue
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
def row2str(row) -> str:
    """Return a multi-line, human-readable description of *row*.

    The text is built from ``row2dct(row, key_descriptions={})`` and
    contains, in order: the formula, a unit-cell table (one line per
    axis with periodicity, vector, length and angle), the optional
    stress, dipole, constraints and data entries when present, and a
    ``Key | Description | Value`` table.
    """
    t = row2dct(row, key_descriptions={})
    S = [t['formula'] + ':',
         'Unit cell in Ang:',
         'axis|periodic|          x|          y|          z|' +
         '    length|     angle']
    fmt = ('   {0}|     {1}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|' +
           '{3:>10}|{4:>10}')
    cell_info = zip(row.pbc, t['cell'], t['lengths'], t['angles'])
    for c, (p, axis, L, A) in enumerate(cell_info, start=1):
        # Truthiness test rather than list indexing, so that p does not
        # have to be usable as an index (it may be a NumPy boolean).
        S.append(fmt.format(c, 'yes' if p else ' no', axis, L, A))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              '   {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # Column widths: at least as wide as the 'Key' / 'Description'
    # headings; default=0 keeps an empty table from raising ValueError.
    width0 = max(max((len(entry[0]) for entry in table), default=0), 3)
    width1 = max(max((len(entry[1]) for entry in table), default=0), 11)
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)
@contextmanager
def no_progressbar(iterable: Iterable,
                   length: int = None) -> Iterator[Iterable]:
    """Context manager that hands back *iterable* unchanged.

    Stand-in for a real progress bar: it takes the same ``length``
    keyword that ``main`` passes to its progress bar, ignores it and
    displays nothing.
    """
    # Nothing to set up or tear down; ``length`` is deliberately unused.
    yield iterable
def check_jsmol():
    """Warn on stderr when the JSmol script is not installed.

    Looks for ``static/jsmol/JSmol.min.js`` next to this module and, if
    the file is missing, prints download and soft-link instructions to
    ``sys.stderr``.  Does nothing when the file is present.
    """
    static = Path(__file__).parent / 'static'
    jsmol_script = static / 'jsmol/JSmol.min.js'
    if jsmol_script.is_file():
        return

    print(f"""
    WARNING:
        You don't have jsmol on your system.

        Download Jmol-*-binary.tar.gz from
        https://sourceforge.net/projects/jmol/files/Jmol/,
        extract jsmol.zip, unzip it and create a soft-link:

            $ tar -xf Jmol-*-binary.tar.gz
            $ unzip jmol-*/jsmol.zip
            $ ln -s $PWD/jsmol {static}/jsmol
    """,
          file=sys.stderr)