Coverage for /builds/kinetik161/ase/ase/io/cif.py: 87.80%

492 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1"""Module to read and write atoms in cif file format. 

2 

3See http://www.iucr.org/resources/cif/spec/version1.1/cifsyntax for a 

4description of the file format. STAR extensions as save frames, 

5global blocks, nested loops and multi-data values are not supported. 

6The "latin-1" encoding is required by the IUCR specification. 

7""" 

8 

9import collections.abc 

10import io 

11import re 

12import shlex 

13import warnings 

14from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union 

15 

16import numpy as np 

17 

18from ase import Atoms 

19from ase.cell import Cell 

20from ase.io.cif_unicode import format_unicode, handle_subscripts 

21from ase.spacegroup import crystal 

22from ase.spacegroup.spacegroup import Spacegroup, spacegroup_from_data 

23from ase.utils import iofunction 

24 

# Space group numbers for which the crystal-system tag is used to choose
# between setting 1 ('hexagonal') and setting 2 ('trigonal'/'rhombohedral').
rhombohedral_spacegroups = {146, 148, 155, 160, 161, 166, 167}


# Superseded Hermann-Mauguin symbols mapped to their current spelling,
# which uses the "e" (double glide plane) symbol.
old_spacegroup_names = {
    'Abm2': 'Aem2',
    'Aba2': 'Aea2',
    'Cmca': 'Cmce',
    'Cmma': 'Cmme',
    # Space group 68.  This entry used to read 'Ccc1', which is not a
    # space-group symbol.
    'Ccca': 'Ccce',
}

# CIF maps names to either single values or to multiple values via loops.
CIFDataValue = Union[str, int, float]
CIFData = Union[CIFDataValue, List[CIFDataValue]]

37 

38 

def convert_value(value: str) -> CIFDataValue:
    """Turn the text of a single CIF value into a str, int or float.

    Surrounding whitespace is ignored.  Quoted text loses its first and
    last character, plain integers and decimal numbers become ``int`` and
    ``float``, and a trailing uncertainty such as ``(5)`` is dropped.  A
    number whose uncertainty lacks the closing parenthesis is still
    converted, but a warning is issued.  Everything that is returned as
    text is passed through ``handle_subscripts``.
    """
    text = value.strip()
    number = r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?'

    # NOTE: the trailing ``$`` belongs to the single-quote alternative only;
    # the double-quote alternative is not anchored at the end.
    if re.match('(".*")|(\'.*\')$', text):
        return handle_subscripts(text[1:-1])

    if re.match(r'[+-]?\d+$', text):
        return int(text)

    if re.match(number + r'$', text):
        return float(text)

    if re.match(number + r'\(\d+\)$', text):
        # Drop the uncertainty given in parentheses.
        return float(text[:text.index('(')])

    if re.match(number + r'\(\d+$', text):
        warnings.warn(f'Badly formed number: "{text}"')
        # Drop the (unterminated) uncertainty.
        return float(text[:text.index('(')])

    return handle_subscripts(text)

57 

58 

def parse_multiline_string(lines: List[str], line: str) -> str:
    """Read a semicolon-delimited text field and return its content.

    ``line`` is the opening line (it must start with ';').  Further lines
    are popped from the end of ``lines`` until one starting with ';' is
    found; that closing line is consumed but not included.
    """
    assert line[0] == ';'
    collected = [line[1:].lstrip()]

    current = lines.pop().strip()
    while not current.startswith(';'):
        collected.append(current)
        current = lines.pop().strip()

    return '\n'.join(collected).strip()

69 

70 

def parse_singletag(lines: List[str], line: str) -> Tuple[str, CIFDataValue]:
    """Parse one CIF tag (an entry starting with an underscore).

    When the value is not on the tag line itself, it is taken from the
    next line popped from ``lines`` that is neither blank nor a comment;
    that line may open a semicolon-delimited text field.

    Returns the (key, converted value) pair.
    """
    parts = line.split(None, 1)
    if len(parts) != 1:
        key, raw = parts
        return key, convert_value(raw)

    # Only the tag is on this line: look for the value further down.
    following = lines.pop().strip()
    while not following or following[0] == '#':
        following = lines.pop().strip()

    if following[0] == ';':
        raw = parse_multiline_string(lines, following)
    else:
        raw = following
    return line, convert_value(raw)

87 

88 

def parse_cif_loop_headers(lines: List[str]) -> Iterator[str]:
    """Yield the lower-cased column tags at the top of a CIF loop.

    Lines are consumed from the end of ``lines`` for as long as each one
    consists of a single token starting with an underscore.  The first
    line that does not is left in place.
    """
    while lines:
        words = lines[-1].split()
        if len(words) != 1 or not words[0].startswith('_'):
            return
        lines.pop()
        yield words[0].lower()

100 

101 

def parse_cif_loop_data(lines: List[str],
                        ncolumns: int) -> List[List[CIFDataValue]]:
    """Read the rows of a CIF loop and return them column by column.

    Lines are popped from the end of ``lines`` until a blank line, a tag,
    a ``data_`` header or a ``loop_`` keyword is met; that line is pushed
    back (stripped).  A row may be spread over several lines.  A row with
    too many tokens triggers a warning and is dropped.

    Raises RuntimeError if the loop ends in the middle of a row.
    """
    columns: List[List[CIFDataValue]] = [[] for _ in range(ncolumns)]

    row = []
    while lines:
        entry = lines.pop().strip()
        ends_loop = (not entry
                     or entry.startswith('_')
                     or entry.lower().startswith(('data_', 'loop_')))
        if ends_loop:
            lines.append(entry)  # give the line back to the caller
            break

        if entry.startswith('#'):
            continue

        # Cut off a trailing comment.
        entry = entry.split(' #')[0]

        if entry.startswith(';'):
            row.append(parse_multiline_string(lines, entry))
        elif ncolumns == 1:
            # With a single column the whole line is the value.
            row.append(entry)
        else:
            row.extend(shlex.split(entry, posix=False))

        if len(row) < ncolumns:
            # The row continues on the next line.
            continue

        if len(row) == ncolumns:
            for column, token in zip(columns, row):
                column.append(convert_value(token))
        else:
            warnings.warn(f'Wrong number {len(row)} of tokens, '
                          f'expected {ncolumns}: {row}')

        row = []

    if row:
        assert len(row) < ncolumns
        raise RuntimeError('CIF loop ended unexpectedly with incomplete row: '
                           f'{row}, expected {ncolumns} tokens')

    return columns

149 

150 

def parse_loop(lines: List[str]) -> Dict[str, List[CIFDataValue]]:
    """Parse a CIF loop into a dict mapping each column tag to its values.

    If a tag occurs more than once among the loop headers, a warning is
    issued and only the first column with that tag is kept.
    """
    # The headers stay a list because the same tag may be repeated.
    headers = list(parse_cif_loop_headers(lines))
    columns = parse_cif_loop_data(lines, len(headers))

    columns_dict = {}
    for header, column in zip(headers, columns):
        if header in columns_dict:
            warnings.warn(f'Duplicated loop tags: {header}')
            continue
        columns_dict[header] = column
    return columns_dict

167 

168 

def parse_items(lines: List[str], line: str) -> Dict[str, CIFData]:
    """Parse the data items of one block and return all tags as a dict.

    Lines are popped from the end of ``lines`` until the list is empty or
    the next ``data_`` header is found; that header is pushed back.  Tag
    names are stored in lower case.  The ``line`` argument is not read.

    Raises ValueError for a line that is not a tag, a loop, a block
    header, a comment or a semicolon-delimited text field.
    """
    tags: Dict[str, CIFData] = {}

    while lines:
        entry = lines.pop().strip()
        if not entry or entry.startswith('#'):
            continue

        lowered = entry.lower()
        if entry.startswith('_'):
            key, value = parse_singletag(lines, entry)
            tags[key.lower()] = value
        elif lowered.startswith('loop_'):
            tags.update(parse_loop(lines))
        elif lowered.startswith('data_'):
            # Start of the next block: leave it for the caller.
            lines.append(entry)
            break
        elif entry.startswith(';'):
            # Stray text field; consumed and discarded.
            parse_multiline_string(lines, entry)
        else:
            raise ValueError(f'Unexpected CIF file entry: "{entry}"')
    return tags

196 

197 

class NoStructureData(RuntimeError):
    """Raised when a CIF block lacks the symbols or positions of a structure."""

200 

201 

class CIFBlock(collections.abc.Mapping):
    """One data block (a single system) of a crystallographic information file.

    The block is a read-only mapping from tag name to value and provides
    helpers that turn the tags into ASE objects."""

    cell_tags = ['_cell_length_a', '_cell_length_b', '_cell_length_c',
                 '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma']

    def __init__(self, name: str, tags: Dict[str, CIFData]):
        self.name = name
        self._tags = tags

    def __repr__(self) -> str:
        return f'CIFBlock({self.name}, tags={set(self._tags)})'

    def __getitem__(self, key: str) -> CIFData:
        return self._tags[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._tags)

    def __len__(self) -> int:
        return len(self._tags)

    def get(self, key, default=None):
        """Return the value of tag ``key``, or ``default`` if it is absent."""
        return self._tags.get(key, default)

    def get_cellpar(self) -> Optional[List]:
        """Return the six cell parameters, or None if any of them is missing."""
        cellpar = []
        for tag in self.cell_tags:
            try:
                cellpar.append(self[tag])
            except KeyError:
                return None
        return cellpar

    def get_cell(self) -> Cell:
        """Return the cell; an all-zero cell when no parameters are given."""
        cellpar = self.get_cellpar()
        return Cell.new([0, 0, 0] if cellpar is None else cellpar)

    def _raw_scaled_positions(self) -> Optional[np.ndarray]:
        """Return the fractional site coordinates as listed, or None."""
        tags = [f'_atom_site_fract_{axis}' for axis in 'xyz']
        columns = list(map(self.get, tags))
        # XXX Shall we try to handle mixed coordinates?
        # (Some scaled vs others fractional)
        if None in columns:
            return None
        return np.array(columns).T

    def _raw_positions(self) -> Optional[np.ndarray]:
        """Return the Cartesian site coordinates as listed, or None."""
        tags = [f'_atom_site_cartn_{axis}' for axis in 'xyz']
        columns = list(map(self.get, tags))
        if None in columns:
            return None
        return np.array(columns).T

    def _get_site_coordinates(self):
        """Return ('scaled', array) or ('cartesian', array).

        Fractional coordinates take precedence.  Raises NoStructureData
        when neither kind is available."""
        scaled = self._raw_scaled_positions()
        if scaled is not None:
            return 'scaled', scaled

        cartesian = self._raw_positions()
        if cartesian is not None:
            return 'cartesian', cartesian

        raise NoStructureData('No positions found in structure')

    def _get_symbols_with_deuterium(self):
        """Return the chemical symbols of the sites, keeping 'D' as is.

        Raises NoStructureData when there are no symbols or labels, or
        when one of them is '.' or '?'."""
        labels = self._get_any(['_atom_site_type_symbol',
                                '_atom_site_label'])
        if labels is None:
            raise NoStructureData('No symbols')

        symbols = []
        for label in labels:
            if label in ('.', '?'):
                raise NoStructureData('Symbols are undetermined')
            # Keep only the chemical symbol found in the label.
            symbols.append(re.search(r'([A-Z][a-z]?)', label).group(0))
        return symbols

    def get_symbols(self) -> List[str]:
        """Return the chemical symbols of the sites with 'D' replaced by 'H'."""
        return ['H' if symbol == 'D' else symbol
                for symbol in self._get_symbols_with_deuterium()]

    def _where_deuterium(self):
        """Return a boolean array marking the sites whose symbol is 'D'."""
        flags = [symbol == 'D'
                 for symbol in self._get_symbols_with_deuterium()]
        return np.array(flags, bool)

    def _get_masses(self) -> Optional[np.ndarray]:
        """Return masses with deuterium sites set to 2.01355.

        Returns None when no site is deuterium."""
        is_deuterium = self._where_deuterium()
        if not any(is_deuterium):
            return None

        masses = Atoms(self.get_symbols()).get_masses()
        masses[is_deuterium] = 2.01355
        return masses

    def _get_any(self, names):
        """Return the value of the first tag in ``names`` that is present.

        Returns None when none of them is."""
        return next((self[name] for name in names if name in self), None)

    def _get_spacegroup_number(self):
        # Symmetry specification, see
        # http://www.iucr.org/resources/cif/dictionaries/cif_sym for a
        # complete list of official keys.  In addition we also try to
        # support some commonly used deprecated notations.
        candidates = ['_space_group.it_number',
                      '_space_group_it_number',
                      '_symmetry_int_tables_number']
        return self._get_any(candidates)

    def _get_spacegroup_name(self):
        """Return the Hermann-Mauguin symbol, with old names translated."""
        candidates = ['_space_group_name_h-m_alt',
                      '_symmetry_space_group_name_h-m',
                      '_space_group.Patterson_name_h-m',
                      '_space_group.patterson_name_h-m']
        hm_symbol = self._get_any(candidates)
        return old_spacegroup_names.get(hm_symbol, hm_symbol)

    def _get_sitesym(self):
        """Return the symmetry operations; a single string becomes a list."""
        sitesym = self._get_any(['_space_group_symop_operation_xyz',
                                 '_space_group_symop.operation_xyz',
                                 '_symmetry_equiv_pos_as_xyz'])
        return [sitesym] if isinstance(sitesym, str) else sitesym

    def _get_fractional_occupancies(self):
        return self.get('_atom_site_occupancy')

    def _get_setting(self) -> Optional[int]:
        """Return the space group setting (1 or 2), or None if not given.

        Raises ValueError for any other value."""
        raw = self.get('_symmetry_space_group_setting')
        if raw is None:
            return None

        setting = int(raw)
        if setting not in (1, 2):
            raise ValueError(
                f'Spacegroup setting must be 1 or 2, not {setting}')
        return setting

    def get_spacegroup(self, subtrans_included) -> Spacegroup:
        """Return the space group described by this block."""
        # XXX The logic in this method needs serious cleaning up!
        # The setting needs to be passed as either 1 or two, not None (default)
        no = self._get_spacegroup_number()
        hm_symbol = self._get_spacegroup_name()
        sitesym = self._get_sitesym()

        setting = 1
        if sitesym:
            # sitesym can also be None or an empty list.  An empty list
            # could be replaced with just the identity, but it seems more
            # correct to fall back on the spacegroup number and derive the
            # symmetries from that.
            spacegroup = spacegroup_from_data(
                no=no, symbol=hm_symbol, sitesym=sitesym,
                subtrans=[(0.0, 0.0, 0.0)] if subtrans_included else None,
                setting=setting)
        elif no is not None:
            spacegroup = no
        elif hm_symbol is not None:
            spacegroup = hm_symbol
        else:
            spacegroup = 1

        setting_std = self._get_setting()

        setting_name = None
        if '_symmetry_space_group_setting' in self:
            assert setting_std is not None
            setting = setting_std
        else:
            setting_name = self._get_any(['_space_group_crystal_system',
                                          '_symmetry_cell_setting'])

        if setting_name:
            no = Spacegroup(spacegroup).no
            if no not in rhombohedral_spacegroups:
                warnings.warn(
                    f'crystal system {repr(setting_name)} is not '
                    f'interpreted for space group {repr(spacegroup)}. '
                    'This may result in wrong setting!')
            elif setting_name == 'hexagonal':
                setting = 1
            elif setting_name in ('trigonal', 'rhombohedral'):
                setting = 2
            else:
                warnings.warn(
                    f'unexpected crystal system {repr(setting_name)} '
                    f'for space group {repr(spacegroup)}')
            # FIXME - check for more crystal systems...

        spg = Spacegroup(spacegroup, setting)
        if no is not None:
            assert int(spg) == no, (int(spg), no)
        return spg

    def get_unsymmetrized_structure(self) -> Atoms:
        """Return Atoms holding only the sites listed in the CIF file.

        No symmetry operations are applied, so the result is normally
        unphysical; it is useful e.g. for debugging.

        This method may change behaviour in the future."""
        symbols = self.get_symbols()
        coordtype, coords = self._get_site_coordinates()

        atoms = Atoms(symbols=symbols,
                      cell=self.get_cell(),
                      masses=self._get_masses())

        if coordtype == 'scaled':
            atoms.set_scaled_positions(coords)
        else:
            assert coordtype == 'cartesian'
            atoms.positions[:] = coords

        return atoms

    def has_structure(self):
        """Whether this CIF block has an atomic configuration."""
        try:
            self.get_symbols()
            self._get_site_coordinates()
        except NoStructureData:
            return False
        return True

    def get_atoms(self, store_tags=False, primitive_cell=False,
                  subtrans_included=True, fractional_occupancies=True) -> Atoms:
        """Build an Atoms object from the tags of this block.

        See read_cif() for a description of the arguments."""
        if primitive_cell and subtrans_included:
            raise RuntimeError(
                'Primitive cell cannot be determined when sublattice '
                'translations are included in the symmetry operations listed '
                'in the CIF file, i.e. when `subtrans_included` is True.')

        cell = self.get_cell()
        assert cell.rank in [0, 3]

        kwargs: Dict[str, Any] = {}
        if store_tags:
            kwargs['info'] = self._tags.copy()

        occupancies = (self._get_fractional_occupancies()
                       if fractional_occupancies else None)
        if occupancies is not None:
            # no warnings in this case
            kwargs['onduplicates'] = 'keep'

        # The unsymmetrized_structure is not the asymmetric unit
        # because the asymmetric unit should have (in general) a smaller cell,
        # whereas we have the full cell.
        unsymmetrized_structure = self.get_unsymmetrized_structure()

        if cell.rank != 3:
            # Without a cell there is nothing to symmetrize.
            atoms = unsymmetrized_structure
            info = kwargs.get('info')
            if info is not None:
                atoms.info.update(info)
            if occupancies is not None:
                # One occupancy entry per atom, keyed by its index.
                atoms.info['occupancy'] = {
                    str(i): {sym: occupancies[i]}
                    for i, sym in enumerate(atoms.symbols)}
            return atoms

        spacegroup = self.get_spacegroup(subtrans_included)
        return crystal(unsymmetrized_structure,
                       spacegroup=spacegroup,
                       setting=spacegroup.setting,
                       occupancies=occupancies,
                       primitive_cell=primitive_cell,
                       **kwargs)

496 

497 

def parse_block(lines: List[str], line: str) -> CIFBlock:
    """Parse one data block whose ``data_`` header is ``line``.

    The block name is the text after the first underscore of the header;
    the items of the block are read from ``lines``.
    """
    assert line.lower().startswith('data_')
    blockname = line.partition('_')[2].rstrip()
    return CIFBlock(blockname, parse_items(lines, line))

503 

504 

def parse_cif(fileobj, reader='ase') -> Iterator[CIFBlock]:
    """Parse a CIF file with the selected reader ('ase' or 'pycodcif').

    Raises ValueError for any other reader name.
    """
    if reader == 'ase':
        return parse_cif_ase(fileobj)
    if reader == 'pycodcif':
        return parse_cif_pycodcif(fileobj)
    raise ValueError(f'No such reader: {reader}')

512 

513 

def parse_cif_ase(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using ase CIF parser.

    ``fileobj`` is either a file name or an object with a ``read()``
    method.  Bytes are decoded as latin1.  One CIFBlock is yielded per
    data block.
    """
    if isinstance(fileobj, str):
        with open(fileobj, 'rb') as stream:
            data = stream.read()
    else:
        data = fileobj.read()

    if isinstance(data, bytes):
        data = data.decode('latin1')
    data = format_unicode(data)

    content = [row for row in data.split('\n') if len(row) > 0]
    if content and content[0].rstrip() == '#\\#CIF_2.0':
        warnings.warn('CIF v2.0 file format detected; `ase` CIF reader might '
                      'incorrectly interpret some syntax constructions, use '
                      '`pycodcif` reader instead')

    # Work on a stack: the first line of the file is popped first.
    stack = [''] + content[::-1]

    while stack:
        header = stack.pop().strip()
        if header and not header.startswith('#'):
            yield parse_block(stack, header)

539 

540 

def parse_cif_pycodcif(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using pycodcif CIF parser.

    ``fileobj`` is a file name, or an object whose ``name`` attribute is
    the file name.  Raises ImportError if pycodcif is not installed.
    """
    filename = fileobj if isinstance(fileobj, str) else fileobj.name

    try:
        from pycodcif import parse
    except ImportError:
        raise ImportError(
            'parse_cif_pycodcif requires pycodcif '
            '(http://wiki.crystallography.net/cod-tools/pycodcif/)')

    data, _, _ = parse(filename)

    for datablock in data:
        tags = datablock['values']
        for tag in tags.keys():
            converted = [convert_value(item) for item in tags[tag]]
            # A tag with a single value is stored as that value, not a list.
            tags[tag] = converted[0] if len(converted) == 1 else converted
        yield CIFBlock(datablock['name'], tags)

564 

565 

def iread_cif(
    fileobj,
    index=-1,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Iterator[Atoms]:
    """Yield the Atoms objects selected by ``index`` from a CIF file.

    Only blocks that contain a structure are considered.  All of them are
    converted before the first one is yielded.  ``index`` may be an
    integer, a slice, None or ':' (the latter two select everything).
    The remaining arguments are described in read_cif().
    """
    # Find all CIF blocks with valid crystal data
    # TODO: return Atoms of the block name ``index`` if it is a string.
    images = [
        block.get_atoms(store_tags, primitive_cell, subtrans_included,
                        fractional_occupancies=fractional_occupancies)
        for block in parse_cif(fileobj, reader)
        if block.has_structure()
    ]

    if index is None or index == ':':
        index = slice(None, None, None)

    if not isinstance(index, (slice, str)):
        # Turn a single position into a one-element slice (-1 -> [-1:]).
        index = slice(index, (index + 1) or None)

    yield from images[index]

596 

597 

def read_cif(
    fileobj,
    index=-1,
    *,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Union[Atoms, List[Atoms]]:
    """Read one or several Atoms objects from a CIF file.

    A list is returned when ``index`` is a slice or a string, otherwise a
    single Atoms object.

    Parameters
    ----------
    store_tags : bool
        If true, all tags of the CIF data block are stored in the *info*
        attribute of the returned Atoms object.
    primitive_cell : bool
        If true, build the primitive cell instead of the conventional
        cell.
    subtrans_included : bool
        If true, the symmetry operations listed in the CIF file are
        assumed to already contain the sublattice translations (this seems
        to be the common behaviour of CIF files).  Otherwise the
        sublattice translations are determined from setting 1 of the
        extracted space group.  When this flag is true, the primitive cell
        cannot be determined.
    fractional_occupancies : bool
        If true, the resulting atoms object carries a dictionary
        `occupancy`.  Its keys are integers converted to strings, which
        avoids trouble when dictionaries with non-string keys are encoded
        to or decoded from JSON.  In case of mixed occupancies, the
        chemical symbol of an atom is that of the most dominant species.
    reader : str
        Select CIF reader.

        * ``ase`` : built-in CIF reader (default)
        * ``pycodcif`` : CIF reader based on ``pycodcif`` package

    Notes
    -----
    Only blocks with valid crystal data will be included.
    """
    images = iread_cif(
        fileobj,
        index,
        store_tags=store_tags,
        primitive_cell=primitive_cell,
        subtrans_included=subtrans_included,
        fractional_occupancies=fractional_occupancies,
        reader=reader,
    )
    if not isinstance(index, (slice, str)):
        # A plain position selects a single Atoms object.
        return next(images)
    return list(images)

657 

658 

def format_cell(cell: Cell) -> str:
    """Return the six cell parameter lines of a CIF file for ``cell``.

    The cell must have rank 3.
    """
    assert cell.rank == 3
    rows = [f'{name:20} {value}\n'
            for name, value in zip(CIFBlock.cell_tags, cell.cellpar())]
    assert len(rows) == 6
    return ''.join(rows)

667 

668 

def format_generic_spacegroup_info() -> str:
    """Return the CIF symmetry section for a structure without symmetry.

    The section declares space group "P 1" (number 1) with the identity
    as the only symmetry operation.
    """
    # We assume no symmetry whatsoever
    return (
        '_space_group_name_H-M_alt "P 1"\n'
        '_space_group_IT_number 1\n'
        '\n'
        'loop_\n'
        ' _space_group_symop_operation_xyz\n'
        " 'x, y, z'\n"
    )

680 

681 

class CIFLoop:
    """Collects the columns of a CIF ``loop_`` and formats them as text.

    ``names``, ``formats`` and ``arrays`` are parallel lists holding, per
    column, the tag, the format string of one value and the values.
    """

    def __init__(self):
        self.names = []
        self.formats = []
        self.arrays = []

    def add(self, name, array, fmt):
        """Append a column to the loop.

        ``name`` must start with an underscore and ``fmt`` is the
        ``str.format`` template used for each value of ``array``.

        Raises ValueError if ``array`` does not have as many elements as
        the first column; the loop is left unchanged in that case.
        """
        assert name.startswith('_')
        # Validate before touching any list so that a rejected column does
        # not leave the loop half-updated.
        if self.arrays and len(array) != len(self.arrays[0]):
            raise ValueError(f'Loop data "{name}" has {len(array)} '
                             f'elements, expected {len(self.arrays[0])}')
        self.names.append(name)
        self.formats.append(fmt)
        self.arrays.append(array)

    def tostring(self):
        """Return the loop as CIF text: headers, then one line per row."""
        lines = ['loop_']
        lines.extend(f' {name}' for name in self.names)

        template = ' ' + ' '.join(self.formats)

        nrows = len(self.arrays[0]) if self.arrays else 0
        for row in range(nrows):
            values = [array[row] for array in self.arrays]
            lines.append(template.format(*values))

        # The empty entry makes the text end with a newline.
        lines.append('')
        return '\n'.join(lines)

714 

715 

@iofunction('wb')
def write_cif(fd, images, cif_format=None,
              wrap=True, labels=None, loop_keys=None) -> None:
    """Write *images* to CIF file.

    Each image becomes one data block named ``data_image<i>``.

    wrap: bool
        Wrap atoms into unit cell.

    labels: list
        Use this list (shaped list[i_frame][i_atom] = string) for the
        '_atom_site_label' section instead of automatically generating
        it from the element symbol.

    loop_keys: dict
        Add the information from this dictionary to the `loop_`
        section.  Keys are printed to the `loop_` section preceded by
        ' _'.  dict[key] should contain the data printed for each atom,
        so it needs to have the setup `dict[key][i_frame][i_atom] =
        string`.  The strings are printed as they are, so take care of
        formatting.  Information can be re-read using the `store_tags`
        option of the cif reader.

    cif_format:
        Deprecated; passing anything but None only triggers a
        FutureWarning.
    """
    if cif_format is not None:
        warnings.warn('The cif_format argument is deprecated and may be '
                      'removed in the future. Use loop_keys to customize '
                      'data written in loop.', FutureWarning)

    if loop_keys is None:
        loop_keys = {}

    # A single Atoms object is treated as a one-image sequence.
    if hasattr(images, 'get_positions'):
        images = [images]

    fd = io.TextIOWrapper(fd, encoding='latin-1')
    try:
        for i, atoms in enumerate(images):
            per_image_keys = {key: loop_keys[key][i] for key in loop_keys}
            per_image_labels = labels[i] if labels is not None else None

            write_cif_image(f'data_image{i}\n', atoms, fd,
                            wrap=wrap,
                            labels=per_image_labels,
                            loop_keys=per_image_keys)
    finally:
        # Using the TextIOWrapper somehow causes the file to close
        # when this function returns.
        # Detach in order to circumvent this highly illogical problem:
        fd.detach()

767 

768 

def autolabel(symbols: Sequence[str]) -> List[str]:
    """Return one label per symbol: the symbol plus a running number.

    The number counts the occurrences of that symbol so far, starting
    from 1.
    """
    seen: Dict[str, int] = {}
    labels = []
    for symbol in symbols:
        seen[symbol] = seen.get(symbol, 0) + 1
        labels.append(f'{symbol}{seen[symbol]}')
    return labels

779 

780 

def chemical_formula_header(atoms):
    """Return the two chemical formula lines of a CIF block for ``atoms``.

    The structural formula is the formatted ``atoms.symbols``; the sum
    formula lists each symbol followed by its count, separated by spaces.
    """
    counts = atoms.symbols.formula.count()
    terms = [f'{sym}{count}' for sym, count in counts.items()]
    formula_sum = ' '.join(terms)

    structural_line = f'_chemical_formula_structural {atoms.symbols}\n'
    sum_line = f'_chemical_formula_sum "{formula_sum}"\n'
    return structural_line + sum_line

787 

788 

class BadOccupancies(ValueError):
    """Raised when the occupancy info does not cover an atom's own symbol."""


def expand_kinds(atoms, coords):
    """Return symbols, coordinates and occupancies including mixed sites.

    The occupancies come from ``atoms.info['occupancy']``, looked up by the
    string form of each atom's entry in
    ``atoms.arrays['spacegroup_kinds']``.  If either is missing, every atom
    gets occupancy 1.  For each further species listed for an atom's kind,
    an extra entry with the same coordinates is appended after the
    original atoms.

    Returns the tuple (symbols, coords, occupancies) of equally long lists.

    Raises BadOccupancies if the occupancy info of a kind does not contain
    the symbol of the atom itself.
    """
    # try to fetch occupancies // spacegroup_kinds - occupancy mapping
    symbols = list(atoms.symbols)
    coords = list(coords)
    occupancies = [1] * len(symbols)
    occ_info = atoms.info.get('occupancy')
    kinds = atoms.arrays.get('spacegroup_kinds')
    if occ_info is None or kinds is None:
        return symbols, coords, occupancies

    for i, kind in enumerate(kinds):
        occ_info_kind = occ_info[str(kind)]
        symbol = symbols[i]
        if symbol not in occ_info_kind:
            raise BadOccupancies('Occupancies present but no occupancy '
                                 f'info for "{symbol}"')
        occupancies[i] = occ_info_kind[symbol]
        # extend the positions array in case of mixed occupancy
        for sym, occ in occ_info_kind.items():
            if sym != symbol:
                symbols.append(sym)
                coords.append(coords[i])
                occupancies.append(occ)
    return symbols, coords, occupancies

815 

816 

def atoms_to_loop_data(atoms, wrap, labels, loop_keys):
    """Collect the per-atom columns of the ``_atom_site`` loop.

    Fractional coordinates are used when the cell has rank 3, Cartesian
    ones otherwise.  If the occupancy information is unusable, a warning
    is issued and every atom gets occupancy 1.

    Returns ``(loopdata, coord_headers)``: ``loopdata`` maps each column
    header to a ``(values, format string)`` pair and ``coord_headers``
    lists the three coordinate headers.
    """
    periodic = atoms.cell.rank == 3
    coord_type = 'fract' if periodic else 'Cartn'
    if periodic:
        coords = atoms.get_scaled_positions(wrap).tolist()
    else:
        coords = atoms.get_positions(wrap).tolist()

    try:
        symbols, coords, occupancies = expand_kinds(atoms, coords)
    except BadOccupancies as err:
        warnings.warn(str(err))
        occupancies = [1] * len(atoms)
        symbols = list(atoms.symbols)

    if labels is None:
        labels = autolabel(symbols)

    coord_headers = [f'_atom_site_{coord_type}_{axisname}'
                     for axisname in 'xyz']

    loopdata = {
        '_atom_site_label': (labels, '{:<8s}'),
        '_atom_site_occupancy': (occupancies, '{:6.4f}'),
    }

    coord_array = np.array(coords)
    for axis, header in enumerate(coord_headers):
        loopdata[header] = (coord_array[:, axis], '{}')

    loopdata['_atom_site_type_symbol'] = (symbols, '{:<2s}')
    loopdata['_atom_site_symmetry_multiplicity'] = (
        [1.0] * len(symbols), '{}')

    # Should expand the loop_keys like we expand the occupancy stuff.
    # Otherwise user will never figure out how to do this.
    natoms = len(symbols)
    for key in loop_keys:
        column = loop_keys[key]
        loopdata['_' + key] = ([column[i] for i in range(natoms)], '{}')

    return loopdata, coord_headers

857 

858 

def write_cif_image(blockname, atoms, fd, *, wrap,
                    labels, loop_keys):
    """Write one data block describing ``atoms`` to the text stream ``fd``.

    ``blockname`` is written verbatim, followed by the chemical formula,
    the cell and a generic symmetry section (only for rank 3 cells) and
    the ``_atom_site`` loop.

    Raises ValueError if the cell has 1 or 2 lattice vectors; the block
    name and formula have already been written at that point.
    """
    fd.write(blockname)
    fd.write(chemical_formula_header(atoms))

    rank = atoms.cell.rank
    if rank not in (0, 3):
        raise ValueError('CIF format can only represent systems with '
                         f'0 or 3 lattice vectors. Got {rank}.')
    if rank == 3:
        fd.write(format_cell(atoms.cell))
        fd.write('\n')
        fd.write(format_generic_spacegroup_info())
        fd.write('\n')

    loopdata, coord_headers = atoms_to_loop_data(atoms, wrap, labels,
                                                 loop_keys)

    # Column order of the loop; user-supplied keys come last.
    headers = [
        '_atom_site_type_symbol',
        '_atom_site_label',
        '_atom_site_symmetry_multiplicity',
        *coord_headers,
        '_atom_site_occupancy',
    ]
    headers.extend('_' + key for key in loop_keys)

    loop = CIFLoop()
    for header in headers:
        values, fmt = loopdata[header]
        loop.add(header, values, fmt)

    fd.write(loop.tostring())