Coverage for /builds/kinetik161/ase/ase/calculators/checkpoint.py: 86.09%
151 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
"""Checkpointing and restart functionality for scripts using ASE Atoms objects.

Initialize checkpoint object:

CP = Checkpoint('checkpoints.db')

Checkpointed code block in try ... except notation:

try:
    a, C, C_err = CP.load()
except NoCheckpoint:
    C, C_err = fit_elastic_constants(a)
    CP.save(a, C, C_err)

Checkpoint code block, shorthand notation:

C, C_err = CP(fit_elastic_constants)(a)

Example for checkpointing within an iterative loop, e.g. for searching crack
tip position:

try:
    a, converged, tip_x, tip_y = CP.load()
except NoCheckpoint:
    converged = False
    tip_x = tip_x0
    tip_y = tip_y0
while not converged:
    ... do something to find better crack tip position ...
    converged = ...
    CP.flush(a, converged, tip_x, tip_y)

The simplest way to use checkpointing is through the CheckpointCalculator. It
wraps any calculator object and does a checkpoint whenever a calculation
is performed:

    calc = ...
    cp_calc = CheckpointCalculator(calc)
    atoms.calc = cp_calc
    e = atoms.get_potential_energy()  # 1st time, does calc, writes to checkfile
                                      # subsequent runs, reads from checkpoint
"""
44from typing import Any, Dict
46import numpy as np
48import ase
49from ase.calculators.calculator import Calculator
50from ase.db import connect
class NoCheckpoint(Exception):
    """Raised by ``Checkpoint.load`` when no stored data exists for the
    current checkpoint region."""
class DevNull:
    """Write-only sink that silently discards everything written to it.

    Used as the default ``logfile`` when none is supplied, so that logging
    code can call ``write`` unconditionally.
    """

    def write(self, *args):
        """Discard *args*; always returns ``None``."""

    def flush(self):
        """No-op, provided so the object can stand in for a text file."""
class Checkpoint:
    """Store and restore intermediate results of a script in an ASE database.

    Every :meth:`load` enters a checkpoint region; :meth:`save`,
    :meth:`flush` or a successful :meth:`load` leaves it.  Regions are
    numbered by ``checkpoint_id``, a list with one counter per nesting
    level: entering a region while another one is open adds a level,
    otherwise the counter of the current level is incremented.  Each region
    is stored as one database row whose ``checkpoint_id`` key holds the
    mangled id (e.g. ``'check3:2:6'``).

    Parameters
    ----------
    db : str
        Database name handed to ``ase.db.connect``.
    logfile : file-like, optional
        Object with a ``write`` method that receives progress messages.
        When omitted, messages are discarded.
    """

    # Prefix of the ``data`` keys holding stored non-Atoms values:
    # ``_values_0``, ``_values_1``, ...
    _value_prefix = '_values_'

    def __init__(self, db='checkpoints.db', logfile=None):
        self.db = db
        if logfile is None:
            logfile = DevNull()
        self.logfile = logfile

        # Position in the checkpoint hierarchy, one counter per nesting level.
        self.checkpoint_id = [0]
        # True from the moment load() enters a region until save()/flush()
        # or a successful load() leaves it.
        self.in_checkpointed_region = False

    def __call__(self, func, *args, **kwargs):
        """Wrap *func* so that its return value is checkpointed.

        The first ``ase.Atoms`` among the positional arguments of the wrapped
        call is passed to :meth:`load` / :meth:`save`.  A tuple return value
        is stored element-wise; any other return value is stored as a single
        value.  Extra arguments passed to this method itself are ignored.
        """
        import functools

        checkpoint_func_name = str(func)

        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            # Get the first ase.Atoms object among the positional arguments.
            atoms = next((a for a in args if isinstance(a, ase.Atoms)), None)

            try:
                retvals = self.load(atoms=atoms)
            except NoCheckpoint:
                retvals = func(*args, **kwargs)
                values = retvals if isinstance(retvals, tuple) else (retvals,)
                self.save(*values, atoms=atoms,
                          checkpoint_func_name=checkpoint_func_name)
            return retvals
        return decorated_func

    def _increase_checkpoint_id(self):
        """Enter the next checkpoint region.

        Inside an open region this opens a nested level; otherwise it moves
        on to the next sibling region on the current level.
        """
        if self.in_checkpointed_region:
            self.checkpoint_id.append(1)
        else:
            self.checkpoint_id[-1] += 1
        self.logfile.write('Entered checkpoint region '
                           f'{self.checkpoint_id}.\n')
        self.in_checkpointed_region = True

    def _decrease_checkpoint_id(self):
        """Leave the current checkpoint region.

        If the region is already closed (an inner region was left before),
        one nesting level is dropped so that the enclosing region is closed.
        """
        self.logfile.write('Leaving checkpoint region '
                           f'{self.checkpoint_id}.\n')
        if not self.in_checkpointed_region:
            self.checkpoint_id = self.checkpoint_id[:-1]
            assert len(self.checkpoint_id) >= 1
        self.in_checkpointed_region = False
        assert self.checkpoint_id[-1] >= 1

    def _mangled_checkpoint_id(self):
        """
        Returns a mangled checkpoint id string:
            check_c_1:c_2:c_3:...
        E.g. if checkpoint is nested and id is [3,2,6] it returns:
            'check3:2:6'
        """
        return 'check' + ':'.join(map(str, self.checkpoint_id))

    def load(self, atoms=None):
        """Restore the values stored for the next checkpoint region.

        Parameters
        ----------
        atoms : ase.Atoms, optional
            If given, its calculator is attached to the restored Atoms object.

        Returns
        -------
        The values passed to :meth:`save` or :meth:`flush`, in the same
        order: the bare value if exactly one was stored, otherwise a tuple.

        Raises
        ------
        NoCheckpoint
            If the database has no row for this region.  The region is then
            left open so the caller can compute the values and store them
            with :meth:`save` or :meth:`flush`.
        """
        self._increase_checkpoint_id()

        retvals = []
        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=self._mangled_checkpoint_id())
            except KeyError:
                raise NoCheckpoint from None

            data = dbentry.data
            # Position of the Atoms object among the stored values (-1: none).
            atomsi = data['checkpoint_atoms_args_index']
            i = 0
            while i == atomsi or f'{self._value_prefix}{i}' in data:
                if i == atomsi:
                    newatoms = dbentry.toatoms()
                    if atoms is not None:
                        # Assign calculator
                        newatoms.calc = atoms.calc
                    retvals.append(newatoms)
                else:
                    retvals.append(data[f'{self._value_prefix}{i}'])
                i += 1

        self.logfile.write('Successfully restored checkpoint '
                           f'{self.checkpoint_id}.\n')
        self._decrease_checkpoint_id()
        if len(retvals) == 1:
            return retvals[0]
        return tuple(retvals)

    def _flush(self, *args, **kwargs):
        """Write *args* (and extra *kwargs*) to the current checkpoint row.

        The first ``ase.Atoms`` in *args* becomes the row's structure and its
        position is recorded under ``checkpoint_atoms_args_index``; if *args*
        holds none, the ``atoms`` keyword is used and the index is -1.  All
        other keyword arguments are stored in the row's ``data`` as well.
        An existing row for the same checkpoint id is replaced.

        Raises
        ------
        RuntimeError
            If neither *args* nor the ``atoms`` keyword supplies atoms.
        """
        atomsi = next((i for i, v in enumerate(args)
                       if isinstance(v, ase.Atoms)), -1)
        if atomsi >= 0:
            atoms = args[atomsi]
        elif 'atoms' in kwargs:
            atoms = kwargs['atoms']
        else:
            raise RuntimeError('No atoms object provided in arguments.')
        kwargs.pop('atoms', None)

        # The Atoms object is stored as the row itself, not inside ``data``.
        data = {f'{self._value_prefix}{i}': v
                for i, v in enumerate(args) if i != atomsi}
        data['checkpoint_atoms_args_index'] = atomsi
        data.update(kwargs)

        checkpoint_id = self._mangled_checkpoint_id()
        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=checkpoint_id)
            except KeyError:
                pass
            else:
                del db[dbentry.id]
            db.write(atoms, checkpoint_id=checkpoint_id, data=data)

        self.logfile.write('Successfully stored checkpoint '
                           f'{self.checkpoint_id}.\n')

    def flush(self, *args, **kwargs):
        """
        Store data to a checkpoint without increasing the checkpoint id. This
        is useful to continuously update the checkpoint state in an iterative
        loop.
        """
        # Mark the region as closed without touching checkpoint_id.  After a
        # load() that raised NoCheckpoint the region is still open; after a
        # successful restore it is already closed.  Either way, repeated
        # flush() calls must keep overwriting this checkpoint, and the next
        # load() must start a new sibling region rather than a nested one.
        self.in_checkpointed_region = False
        self._flush(*args, **kwargs)

    def save(self, *args, **kwargs):
        """
        Store data to the current checkpoint and close its region.

        Intended to follow a :meth:`load` call that raised
        :class:`NoCheckpoint`; the id only advances on the next
        :meth:`load`.
        """
        self._decrease_checkpoint_id()
        self._flush(*args, **kwargs)
def atoms_almost_equal(a, b, tol=1e-9):
    """Return True if two atomic configurations agree.

    Atomic numbers and periodic boundary flags must match exactly; positions
    and cell vectors must agree element-wise to an absolute difference below
    *tol*.  Configurations with a different number of atoms are never equal,
    and two empty configurations compare by cell and pbc only.
    """
    pos_a = np.asarray(a.positions)
    pos_b = np.asarray(b.positions)
    num_a = np.asarray(a.numbers)
    num_b = np.asarray(b.numbers)
    # Reject size mismatches up front: otherwise NumPy broadcasting could make
    # e.g. a 1-atom and an N-atom system of the same element look equal.
    if pos_a.shape != pos_b.shape or num_a.shape != num_b.shape:
        return False
    # np.all(... < tol) instead of .max() < tol also works for zero atoms.
    return bool(np.all(np.abs(pos_a - pos_b) < tol) and
                np.array_equal(num_a, num_b) and
                np.all(np.abs(np.asarray(a.cell) - np.asarray(b.cell)) < tol)
                and np.array_equal(np.asarray(a.pbc), np.asarray(b.pbc)))
class CheckpointCalculator(Calculator):
    """
    This wraps any calculator object to checkpoint whenever a calculation
    is performed.

    This is particularly useful for expensive calculators, e.g. DFT and
    allows usage of complex workflows.

    Example usage:

        calc = ...
        cp_calc = CheckpointCalculator(calc)
        atoms.calc = cp_calc
        e = atoms.get_potential_energy()
        # 1st time, does calc, writes to checkfile
        # subsequent runs, reads from checkpoint file

    Parameters
    ----------
    calculator
        The wrapped calculator.  ``Calculator`` instances are driven through
        their ``calculate`` method; any other object through the getter
        methods listed in ``property_to_method_name``.
    db : str
        Checkpoint database name, passed on to ``Checkpoint``.
    logfile : file-like, optional
        Receives progress messages; they are discarded when omitted.
    """
    implemented_properties = ase.calculators.calculator.all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'CheckpointCalculator'

    # Getter used for each property when the wrapped calculator is not a
    # Calculator instance (old-style interface).
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, db='checkpoints.db', logfile=None):
        Calculator.__init__(self)
        self.calculator = calculator
        if logfile is None:
            logfile = DevNull()
        self.checkpoint = Checkpoint(db, logfile)
        self.logfile = logfile

    def calculate(self, atoms, properties, system_changes):
        """Restore *properties* from the next checkpoint, or compute them.

        If a checkpoint exists, its stored atoms must match *atoms*,
        otherwise ``AssertionError`` is raised; the restored values are also
        copied into the wrapped calculator's ``results`` when it is a
        ``Calculator``.  If no checkpoint exists, the wrapped calculator
        computes the properties and they are saved as a new checkpoint.

        NOTE(review): values are stored and restored purely by position and
        re-labelled with the *properties* of the current call, so a restart
        that requests different properties for the same checkpoint would
        mislabel them.
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        try:
            stored = self.checkpoint.load(atoms)
        except NoCheckpoint:
            if isinstance(self.calculator, Calculator):
                self.logfile.write('doing calculation of {} with new-style '
                                   'calculator interface\n'.format(properties))
                self.calculator.calculate(atoms, properties, system_changes)
                results = [self.calculator.results[prop]
                           for prop in properties]
            else:
                self.logfile.write('doing calculation of {} with old-style '
                                   'calculator interface\n'.format(properties))
                results = []
                for prop in properties:
                    method_name = self.property_to_method_name[prop]
                    method = getattr(self.calculator, method_name)
                    results.append(method(atoms))
            # Attach the wrapped calculator while saving and always restore
            # the caller's calculator afterwards.  NOTE(review): presumably
            # so the database row records the wrapped calculator — confirm.
            _calculator = atoms.calc
            try:
                atoms.calc = self.calculator
                self.checkpoint.save(atoms, *results)
            finally:
                atoms.calc = _calculator
        else:
            # load() returns a bare object instead of a tuple when only one
            # value (the atoms, i.e. no properties) was stored.
            if not isinstance(stored, tuple):
                stored = (stored,)
            prev_atoms, *results = stored
            # Explicit raise rather than ``assert`` so the check also runs
            # under ``python -O``; the exception type is unchanged.
            if not atoms_almost_equal(atoms, prev_atoms):
                raise AssertionError('mismatch between current atoms and '
                                     'those read from checkpoint file')
            self.logfile.write('retrieved results for {} from checkpoint\n'
                               .format(properties))
            # save results in calculator for next time
            if isinstance(self.calculator, Calculator):
                if not hasattr(self.calculator, 'results'):
                    self.calculator.results = {}
                self.calculator.results.update(dict(zip(properties, results)))

        self.results = dict(zip(properties, results))