Coverage for /builds/kinetik161/ase/ase/io/netcdftrajectory.py: 82.83%
361 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
"""
netcdftrajectory - I/O trajectory files in the AMBER NetCDF convention

More information on the AMBER NetCDF conventions can be found at
http://ambermd.org/netcdf/. This module supports extensions to
these conventions, such as writing of additional fields and writing to
HDF5 (NetCDF-4) files.

The netCDF4-python package is required by this module:

    netCDF4-python - https://github.com/Unidata/netcdf4-python

NetCDF files can be directly visualized using the libAtoms flavor of
AtomEye (http://www.libatoms.org/),
VMD (http://www.ks.uiuc.edu/Research/vmd/)
or Ovito (http://www.ovito.org/, starting with version 2.3).
"""
20import collections
21import os
22import warnings
23from functools import reduce
25import numpy as np
27import ase
28from ase.data import atomic_masses
29from ase.geometry import cellpar_to_cell
class NetCDFTrajectory:
    """
    Reads/writes Atoms objects into an AMBER-style .nc trajectory file.

    The object can be used as a context manager; the underlying NetCDF
    file is opened lazily on the first read/write operation.
    """

    # Default dimension names
    _frame_dim = 'frame'
    _spatial_dim = 'spatial'
    _atom_dim = 'atom'
    _cell_spatial_dim = 'cell_spatial'
    _cell_angular_dim = 'cell_angular'
    _label_dim = 'label'
    _Voigt_dim = 'Voigt'  # For stress/strain tensors

    # Default field names. If it is a list, check for any of these names upon
    # opening. Upon writing, use the first name.
    _spatial_var = 'spatial'
    _cell_spatial_var = 'cell_spatial'
    _cell_angular_var = 'cell_angular'
    _time_var = 'time'
    _numbers_var = ['atom_types', 'type', 'Z']
    _positions_var = 'coordinates'
    _velocities_var = 'velocities'
    _cell_origin_var = 'cell_origin'
    _cell_lengths_var = 'cell_lengths'
    _cell_angles_var = 'cell_angles'

    # Names handled explicitly by the reader; every other variable found in
    # a file is treated as "extra" data by _read_header().
    _default_vars = _numbers_var + [_positions_var, _velocities_var,
                                    _cell_origin_var, _cell_lengths_var,
                                    _cell_angles_var]

    def __init__(self, filename, mode='r', atoms=None, types_to_numbers=None,
                 double=True, netcdf_format='NETCDF3_CLASSIC', keep_open=True,
                 index_var='id', chunk_size=1000000):
        """
        A NetCDFTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file.  Should end in .nc.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and no atoms
            argument should be specified.

            'w' is write mode.  The atoms argument specifies the Atoms object
            to be written to the file, if not given it must instead be given
            as an argument to the write() method.

            'a' is append mode.  It acts as write mode, except that data is
            appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        types_to_numbers=None:
            Dictionary or list for conversion of atom types to atomic numbers
            when reading a trajectory file.  The mapping is copied, the
            caller's object is never modified.

        double=True:
            Create new variable in double precision.

        netcdf_format='NETCDF3_CLASSIC':
            Format string for the underlying NetCDF file format. Only relevant
            if a new file is created. More information can be found at
            https://www.unidata.ucar.edu/software/netcdf/docs/netcdf/File-Format.html

            'NETCDF3_CLASSIC' is the original binary format.

            'NETCDF3_64BIT' can be used to write larger files.

            'NETCDF4_CLASSIC' is HDF5 with some NetCDF limitations.

            'NETCDF4' is HDF5.

        keep_open=True:
            Keep the file open during consecutive read/write operations.
            Set to false if you experience data corruption. This will close
            the file after each read/write operation but comes with a serious
            performance penalty.  If None, the file is kept open in read
            mode only.

        index_var='id':
            Name of variable containing the atom indices. Atoms are reordered
            by this index upon reading if this variable is present. Default
            value is for LAMMPS output. None switches atom indices off.

        chunk_size=1000000:
            Maximum size of consecutive number of records (along the 'atom')
            dimension read when reading from a NetCDF file. This is used to
            reduce the memory footprint of a read operation on very large
            files.
        """
        self.nc = None
        self.chunk_size = chunk_size

        self.numbers = None
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write are called
        # Number of completed write() calls; used to decide which observers
        # fire (see _call_observers).
        self.write_counter = 0

        self.has_header = False
        self._set_atoms(atoms)

        self.types_to_numbers = None
        if isinstance(types_to_numbers, list):
            types_to_numbers = dict(enumerate(types_to_numbers))
        if types_to_numbers is not None:
            # Copy: __getitem__ extends this mapping with unknown types and
            # must not modify the object owned by the caller.
            self.types_to_numbers = dict(types_to_numbers)

        self.index_var = index_var

        if self.index_var is not None:
            # Build a per-instance list.  An in-place "+=" here would mutate
            # the list shared by all instances through the class attribute.
            self._default_vars = self._default_vars + [self.index_var]

        # 'l' should be a valid type according to the netcdf4-python
        # documentation, but does not appear to work.
        self.dtype_conv = {'l': 'i'}
        if not double:
            self.dtype_conv['d'] = 'f'

        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        # per frame atts are global quantities, not quantities stored for each
        # atom
        self.extra_per_frame_atts = []

        self.mode = mode
        self.netcdf_format = netcdf_format

        if atoms:
            self.n_atoms = len(atoms)
        else:
            self.n_atoms = None

        self.filename = filename
        if keep_open is None:
            # Only netCDF4-python supports append to files
            self.keep_open = self.mode == 'r'
        else:
            self.keep_open = keep_open

    def __del__(self):
        self.close()

    def _open(self):
        """
        Opens the file.

        Does nothing if the file is already open.  Append mode silently
        turns into write mode if the file does not exist yet.

        For internal use only.
        """
        if self.nc is not None:
            return
        import netCDF4
        if self.mode == 'a' and not os.path.exists(self.filename):
            self.mode = 'w'
        self.nc = netCDF4.Dataset(self.filename, self.mode,
                                  format=self.netcdf_format)

        self.frame = 0
        if self.mode == 'r' or self.mode == 'a':
            self._read_header()
            self.frame = self._len()

    def _set_atoms(self, atoms=None):
        """
        Associate an Atoms object with the trajectory.

        Raises TypeError if *atoms* has no ``get_positions`` attribute.

        For internal use only.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def _read_header(self):
        """
        Inspect the open file and classify all non-default variables.

        Variables are sorted into per-frame per-atom arrays, per-file
        per-atom arrays and per-frame global attributes according to their
        leading dimensions.
        """
        if not self.n_atoms:
            self.n_atoms = len(self.nc.dimensions[self._atom_dim])

        # The header is read again every time the file is reopened (e.g. with
        # keep_open=False).  Start from scratch so that names are not
        # accumulated multiple times.
        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        self.extra_per_frame_atts = []

        for name, var in self.nc.variables.items():
            # This can be unicode which confuses ASE
            name = str(name)
            # _default_vars is taken care of already
            if name in self._default_vars:
                continue
            dimensions = var.dimensions
            if len(dimensions) >= 2:
                if dimensions[0] == self._frame_dim:
                    if dimensions[1] == self._atom_dim:
                        self.extra_per_frame_vars.append(name)
                    else:
                        self.extra_per_frame_atts.append(name)
            elif len(dimensions) == 1:
                if dimensions[0] == self._atom_dim:
                    self.extra_per_file_vars.append(name)
                elif dimensions[0] == self._frame_dim:
                    self.extra_per_frame_atts.append(name)

        self.has_header = True

    def write(self, atoms=None, frame=None, arrays=None, time=None):
        """
        Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.

        Parameters:

        atoms=None:
            Atoms object to write.  An object with an ``interpolate``
            attribute is treated as a NEB and each of its images is written.

        frame=None:
            Index of the frame to write to.  Defaults to the current end of
            the trajectory.

        arrays=None:
            Names of additional per-atom arrays of *atoms* to store.

        time=None:
            Time stamp stored for this frame.

        Raises ValueError if the number of atoms or the atomic numbers are
        inconsistent with data already present in the file.
        """
        self._open()
        self._call_observers(self.pre_observers)
        if atoms is None:
            atoms = self.atoms

        if hasattr(atoms, 'interpolate'):
            # seems to be a NEB
            neb = atoms
            assert not neb.parallel
            try:
                neb.get_energies_and_forces(all=True)
            except AttributeError:
                # Best effort only: not every NEB flavor offers this method.
                pass
            for image in neb.images:
                self.write(image)
            return

        if not self.has_header:
            self._define_file_structure(atoms)
        elif len(atoms) != self.n_atoms:
            raise ValueError('Bad number of atoms!')

        i = self.frame if frame is None else frame

        # Number can be per file variable
        numbers = self._get_variable(self._numbers_var)
        if numbers.dimensions[0] == self._frame_dim:
            numbers[i] = atoms.get_atomic_numbers()
        elif np.any(numbers != atoms.get_atomic_numbers()):
            raise ValueError('Atomic numbers do not match!')
        self._get_variable(self._positions_var)[i] = atoms.get_positions()
        if atoms.has('momenta'):
            self._add_velocities()
            self._get_variable(self._velocities_var)[i] = \
                atoms.get_momenta() / atoms.get_masses().reshape(-1, 1)
        a, b, c, alpha, beta, gamma = atoms.cell.cellpar()
        if np.any(np.logical_not(atoms.pbc)):
            warnings.warn('Atoms have nonperiodic directions. Cell lengths in '
                          'these directions are lost and will be '
                          'shrink-wrapped when reading the NetCDF file.')
        # A zero cell length marks a nonperiodic direction in the file.
        cell_lengths = np.array([a, b, c]) * atoms.pbc
        self._get_variable(self._cell_lengths_var)[i] = cell_lengths
        self._get_variable(self._cell_angles_var)[i] = [alpha, beta, gamma]
        self._get_variable(self._cell_origin_var)[i] = \
            atoms.get_celldisp().reshape(3)
        if arrays is not None:
            self._write_extra_arrays(atoms, i, arrays)
        if time is not None:
            self._add_time()
            self._get_variable(self._time_var)[i] = time

        self.sync()

        self._call_observers(self.post_observers)
        self.frame += 1
        self.write_counter += 1
        self._close()

    def write_arrays(self, atoms, frame, arrays):
        """
        Write the per-atom arrays named in *arrays* of *atoms* to *frame*.

        Raises ValueError if an array exists as per-file data in the
        trajectory and differs from the data attached to *atoms*.
        """
        self._open()
        self._call_observers(self.pre_observers)
        self._write_extra_arrays(atoms, frame, arrays)
        self._call_observers(self.post_observers)
        self._close()

    def _write_extra_arrays(self, atoms, frame, arrays):
        """Store the named per-atom arrays of *atoms* in frame *frame*."""
        for array in arrays:
            data = atoms.get_array(array)
            if array in self.extra_per_file_vars:
                # This field exists but is per file data. Check that the
                # data remains consistent.
                if np.any(self._get_variable(array) != data):
                    raise ValueError('Trying to write Atoms object with '
                                     'incompatible data for the {} '
                                     'array.'.format(array))
            else:
                self._add_array(atoms, array, data.dtype, data.shape)
                self._get_variable(array)[frame] = data

    def _define_file_structure(self, atoms):
        """
        Create global attributes, dimensions and default variables.

        Dimensions and variables that already exist in the file are left
        untouched.
        """
        self.nc.Conventions = 'AMBER'
        self.nc.ConventionVersion = '1.0'
        self.nc.program = 'ASE'
        self.nc.programVersion = ase.__version__
        self.nc.title = "MOL"

        # A size of None creates the unlimited (record) dimension.
        dimension_sizes = [
            (self._frame_dim, None),
            (self._spatial_dim, 3),
            (self._atom_dim, len(atoms)),
            (self._cell_spatial_dim, 3),
            (self._cell_angular_dim, 3),
            (self._label_dim, 5),
        ]
        for dim_name, dim_size in dimension_sizes:
            if dim_name not in self.nc.dimensions:
                self.nc.createDimension(dim_name, dim_size)

        # Self-describing variables from AMBER convention
        if not self._has_variable(self._spatial_var):
            self.nc.createVariable(self._spatial_var, 'S1',
                                   (self._spatial_dim,))
            self.nc.variables[self._spatial_var][:] = ['x', 'y', 'z']
        if not self._has_variable(self._cell_spatial_var):
            self.nc.createVariable(self._cell_spatial_var, 'S1',
                                   (self._cell_spatial_dim,))
            self.nc.variables[self._cell_spatial_var][:] = ['a', 'b', 'c']
        if not self._has_variable(self._cell_angular_var):
            self.nc.createVariable(self._cell_angular_var, 'S1',
                                   (self._cell_angular_dim, self._label_dim,))
            # Labels are padded to the length of the label dimension (5).
            self.nc.variables[self._cell_angular_var][0] = list('alpha')
            self.nc.variables[self._cell_angular_var][1] = list('beta ')
            self.nc.variables[self._cell_angular_var][2] = list('gamma')

        if not self._has_variable(self._numbers_var):
            self.nc.createVariable(self._numbers_var[0], 'i',
                                   (self._frame_dim, self._atom_dim,))
        if not self._has_variable(self._positions_var):
            self.nc.createVariable(self._positions_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            self.nc.variables[self._positions_var].units = 'Angstrom'
            self.nc.variables[self._positions_var].scale_factor = 1.
        if not self._has_variable(self._cell_lengths_var):
            self.nc.createVariable(self._cell_lengths_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_lengths_var].units = 'Angstrom'
            self.nc.variables[self._cell_lengths_var].scale_factor = 1.
        if not self._has_variable(self._cell_angles_var):
            self.nc.createVariable(self._cell_angles_var, 'd',
                                   (self._frame_dim, self._cell_angular_dim))
            self.nc.variables[self._cell_angles_var].units = 'degree'
        if not self._has_variable(self._cell_origin_var):
            self.nc.createVariable(self._cell_origin_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_origin_var].units = 'Angstrom'
            self.nc.variables[self._cell_origin_var].scale_factor = 1.

    def _add_time(self):
        """Create the per-frame time variable if it does not exist."""
        if not self._has_variable(self._time_var):
            self.nc.createVariable(self._time_var, 'f8', (self._frame_dim,))

    def _add_velocities(self):
        """Create the velocities variable if it does not exist."""
        if not self._has_variable(self._velocities_var):
            self.nc.createVariable(self._velocities_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            # The attributes belong to the velocities variable; attaching
            # them to the positions would overwrite the units of the latter.
            self.nc.variables[self._velocities_var].units = \
                'Angstrom/Femtosecond'
            self.nc.variables[self._velocities_var].scale_factor = 1.

    def _add_array(self, atoms, array_name, type, shape):
        """
        Create a per-frame variable for an additional array.

        Each extent in *shape* is mapped onto a NetCDF dimension: the number
        of atoms, 3 (spatial) or 6 (Voigt).  Nothing happens if a variable
        named *array_name* exists already.

        Raises TypeError if an extent cannot be mapped onto a dimension.
        """
        if self._has_variable(array_name):
            return
        dims = [self._frame_dim]
        for extent in shape:
            if extent == len(atoms):
                dims.append(self._atom_dim)
            elif extent == 3:
                dims.append(self._spatial_dim)
            elif extent == 6:
                # This can only be stress/strain tensor in Voigt notation
                if self._Voigt_dim not in self.nc.dimensions:
                    self.nc.createDimension(self._Voigt_dim, 6)
                dims.append(self._Voigt_dim)
            else:
                raise TypeError("Don't know how to dump array of shape {}"
                                " into NetCDF trajectory.".format(shape))
        if hasattr(type, 'char'):
            # numpy dtype: translate type codes netCDF4 cannot handle
            t = self.dtype_conv.get(type.char, type)
        else:
            t = type
        self.nc.createVariable(array_name, t, dims)

    def _get_variable(self, name, exc=True):
        """
        Return the NetCDF variable called *name*.

        *name* may be a list of alternative names, in which case the first
        one present in the file is returned.  If nothing is found, a
        RuntimeError is raised if *exc* is true, otherwise None is returned.
        """
        variables = self.nc.variables
        if isinstance(name, list):
            for candidate in name:
                if candidate in variables:
                    return variables[candidate]
            if exc:
                raise RuntimeError(
                    'None of the variables {} was found in the '
                    'NetCDF trajectory.'.format(', '.join(name)))
        else:
            if name in variables:
                return variables[name]
            if exc:
                raise RuntimeError('Variable {} was not found in the NetCDF '
                                   'trajectory.'.format(name))
        return None

    def _has_variable(self, name):
        """Return whether *name* (or any name of a list) is in the file."""
        if isinstance(name, list):
            return any(n in self.nc.variables for n in name)
        return name in self.nc.variables

    def _get_data(self, name, frame, index, exc=True):
        """
        Read a per-atom variable, placing record j at position index[j].

        Per-frame variables are read for frame *frame*, per-file variables
        are read as a whole.  Returns None if the variable does not exist
        and *exc* is false.
        """
        var = self._get_variable(name, exc=exc)
        if var is None:
            return None
        per_frame = var.dimensions[0] == self._frame_dim
        shape = var.shape[1:] if per_frame else var.shape
        data = np.zeros(shape, dtype=var.dtype)
        s = shape[0]
        if s < self.chunk_size:
            data[index] = var[frame] if per_frame else var[...]
        else:
            # If this is a large data set, only read chunks from it to
            # reduce memory footprint of the NetCDFTrajectory reader.
            for start in range(0, s, self.chunk_size):
                sl = slice(start, min(start + self.chunk_size, s))
                data[index[sl]] = var[frame, sl] if per_frame else var[sl]
        return data

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Close the trajectory file."""
        if self.nc is not None:
            self.nc.close()
            self.nc = None

    def _close(self):
        """Close the file after an operation unless it is kept open."""
        if not self.keep_open:
            self.close()
            if self.mode == 'w':
                # The file exists now; further writes must not truncate it.
                self.mode = 'a'

    def sync(self):
        """Flush pending data of the open file to disk."""
        self.nc.sync()

    def __getitem__(self, i=-1):
        """
        Return frame *i* as an Atoms object, or a list of them for a slice.

        Negative indices count from the end of the trajectory.  Raises
        IndexError if *i* is out of range.
        """
        self._open()

        if isinstance(i, slice):
            try:
                indices = range(*i.indices(self._len()))
                return [self[j] for j in indices]
            finally:
                self._close()

        try:
            N = self._len()
            if not 0 <= i < N:
                i = N + i
                if i < 0 or i >= N:
                    raise IndexError('Trajectory index out of range.')
            return self._read_frame(i)
        finally:
            self._close()

    def _read_frame(self, i):
        """Construct the Atoms object of frame *i* (0 <= i < len)."""
        # Non-periodic boundaries have cell_length == 0.0
        cell_lengths = \
            np.array(self.nc.variables[self._cell_lengths_var][i][:])
        pbc = np.abs(cell_lengths) > 1e-6

        # Do we have a cell origin?
        if self._has_variable(self._cell_origin_var):
            origin = np.array(
                self.nc.variables[self._cell_origin_var][i][:])
        else:
            origin = np.zeros([3], dtype=float)

        # Do we have an index variable?
        if (self.index_var is not None and
                self._has_variable(self.index_var)):
            index = np.array(self.nc.variables[self.index_var][i][:])
            # The index variable can be non-consecutive, we here construct
            # a consecutive one.
            consecutive_index = np.zeros_like(index)
            consecutive_index[np.argsort(index)] = np.arange(self.n_atoms)
        else:
            consecutive_index = np.arange(self.n_atoms)

        # Read element numbers
        self.numbers = self._get_data(self._numbers_var, i,
                                      consecutive_index, exc=False)
        if self.numbers is None:
            self.numbers = np.ones(self.n_atoms, dtype=int)
        if self.types_to_numbers is not None:
            # Types without an entry in the table map onto themselves.
            d = set(self.numbers).difference(self.types_to_numbers.keys())
            if len(d) > 0:
                self.types_to_numbers.update({num: num for num in d})
            func = np.vectorize(self.types_to_numbers.get)
            self.numbers = func(self.numbers)
        self.masses = atomic_masses[self.numbers]

        # Read positions
        positions = self._get_data(self._positions_var, i,
                                   consecutive_index)

        # Determine cell size for non-periodic directions from shrink
        # wrapped cell.
        for dim in np.arange(3)[np.logical_not(pbc)]:
            origin[dim] = positions[:, dim].min()
            cell_lengths[dim] = positions[:, dim].max() - origin[dim]

        # Construct cell shape from cell lengths and angles
        cell = cellpar_to_cell(
            list(cell_lengths) +
            list(self.nc.variables[self._cell_angles_var][i])
        )

        # Compute momenta from velocities (if present)
        momenta = self._get_data(self._velocities_var, i,
                                 consecutive_index, exc=False)
        if momenta is not None:
            momenta *= self.masses.reshape(-1, 1)

        # Fill info dict with additional data found in the NetCDF file
        info = {}
        for name in self.extra_per_frame_atts:
            info[name] = np.array(self.nc.variables[name][i])

        # Create atoms object
        atoms = ase.Atoms(
            positions=positions,
            numbers=self.numbers,
            cell=cell,
            celldisp=origin,
            momenta=momenta,
            masses=self.masses,
            pbc=pbc,
            info=info
        )

        # Attach additional arrays found in the NetCDF file
        for name in self.extra_per_frame_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        for name in self.extra_per_file_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        return atoms

    def _len(self):
        """Return the number of frames of the already opened file."""
        if self._frame_dim in self.nc.dimensions:
            return int(self._get_variable(self._positions_var).shape[0])
        else:
            return 0

    def __len__(self):
        self._open()
        n_frames = self._len()
        self._close()
        return n_frames

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)
def read_netcdftrajectory(filename, index=-1):
    """Read frame(s) *index* from the NetCDF trajectory file *filename*.

    *index* is passed on to NetCDFTrajectory.__getitem__ unchanged; the
    trajectory is closed again before returning.
    """
    with NetCDFTrajectory(filename, mode='r') as trajectory:
        selection = trajectory[index]
    return selection
652def write_netcdftrajectory(filename, images):
653 if hasattr(images, 'get_positions'):
654 images = [images]
656 with NetCDFTrajectory(filename, mode='w') as traj:
657 for atoms in images:
658 traj.write(atoms)