Coverage for /builds/kinetik161/ase/ase/db/core.py: 85.68%

391 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import functools 

2import json 

3import numbers 

4import operator 

5import os 

6import re 

7import warnings 

8from time import time 

9from typing import Any, Dict, List 

10 

11import numpy as np 

12 

13from ase.atoms import Atoms 

14from ase.calculators.calculator import all_changes, all_properties 

15from ase.data import atomic_numbers 

16from ase.db.row import AtomsRow 

17from ase.formula import Formula 

18from ase.io.jsonio import create_ase_object 

19from ase.parallel import DummyMPI, parallel_function, parallel_generator, world 

20from ase.utils import Lock, PurePath 

21 

# Reference point (January 1. 2000) as a number of seconds.
T2000 = 946681200.0

# Length of one year in seconds (365.25 days).
YEAR = 365.25 * 24 * 3600

24 

25 

26@functools.total_ordering 

27class KeyDescription: 

28 _subscript = re.compile(r'`(.)_(.)`') 

29 _superscript = re.compile(r'`(.*)\^\{?(.*?)\}?`') 

30 

31 def __init__(self, key, shortdesc=None, longdesc=None, unit=''): 

32 self.key = key 

33 

34 if shortdesc is None: 

35 shortdesc = key 

36 

37 if longdesc is None: 

38 longdesc = shortdesc 

39 

40 self.shortdesc = shortdesc 

41 self.longdesc = longdesc 

42 

43 # Somewhat arbitrary that we do this conversion. Can we avoid that? 

44 # Previously done in create_key_descriptions(). 

45 unit = self._subscript.sub(r'\1<sub>\2</sub>', unit) 

46 unit = self._superscript.sub(r'\1<sup>\2</sup>', unit) 

47 unit = unit.replace(r'\text{', '').replace('}', '') 

48 

49 self.unit = unit 

50 

51 def __repr__(self): 

52 cls = type(self).__name__ 

53 return (f'{cls}({self.key!r}, {self.shortdesc!r}, {self.longdesc!r}, ' 

54 f'unit={self.unit!r})') 

55 

56 # The templates like to sort key descriptions by shortdesc. 

57 def __eq__(self, other): 

58 return self.shortdesc == getattr(other, 'shortdesc', None) 

59 

60 def __lt__(self, other): 

61 return self.shortdesc < getattr(other, 'shortdesc', self.shortdesc) 

62 

63 

def get_key_descriptions():
    """Return descriptions of the built-in keys.

    The result is a dict that maps each key name to its KeyDescription.
    """
    KD = KeyDescription
    descriptions = [
        KD('id', 'ID', 'Unique row ID'),
        KD('age', 'Age', 'Time since creation'),
        KD('formula', 'Formula', 'Chemical formula'),
        KD('pbc', 'PBC', 'Periodic boundary conditions'),
        KD('user', 'Username'),
        KD('calculator', 'Calculator', 'ASE-calculator name'),
        KD('energy', 'Energy', 'Total energy', unit='eV'),
        KD('natoms', 'Number of atoms'),
        KD('fmax', 'Maximum force', unit='eV/Å'),
        KD('smax', 'Maximum stress', 'Maximum stress on unit cell',
           unit='eV/ų'),
        KD('charge', 'Charge', 'Net charge in unit cell', unit='|e|'),
        KD('mass', 'Mass', 'Sum of atomic masses in unit cell', unit='au'),
        KD('magmom', 'Magnetic moment', unit='μ_B'),
        KD('unique_id', 'Unique ID', 'Random (unique) ID'),
        KD('volume', 'Volume', 'Volume of unit cell', unit='ų'),
    ]
    return {keydesc.key: keydesc for keydesc in descriptions}

84 

85 

def now():
    """Return the current time as years elapsed since January 1. 2000."""
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR

89 

90 

# Length in seconds of each unit letter used in time strings.
seconds = {
    's': 1,
    'm': 60,
    'h': 3600,
    'd': 86400,
    'w': 604800,
    'M': 2629800,
    'y': YEAR,
}

# Spelled-out (singular) name of each unit letter.
longwords = {
    's': 'second',
    'm': 'minute',
    'h': 'hour',
    'd': 'day',
    'w': 'week',
    'M': 'month',
    'y': 'year',
}

# Comparison operators of the selection syntax and their implementations.
ops = {
    '<': operator.lt,
    '<=': operator.le,
    '=': operator.eq,
    '>=': operator.ge,
    '>': operator.gt,
    '!=': operator.ne,
}

# Logical negation of each comparison operator.
invop = {
    '<': '>=',
    '<=': '>',
    '>=': '<',
    '>': '<=',
    '=': '!=',
    '!=': '=',
}

# A valid key: a letter or underscore followed by letters, digits or
# underscores.
word = re.compile('[_a-zA-Z][_0-9a-zA-Z]*$')

# Names that may not be used as keys in key-value pairs.
reserved_keys = (
    set(all_properties)
    | set(all_changes)
    | set(atomic_numbers)
    | {'id', 'unique_id', 'ctime', 'mtime', 'user',
       'fmax', 'smax',
       'momenta', 'constraints', 'natoms', 'formula', 'age',
       'calculator', 'calculator_parameters',
       'key_value_pairs', 'data'}
)

# Keys whose values must be numbers in a selection.
numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}

128 

129 

def check(key_value_pairs):
    """Validate key-value pairs.

    key_value_pairs: dict
        Mapping from key to value.  The special key "external_tables"
        is skipped without any checks.

    Raises ValueError if a key is not a valid word or is reserved, if a
    value is not a real number, a string or a numpy bool, or if a string
    value could be interpreted as an int or a float.  A warning is issued
    for a key that is also a valid chemical formula.
    """
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not
            # performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError(f'Bad key: {key}')
        try:
            Formula(key, strict=True)
        except ValueError:
            # Not a chemical formula: nothing to warn about.
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will '
                'get rows containing the atoms in the formula!'.format(key))
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError(f'Bad value for {key!r}: {value}')
        if isinstance(value, str):
            # Number-like strings are rejected and must be stored as
            # numbers instead.
            for t in [int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        f'Value {value} is put in as string '
                        f'but can be interpreted as {t.__name__}! '
                        f'Please convert to {t.__name__} using '
                        f'{t.__name__}(value) before '
                        'writing to the database OR change '
                        'to a different string.')

162 

163 

def str_represents(value, t=int):
    """Tell whether *value* can be converted with the type *t*.

    Returns False if ``t(value)`` raises ValueError and True otherwise.
    """
    try:
        t(value)
    except ValueError:
        return False
    else:
        return True

170 

171 

def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str
        Filename or address of database.
    type: str
        One of 'json', 'db', 'postgresql', 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL).
        Default is 'extract_from_name', which will guess the type
        from the name.
    create_indices: bool
        Passed on to the SQLite backend.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    serial: bool
        Passed on to the JSON and SQLite backends.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            # Anything that is not a string is handled as JSON.
            type = 'json'
        elif name.startswith(('postgresql://', 'postgres://')):
            type = 'postgresql'
        elif name.startswith(('mysql://', 'mariadb://')):
            type = 'mysql'
        else:
            _, extension = os.path.splitext(name)
            type = extension[1:]
            if not type:
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    name_is_string = isinstance(name, str)

    # Starting over: only rank 0 removes an existing file.
    if (not append and world.rank == 0
            and name_is_string and os.path.isfile(name)):
        os.remove(name)

    if name_is_string and type not in ['postgresql', 'mysql']:
        name = os.path.abspath(name)

    # Backends are imported only when they are actually requested.
    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    elif type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    elif type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)
    elif type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)

    raise ValueError('Unknown database type: ' + type)

232 

233 

def lock(method):
    """Decorator that runs a method while holding ``self.lock``.

    If ``self.lock`` is None the method is simply called.
    """
    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self.lock is not None:
            with self.lock:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return new_method

244 

245 

def convert_str_to_int_float_or_str(value):
    """Safe eval()

    Try int first, then float; otherwise map 'True'/'False' to the
    corresponding bool and return any other value unchanged.
    """
    try:
        return int(value)
    except ValueError:
        pass

    try:
        return float(value)
    except ValueError:
        pass

    booleans = {'True': True, 'False': False}
    return booleans.get(value, value)

256 

257 

def parse_selection(selection, **kwargs):
    """Translate a selection into keys and comparisons.

    selection: None, int, str or list
        None or '' selects everything, an int is taken as a row id, a
        string is split at commas into expressions and a list is taken as
        a list of expressions (strings or ready-made
        ``(key, op, value)`` tuples/lists).

    Keyword arguments are added as ``key=value`` comparisons.

    Returns ``(keys, cmps)``: a list of bare key names and a list of
    ``(key, op, value)`` tuples.

    Raises ValueError for "formula" used with an operator other than '='
    and for a non-numeric value given for one of the numeric keys.
    """
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]

    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            # Already a (key, op, value) comparison.
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            # "a<key<b": peel off the lower bound as "key>a" and let the
            # rest ("key<b") be handled below.
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        # Two-character operators are tried before their one-character
        # prefixes.
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            # No operator: a chemical symbol, a formula or a bare key.
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            # An age is a creation time counted backwards from now, so
            # the operator is replaced by its inverse.
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps

329 

330 

class Database:
    """Base class for all databases."""

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or None
            Name of the database.  For a string, a leading "~" is
            expanded.
        create_indices: bool
            Stored as ``self.create_indices``.
        use_lock_file: bool
            Use a "<filename>.lock" lock-file for the methods decorated
            with @lock.  Only has an effect when filename is a string.
        serial: bool
            Let someone else handle parallelization.  Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs={}, data={}, id=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions.  If a calculator is attached, write also already
            calculated properties such as the energy and forces.
        key_value_pairs: dict
            Dictionary of key-value pairs.  Values must be strings or numbers.
        data: dict
            Extra stuff (not for searching).
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        # The (shared) default dict is never modified here.
        kvp = dict(key_value_pairs)  # modify a copy
        kvp.update(kwargs)

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        # The base class only validates the key-value pairs.
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id.  If such a row already exists, don't write
        anything and return None.
        """

        # Any matching row means that the reservation is already taken.
        for dct in self._select([],
                                [(key, '=', value)
                                 for key, value in key_value_pairs.items()]):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        self.delete([id])

    def get_atoms(self, selection=None,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Raises KeyError if no row matches; more than one matching row
        trips an assertion.
        """
        # limit=2 is enough to tell "exactly one" from "more than one".
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results.  Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows after key.  Prepend with minus sign for a descending
            sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy is needed,
            queries can be speeded up by setting columns=['id', 'energy'].
        """

        if sort:
            # Age is sorted via the creation time, in the opposite
            # direction; "user" is sorted via "username".
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax.  Use db.count() or
        len(db) to count all rows.
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    def __len__(self):
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=[], data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.

        Raises ValueError if id is a list and TypeError if it is any
        other non-integer.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        # n: number of keys removed, m: number of keys added.
        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        # Test against None rather than truthiness, so that an explicitly
        # passed object is used even if it evaluates as false.
        have_atoms = atoms is not None

        if have_atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        if have_atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError

615 

616 

def time_string_to_float(s):
    """Convert a time string to a number of years.

    s: str, float or int
        A float or int is returned unchanged.  A string is an integer
        followed by one of the unit letters in ``seconds`` (for example
        '3h'); several such terms can be summed with '+'.  Spaces are
        ignored and a trailing 's' that follows a letter is dropped.

    Raises ValueError if a term is empty or has no unit (such input used
    to fail with an IndexError) or if the number is not an integer, and
    KeyError for an unknown unit.
    """
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))
    if len(s) >= 2 and s[-2].isalpha() and s[-1] == 's':
        s = s[:-1]
    # The first character always belongs to the number; the scan for
    # further digits starts at the second character.
    i = 1
    while i < len(s) and s[i].isdigit():
        i += 1
    number, unit = s[:i], s[i:]
    if not unit:
        raise ValueError(f'Bad time string: {s!r}')
    return seconds[unit] * int(number) / YEAR

629 

630 

def float_to_time_string(t, long=False):
    """Format the time *t* as a number followed by a unit.

    *t* is scaled by YEAR and expressed in the first unit of
    'yMwdhms' that gives a value above 5 (seconds if there is none).
    With long=True the value gets three decimals and the unit is
    spelled out in plural form; otherwise the value is rounded and
    followed by the unit letter.
    """
    t *= YEAR
    for unit in 'yMwdhms':
        amount = t / seconds[unit]
        if amount > 5:
            break
    # Without a break, unit and amount are left at seconds.
    if not long:
        return f'{round(amount):.0f}{unit}'
    return f'{amount:.3f} {longwords[unit]}s'

641 

642 

def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes.

    The result starts with an 8-byte little-endian integer giving the
    position of the JSON text, followed by the byte chunks collected by
    o2b() and finally the JSON text itself.
    """
    chunks = [b'12345678']  # placeholder for the 8-byte header
    skeleton = o2b(obj, chunks)
    json_start = sum(map(len, chunks))
    header = np.array(json_start, np.int64)
    if not np.little_endian:
        header.byteswap(True)
    chunks[0] = header.tobytes()
    text = json.dumps(skeleton, separators=(',', ':'))
    chunks.append(text.encode())
    return b''.join(chunks)

654 

655 

def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object.

    The first 8 bytes hold the position of the JSON text as a
    little-endian integer; the decoded JSON is handed to b2o() together
    with the complete buffer.
    """
    header = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        header = header.byteswap()
    json_start = header.item()
    skeleton = json.loads(b[json_start:].decode())
    return b2o(skeleton, b)

664 

665 

def o2b(obj: Any, parts: List[bytes]):
    """Convert *obj* to a JSON-compatible structure.

    obj: object
        int, float, bool, str and None are returned as they are; dicts,
        lists and tuples are converted recursively (tuples become lists).
        An ndarray has its bytes appended to *parts* and is replaced by
        ``{'__ndarray__': [shape, dtype name, offset]}``, where offset is
        the total length of *parts* before the append.  A complex number
        becomes ``{'__complex__': [real, imag]}``.  An object with a
        truthy ``ase_objtype`` attribute is converted via its todict()
        method and tagged with '__ase_objtype__'.
    parts: list of bytes
        Byte chunks collected so far; extended in place.

    Raises ValueError for any other kind of object.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        if not np.little_endian:
            # Array data is always stored as little-endian.
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    # Default of None: objects without the attribute must reach the
    # ValueError below instead of failing with an AttributeError.
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))

692 

693 

def b2o(obj: Any, b: bytes) -> Any:
    """Rebuild a Python object from the structure *obj* and the bytes *b*.

    Plain values are returned as they are and lists are converted
    element by element.  A dict holding '__complex__' gives a complex
    number, one holding '__ndarray__' gives an array read from *b* at
    the recorded offset, and a dict tagged with '__ase_objtype__' is
    handed to create_ase_object().  Other dicts are converted value by
    value.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(item, b) for item in obj]

    assert isinstance(obj, dict)

    complex_parts = obj.get('__complex__')
    if complex_parts is not None:
        return complex(*complex_parts)

    array_info = obj.get('__ndarray__')
    if array_info is not None:
        shape, name, offset = array_info
        dtype = np.dtype(name)
        nbytes = dtype.itemsize * np.prod(shape).astype(int)
        array = np.frombuffer(b[offset:offset + nbytes], dtype)
        array.shape = shape
        # Array data is stored as little-endian.
        if not np.little_endian:
            array = array.byteswap()
        return array

    decoded = {}
    for key, value in obj.items():
        decoded[key] = b2o(value, b)
    objtype = decoded.pop('__ase_objtype__', None)
    if objtype is not None:
        return create_ase_object(objtype, decoded)
    return decoded