Coverage for /builds/kinetik161/ase/ase/utils/filecache.py: 95.92% (196 statements)
coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

import json
from collections.abc import Mapping, MutableMapping
from contextlib import contextmanager
from pathlib import Path

from ase.io.jsonio import encode as encode_json
from ase.io.jsonio import read_json, write_json
from ase.io.ulm import InvalidULMFileError, NDArrayReader, Writer, ulmopen
from ase.parallel import world
from ase.utils import opencew


def missing(key):
    raise KeyError(key)


class Locked(Exception):
    pass


# Note:
#
# The communicator handling is a complete hack.
# We should entirely remove communicators from these objects.
# (Actually: opencew() should not know about communicators.)
# Then the caller is responsible for handling parallelism,
# which makes life simpler for both the caller and us!
#
# Also, things like clean()/__del__ are not correctly implemented
# in parallel. The reason why it currently "works" is that
# we don't call those functions from Vibrations etc., or they do so
# only for rank == 0.


class JSONBackend:
    extension = '.json'
    DecodeError = json.decoder.JSONDecodeError

    @staticmethod
    def open_for_writing(path, comm):
        return opencew(path, world=comm)

    @staticmethod
    def read(fname):
        return read_json(fname, always_array=False)

    @staticmethod
    def open_and_write(target, data, comm):
        if comm.rank == 0:
            write_json(target, data)

    @staticmethod
    def write(fd, value):
        fd.write(encode_json(value).encode('utf-8'))

    @classmethod
    def dump_cache(cls, path, dct, comm):
        return CombinedJSONCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        return MultiFileJSONCache(directory, comm=comm)


class ULMBackend:
    extension = '.ulm'
    DecodeError = InvalidULMFileError

    @staticmethod
    def open_for_writing(path, comm):
        fd = opencew(path, world=comm)
        if fd is not None:
            return Writer(fd, 'w', '')

    @staticmethod
    def read(fname):
        with ulmopen(fname, 'r') as r:
            data = r._data['cache']
            if isinstance(data, NDArrayReader):
                return data.read()
        return data

    @staticmethod
    def open_and_write(target, data, comm):
        if comm.rank == 0:
            with ulmopen(target, 'w') as w:
                w.write('cache', data)

    @staticmethod
    def write(fd, value):
        fd.write('cache', value)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        return CombinedULMCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        return MultiFileULMCache(directory, comm=comm)
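
# Both backends expose the same informal interface consumed by the cache
# classes below: ``extension``, ``DecodeError``, ``open_for_writing``,
# ``read``, ``open_and_write``, ``write``, ``dump_cache`` and
# ``create_multifile_cache``.  A minimal round-trip sketch for the JSON
# backend (file name and value are hypothetical, for illustration only):
#
#     fd = JSONBackend.open_for_writing(Path('cache.mykey.json'), world)
#     if fd is not None:  # opencew() returns None if the file exists
#         JSONBackend.write(fd, {'energy': -1.0})
#         fd.close()
#     assert JSONBackend.read('cache.mykey.json') == {'energy': -1.0}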


class CacheLock:
    def __init__(self, fd, key, backend):
        self.fd = fd
        self.key = key
        self.backend = backend

    def save(self, value):
        try:
            self.backend.write(self.fd, value)
        except Exception as ex:
            raise RuntimeError(f'Failed to save {value} to cache') from ex
        finally:
            self.fd.close()
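
# CacheLock couples an exclusively opened file with the key it guards;
# save() writes the value through the backend and closes the file even
# when the write fails, so the lock is always released.  A sketch of the
# intended flow (mirrors __setitem__ of the multifile cache below):
#
#     with cache.lock('mykey') as handle:
#         if handle is None:
#             ...  # another process holds the file; the key is locked
#         else:
#             handle.save(value)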


class _MultiFileCacheTemplate(MutableMapping):
    writable = True

    def __init__(self, directory, comm=world):
        self.directory = Path(directory)
        self.comm = comm

    def _filename(self, key):
        return self.directory / (f'cache.{key}' + self.backend.extension)

    def _glob(self):
        return self.directory.glob('cache.*' + self.backend.extension)

    def __iter__(self):
        for path in self._glob():
            cache, key = path.stem.split('.', 1)
            if cache != 'cache':
                continue
            yield key

    def __len__(self):
        # Very inefficient, but not a big use case.
        return len(list(self._glob()))

    @contextmanager
    def lock(self, key):
        if self.comm.rank == 0:
            self.directory.mkdir(exist_ok=True, parents=True)
        path = self._filename(key)
        fd = self.backend.open_for_writing(path, self.comm)
        try:
            if fd is None:
                yield None
            else:
                yield CacheLock(fd, key, self.backend)
        finally:
            if fd is not None:
                fd.close()

    def __setitem__(self, key, value):
        with self.lock(key) as handle:
            if handle is None:
                raise Locked(key)
            handle.save(value)

    def __getitem__(self, key):
        path = self._filename(key)
        try:
            return self.backend.read(path)
        except FileNotFoundError:
            missing(key)
        except self.backend.DecodeError:
            # May be partially written, which typically means empty
            # because the file was locked with exclusive-write-open.
            #
            # Since we decide what keys we have based on which files exist,
            # we are obligated to return a value for this case too.
            # So we return None.
            return None

    def __delitem__(self, key):
        try:
            self._filename(key).unlink()
        except FileNotFoundError:
            missing(key)

    def combine(self):
        cache = self.backend.dump_cache(self.directory, dict(self),
                                        comm=self.comm)
        assert set(cache) == set(self)
        self.clear()
        assert len(self) == 0
        return cache

    def split(self):
        return self

    def filecount(self):
        return len(self)

    def strip_empties(self):
        empties = [key for key, value in self.items() if value is None]
        for key in empties:
            del self[key]
        return len(empties)
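
# A usage sketch for the multifile flavour (directory and key names are
# hypothetical):
#
#     cache = MultiFileJSONCache('my_cache_dir')
#     cache['forces'] = [0.0, 0.1]   # writes my_cache_dir/cache.forces.json
#     print(cache['forces'])
#     cache.strip_empties()          # drop leftovers of interrupted writes
#     combined = cache.combine()     # pack into my_cache_dir/combined.json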


class _CombinedCacheTemplate(Mapping):
    writable = False

    def __init__(self, directory, dct, comm=world):
        self.directory = Path(directory)
        self._dct = dict(dct)
        self.comm = comm

    def filecount(self):
        return int(self._filename.is_file())

    @property
    def _filename(self):
        return self.directory / ('combined' + self.backend.extension)

    def __len__(self):
        return len(self._dct)

    def __iter__(self):
        return iter(self._dct)

    def __getitem__(self, index):
        return self._dct[index]

    def _dump(self):
        target = self._filename
        if target.exists():
            raise RuntimeError(f'Already exists: {target}')
        self.directory.mkdir(exist_ok=True, parents=True)
        self.backend.open_and_write(target, self._dct, comm=self.comm)

    @classmethod
    def dump_cache(cls, path, dct, comm=world):
        cache = cls(path, dct, comm=comm)
        cache._dump()
        return cache

    @classmethod
    def load(cls, path, comm):
        # XXX Very hacky this one
        cache = cls(path, {}, comm=comm)
        dct = cls.backend.read(cache._filename)
        cache._dct.update(dct)
        return cache

    def clear(self):
        self._filename.unlink()
        self._dct.clear()

    def combine(self):
        return self

    def split(self):
        cache = self.backend.create_multifile_cache(self.directory,
                                                    comm=self.comm)
        assert len(cache) == 0
        cache.update(self)
        assert set(cache) == set(self)
        self.clear()
        return cache


class MultiFileJSONCache(_MultiFileCacheTemplate):
    backend = JSONBackend()


class MultiFileULMCache(_MultiFileCacheTemplate):
    backend = ULMBackend()


class CombinedJSONCache(_CombinedCacheTemplate):
    backend = JSONBackend()


class CombinedULMCache(_CombinedCacheTemplate):
    backend = ULMBackend()
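
# combine() and split() invert each other at the level of stored keys: a
# multifile cache packs into a single combined file, and a combined cache
# unpacks back into per-key files.  Sketch (directory name hypothetical):
#
#     multi = MultiFileJSONCache('dir')
#     combined = multi.combine()   # dir/combined.json now holds all keys
#     multi = combined.split()     # back to dir/cache.<key>.json files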


def get_json_cache(directory, comm=world):
    try:
        return CombinedJSONCache.load(directory, comm=comm)
    except FileNotFoundError:
        return MultiFileJSONCache(directory, comm=comm)
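

# get_json_cache() returns whichever representation exists on disk: the
# combined file if present, otherwise a (possibly empty) multifile cache.
# A minimal sketch, assuming a scratch directory 'vib_cache':
#
#     cache = get_json_cache('vib_cache')
#     if cache.writable:   # only the multifile flavour supports assignment
#         cache['mykey'] = {'energy': -1.0}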