Coverage for /builds/kinetik161/ase/ase/io/pickletrajectory.py: 76.87%

268 statements  

coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

import collections
import errno
import os
import pickle
import sys
import warnings

import numpy as np

from ase.atoms import Atoms
from ase.calculators.calculator import PropertyNotImplementedError
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.parallel import barrier, world


class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file."""
    # Per default, write these quantities
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    write_momenta = True
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file.  Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode: the file must already exist, and no
            atoms argument should be given.

            'w' is write mode.  If the file already exists, it is
            renamed by appending .bak to the file name.  The atoms
            argument specifies the Atoms object to be written to the
            file; if it is not given, it must instead be passed to the
            write() method.

            'a' is append mode.  It acts as write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing.  The
            default is that process number 0 does this.  If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.
        """

        if _warn:
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n    $ python3 -m ase.io.trajectory ' +
                        filename + '\n')

            raise RuntimeError(msg)

        self.numbers = None
        self.pbc = None
        self.sanitycheck = True
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.
        """
        self.fd = filename
        if mode == 'r':
            if isinstance(filename, str):
                self.fd = open(filename, 'rb')
            self.read_header()
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = os.path.isfile(filename)
                if exists:
                    exists = os.path.getsize(filename) > 0
                if exists:
                    self.fd = open(filename, 'rb')
                    self.read_header()
                    self.fd.close()
                barrier()
                if self.master:
                    self.fd = open(filename, 'ab+')
                else:
                    self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except OSError as e:
                            # Only happens on Windows, where rename is not
                            # atomic and fails if the .bak file exists.
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    return
        self.fd.seek(0)
        try:
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise OSError('This is not a trajectory file!')
            d = pickle.load(self.fd)
        except EOFError:
            raise EOFError('Bad trajectory file.')

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        self.offsets.append(self.fd.tell())

    def write(self, atoms=None):
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            self.write_header(atoms)
        else:
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        if self.master:
            pickle.dump(d, self.fd, protocol=2)
            self.fd.flush()
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints, protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header.  Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            self.fd.seek(self.offsets[i])
            try:
                d = pickle.load(self.fd, encoding='bytes')
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError
            if i == N - 1:
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        if len(self.offsets) == 0:
            return 0
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                pickle.load(self.fd)
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
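
# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for documentation; not part of the
# original module).  It exercises the 'w' and 'r' modes described in
# PickleTrajectory.__init__, the observer hook from post_write_attach(),
# and frame access through __getitem__/__len__.  The file name and the
# helper name are hypothetical, and _warn=False must be passed because
# the constructor otherwise rejects this deprecated format.

def _example_write_and_read(path='example.traj'):
    from ase.build import molecule

    atoms = molecule('H2O')

    # Write mode: an existing file would be renamed to 'example.traj.bak'.
    with PickleTrajectory(path, mode='w', atoms=atoms, _warn=False) as traj:
        traj.post_write_attach(print, 1, 'frame written')  # called each write
        for _ in range(3):
            atoms.rattle(0.01)
            traj.write()

    # Read mode: frames come back as Atoms objects; indexing, negative
    # indices and slices are all supported.
    with PickleTrajectory(path, mode='r', _warn=False) as traj:
        print(len(traj), 'frames')
        last = traj[-1]
        every_other = traj[::2]
        return last, every_other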

def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable.  Items with non-string keys or
    unpicklable values are dropped and a warning is issued."""
    stringnified = {}
    for k, v in info.items():
        if not isinstance(k, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(k), UserWarning)
            continue
        try:
            # Should highest protocol be used here for efficiency?
            # Protocol 2 seems not to raise an exception when one
            # tries to pickle a file object, so by using that, we
            # might end up with file objects in inconsistent states.
            s = pickle.dumps(v, protocol=0)
        except pickle.PicklingError:
            warnings.warn('Skipping not picklable info-dict item: ' +
                          f'"{k}" ({sys.exc_info()[1]})', UserWarning)
        else:
            stringnified[k] = s
    return stringnified

def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it.  Objects that cannot be unpickled will be
    skipped and a warning will be issued."""
    info = {}
    for k, s in stringnified.items():
        try:
            v = pickle.loads(s)
        except pickle.UnpicklingError:
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          f'"{k}" ({sys.exc_info()[1]})', UserWarning)
        else:
            info[k] = v
    return info
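
# Illustrative round trip for the two info helpers above (added for
# documentation; not part of the original module).  A picklable value
# survives, while a non-string key and an unpicklable value are dropped
# with warnings, as the docstrings describe.  The function name is
# hypothetical.

def _example_info_roundtrip():
    info = {'temperature': 300.0,      # picklable: survives the round trip
            42: 'non-string key',      # dropped by stringnify_info
            'callback': lambda x: x}   # raises PicklingError: dropped
    stringnified = stringnify_info(info)    # values pickled with protocol 0
    return unstringnify_info(stringnified)  # -> {'temperature': 300.0}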

def dict2constraints(d):
    """Convert dict unpickled from trajectory file to list of constraints."""

    version = d.get('version', 1)

    if version == 1:
        return d['constraints']
    elif version in (2, 3):
        try:
            constraints = pickle.loads(d['constraints_string'])
            for c in constraints:
                if isinstance(c, FixAtoms) and c.index.dtype == bool:
                    # Special handling of old pickles:
                    c.index = np.arange(len(c.index))[c.index]
            return constraints
        except (AttributeError, KeyError, EOFError, ImportError, TypeError):
            warnings.warn('Could not unpickle constraints!')
            return []
    else:
        return []
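
# Illustrative example for dict2constraints() (added for documentation; not
# part of the original module).  It builds a minimal version-3 header dict
# the same way write_header() does and recovers the constraint list from it.
# The function name is hypothetical.

def _example_dict2constraints():
    constraints = [FixAtoms(indices=[0, 2])]
    header = {'version': 3,
              'constraints': [],  # kept for backwards compatibility, unused
              'constraints_string': pickle.dumps(constraints, protocol=0)}
    return dict2constraints(header)  # -> [FixAtoms(indices=[0, 2])]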