Coverage for /builds/kinetik161/ase/ase/db/cli.py: 62.30%

244 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import json 

2import sys 

3from collections import defaultdict 

4from contextlib import contextmanager 

5from pathlib import Path 

6from typing import Iterable, Iterator 

7 

8import ase.io 

9from ase.db import connect 

10from ase.db.core import convert_str_to_int_float_or_str 

11from ase.db.row import row2dct 

12from ase.db.table import Table, all_columns 

13from ase.utils import plural 

14 

15 

def count_keys(db, query):
    """Print how many rows matching *query* carry each key-value-pair key.

    Rows come from ``db.select(query)``; the keys of each row are taken from
    ``row._keys``.  One line ``"<key>: <count>"`` is printed per key, with
    the counts aligned in a single column.  Nothing is printed when no
    matching row has any keys.
    """
    counts = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            counts[key] += 1

    if not counts:
        # max() below would raise ValueError on an empty sequence (e.g. an
        # empty database or a query that matches nothing).
        return

    width = max(len(key) for key in counts) + 1
    for key, number in counts.items():
        print('{:{}} {}'.format(key + ':', width, number))

26 

27 

def main(args):
    """Run the ``ase db`` command-line tool.

    Parameters
    ----------
    args:
        Parsed command-line options.  Attributes read include ``database``,
        ``query`` (a list of strings joined with ``','``), ``sort``,
        ``limit``, ``offset``, ``quiet``/``verbose`` and one option per
        action (``analyse``, ``show_keys``, ``show_values``, ``count``,
        ``insert_into``, ``delete``, ``plot``, ``json``, ``long``, ...).

    The first requested action, in the order checked below, is performed
    and the function returns.  When no action is requested, the selected
    rows are printed as a table (CSV when ``args.csv`` is set).

    Raises
    ------
    ValueError
        If ``--add-key-value-pairs`` contains an item without ``'='``, or a
        ``-col`` column specification names a column that is not shown.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A bare number selects a row by id.
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only, so that values may contain '='.
            key, value = pair.split('=', 1)
            add_key_value_pairs[key] = convert_str_to_int_float_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*items):
        # Informational output, silenced when verbosity drops to 0 or less.
        if verbosity > 0:
            print(*items)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        numeric = set()  # keys seen with at least one non-string value
        textual = set()  # keys seen with at least one string value
        for row in db.select(query):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if isinstance(value, str):
                        textual.add(key)
                    else:
                        numeric.add(key)

        width = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            # Only show a [min..max] range when every value is non-string;
            # comparing str with numbers in min()/max() raises TypeError.
            if key in numeric and key not in textual:
                print('{:{}} [{}..{}]'
                      .format(key + ':', width, min(vals), max(vals)))
            else:
                print('{:{}} {}'
                      .format(key + ':', width,
                              ', '.join(f'{v}({count})'
                                        for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print(f'{plural(n, "row")}')
        return

    if args.insert_into:
        # An unset limit (-1) becomes 0 when copying rows; presumably 0
        # means "no limit" for db.select -- TODO confirm.
        if args.limit == -1:
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    # Count only the pairs that are new to this row.
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out(f'Inserted {plural(nrows, "row")}')
        return

    # Default number of rows shown when no limit was given.
    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        # Only the ids are needed, so skip loading the (possibly large) data.
        ids = [row['id'] for row in db.select(query, include_data=False)]
        total_added = 0
        total_removed = 0
        with db:
            for row_id in ids:
                added, removed = db.update(row_id, delete_keys=delete_keys,
                                           **add_key_value_pairs)
                total_added += added
                total_removed += removed
        out('Added %s (%s updated)' %
            (plural(total_added, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - total_added,
                    'pair')))
        out('Removed', plural(total_removed, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = f'Delete {plural(len(ids), "row")}? (yes/No): '
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out(f'Deleted {plural(len(ids), "row")}')
        return

    if args.plot:
        # Format: "[tag1,tag2,...:]xkey,ykey1,ykey2,...".
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # maps string x-values to consecutive integer positions
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.DBApp().run_db(db)
        return

    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        # "++" adds every key-value-pair key found in the selected rows.
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        # Sorted so the column order does not depend on set/hash ordering.
        columns.extend(sorted(keys))
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            # A plain list replaces the default columns entirely.
            columns = []
        for col in c.split(','):
            if not col:
                # Tolerate empty items (trailing comma, a lone "+"), which
                # previously crashed with IndexError on col[0].
                continue
            if col[0] == '-':
                name = col[1:]
                if name not in columns:
                    raise ValueError(
                        f'Cannot remove column {name!r}: it is not shown')
                columns.remove(name)
            else:
                name = col.lstrip('+')
                if name:
                    columns.append(name)

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)

281 

282 

def row2str(row) -> str:
    """Return a multi-line, human-readable description of a database row.

    The row is first converted with ``row2dct(row, key_descriptions={})``.
    The text contains:

    * the formula, followed by a unit-cell table with one line per axis
      (axis number, periodicity, cell vector, length, angle);
    * optional stress, dipole, constraints and data sections, present only
      when ``row2dct`` supplies those entries;
    * a ``Key | Description | Value`` table built from ``t['table']``.
    """
    t = row2dct(row, key_descriptions={})

    # A single layout drives both the header and the per-axis lines, so the
    # column titles always line up with the values underneath them.
    cell_fmt = '{:>4}|{:>8}|{:>11}|{:>11}|{:>11}|{:>10}|{:>10}'
    lines = [t['formula'] + ':',
             'Unit cell in Ang:',
             cell_fmt.format('axis', 'periodic', 'x', 'y', 'z',
                             'length', 'angle')]
    cell_rows = zip(row.pbc, t['cell'], t['lengths'], t['angles'])
    for number, (periodic, axis, length, angle) in enumerate(cell_rows,
                                                             start=1):
        lines.append(cell_fmt.format(number, [' no', 'yes'][periodic],
                                     axis[0], axis[1], axis[2],
                                     length, angle))
    lines.append('')

    if 'stress' in t:
        lines += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
                  ' {}\n'.format(t['stress'])]

    if 'dipole' in t:
        lines.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        lines.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        lines.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # Seeding with the header widths keeps max() valid for an empty table.
    width0 = max([3] + [len(key) for key, _, _ in table])
    width1 = max([11] + [len(desc) for _, desc, _ in table])
    lines.append('{:{}} | {:{}} | Value'
                 .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        lines.append('{:{}} | {:{}} | {}'
                     .format(key, width0, desc, width1, value))
    return '\n'.join(lines)

318 

319 

@contextmanager
def no_progressbar(iterable: Iterable,
                   length: 'int | None' = None) -> Iterator[Iterable]:
    """A do-nothing stand-in for a progress-bar context manager.

    ``with no_progressbar(items) as it:`` binds *items* itself to ``it``.
    *length* is accepted but ignored; it exists only so this function can
    be called the same way as ``click.progressbar`` in ``main``.
    """
    yield iterable

325 

326 

def check_jsmol():
    """Print a warning to stderr if JSmol is not installed next to this module.

    Checks for the file ``static/jsmol/JSmol.min.js`` in the directory that
    contains this source file.  If it is not a regular file, prints
    instructions for downloading Jmol and soft-linking its ``jsmol``
    directory into that ``static`` directory.  Returns None.
    """
    # "static" directory located beside this module file.
    static = Path(__file__).parent / 'static'
    if not (static / 'jsmol/JSmol.min.js').is_file():
        # NOTE(review): the exact indentation inside this message could not
        # be recovered from the extracted source -- confirm against upstream.
        print(f"""
    WARNING:
        You don't have jsmol on your system.

        Download Jmol-*-binary.tar.gz from
        https://sourceforge.net/projects/jmol/files/Jmol/,
        extract jsmol.zip, unzip it and create a soft-link:

            $ tar -xf Jmol-*-binary.tar.gz
            $ unzip jmol-*/jsmol.zip
            $ ln -s $PWD/jsmol {static}/jsmol
    """,
              file=sys.stderr)