Coverage for /builds/kinetik161/ase/ase/calculators/checkpoint.py: 86.09%

151 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1"""Checkpointing and restart functionality for scripts using ASE Atoms objects. 

2 

3Initialize checkpoint object: 

4 

5CP = Checkpoint('checkpoints.db') 

6 

7Checkpointed code block in try ... except notation: 

8 

9try: 

10 a, C, C_err = CP.load() 

11except NoCheckpoint: 

12 C, C_err = fit_elastic_constants(a) 

13 CP.save(a, C, C_err) 

14 

15Checkpoint code block, shorthand notation: 

16 

17C, C_err = CP(fit_elastic_constants)(a) 

18 

19Example for checkpointing within an iterative loop, e.g. for searching crack 

20tip position: 

21 

22try: 

23 a, converged, tip_x, tip_y = CP.load() 

24except NoCheckpoint: 

25 converged = False 

26 tip_x = tip_x0 

27 tip_y = tip_y0 

28while not converged: 

29 ... do something to find better crack tip position ... 

30 converged = ... 

31 CP.flush(a, converged, tip_x, tip_y) 

32 

33The simplest way to use checkpointing is through the CheckpointCalculator. It 

34wraps any calculator object and does a checkpoint whenever a calculation 

35is performed: 

36 

37 calc = ... 

38 cp_calc = CheckpointCalculator(calc) 

39 atoms.calc = cp_calc 

40 e = atoms.get_potential_energy() # 1st time, does calc, writes to checkfile 

41 # subsequent runs, reads from checkpoint 

42""" 

43 

44from typing import Any, Dict 

45 

46import numpy as np 

47 

48import ase 

49from ase.calculators.calculator import Calculator 

50from ase.db import connect 

51 

52 

class NoCheckpoint(Exception):
    """Raised by ``Checkpoint.load`` when the database holds no entry for
    the current checkpoint id."""

55 

56 

class DevNull:
    """File-like sink that silently discards everything written to it.

    Used as the default ``logfile`` so logging code can call ``write``
    unconditionally instead of checking for ``None``.
    """

    # The instance parameter was previously named ``str``, shadowing the
    # builtin; it is now the conventional ``self``.
    def write(self, *args):
        """Discard all arguments and return ``None``."""

    def flush(self):
        """Do nothing; provided for compatibility with file-like objects."""

60 

61 

class Checkpoint:
    """Store and restore intermediate script results in an ASE database.

    Every call to :meth:`load` opens a checkpoint region and advances an
    internal, possibly nested, checkpoint id (e.g. ``[3, 2, 6]``).  Data is
    written with :meth:`flush` (same id, region treated as no longer open)
    or :meth:`save` (closes the region).  Database rows are matched through
    the ``checkpoint_id`` key produced by :meth:`_mangled_checkpoint_id`.

    Parameters:

    db: database name passed to ``ase.db.connect``.
    logfile: object with a ``write`` method that receives progress
        messages; ``None`` discards them.
    """
    _value_prefix = '_values_'

    def __init__(self, db='checkpoints.db', logfile=None):
        self.db = db
        if logfile is None:
            logfile = DevNull()
        self.logfile = logfile

        # Counters of the nested checkpoint regions, innermost last.
        self.checkpoint_id = [0]
        # True after a load() that raised NoCheckpoint until the matching
        # flush()/save(); decides whether the next load() nests or advances.
        self.in_checkpointed_region = False

    def __call__(self, func, *args, **kwargs):
        """Wrap *func* so that its return value(s) are checkpointed.

        The returned function first tries :meth:`load`; if no checkpoint
        exists it calls *func* and stores the result with :meth:`save`.  The
        first ``ase.Atoms`` positional argument, if any, is passed as
        ``atoms`` to both.

        The extra ``*args``/``**kwargs`` of this method are unused and kept
        only for backward compatibility.
        """
        import functools

        checkpoint_func_name = str(func)

        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            # Use the first ase.Atoms positional argument, if any.
            atoms = next((a for a in args if isinstance(a, ase.Atoms)), None)

            try:
                retvals = self.load(atoms=atoms)
            except NoCheckpoint:
                retvals = func(*args, **kwargs)
                # A tuple result is stored as separate values, one per item.
                if isinstance(retvals, tuple):
                    self.save(*retvals, atoms=atoms,
                              checkpoint_func_name=checkpoint_func_name)
                else:
                    self.save(retvals, atoms=atoms,
                              checkpoint_func_name=checkpoint_func_name)
            return retvals

        return decorated_func

    def _increase_checkpoint_id(self):
        """Advance the checkpoint id when entering a region.

        Inside an open region a new nesting level is started; otherwise the
        innermost counter is incremented.
        """
        if self.in_checkpointed_region:
            self.checkpoint_id += [1]
        else:
            self.checkpoint_id[-1] += 1
        self.logfile.write('Entered checkpoint region '
                           f'{self.checkpoint_id}.\n')

        self.in_checkpointed_region = True

    def _decrease_checkpoint_id(self):
        """Close the current checkpoint region.

        If no region is open (an inner region was already closed), the
        innermost nesting level is dropped so the id refers to the enclosing
        region again.
        """
        self.logfile.write('Leaving checkpoint region '
                           f'{self.checkpoint_id}.\n')
        if not self.in_checkpointed_region:
            self.checkpoint_id = self.checkpoint_id[:-1]
            assert len(self.checkpoint_id) >= 1
        self.in_checkpointed_region = False
        assert self.checkpoint_id[-1] >= 1

    def _mangled_checkpoint_id(self):
        """Return the string stored under the ``checkpoint_id`` db key.

        It is ``'check'`` followed by the nested counters joined with
        ``':'``; e.g. the id ``[3, 2, 6]`` gives ``'check3:2:6'``.
        """
        return 'check' + ':'.join(str(level) for level in self.checkpoint_id)

    def load(self, atoms=None):
        """Retrieve checkpoint data from the database.

        Opens a checkpoint region.  If no entry exists for the current id,
        :class:`NoCheckpoint` is raised and the region stays open; otherwise
        the region is closed again before returning.

        If *atoms* is given, its calculator is attached to the restored
        ``ase.Atoms`` object.

        Returns the values passed to :meth:`flush` or :meth:`save`, in the
        order they were passed: a bare value if exactly one value was stored,
        otherwise a tuple.
        """
        self._increase_checkpoint_id()

        retvals = []
        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=self._mangled_checkpoint_id())
            except KeyError:
                raise NoCheckpoint from None

            data = dbentry.data
            # Position of the Atoms object among the stored values, or -1 if
            # it was supplied via the ``atoms`` keyword instead.
            atomsi = data['checkpoint_atoms_args_index']
            i = 0
            while (i == atomsi or
                   f'{self._value_prefix}{i}' in data):
                if i == atomsi:
                    # The Atoms object is the database row itself.
                    newatoms = dbentry.toatoms()
                    if atoms is not None:
                        newatoms.calc = atoms.calc
                    retvals.append(newatoms)
                else:
                    retvals.append(data[f'{self._value_prefix}{i}'])
                i += 1

        self.logfile.write('Successfully restored checkpoint '
                           f'{self.checkpoint_id}.\n')
        self._decrease_checkpoint_id()
        if len(retvals) == 1:
            return retvals[0]
        return tuple(retvals)

    def _flush(self, *args, **kwargs):
        """Write *args* plus extra keyword data under the current id.

        The first ``ase.Atoms`` in *args* becomes the database row and the
        other values go into the row's ``data``.  Without a positional Atoms
        object an ``atoms`` keyword is required.  Any existing row with the
        same id is deleted first.

        Raises RuntimeError if no Atoms object is provided at all.
        """
        data = {f'{self._value_prefix}{i}': v
                for i, v in enumerate(args)}

        try:
            atomsi = [isinstance(v, ase.Atoms) for v in args].index(True)
        except ValueError:
            atomsi = -1
            try:
                atoms = kwargs['atoms']
            except KeyError:
                raise RuntimeError(
                    'No atoms object provided in arguments.') from None
        else:
            atoms = args[atomsi]
            # The Atoms object is stored as the row, not inside ``data``.
            del data[f'{self._value_prefix}{atomsi}']

        kwargs.pop('atoms', None)

        data['checkpoint_atoms_args_index'] = atomsi
        data.update(kwargs)

        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=self._mangled_checkpoint_id())
            except KeyError:
                pass
            else:
                del db[dbentry.id]
            db.write(atoms, checkpoint_id=self._mangled_checkpoint_id(),
                     data=data)

        self.logfile.write('Successfully stored checkpoint '
                           f'{self.checkpoint_id}.\n')

    def flush(self, *args, **kwargs):
        """Store data to a checkpoint without changing the checkpoint id.

        Useful to continuously update the checkpoint state in an iterative
        loop.
        """
        # Keep the current id but mark the region as no longer open.  After
        # a successful load() the flag is already False; after a load() that
        # raised NoCheckpoint it is True.  Setting it to False here makes the
        # next load() advance to a sibling id in both cases instead of
        # nesting, so ids agree between fresh runs and restarts.
        self.in_checkpointed_region = False
        self._flush(*args, **kwargs)

    def save(self, *args, **kwargs):
        """Store data to a checkpoint and close the checkpoint region.

        The region is closed via :meth:`_decrease_checkpoint_id`, which
        returns to the enclosing nesting level if an inner region was
        already closed; the next :meth:`load` then advances the id.
        """
        self._decrease_checkpoint_id()
        self._flush(*args, **kwargs)

221 

222 

def atoms_almost_equal(a, b, tol=1e-9):
    """Return True if two Atoms-like objects describe the same structure.

    Positions and cell must agree element-wise to within *tol*; atomic
    numbers and periodic boundary conditions must match exactly.  Objects
    with different numbers of atoms compare unequal (instead of raising a
    broadcasting error), and two empty systems compare equal (instead of
    ``max()`` failing on an empty array).
    """
    pos_a = np.asarray(a.positions)
    pos_b = np.asarray(b.positions)
    if pos_a.shape != pos_b.shape:
        return False
    cell_diff = np.abs(np.asarray(a.cell) - np.asarray(b.cell))
    return bool(np.all(np.abs(pos_a - pos_b) < tol)
                and np.array_equal(a.numbers, b.numbers)
                and np.all(cell_diff < tol)
                and np.array_equal(a.pbc, b.pbc))

228 

229 

class CheckpointCalculator(Calculator):
    """
    This wraps any calculator object to checkpoint whenever a calculation
    is performed.

    This is particularly useful for expensive calculators, e.g. DFT and
    allows usage of complex workflows.

    Example usage:

        calc = ...
        cp_calc = CheckpointCalculator(calc)
        atoms.calc = cp_calc
        e = atoms.get_potential_energy()
        # 1st time, does calc, writes to checkfile
        # subsequent runs, reads from checkpoint file

    NOTE: stored values are matched to the requested properties by position
    (checkpoint order and the order of ``properties``), not by name.
    """
    implemented_properties = ase.calculators.calculator.all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'CheckpointCalculator'

    # Getter method used per property when the wrapped calculator is not an
    # ase Calculator instance ("old-style" interface).
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, db='checkpoints.db', logfile=None):
        Calculator.__init__(self)
        self.calculator = calculator
        if logfile is None:
            logfile = DevNull()
        self.checkpoint = Checkpoint(db, logfile)
        self.logfile = logfile

    def calculate(self, atoms, properties, system_changes):
        """Fill ``self.results`` for *properties*, reusing checkpoints.

        If a checkpoint exists for the current id, the stored atoms must
        match *atoms* (otherwise AssertionError is raised) and the stored
        values are used.  They are also copied into the wrapped calculator's
        ``results`` when it is an ase Calculator.  Otherwise the wrapped
        calculator computes the properties and they are saved together with
        *atoms*.
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        try:
            results = self.checkpoint.load(atoms)
        except NoCheckpoint:
            if isinstance(self.calculator, Calculator):
                self.logfile.write('doing calculation of {} with new-style '
                                   'calculator interface\n'.format(properties))
                self.calculator.calculate(atoms, properties, system_changes)
                results = [self.calculator.results[prop]
                           for prop in properties]
            else:
                self.logfile.write('doing calculation of {} with old-style '
                                   'calculator interface\n'.format(properties))
                results = [
                    getattr(self.calculator,
                            self.property_to_method_name[prop])(atoms)
                    for prop in properties]
            # Attach the wrapped calculator while saving (restored
            # afterwards), presumably so the database row is written with it
            # rather than with this wrapper -- confirm against ase.db.
            _calculator = atoms.calc
            try:
                atoms.calc = self.calculator
                self.checkpoint.save(atoms, *results)
            finally:
                atoms.calc = _calculator
        else:
            # load() returns a bare value when only the atoms were stored.
            if not isinstance(results, tuple):
                results = (results,)
            prev_atoms, results = results[0], results[1:]
            # Explicit check rather than ``assert`` so it still runs under
            # ``python -O``; the exception type and message are unchanged.
            if not atoms_almost_equal(atoms, prev_atoms):
                raise AssertionError('mismatch between current atoms and '
                                     'those read from checkpoint file')
            self.logfile.write('retrieved results for {} from checkpoint\n'
                               .format(properties))
            # save results in calculator for next time
            if isinstance(self.calculator, Calculator):
                if not hasattr(self.calculator, 'results'):
                    self.calculator.results = {}
                self.calculator.results.update(dict(zip(properties, results)))

        self.results = dict(zip(properties, results))