Coverage for /builds/kinetik161/ase/ase/io/cif.py: 87.80%
492 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
"""Module to read and write atoms in cif file format.

See http://www.iucr.org/resources/cif/spec/version1.1/cifsyntax for a
description of the file format. STAR extensions such as save frames,
global blocks, nested loops and multi-data values are not supported.
The "latin-1" encoding is required by the IUCr specification.
"""
9import collections.abc
10import io
11import re
12import shlex
13import warnings
14from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
16import numpy as np
18from ase import Atoms
19from ase.cell import Cell
20from ase.io.cif_unicode import format_unicode, handle_subscripts
21from ase.spacegroup import crystal
22from ase.spacegroup.spacegroup import Spacegroup, spacegroup_from_data
23from ase.utils import iofunction
# Space group numbers of the rhombohedrally centred (R) groups: R3, R-3,
# R32, R3m, R3c, R-3m and R-3c.  For these, CIFBlock.get_spacegroup() picks
# setting 1 or 2 from the crystal-system name given in the CIF file.
rhombohedral_spacegroups = {146, 148, 155, 160, 161, 166, 167}

# Obsolete Hermann-Mauguin symbols mapped to their current names.  The
# lookup in CIFBlock._get_spacegroup_name() is an exact string match, so
# only symbols written without spaces are translated.
# NOTE(review): the current symbol for Ccca (No. 68) is 'Ccce'; 'Ccc1' looks
# like a typo -- confirm against the spacegroup database before changing it.
old_spacegroup_names = {'Abm2': 'Aem2',
                        'Aba2': 'Aea2',
                        'Cmca': 'Cmce',
                        'Cmma': 'Cmme',
                        'Ccca': 'Ccc1'}

# CIF maps names to either single values or to multiple values via loops.
CIFDataValue = Union[str, int, float]
CIFData = Union[CIFDataValue, List[CIFDataValue]]
def convert_value(value: str) -> CIFDataValue:
    """Convert a CIF value string to the corresponding Python type.

    The value is stripped first.  A token fully enclosed in single or double
    quotes loses its quotes; an integer becomes ``int``; a floating-point
    number becomes ``float``.  A trailing standard uncertainty such as
    ``1.234(5)`` is discarded, and a malformed one such as ``1.234(5``
    (missing closing parenthesis) is discarded with a warning.  Anything
    else is returned as a string after subscript handling.
    """
    value = value.strip()
    number = r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?'

    # fullmatch anchors *both* quote alternatives at the end of the value.
    # (With re.match and a trailing ``$`` only the second alternative was
    # anchored, so a value like  "a" b  was wrongly treated as quoted.)
    if re.fullmatch(r'".*"|\'.*\'', value):
        return handle_subscripts(value[1:-1])
    if re.fullmatch(r'[+-]?\d+', value):
        return int(value)
    if re.fullmatch(number, value):
        return float(value)
    if re.fullmatch(number + r'\(\d+\)', value):
        return float(value[:value.index('(')])  # strip off uncertainties
    if re.fullmatch(number + r'\(\d+', value):
        warnings.warn(f'Badly formed number: "{value}"')
        return float(value[:value.index('(')])  # strip off uncertainties
    return handle_subscripts(value)
def parse_multiline_string(lines: List[str], line: str) -> str:
    """Parse a semicolon-delimited multi-line text field and return it.

    ``line`` is the opening line, which starts with ``;``; any text after
    the semicolon becomes the first line of the value.  Further lines are
    popped from the end of ``lines`` (the parser keeps its remaining input
    reversed, so ``pop()`` yields the next line) until a line starting with
    ``;`` closes the field.  Every line is stripped of surrounding
    whitespace, and the joined result is stripped as well.

    Raises
    ------
    ValueError
        If the input ends before the closing ``;`` line is found.
    """
    assert line[0] == ';'
    strings = [line[1:].lstrip()]
    while lines:
        line = lines.pop().strip()
        if line[:1] == ';':
            return '\n'.join(strings).strip()
        strings.append(line)
    # Previously this surfaced as an uninformative
    # "IndexError: pop from empty list".
    raise ValueError('Unterminated multi-line string in CIF file: '
                     'no closing ";" line found')
def parse_singletag(lines: List[str], line: str) -> Tuple[str, CIFDataValue]:
    """Parse one ``_tag value`` entry and return ``(tag, converted value)``.

    When the value is not on the tag's own line, it is taken from the next
    non-blank, non-comment line popped from ``lines``.  A value line that
    starts with ``;`` opens a multi-line text field.
    """
    parts = line.split(None, 1)
    if len(parts) != 1:
        key, raw = parts
        return key, convert_value(raw)

    key = line
    # Skip blank lines and comment lines between the tag and its value.
    nextline = lines.pop().strip()
    while not nextline or nextline.startswith('#'):
        nextline = lines.pop().strip()

    if nextline.startswith(';'):
        raw = parse_multiline_string(lines, nextline)
    else:
        raw = nextline
    return key, convert_value(raw)
def parse_cif_loop_headers(lines: List[str]) -> Iterator[str]:
    """Yield the lower-cased tag names heading a ``loop_``.

    Lines are popped from ``lines`` while each consists of exactly one token
    that begins with ``_``.  The first line that does not qualify is pushed
    back unchanged, so the loop-data parser sees it next.
    """
    while lines:
        candidate = lines.pop()
        fields = candidate.split()
        is_header = len(fields) == 1 and fields[0].startswith('_')
        if not is_header:
            lines.append(candidate)  # give the line back ('undo' pop)
            return
        yield fields[0].lower()
def parse_cif_loop_data(lines: List[str],
                        ncolumns: int) -> List[List[CIFDataValue]]:
    """Parse the data rows of a ``loop_`` into ``ncolumns`` column lists.

    Lines are popped from ``lines`` (the remaining input, stored reversed)
    until a blank line, a tag (``_...``), a ``data_`` line or a ``loop_``
    line is reached.  That terminating line is pushed back for the caller,
    so a blank line ends the loop.  Tokens are accumulated across lines until
    a full row of ``ncolumns`` tokens is available.  The row is then
    converted with ``convert_value`` and appended column-wise.

    Raises
    ------
    RuntimeError
        If the input ends while an incomplete row is pending.
    """
    columns: List[List[CIFDataValue]] = [[] for _ in range(ncolumns)]

    tokens: List[str] = []
    while lines:
        line = lines.pop().strip()
        lowerline = line.lower()
        if (not line or
                line.startswith('_') or
                lowerline.startswith('data_') or
                lowerline.startswith('loop_')):
            lines.append(line)
            break

        if line.startswith('#'):
            continue

        # Drop a trailing comment.  NOTE(review): this also truncates quoted
        # values that happen to contain " #".
        line = line.split(' #')[0]

        if line.startswith(';'):
            moretokens = [parse_multiline_string(lines, line)]
        else:
            if ncolumns == 1:
                # With a single column the whole line is one token, so it is
                # not split on whitespace.
                moretokens = [line]
            else:
                # Non-POSIX mode keeps the quote characters in each token;
                # convert_value removes them later.
                moretokens = shlex.split(line, posix=False)

        tokens += moretokens
        if len(tokens) < ncolumns:
            # The row continues on the next line.
            continue
        if len(tokens) == ncolumns:
            for i, token in enumerate(tokens):
                columns[i].append(convert_value(token))
        else:
            # Too many tokens: the accumulated row is discarded.
            warnings.warn(f'Wrong number {len(tokens)} of tokens, '
                          f'expected {ncolumns}: {tokens}')

        # (Due to continue statements we cannot move this to start of loop)
        tokens = []

    if tokens:
        assert len(tokens) < ncolumns
        raise RuntimeError('CIF loop ended unexpectedly with incomplete row: '
                           f'{tokens}, expected {ncolumns} tokens')

    return columns
def parse_loop(lines: List[str]) -> Dict[str, List[CIFDataValue]]:
    """Parse a CIF ``loop_`` and map each column tag to its list of values.

    A tag repeated within the loop triggers a warning; only the first column
    carrying that tag is kept.
    """
    headers = list(parse_cif_loop_headers(lines))
    # The headers are kept as a list rather than a dict because a malformed
    # file may repeat a tag.
    columns = parse_cif_loop_data(lines, len(headers))

    result = {}
    for header, column in zip(headers, columns):
        if header not in result:
            result[header] = column
            continue
        warnings.warn(f'Duplicated loop tags: {header}')
    return result
def parse_items(lines: List[str], line: str) -> Dict[str, CIFData]:
    """Parse the data items of one CIF data block and return all its tags.

    Lines are popped from ``lines`` (the remaining input, stored reversed)
    until the input is exhausted or the next ``data_`` line is reached.  That
    line is pushed back so the caller can start the next block with it.
    Single tags are stored under their lower-cased names, loop columns are
    merged in via ``parse_loop``, and stray multi-line text fields are
    skipped.

    ``line`` (the block's ``data_`` line) is accepted for interface
    compatibility but is not used.

    Raises
    ------
    ValueError
        On a line that is not a tag, loop, data block, comment or text field.
    """
    tags: Dict[str, CIFData] = {}

    while lines:
        line = lines.pop().strip()
        if not line or line.startswith('#'):
            continue
        lowerline = line.lower()
        if line.startswith('_'):
            key, value = parse_singletag(lines, line)
            tags[key.lower()] = value
        elif lowerline.startswith('loop_'):
            tags.update(parse_loop(lines))
        elif lowerline.startswith('data_'):
            # Start of the next block: hand the line back to the caller.
            lines.append(line)
            break
        elif line.startswith(';'):
            parse_multiline_string(lines, line)
        else:
            raise ValueError(f'Unexpected CIF file entry: "{line}"')
    return tags
class NoStructureData(RuntimeError):
    """Raised when a CIF block lacks usable atomic symbols or positions."""
class CIFBlock(collections.abc.Mapping):
    """A block (i.e., a single system) in a crystallographic information file.

    Use this object to query CIF tags or import information as ASE objects.
    The block is a read-only mapping from tag names to their parsed values."""

    # Tags holding the cell parameters a, b, c, alpha, beta, gamma.
    cell_tags = ['_cell_length_a', '_cell_length_b', '_cell_length_c',
                 '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma']

    def __init__(self, name: str, tags: Dict[str, CIFData]):
        self.name = name
        self._tags = tags

    def __repr__(self) -> str:
        tags = set(self._tags)
        return f'CIFBlock({self.name}, tags={tags})'

    def __getitem__(self, key: str) -> CIFData:
        return self._tags[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._tags)

    def __len__(self) -> int:
        return len(self._tags)

    def get(self, key, default=None):
        return self._tags.get(key, default)

    def get_cellpar(self) -> Optional[List]:
        """Return [a, b, c, alpha, beta, gamma], or None if any is missing."""
        try:
            return [self[tag] for tag in self.cell_tags]
        except KeyError:
            return None

    def get_cell(self) -> Cell:
        """Return the unit cell, or a zero cell if no cell is given."""
        cellpar = self.get_cellpar()
        if cellpar is None:
            return Cell.new([0, 0, 0])
        return Cell.new(cellpar)

    def _raw_scaled_positions(self) -> Optional[np.ndarray]:
        """Return the fractional coordinates (one row per site), or None."""
        coords = [self.get(name) for name in ['_atom_site_fract_x',
                                              '_atom_site_fract_y',
                                              '_atom_site_fract_z']]
        # XXX Shall we try to handle mixed coordinates?
        # (Some scaled vs others fractional)
        if None in coords:
            return None
        return np.array(coords).T

    def _raw_positions(self) -> Optional[np.ndarray]:
        """Return the Cartesian coordinates (one row per site), or None."""
        coords = [self.get('_atom_site_cartn_x'),
                  self.get('_atom_site_cartn_y'),
                  self.get('_atom_site_cartn_z')]
        if None in coords:
            return None
        return np.array(coords).T

    def _get_site_coordinates(self):
        """Return ``('scaled', coords)`` or ``('cartesian', coords)``.

        Fractional coordinates take precedence.  Raises NoStructureData if
        neither kind of coordinates is present."""
        scaled = self._raw_scaled_positions()

        if scaled is not None:
            return 'scaled', scaled

        cartesian = self._raw_positions()

        if cartesian is None:
            raise NoStructureData('No positions found in structure')

        return 'cartesian', cartesian

    def _get_symbols_with_deuterium(self):
        """Return the chemical symbol of every site, keeping 'D' as is.

        Symbols come from ``_atom_site_type_symbol`` or, failing that,
        ``_atom_site_label``.  The first capitalised element-like token of
        each entry is used (e.g. 'Fe1' -> 'Fe').

        Raises NoStructureData if no symbols are given, a symbol is
        undetermined ('.' or '?'), or an entry contains no element symbol.
        """
        labels = self._get_any(['_atom_site_type_symbol',
                                '_atom_site_label'])
        if labels is None:
            raise NoStructureData('No symbols')
        if not isinstance(labels, list):
            # A tag given outside a loop holds a single value.  Iterating a
            # bare string would split it into characters ('Fe' -> 'F', 'e').
            labels = [labels]

        symbols = []
        for label in labels:
            if label == '.' or label == '?':
                raise NoStructureData('Symbols are undetermined')
            # Strip off additional labeling on chemical symbols.
            # str() guards against labels that were converted to numbers.
            match = re.search(r'([A-Z][a-z]?)', str(label))
            if match is None:
                raise NoStructureData(
                    f'No chemical symbol found in label {label!r}')
            symbols.append(match.group(0))
        return symbols

    def get_symbols(self) -> List[str]:
        """Return the chemical symbols, with deuterium reported as 'H'."""
        symbols = self._get_symbols_with_deuterium()
        return [symbol if symbol != 'D' else 'H' for symbol in symbols]

    def _where_deuterium(self):
        """Return a boolean array that is True for deuterium sites."""
        return np.array([symbol == 'D' for symbol
                         in self._get_symbols_with_deuterium()], bool)

    def _get_masses(self) -> Optional[np.ndarray]:
        """Return per-atom masses if any site is deuterium, otherwise None."""
        mask = self._where_deuterium()
        if not mask.any():
            return None

        symbols = self.get_symbols()
        masses = Atoms(symbols).get_masses()
        # NOTE(review): 2.01355 u is the deuteron (nuclear) mass; the atomic
        # mass of deuterium is about 2.01410 u -- confirm which is intended.
        masses[mask] = 2.01355
        return masses

    def _get_any(self, names):
        """Return the value of the first tag in ``names`` present, or None."""
        for name in names:
            if name in self:
                return self[name]
        return None

    def _get_spacegroup_number(self):
        # Symmetry specification, see
        # http://www.iucr.org/resources/cif/dictionaries/cif_sym for a
        # complete list of official keys.  In addition we also try to
        # support some commonly used deprecated notations.
        return self._get_any(['_space_group.it_number',
                              '_space_group_it_number',
                              '_symmetry_int_tables_number'])

    def _get_spacegroup_name(self):
        """Return the Hermann-Mauguin symbol, translating obsolete names."""
        # The mixed-case key can only match tags that were not lower-cased;
        # the built-in reader lower-cases all tag names.
        hm_symbol = self._get_any(['_space_group_name_h-m_alt',
                                   '_symmetry_space_group_name_h-m',
                                   '_space_group.Patterson_name_h-m',
                                   '_space_group.patterson_name_h-m'])

        hm_symbol = old_spacegroup_names.get(hm_symbol, hm_symbol)
        return hm_symbol

    def _get_sitesym(self):
        """Return the symmetry operations as a list of strings, or None."""
        sitesym = self._get_any(['_space_group_symop_operation_xyz',
                                 '_space_group_symop.operation_xyz',
                                 '_symmetry_equiv_pos_as_xyz'])
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        return sitesym

    def _get_fractional_occupancies(self):
        return self.get('_atom_site_occupancy')

    def _get_setting(self) -> Optional[int]:
        """Return the explicit space-group setting (1 or 2), or None.

        Raises ValueError for any other setting value."""
        setting_str = self.get('_symmetry_space_group_setting')
        if setting_str is None:
            return None

        setting = int(setting_str)
        if setting not in [1, 2]:
            raise ValueError(
                f'Spacegroup setting must be 1 or 2, not {setting}')
        return setting

    def get_spacegroup(self, subtrans_included) -> Spacegroup:
        """Return the space group described by this block.

        Listed symmetry operations take precedence.  ``subtrans_included``
        says whether they already contain the sublattice translations.
        Otherwise the space group number or Hermann-Mauguin symbol is used,
        falling back to number 1.  The setting comes from
        ``_symmetry_space_group_setting`` or, for rhombohedral groups, from
        the crystal-system name.
        """
        # XXX The logic in this method needs serious cleaning up!
        # The setting needs to be passed as either 1 or two, not None (default)
        no = self._get_spacegroup_number()
        hm_symbol = self._get_spacegroup_name()
        sitesym = self._get_sitesym()

        setting = 1
        if sitesym:
            # Special cases: sitesym can be None or an empty list.
            # The empty list could be replaced with just the identity
            # function, but it seems more correct to try to get the
            # spacegroup number and derive the symmetries for that.
            subtrans = [(0.0, 0.0, 0.0)] if subtrans_included else None
            spacegroup = spacegroup_from_data(
                no=no, symbol=hm_symbol, sitesym=sitesym, subtrans=subtrans,
                setting=setting)
        elif no is not None:
            spacegroup = no
        elif hm_symbol is not None:
            spacegroup = hm_symbol
        else:
            spacegroup = 1

        setting_std = self._get_setting()

        setting_name = None
        if '_symmetry_space_group_setting' in self:
            assert setting_std is not None
            setting = setting_std
        elif '_space_group_crystal_system' in self:
            setting_name = self['_space_group_crystal_system']
        elif '_symmetry_cell_setting' in self:
            setting_name = self['_symmetry_cell_setting']

        if setting_name:
            no = Spacegroup(spacegroup).no
            if no in rhombohedral_spacegroups:
                if setting_name == 'hexagonal':
                    setting = 1
                elif setting_name in ('trigonal', 'rhombohedral'):
                    setting = 2
                else:
                    warnings.warn(
                        f'unexpected crystal system {repr(setting_name)} '
                        f'for space group {repr(spacegroup)}')
            # FIXME - check for more crystal systems...
            else:
                warnings.warn(
                    f'crystal system {repr(setting_name)} is not '
                    f'interpreted for space group {repr(spacegroup)}. '
                    'This may result in wrong setting!')

        spg = Spacegroup(spacegroup, setting)
        if no is not None:
            assert int(spg) == no, (int(spg), no)
        return spg

    def get_unsymmetrized_structure(self) -> Atoms:
        """Return Atoms without symmetrizing coordinates.

        This returns a (normally) unphysical Atoms object
        corresponding only to those coordinates included
        in the CIF file, useful for e.g. debugging.

        This method may change behaviour in the future."""
        symbols = self.get_symbols()
        coordtype, coords = self._get_site_coordinates()

        atoms = Atoms(symbols=symbols,
                      cell=self.get_cell(),
                      masses=self._get_masses())

        if coordtype == 'scaled':
            atoms.set_scaled_positions(coords)
        else:
            assert coordtype == 'cartesian'
            atoms.positions[:] = coords

        return atoms

    def has_structure(self):
        """Whether this CIF block has an atomic configuration."""
        try:
            self.get_symbols()
            self._get_site_coordinates()
        except NoStructureData:
            return False
        else:
            return True

    def get_atoms(self, store_tags=False, primitive_cell=False,
                  subtrans_included=True, fractional_occupancies=True) -> Atoms:
        """Returns an Atoms object from a cif tags dictionary. See read_cif()
        for a description of the arguments.

        Raises RuntimeError if ``primitive_cell`` and ``subtrans_included``
        are both true.  Raises ValueError if the cell is neither absent nor
        fully three-dimensional."""
        if primitive_cell and subtrans_included:
            raise RuntimeError(
                'Primitive cell cannot be determined when sublattice '
                'translations are included in the symmetry operations listed '
                'in the CIF file, i.e. when `subtrans_included` is True.')

        cell = self.get_cell()
        # This validates file contents, so it must not be an assert
        # (asserts are stripped under ``python -O``).
        if cell.rank not in (0, 3):
            raise ValueError('CIF cell must be either absent or fully '
                             f'three-dimensional, got rank {cell.rank}')

        kwargs: Dict[str, Any] = {}
        if store_tags:
            kwargs['info'] = self._tags.copy()

        if fractional_occupancies:
            occupancies = self._get_fractional_occupancies()
        else:
            occupancies = None

        if occupancies is not None:
            # no warnings in this case
            kwargs['onduplicates'] = 'keep'

        # The unsymmetrized_structure is not the asymmetric unit
        # because the asymmetric unit should have (in general) a smaller cell,
        # whereas we have the full cell.
        unsymmetrized_structure = self.get_unsymmetrized_structure()

        if cell.rank == 3:
            spacegroup = self.get_spacegroup(subtrans_included)
            atoms = crystal(unsymmetrized_structure,
                            spacegroup=spacegroup,
                            setting=spacegroup.setting,
                            occupancies=occupancies,
                            primitive_cell=primitive_cell,
                            **kwargs)
        else:
            atoms = unsymmetrized_structure
            if kwargs.get('info') is not None:
                atoms.info.update(kwargs['info'])
            if occupancies is not None:
                # Compile an occupancies dictionary keyed by str(atom index)
                occ_dict = {}
                for i, sym in enumerate(atoms.symbols):
                    occ_dict[str(i)] = {sym: occupancies[i]}
                atoms.info['occupancy'] = occ_dict

        return atoms
def parse_block(lines: List[str], line: str) -> CIFBlock:
    """Parse the data block that starts with the ``data_<name>`` ``line``.

    The block's items are consumed from ``lines`` and returned as a
    CIFBlock named after the text following ``data_``."""
    assert line.lower().startswith('data_')
    name = line.split('_', 1)[1].rstrip()
    return CIFBlock(name, parse_items(lines, line))
def parse_cif(fileobj, reader='ase') -> Iterator[CIFBlock]:
    """Return an iterator over the CIF blocks in ``fileobj``.

    ``reader`` selects the parser: ``'ase'`` (built-in) or ``'pycodcif'``.
    Any other name raises ValueError."""
    if reader not in ('ase', 'pycodcif'):
        raise ValueError(f'No such reader: {reader}')
    if reader == 'pycodcif':
        return parse_cif_pycodcif(fileobj)
    return parse_cif_ase(fileobj)
def parse_cif_ase(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using ase CIF parser.

    ``fileobj`` is a path or a readable file object; bytes are decoded as
    latin-1.  One CIFBlock is yielded per data block."""
    if isinstance(fileobj, str):
        with open(fileobj, 'rb') as fd:
            data = fd.read()
    else:
        data = fileobj.read()

    if isinstance(data, bytes):
        data = data.decode('latin1')
    data = format_unicode(data)

    nonempty = [text for text in data.split('\n') if text]
    if nonempty and nonempty[0].rstrip() == '#\\#CIF_2.0':
        warnings.warn('CIF v2.0 file format detected; `ase` CIF reader might '
                      'incorrectly interpret some syntax constructions, use '
                      '`pycodcif` reader instead')

    # The parsers pop from the end, so store the lines reversed on top of an
    # empty sentinel line.
    pending = ['']
    pending.extend(reversed(nonempty))

    while pending:
        line = pending.pop().strip()
        if line and not line.startswith('#'):
            yield parse_block(pending, line)
def parse_cif_pycodcif(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using pycodcif CIF parser.

    pycodcif's ``parse()`` is called with a path; for a file object its
    ``name`` attribute is used.  Each tag's values are converted with
    ``convert_value``.  A tag with exactly one value is stored as a scalar,
    and any other tag as a list.

    Raises
    ------
    ImportError
        If pycodcif is not installed (chained to the original error).
    """
    if not isinstance(fileobj, str):
        fileobj = fileobj.name

    try:
        from pycodcif import parse
    except ImportError as err:
        raise ImportError(
            'parse_cif_pycodcif requires pycodcif ' +
            '(http://wiki.crystallography.net/cod-tools/pycodcif/)') from err

    data, _, _ = parse(fileobj)

    for datablock in data:
        tags = datablock['values']
        # Only values are replaced, so mutating while iterating is safe.
        for tag, raw_values in tags.items():
            values = [convert_value(x) for x in raw_values]
            tags[tag] = values[0] if len(values) == 1 else values
        yield CIFBlock(datablock['name'], tags)
def iread_cif(
    fileobj,
    index=-1,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Iterator[Atoms]:
    """Yield Atoms for the CIF blocks of ``fileobj`` selected by ``index``.

    Blocks without structure data are skipped.  ``index`` may be an int, a
    slice, or ``None``/``':'`` (both meaning every block).  See read_cif()
    for the remaining arguments.
    """
    # Find all CIF blocks with valid crystal data
    # TODO: return Atoms of the block name ``index`` if it is a string.
    images = [
        block.get_atoms(store_tags, primitive_cell, subtrans_included,
                        fractional_occupancies=fractional_occupancies)
        for block in parse_cif(fileobj, reader)
        if block.has_structure()
    ]

    if index is None or index == ':':
        index = slice(None, None, None)
    elif not isinstance(index, (slice, str)):
        # An integer picks a single image; ``(index + 1) or None`` maps
        # index == -1 to slice(-1, None).
        index = slice(index, (index + 1) or None)

    yield from images[index]
def read_cif(
    fileobj,
    index=-1,
    *,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Union[Atoms, List[Atoms]]:
    """Read Atoms object from CIF file.

    Parameters
    ----------
    fileobj : str or file
        Path or open file to read from.
    index : int, slice or str
        Which of the structure-bearing blocks to return.  An int selects a
        single block (default: the last one).  A slice or the string ``':'``
        returns a list.
    store_tags : bool
        If true, the *info* attribute of the returned Atoms object will be
        populated with all tags in the corresponding cif data block.
    primitive_cell : bool
        If true, the primitive cell is built instead of the conventional cell.
    subtrans_included : bool
        If true, sublattice translations are assumed to be included among the
        symmetry operations listed in the CIF file (seems to be the common
        behaviour of CIF files).
        Otherwise the sublattice translations are determined from setting 1 of
        the extracted space group.  As a result of setting this flag to true,
        it will not be possible to determine the primitive cell.
    fractional_occupancies : bool
        If true, the resulting atoms object will be equipped with a
        dictionary `occupancy`.  The keys of this dictionary will be integers
        converted to strings.  The conversion to string is done in order to
        avoid troubles with JSON encoding/decoding of the dictionaries with
        non-string keys.
        Also, in case of mixed occupancies, the atom's chemical symbol will be
        that of the most dominant species.
    reader : str
        Select CIF reader.

        * ``ase`` : built-in CIF reader (default)
        * ``pycodcif`` : CIF reader based on ``pycodcif`` package

    Returns
    -------
    Atoms or list of Atoms
        A single Atoms object for an integer index, otherwise a list.

    Raises
    ------
    IndexError
        If an integer ``index`` selects no block, e.g. because the file
        contains no structure data.

    Notes
    -----
    Only blocks with valid crystal data will be included.
    """
    g = iread_cif(
        fileobj,
        index,
        store_tags,
        primitive_cell,
        subtrans_included,
        fractional_occupancies,
        reader,
    )
    if isinstance(index, (slice, str)):
        # Return list of atoms
        return list(g)

    # Return single atoms object.  Without the translation below a bare
    # StopIteration would escape this plain function (and become a
    # RuntimeError inside any calling generator).
    try:
        return next(g)
    except StopIteration:
        raise IndexError(
            f'No CIF block with structure data at index {index!r}'
        ) from None
def format_cell(cell: Cell) -> str:
    """Return the six ``_cell_*`` tag lines for a full 3D ``cell``."""
    assert cell.rank == 3
    entries = [f'{tag:20} {par}\n'
               for tag, par in zip(CIFBlock.cell_tags, cell.cellpar())]
    assert len(entries) == 6
    return ''.join(entries)
def format_generic_spacegroup_info() -> str:
    """Return space-group lines declaring P 1 with only the identity."""
    # We assume no symmetry whatsoever
    header = ('_space_group_name_H-M_alt "P 1"\n'
              '_space_group_IT_number 1\n'
              '\n')
    symops = ('loop_\n'
              ' _space_group_symop_operation_xyz\n'
              " 'x, y, z'\n")
    return header + symops
class CIFLoop:
    """Collect equally long columns and render them as a CIF ``loop_``."""

    def __init__(self):
        self.names = []
        self.formats = []
        self.arrays = []

    def add(self, name, array, fmt):
        """Append a column to the loop.

        ``name`` is the CIF tag (must start with ``_``), ``array`` holds the
        column values and ``fmt`` is a ``str.format`` field such as
        ``'{:6.4f}'`` applied to each value.

        Raises ValueError, leaving the loop unchanged, if ``array`` differs
        in length from the columns already added.
        """
        assert name.startswith('_')
        # Validate before appending so a rejected column does not leave the
        # loop with mismatched columns.
        if self.arrays and len(array) != len(self.arrays[0]):
            raise ValueError(f'Loop data "{name}" has {len(array)} '
                             f'elements, expected {len(self.arrays[0])}')
        self.names.append(name)
        self.formats.append(fmt)
        self.arrays.append(array)

    def tostring(self):
        """Return the loop as text ending in a newline.

        The text is ``loop_``, then one tag per line, then one formatted row
        per line."""
        lines = []
        append = lines.append
        append('loop_')
        for name in self.names:
            append(f' {name}')

        template = ' ' + ' '.join(self.formats)

        ncolumns = len(self.arrays)
        nrows = len(self.arrays[0]) if ncolumns > 0 else 0
        for row in range(nrows):
            arraydata = [array[row] for array in self.arrays]
            line = template.format(*arraydata)
            append(line)
        append('')
        return '\n'.join(lines)
@iofunction('wb')
def write_cif(fd, images, cif_format=None,
              wrap=True, labels=None, loop_keys=None) -> None:
    """Write *images* to CIF file.

    images: Atoms or list of Atoms
        Each image is written as its own ``data_image<i>`` block.

    cif_format:
        Deprecated.  Any value other than None only triggers a
        FutureWarning.

    wrap: bool
        Wrap atoms into unit cell.

    labels: list
        Use this list (shaped list[i_frame][i_atom] = string) for the
        '_atom_site_label' section instead of automatically generating
        it from the element symbol.

    loop_keys: dict
        Add the information from this dictionary to the `loop_`
        section.  Each key is written to the `loop_` section with a leading
        '_'.  dict[key] should contain the data printed for each atom,
        so it needs to have the setup `dict[key][i_frame][i_atom] =
        string`.  The strings are printed as they are, so take care of
        formatting.  Information can be re-read using the `store_tags`
        option of the cif reader.
    """
    if cif_format is not None:
        warnings.warn('The cif_format argument is deprecated and may be '
                      'removed in the future. Use loop_keys to customize '
                      'data written in loop.', FutureWarning)

    loop_keys = {} if loop_keys is None else loop_keys

    if hasattr(images, 'get_positions'):
        images = [images]

    text = io.TextIOWrapper(fd, encoding='latin-1')
    try:
        for i, atoms in enumerate(images):
            frame_keys = {key: values[i] for key, values in loop_keys.items()}
            write_cif_image(f'data_image{i}\n', atoms, text,
                            wrap=wrap,
                            labels=labels[i] if labels is not None else None,
                            loop_keys=frame_keys)
    finally:
        # Closing (or garbage-collecting) the wrapper would also close the
        # caller's binary file; detaching first leaves ``fd`` open.
        text.detach()
def autolabel(symbols: Sequence[str]) -> List[str]:
    """Number the atoms of each element consecutively: Fe1, O1, Fe2, ..."""
    seen: Dict[str, int] = {}
    result = []
    for sym in symbols:
        seen[sym] = seen.get(sym, 0) + 1
        result.append(f'{sym}{seen[sym]}')
    return result
def chemical_formula_header(atoms):
    """Return the ``_chemical_formula_structural`` and
    ``_chemical_formula_sum`` lines for ``atoms``."""
    pieces = [f'{sym}{count}'
              for sym, count in atoms.symbols.formula.count().items()]
    structural = f'_chemical_formula_structural {atoms.symbols}\n'
    total = '_chemical_formula_sum "' + ' '.join(pieces) + '"\n'
    return structural + total
class BadOccupancies(ValueError):
    """Raised when stored occupancy info does not match the atoms."""
def expand_kinds(atoms, coords):
    """Expand sites with mixed occupancy into one entry per species.

    Uses ``atoms.info['occupancy']`` (a dict mapping ``str(kind)`` to a
    ``{symbol: occupancy}`` dict) together with
    ``atoms.arrays['spacegroup_kinds']``.  For every site, the occupancy of
    its own symbol is recorded.  Each other species of the same kind is
    appended as an extra entry at the same coordinates.  Without that
    information every occupancy is 1.

    Returns ``(symbols, coords, occupancies)`` as lists.

    Raises BadOccupancies if the occupancy info does not cover a site's kind
    or symbol.
    """
    # try to fetch occupancies // spacegroup_kinds - occupancy mapping
    symbols = list(atoms.symbols)
    coords = list(coords)
    occupancies = [1] * len(symbols)
    occ_info = atoms.info.get('occupancy')
    kinds = atoms.arrays.get('spacegroup_kinds')
    if occ_info is not None and kinds is not None:
        for i, kind in enumerate(kinds):
            occ_info_kind = occ_info.get(str(kind))
            if occ_info_kind is None:
                # Previously a KeyError that bypassed the caller's
                # BadOccupancies fallback.
                raise BadOccupancies('Occupancies present but no occupancy '
                                     f'info for kind "{kind}"')
            symbol = symbols[i]
            if symbol not in occ_info_kind:
                raise BadOccupancies('Occupancies present but no occupancy '
                                     f'info for "{symbol}"')
            occupancies[i] = occ_info_kind[symbol]
            # extend the positions array in case of mixed occupancy
            for sym, occ in occ_info_kind.items():
                if sym != symbol:
                    symbols.append(sym)
                    coords.append(coords[i])
                    occupancies.append(occ)
    return symbols, coords, occupancies
def atoms_to_loop_data(atoms, wrap, labels, loop_keys):
    """Build the per-site columns of the ``_atom_site`` loop.

    Returns ``(loopdata, coord_headers)``.  ``loopdata`` maps each tag to
    ``(values, format)``, and ``coord_headers`` are the three coordinate
    tags: fractional for a 3D cell, Cartesian otherwise.
    """
    periodic = atoms.cell.rank == 3
    coord_type = 'fract' if periodic else 'Cartn'
    if periodic:
        raw = atoms.get_scaled_positions(wrap)
    else:
        raw = atoms.get_positions(wrap)
    coords = raw.tolist()

    try:
        symbols, coords, occupancies = expand_kinds(atoms, coords)
    except BadOccupancies as err:
        # Fall back to one fully occupied entry per atom.
        warnings.warn(str(err))
        occupancies = [1] * len(atoms)
        symbols = list(atoms.symbols)

    if labels is None:
        labels = autolabel(symbols)

    coord_headers = [f'_atom_site_{coord_type}_{axis}' for axis in 'xyz']
    nsites = len(symbols)

    loopdata = {
        '_atom_site_label': (labels, '{:<8s}'),
        '_atom_site_occupancy': (occupancies, '{:6.4f}'),
    }
    xyz = np.array(coords)
    for column, header in enumerate(coord_headers):
        loopdata[header] = (xyz[:, column], '{}')
    loopdata['_atom_site_type_symbol'] = (symbols, '{:<2s}')
    loopdata['_atom_site_symmetry_multiplicity'] = ([1.0] * nsites, '{}')

    # NOTE(review): loop_keys values are indexed once per entry of
    # ``symbols``, which expand_kinds may make longer than len(atoms) for
    # mixed occupancies; such inputs then fail with IndexError.
    for key, per_atom in loop_keys.items():
        loopdata['_' + key] = ([per_atom[i] for i in range(nsites)], '{}')

    return loopdata, coord_headers
def write_cif_image(blockname, atoms, fd, *, wrap,
                    labels, loop_keys):
    """Write one ``data_`` block describing ``atoms`` to the text file ``fd``.

    A 3D cell is written together with a P 1 space group.  A system without
    a cell gets only Cartesian coordinates.  Any other cell rank raises
    ValueError, after the block name and formula lines have been written.
    """
    fd.write(blockname)
    fd.write(chemical_formula_header(atoms))

    rank = atoms.cell.rank
    if rank not in (0, 3):
        raise ValueError('CIF format can only represent systems with '
                         f'0 or 3 lattice vectors. Got {rank}.')
    if rank == 3:
        fd.write(format_cell(atoms.cell))
        fd.write('\n')
        fd.write(format_generic_spacegroup_info())
        fd.write('\n')

    loopdata, coord_headers = atoms_to_loop_data(atoms, wrap, labels,
                                                 loop_keys)

    headers = ['_atom_site_type_symbol',
               '_atom_site_label',
               '_atom_site_symmetry_multiplicity',
               *coord_headers,
               '_atom_site_occupancy']
    headers.extend('_' + key for key in loop_keys)

    loop = CIFLoop()
    for header in headers:
        values, fmt = loopdata[header]
        loop.add(header, values, fmt)

    fd.write(loop.tostring())