Coverage for /builds/kinetik161/ase/ase/io/pickletrajectory.py: 76.87%
268 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import collections
2import errno
3import os
4import pickle
5import sys
6import warnings
8import numpy as np
10from ase.atoms import Atoms
11from ase.calculators.calculator import PropertyNotImplementedError
12from ase.calculators.singlepoint import SinglePointCalculator
13from ase.constraints import FixAtoms
14from ase.parallel import barrier, world
class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file.

    Legacy pickle-based trajectory format.  Constructing it with
    ``_warn=True`` (the default) raises RuntimeError to steer users to
    the modern trajectory format.
    """

    # Per default, write these quantities
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    write_momenta = True
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the parameter file.  Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and
            no atoms argument should be specified.

            'w' is write mode.  If the file already exists, it is
            renamed by appending .bak to the file name.  The atoms
            argument specifies the Atoms object to be written to the
            file, if not given it must instead be given as an argument
            to the write() method.

            'a' is append mode.  It acts a write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing. The
            default is that process number 0 does this.  If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.
        """
        if _warn:
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n $ python3 -m ase.io.trajectory ' +
                        filename + '\n')
            raise RuntimeError(msg)

        self.numbers = None
        self.pbc = None
        self.sanitycheck = True
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.
        """
        self.fd = filename
        if mode == 'r':
            if isinstance(filename, str):
                self.fd = open(filename, 'rb')
            self.read_header()
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = os.path.isfile(filename)
                if exists:
                    # A zero-length file counts as "does not exist yet":
                    exists = os.path.getsize(filename) > 0
                if exists:
                    self.fd = open(filename, 'rb')
                    self.read_header()
                    self.fd.close()
                barrier()
            if self.master:
                self.fd = open(filename, 'ab+')
            else:
                # Non-master ranks write to the bit bucket so that the
                # code path stays identical on all processes.
                self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except OSError as e:
                            # this must run on Win only! Not atomic!
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        # An empty file has no header yet; nothing to read.
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    return
        self.fd.seek(0)
        try:
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise OSError('This is not a trajectory file!')
            # NOTE: unpickling data from the file; do not open
            # trajectory files from untrusted sources.
            d = pickle.load(self.fd)
        except EOFError as err:
            raise EOFError('Bad trajectory file.') from err

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        self.offsets.append(self.fd.tell())

    def write(self, atoms=None):
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            self.write_header(atoms)
        else:
            # Atomic numbers and pbc are stored only in the header, so
            # all images must agree with it:
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    # Only store magmoms when at least one is non-zero:
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        # Fall back to the initial values when the calculator did not
        # provide magmoms/charges:
        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        if self.master:
            # Protocol 2 keeps old readers working.
            pickle.dump(d, self.fd, protocol=2)
        self.fd.flush()
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints,
                                                protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header.  Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            self.fd.seek(self.offsets[i])
            try:
                d = pickle.load(self.fd, encoding='bytes')
                # Old files may have bytes keys; normalize to str:
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError from None
            if i == N - 1:
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            # Scan forward, extending self.offsets as we go:
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        # Negative index: wrap around.
        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        if len(self.offsets) == 0:
            return 0
        # Scan from the last known offset until EOF, caching offsets:
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                pickle.load(self.fd)
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable.  Items with non-string keys or
    unpicklable values are dropped and a warning is issued."""
    stringnified = {}
    for k, v in info.items():
        if not isinstance(k, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(k), UserWarning)
            continue
        try:
            # Protocol 0 is used for backwards compatibility of the
            # stored strings.  Note: pickling e.g. an open file object
            # raises TypeError rather than PicklingError, so both must
            # be treated as "not picklable" here.
            s = pickle.dumps(v, protocol=0)
        except (pickle.PicklingError, TypeError) as err:
            warnings.warn('Skipping not picklable info-dict item: ' +
                          f'"{k}" ({err})', UserWarning)
        else:
            stringnified[k] = s
    return stringnified
def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it.  Objects that cannot be unpickled will be
    skipped and a warning will be issued."""
    info = {}
    for k, s in stringnified.items():
        try:
            # NOTE: unpickling data read from a trajectory file; do not
            # use on untrusted files.
            v = pickle.loads(s)
        except pickle.UnpicklingError as err:
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          f'"{k}" ({err})', UserWarning)
        else:
            info[k] = v
    return info
def dict2constraints(d):
    """Convert dict unpickled from trajectory file to list of constraints.

    Returns [] (with a warning) when the stored constraints cannot be
    reconstructed.
    """
    version = d.get('version', 1)

    if version == 1:
        # Version-1 files stored the constraint objects directly.
        return d['constraints']
    elif version in (2, 3):
        try:
            # NOTE: unpickling data read from a trajectory file; do not
            # use on untrusted files.
            constraints = pickle.loads(d['constraints_string'])
            for c in constraints:
                if isinstance(c, FixAtoms) and c.index.dtype == bool:
                    # Special handling of old pickles:
                    c.index = np.arange(len(c.index))[c.index]
            return constraints
        except (AttributeError, KeyError, EOFError, ImportError, TypeError,
                pickle.UnpicklingError):
            # A corrupted constraints_string raises UnpicklingError and
            # must be handled like the other failure modes.
            warnings.warn('Could not unpickle constraints!')
            return []
    else:
        return []