Coverage for /builds/kinetik161/ase/ase/utils/filecache.py: 95.92%

196 statements  

coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

import json
from collections.abc import Mapping, MutableMapping
from contextlib import contextmanager
from pathlib import Path

from ase.io.jsonio import encode as encode_json
from ase.io.jsonio import read_json, write_json
from ase.io.ulm import InvalidULMFileError, NDArrayReader, Writer, ulmopen
from ase.parallel import world
from ase.utils import opencew


def missing(key):
    raise KeyError(key)


class Locked(Exception):
    pass


# Note:
#
# The communicator handling is a complete hack.
# We should entirely remove communicators from these objects.
# (Actually: opencew() should not know about communicators.)
# Then the caller is responsible for handling parallelism,
# which makes life simpler for both the caller and us!
#
# Also, things like clean()/__del__ are not correctly implemented
# in parallel.  The reason why it currently "works" is that
# we don't call those functions from Vibrations etc., or they do so
# only for rank == 0.


class JSONBackend:
    extension = '.json'
    DecodeError = json.decoder.JSONDecodeError

    @staticmethod
    def open_for_writing(path, comm):
        return opencew(path, world=comm)

    @staticmethod
    def read(fname):
        return read_json(fname, always_array=False)

    @staticmethod
    def open_and_write(target, data, comm):
        if comm.rank == 0:
            write_json(target, data)

    @staticmethod
    def write(fd, value):
        fd.write(encode_json(value).encode('utf-8'))

    @classmethod
    def dump_cache(cls, path, dct, comm):
        return CombinedJSONCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        return MultiFileJSONCache(directory, comm=comm)
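
For reference, a minimal round trip through this backend protocol might look as follows. This is a sketch, assuming an installed ASE; the file name and payload are invented for illustration, and in a serial run the default communicator ase.parallel.world has rank 0, so the write goes through:

from ase.parallel import world
from ase.utils.filecache import JSONBackend

fd = JSONBackend.open_for_writing('cache.energy.json', world)
if fd is None:
    # The file already exists: some other process owns this entry.
    print('locked')
else:
    JSONBackend.write(fd, {'energy': -1.234})
    fd.close()
print(JSONBackend.read('cache.energy.json'))  # -> {'energy': -1.234}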


class ULMBackend:
    extension = '.ulm'
    DecodeError = InvalidULMFileError

    @staticmethod
    def open_for_writing(path, comm):
        fd = opencew(path, world=comm)
        if fd is not None:
            return Writer(fd, 'w', '')

    @staticmethod
    def read(fname):
        with ulmopen(fname, 'r') as r:
            data = r._data['cache']
            if isinstance(data, NDArrayReader):
                return data.read()
        return data

    @staticmethod
    def open_and_write(target, data, comm):
        if comm.rank == 0:
            with ulmopen(target, 'w') as w:
                w.write('cache', data)

    @staticmethod
    def write(fd, value):
        fd.write('cache', value)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        return CombinedULMCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        return MultiFileULMCache(directory, comm=comm)
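
ULMBackend mirrors the JSON backend, except that open_for_writing returns an ase.io.ulm.Writer and values may be numpy arrays (read back through NDArrayReader). A corresponding sketch, again with an invented file name:

import numpy as np

from ase.parallel import world
from ase.utils.filecache import ULMBackend

w = ULMBackend.open_for_writing('cache.forces.ulm', world)
if w is not None:
    ULMBackend.write(w, np.zeros((2, 3)))
    w.close()
print(ULMBackend.read('cache.forces.ulm'))  # -> a (2, 3) array of zeros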


class CacheLock:
    def __init__(self, fd, key, backend):
        self.fd = fd
        self.key = key
        self.backend = backend

    def save(self, value):
        try:
            self.backend.write(self.fd, value)
        except Exception as ex:
            raise RuntimeError(f'Failed to save {value} to cache') from ex
        finally:
            self.fd.close()
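
A CacheLock is what the lock() context manager of the multi-file caches below hands out. The compute-at-most-once idiom built on it looks roughly like this (a sketch; the directory and key are invented):

from ase.utils.filecache import Locked, MultiFileJSONCache

cache = MultiFileJSONCache('my_cache_dir')
with cache.lock('energy') as handle:
    if handle is None:
        # Another process created the file first and owns the entry.
        raise Locked('energy')
    handle.save(-1.234)  # writes the value and closes the file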


class _MultiFileCacheTemplate(MutableMapping):
    writable = True

    def __init__(self, directory, comm=world):
        self.directory = Path(directory)
        self.comm = comm

    def _filename(self, key):
        return self.directory / (f'cache.{key}' + self.backend.extension)

    def _glob(self):
        return self.directory.glob('cache.*' + self.backend.extension)

    def __iter__(self):
        for path in self._glob():
            cache, key = path.stem.split('.', 1)
            if cache != 'cache':
                continue
            yield key

    def __len__(self):
        # Very inefficient, but not a big use case.
        return len(list(self._glob()))

    @contextmanager
    def lock(self, key):
        if self.comm.rank == 0:
            self.directory.mkdir(exist_ok=True, parents=True)
        path = self._filename(key)
        fd = self.backend.open_for_writing(path, self.comm)
        try:
            if fd is None:
                yield None
            else:
                yield CacheLock(fd, key, self.backend)
        finally:
            if fd is not None:
                fd.close()

    def __setitem__(self, key, value):
        with self.lock(key) as handle:
            if handle is None:
                raise Locked(key)
            handle.save(value)

    def __getitem__(self, key):
        path = self._filename(key)
        try:
            return self.backend.read(path)
        except FileNotFoundError:
            missing(key)
        except self.backend.DecodeError:
            # May be partially written, which typically means empty
            # because the file was locked with exclusive-write-open.
            #
            # Since we decide what keys we have based on which files exist,
            # we are obligated to return a value for this case too.
            # So we return None.
            return None

    def __delitem__(self, key):
        try:
            self._filename(key).unlink()
        except FileNotFoundError:
            missing(key)

    def combine(self):
        cache = self.backend.dump_cache(self.directory, dict(self),
                                        comm=self.comm)
        assert set(cache) == set(self)
        self.clear()
        assert len(self) == 0
        return cache

    def split(self):
        return self

    def filecount(self):
        return len(self)

    def strip_empties(self):
        empties = [key for key, value in self.items() if value is None]
        for key in empties:
            del self[key]
        return len(empties)
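
In practice the multi-file caches are used as ordinary mutable mappings, with one file per key. A sketch with invented names (note that assigning to an existing, still-locked key raises Locked, and a partially written entry reads back as None):

from ase.utils.filecache import MultiFileJSONCache

cache = MultiFileJSONCache('demo_cache')
cache['gamma'] = {'energy': -1.234}  # creates demo_cache/cache.gamma.json
print(cache['gamma'])                # -> {'energy': -1.234}
print('gamma' in cache, len(cache))  # -> True 1
cache.strip_empties()                # delete entries that read back as None
del cache['gamma']                   # removes the file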


class _CombinedCacheTemplate(Mapping):
    writable = False

    def __init__(self, directory, dct, comm=world):
        self.directory = Path(directory)
        self._dct = dict(dct)
        self.comm = comm

    def filecount(self):
        return int(self._filename.is_file())

    @property
    def _filename(self):
        return self.directory / ('combined' + self.backend.extension)

    def __len__(self):
        return len(self._dct)

    def __iter__(self):
        return iter(self._dct)

    def __getitem__(self, index):
        return self._dct[index]

    def _dump(self):
        target = self._filename
        if target.exists():
            raise RuntimeError(f'Already exists: {target}')
        self.directory.mkdir(exist_ok=True, parents=True)
        self.backend.open_and_write(target, self._dct, comm=self.comm)

    @classmethod
    def dump_cache(cls, path, dct, comm=world):
        cache = cls(path, dct, comm=comm)
        cache._dump()
        return cache

    @classmethod
    def load(cls, path, comm):
        # XXX Very hacky, this one.
        cache = cls(path, {}, comm=comm)
        dct = cls.backend.read(cache._filename)
        cache._dct.update(dct)
        return cache

    def clear(self):
        self._filename.unlink()
        self._dct.clear()

    def combine(self):
        return self

    def split(self):
        cache = self.backend.create_multifile_cache(self.directory,
                                                    comm=self.comm)
        assert len(cache) == 0
        cache.update(self)
        assert set(cache) == set(self)
        self.clear()
        return cache
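
Combining and splitting move all entries between the one-file-per-key layout and a single combined file. A sketch in a fresh, invented directory:

from ase.utils.filecache import MultiFileJSONCache

cache = MultiFileJSONCache('combine_demo')
cache['gamma'] = {'energy': -1.234}
combined = cache.combine()  # writes combine_demo/combined.json,
                            # then deletes the per-key files
print(dict(combined))       # -> {'gamma': {'energy': -1.234}}
multi = combined.split()    # back to one file per key; deletes combined.json
print(multi.filecount())    # -> 1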


class MultiFileJSONCache(_MultiFileCacheTemplate):
    backend = JSONBackend()


class MultiFileULMCache(_MultiFileCacheTemplate):
    backend = ULMBackend()


class CombinedJSONCache(_CombinedCacheTemplate):
    backend = JSONBackend()


class CombinedULMCache(_CombinedCacheTemplate):
    backend = ULMBackend()


def get_json_cache(directory, comm=world):
    try:
        return CombinedJSONCache.load(directory, comm=comm)
    except FileNotFoundError:
        return MultiFileJSONCache(directory, comm=comm)
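
get_json_cache() is the usual entry point: it returns the combined cache when <directory>/combined.json exists and otherwise a fresh multi-file cache. Typical usage (sketch; the directory name is invented):

from ase.utils.filecache import get_json_cache

cache = get_json_cache('my_cache_dir')
print(type(cache).__name__)  # MultiFileJSONCache, or CombinedJSONCache if
                             # my_cache_dir/combined.json already exists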