Coverage for /builds/kinetik161/ase/ase/db/jsondb.py: 91.43%
175 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import os
2import sys
3from contextlib import ExitStack
5import numpy as np
7from ase.db.core import Database, lock, now, ops
8from ase.db.row import AtomsRow
9from ase.io.jsonio import decode, encode
10from ase.parallel import parallel_function, world
class JSONDatabase(Database):
    """Database backend that keeps all rows in a single JSON document.

    Every operation re-reads the whole document with ``_read_json`` and
    every modification rewrites it completely with ``_write_json``.
    ``self.filename`` may be a path (``str``) or an already open
    file-like object.
    """

    def __enter__(self):
        """Return the database itself; nothing is acquired on entry."""
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Do nothing on exit; exceptions are not suppressed."""
        pass

    def _write(self, atoms, key_value_pairs, data, id):
        """Insert a new row or overwrite the row with the given ``id``.

        Parameters:

        atoms:
            An ``AtomsRow`` (used as is) or an object that is passed to
            ``AtomsRow(atoms)``.
        key_value_pairs:
            Stored under ``'key_value_pairs'`` when truthy.
        data:
            Stored under ``'data'`` when truthy.
        id:
            ``None`` to append a new row using the next free id, or the
            id of an existing row to replace.

        Returns the id of the written row.
        """
        # NOTE(review): presumably validates the input -- the base class
        # implementation is not visible here.
        Database._write(self, atoms, key_value_pairs, data)

        bigdct = {}
        ids = []
        nextid = 1

        if (isinstance(self.filename, str) and
                os.path.isfile(self.filename)):
            try:
                bigdct, ids, nextid = self._read_json()
            except (SyntaxError, ValueError):
                # Deliberate best effort: an existing file that cannot be
                # decoded is treated like an empty database.
                pass

        mtime = now()

        if isinstance(atoms, AtomsRow):
            row = atoms
        else:
            row = AtomsRow(atoms)
            # Only freshly created rows get a creation time and a user.
            row.ctime = mtime
            row.user = os.getenv('USER')

        dct = {}
        for key in row.__dict__:
            # Skip private attributes, everything listed in row._keys and
            # the id (the id is the key in bigdct, not part of the row).
            if key[0] == '_' or key in row._keys or key == 'id':
                continue
            dct[key] = row[key]

        dct['mtime'] = mtime

        if key_value_pairs:
            dct['key_value_pairs'] = key_value_pairs

        if data:
            dct['data'] = data

        constraints = row.get('constraints')
        if constraints:
            dct['constraints'] = constraints

        if id is None:
            id = nextid
            ids.append(id)
            nextid += 1
        else:
            # Updating is only allowed for rows that already exist.
            assert id in bigdct

        bigdct[id] = dct
        self._write_json(bigdct, ids, nextid)
        return id

    def _read_json(self):
        """Read and decode the whole database.

        Returns a tuple ``(bigdct, ids, nextid)``: the decoded dictionary,
        the list of row ids and the id to use for the next new row.

        Raises ``UnknownFileTypeError`` if the decoded object is not a
        dictionary or contains neither an ``'ids'`` entry nor a row ``1``.
        """
        if isinstance(self.filename, str):
            with open(self.filename) as fd:
                bigdct = decode(fd.read())
        else:
            bigdct = decode(self.filename.read())
            # Rewind so that the same file object can be read again;
            # sys.stdin is left alone.
            if self.filename is not sys.stdin:
                self.filename.seek(0)

        if not isinstance(bigdct, dict) or not ('ids' in bigdct
                                                or 1 in bigdct):
            from ase.io.formats import UnknownFileTypeError
            raise UnknownFileTypeError('Does not resemble ASE JSON database')

        ids = bigdct.get('ids')
        if ids is None:
            # Allow for missing "ids" and "nextid":
            assert 1 in bigdct
            return bigdct, [1], 2
        if not isinstance(ids, list):
            # NOTE(review): presumably a numpy array produced by decode().
            ids = ids.tolist()
        return bigdct, ids, bigdct['nextid']

    def _write_json(self, bigdct, ids, nextid):
        """Write all rows, the metadata, ``ids`` and ``nextid``.

        Returns without writing anything when ``world.rank > 0``.  A file
        opened here (``self.filename`` is a ``str``) is closed on return;
        a file-like object supplied by the user is written to and left
        open.
        """
        if world.rank > 0:
            return

        with ExitStack() as stack:
            if isinstance(self.filename, str):
                fd = stack.enter_context(open(self.filename, 'w'))
            else:
                fd = self.filename
            print('{', end='', file=fd)
            for id in ids:
                dct = bigdct[id]
                txt = ',\n '.join(f'"{key}": {encode(dct[key])}'
                                  for key in sorted(dct))
                print(f'"{id}": {{\n {txt}}},', file=fd)
            if self._metadata is not None:
                print(f'"metadata": {encode(self.metadata)},', file=fd)
            print(f'"ids": {ids},', file=fd)
            print(f'"nextid": {nextid}}}', file=fd)

    @parallel_function
    @lock
    def delete(self, ids):
        """Remove the rows with the given ids and rewrite the database."""
        bigdct, myids, nextid = self._read_json()
        for id in ids:
            del bigdct[id]
            myids.remove(id)
        self._write_json(bigdct, myids, nextid)

    def _get_row(self, id):
        """Return the row with the given id as an ``AtomsRow``.

        With ``id=None`` the database must contain exactly one row, and
        that row is returned.
        """
        bigdct, ids, nextid = self._read_json()
        if id is None:
            assert len(ids) == 1
            id = ids[0]
        dct = bigdct[id]
        dct['id'] = id
        return AtomsRow(dct)

    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
        """Generate the rows matching ``keys`` and ``cmps``.

        Parameters:

        keys:
            Names that must all be present in a row.
        cmps:
            ``(key, op, value)`` tuples, where ``op`` is a key of ``ops``.
            An integer ``key`` is compared against the number of entries
            of ``row.numbers`` equal to that integer.
        explain:
            If true, yield a single ``{'explain': ...}`` dictionary
            instead of rows.
        limit:
            Maximum number of rows to yield; a falsy value means no limit.
        offset:
            Number of matching rows to skip first.
        sort:
            Name to sort by; a leading ``'-'`` reverses the order.  Rows
            without a value for that name are placed last.
        include_data:
            If false, the ``'data'`` entry is dropped from the rows.

        ``verbosity`` and ``columns`` are accepted but not used here.
        """
        if explain:
            yield {'explain': (0, 0, 0, 'scan table')}
            return

        if sort:
            if sort[0] == '-':
                reverse = True
                sort = sort[1:]
            else:
                reverse = False

            rows = []
            missing = []
            # Collect every match first; limit and offset must be applied
            # after sorting.  include_data is forwarded so that it is
            # honoured for sorted selections as well.
            for row in self._select(keys, cmps, include_data=include_data):
                key = row.get(sort)
                if key is None:
                    missing.append((0, row))
                else:
                    rows.append((key, row))

            rows.sort(reverse=reverse, key=lambda x: x[0])
            rows += missing

            # Apply the offset also when no limit is given, like the
            # unsorted branch below does.
            if limit:
                rows = rows[offset:offset + limit]
            else:
                rows = rows[offset:]
            for key, row in rows:
                yield row
            return

        try:
            bigdct, ids, nextid = self._read_json()
        except OSError:
            # Nothing to read from: yield no rows.
            return

        if not limit:
            # n - offset is never smaller than -offset, so it can never
            # equal this value: the limit check below is disabled.
            limit = -offset - 1

        cmps = [(key, ops[op], val) for key, op, val in cmps]
        n = 0  # number of matching rows seen so far
        for id in ids:
            if n - offset == limit:
                return
            dct = bigdct[id]
            if not include_data:
                dct.pop('data', None)
            row = AtomsRow(dct)
            row.id = id
            for key in keys:
                if key not in row:
                    break
            else:
                for key, op, val in cmps:
                    if isinstance(key, int):
                        value = np.equal(row.numbers, key).sum()
                    else:
                        value = row.get(key)
                        if key == 'pbc':
                            assert op in [ops['='], ops['!=']]
                            # Index 'FT' with each entry to get a string
                            # of 'F' and 'T' characters.
                            value = ''.join('FT'[x] for x in value)
                    if value is None or not op(value, val):
                        break
                else:
                    if n >= offset:
                        yield row
                    n += 1

    @property
    def metadata(self):
        """A copy of the metadata dictionary.

        The metadata is read from the database the first time it is
        needed and cached in ``self._metadata``.
        """
        if self._metadata is None:
            bigdct, myids, nextid = self._read_json()
            self._metadata = bigdct.get('metadata', {})
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        """Store ``dct`` as the metadata and rewrite the database."""
        bigdct, ids, nextid = self._read_json()
        self._metadata = dct
        self._write_json(bigdct, ids, nextid)

    def get_all_key_names(self):
        """Return the set of all names used in any row's key-value pairs."""
        keys = set()
        bigdct, ids, nextid = self._read_json()
        for id in ids:
            dct = bigdct[id]
            kvp = dct.get('key_value_pairs')
            if kvp:
                keys.update(kvp)
        return keys