Coverage for /builds/kinetik161/ase/ase/db/jsondb.py: 91.43% (175 statements; coverage.py v7.2.7, created at 2023-12-10 11:04 +0000)

import os
import sys
from contextlib import ExitStack

import numpy as np

from ase.db.core import Database, lock, now, ops
from ase.db.row import AtomsRow
from ase.io.jsonio import decode, encode
from ase.parallel import parallel_function, world


class JSONDatabase(Database):
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        pass
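
    # Editor's note: the whole database lives in a single JSON file, so
    # _write below reads the full file, updates it in memory and serializes
    # everything back on every call; cheap for small databases, O(N) per
    # write.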

    def _write(self, atoms, key_value_pairs, data, id):
        Database._write(self, atoms, key_value_pairs, data)

        bigdct = {}
        ids = []
        nextid = 1

        if (isinstance(self.filename, str) and
                os.path.isfile(self.filename)):
            try:
                bigdct, ids, nextid = self._read_json()
            except (SyntaxError, ValueError):
                pass

        mtime = now()

        if isinstance(atoms, AtomsRow):
            row = atoms
        else:
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')

        dct = {}
        for key in row.__dict__:
            if key[0] == '_' or key in row._keys or key == 'id':
                continue
            dct[key] = row[key]

        dct['mtime'] = mtime

        if key_value_pairs:
            dct['key_value_pairs'] = key_value_pairs

        if data:
            dct['data'] = data

        constraints = row.get('constraints')
        if constraints:
            dct['constraints'] = constraints

        if id is None:
            id = nextid
            ids.append(id)
            nextid += 1
        else:
            assert id in bigdct

        bigdct[id] = dct
        self._write_json(bigdct, ids, nextid)
        return id
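
    # Editor's note: _read_json tolerates two layouts: the current one with
    # explicit "ids"/"nextid" bookkeeping entries, and a legacy one holding
    # a single row keyed by id 1 with no bookkeeping at all.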

    def _read_json(self):
        if isinstance(self.filename, str):
            with open(self.filename) as fd:
                bigdct = decode(fd.read())
        else:
            bigdct = decode(self.filename.read())
            if self.filename is not sys.stdin:
                self.filename.seek(0)

        if not isinstance(bigdct, dict) or not ('ids' in bigdct
                                                or 1 in bigdct):
            from ase.io.formats import UnknownFileTypeError
            raise UnknownFileTypeError('Does not resemble ASE JSON database')

        ids = bigdct.get('ids')
        if ids is None:
            # Allow for missing "ids" and "nextid":
            assert 1 in bigdct
            return bigdct, [1], 2
        if not isinstance(ids, list):
            ids = ids.tolist()
        return bigdct, ids, bigdct['nextid']
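
    # Editor's note: only MPI rank 0 writes. The JSON is emitted by hand,
    # keeping each row in its own block, in id order, followed by
    # "metadata" (if set), "ids" and "nextid".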

    def _write_json(self, bigdct, ids, nextid):
        if world.rank > 0:
            return

        with ExitStack() as stack:
            if isinstance(self.filename, str):
                fd = stack.enter_context(open(self.filename, 'w'))
            else:
                fd = self.filename
            print('{', end='', file=fd)
            for id in ids:
                dct = bigdct[id]
                txt = ',\n '.join(f'"{key}": {encode(dct[key])}'
                                  for key in sorted(dct.keys()))
                print(f'"{id}": {{\n {txt}}},', file=fd)
            if self._metadata is not None:
                print(f'"metadata": {encode(self.metadata)},', file=fd)
            print(f'"ids": {ids},', file=fd)
            print(f'"nextid": {nextid}}}', file=fd)
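
    # Editor's note: delete is guarded by @parallel_function and @lock, so
    # the read-modify-write of the file happens under the database lock and
    # is coordinated across parallel ranks.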

    @parallel_function
    @lock
    def delete(self, ids):
        bigdct, myids, nextid = self._read_json()
        for id in ids:
            del bigdct[id]
            myids.remove(id)
        self._write_json(bigdct, myids, nextid)

    def _get_row(self, id):
        bigdct, ids, nextid = self._read_json()
        if id is None:
            assert len(ids) == 1
            id = ids[0]
        dct = bigdct[id]
        dct['id'] = id
        return AtomsRow(dct)
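
    # Editor's note: there is no index; selection is a linear scan over all
    # rows. With a sort key, _select recurses into itself without sorting,
    # orders the matches in Python and appends rows lacking the key.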

    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
        if explain:
            yield {'explain': (0, 0, 0, 'scan table')}
            return

        if sort:
            if sort[0] == '-':
                reverse = True
                sort = sort[1:]
            else:
                reverse = False

            def f(row):
                return row.get(sort, missing)

            rows = []
            missing = []
            for row in self._select(keys, cmps):
                key = row.get(sort)
                if key is None:
                    missing.append((0, row))
                else:
                    rows.append((key, row))

            rows.sort(reverse=reverse, key=lambda x: x[0])
            rows += missing

            if limit:
                rows = rows[offset:offset + limit]
            for key, row in rows:
                yield row
            return

        try:
            bigdct, ids, nextid = self._read_json()
        except OSError:
            return

        if not limit:
            limit = -offset - 1

        cmps = [(key, ops[op], val) for key, op, val in cmps]
        n = 0
        for id in ids:
            if n - offset == limit:
                return
            dct = bigdct[id]
            if not include_data:
                dct.pop('data', None)
            row = AtomsRow(dct)
            row.id = id
            for key in keys:
                if key not in row:
                    break
            else:
                for key, op, val in cmps:
                    if isinstance(key, int):
                        value = np.equal(row.numbers, key).sum()
                    else:
                        value = row.get(key)
                        if key == 'pbc':
                            assert op in [ops['='], ops['!=']]
                            value = ''.join('FT'[x] for x in value)
                    if value is None or not op(value, val):
                        break
                else:
                    if n >= offset:
                        yield row
                    n += 1
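
    # Editor's note: metadata is read lazily and cached in _metadata; the
    # setter rewrites the entire file to persist the new metadata.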

    @property
    def metadata(self):
        if self._metadata is None:
            bigdct, myids, nextid = self._read_json()
            self._metadata = bigdct.get('metadata', {})
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        bigdct, ids, nextid = self._read_json()
        self._metadata = dct
        self._write_json(bigdct, ids, nextid)

    def get_all_key_names(self):
        keys = set()
        bigdct, ids, nextid = self._read_json()
        for id in ids:
            dct = bigdct[id]
            kvp = dct.get('key_value_pairs')
            if kvp:
                keys.update(kvp)
        return keys
226 return keys