Coverage for /builds/kinetik161/ase/ase/db/core.py: 85.68%
391 statements
coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
import functools
import json
import numbers
import operator
import os
import re
import warnings
from time import time
from typing import Any, Dict, List

import numpy as np

from ase.atoms import Atoms
from ase.calculators.calculator import all_changes, all_properties
from ase.data import atomic_numbers
from ase.db.row import AtomsRow
from ase.formula import Formula
from ase.io.jsonio import create_ase_object
from ase.parallel import DummyMPI, parallel_function, parallel_generator, world
from ase.utils import Lock, PurePath

T2000 = 946681200.0  # January 1. 2000
YEAR = 31557600.0  # 365.25 days

@functools.total_ordering
class KeyDescription:
    _subscript = re.compile(r'`(.)_(.)`')
    _superscript = re.compile(r'`(.*)\^\{?(.*?)\}?`')

    def __init__(self, key, shortdesc=None, longdesc=None, unit=''):
        self.key = key

        if shortdesc is None:
            shortdesc = key

        if longdesc is None:
            longdesc = shortdesc

        self.shortdesc = shortdesc
        self.longdesc = longdesc

        # Somewhat arbitrary that we do this conversion. Can we avoid that?
        # Previously done in create_key_descriptions().
        unit = self._subscript.sub(r'\1<sub>\2</sub>', unit)
        unit = self._superscript.sub(r'\1<sup>\2</sup>', unit)
        unit = unit.replace(r'\text{', '').replace('}', '')

        self.unit = unit

    def __repr__(self):
        cls = type(self).__name__
        return (f'{cls}({self.key!r}, {self.shortdesc!r}, {self.longdesc!r}, '
                f'unit={self.unit!r})')

    # The templates like to sort key descriptions by shortdesc.
    def __eq__(self, other):
        return self.shortdesc == getattr(other, 'shortdesc', None)

    def __lt__(self, other):
        return self.shortdesc < getattr(other, 'shortdesc', self.shortdesc)

def get_key_descriptions():
    KD = KeyDescription
    return {keydesc.key: keydesc for keydesc in [
        KD('id', 'ID', 'Unique row ID'),
        KD('age', 'Age', 'Time since creation'),
        KD('formula', 'Formula', 'Chemical formula'),
        KD('pbc', 'PBC', 'Periodic boundary conditions'),
        KD('user', 'Username'),
        KD('calculator', 'Calculator', 'ASE-calculator name'),
        KD('energy', 'Energy', 'Total energy', unit='eV'),
        KD('natoms', 'Number of atoms'),
        KD('fmax', 'Maximum force', unit='eV/Å'),
        KD('smax', 'Maximum stress', 'Maximum stress on unit cell',
           unit='eV/Å³'),
        KD('charge', 'Charge', 'Net charge in unit cell', unit='|e|'),
        KD('mass', 'Mass', 'Sum of atomic masses in unit cell', unit='au'),
        KD('magmom', 'Magnetic moment', unit='μ_B'),
        KD('unique_id', 'Unique ID', 'Random (unique) ID'),
        KD('volume', 'Volume', 'Volume of unit cell', unit='Å³')
    ]}
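
# Illustrative note (not part of the original module): the returned mapping
# goes from key name to its KeyDescription, e.g.
#   get_key_descriptions()['energy'].unit   # -> 'eV'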

def now():
    """Return time since January 1. 2000 in years."""
    return (time() - T2000) / YEAR
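
# Illustrative note (not part of the original module): now() measures age in
# fractional years since 2000-01-01, so a timestamp from mid-2023 gives a
# value of roughly 23.5.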

seconds = {'s': 1,
           'm': 60,
           'h': 3600,
           'd': 86400,
           'w': 604800,
           'M': 2629800,
           'y': YEAR}

longwords = {'s': 'second',
             'm': 'minute',
             'h': 'hour',
             'd': 'day',
             'w': 'week',
             'M': 'month',
             'y': 'year'}

ops = {'<': operator.lt,
       '<=': operator.le,
       '=': operator.eq,
       '>=': operator.ge,
       '>': operator.gt,
       '!=': operator.ne}

invop = {'<': '>=', '<=': '>', '>=': '<', '>': '<=', '=': '!=', '!=': '='}

word = re.compile('[_a-zA-Z][_0-9a-zA-Z]*$')

reserved_keys = set(all_properties +
                    all_changes +
                    list(atomic_numbers) +
                    ['id', 'unique_id', 'ctime', 'mtime', 'user',
                     'fmax', 'smax',
                     'momenta', 'constraints', 'natoms', 'formula', 'age',
                     'calculator', 'calculator_parameters',
                     'key_value_pairs', 'data'])

numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}

def check(key_value_pairs):
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError(f'Bad key: {key}')
        try:
            Formula(key, strict=True)
        except ValueError:
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will get '
                'rows containing the atoms in the formula!'.format(key))
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError(f'Bad value for {key!r}: {value}')
        if isinstance(value, str):
            for t in [int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        'Value ' + value + ' is put in as string ' +
                        'but can be interpreted as ' +
                        f'{t.__name__}! Please convert ' +
                        f'to {t.__name__} using ' +
                        f'{t.__name__}(value) before ' +
                        'writing to the database OR change ' +
                        'to a different string.')
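
# Illustrative examples (not part of the original module) of what check()
# accepts and rejects:
#   check({'nbands': 42})       # OK
#   check({'energy': 1.0})      # ValueError: 'energy' is a reserved key
#   check({'my key': 1.0})      # ValueError: keys must match the 'word' regex
#   check({'nbands': '42'})     # ValueError: store it as int(42), not as str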

def str_represents(value, t=int):
    try:
        t(value)
    except ValueError:
        return False
    return True

def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str
        Filename or address of database.
    type: str
        One of 'json', 'db', 'postgresql' or 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL/MariaDB).
        Default is 'extract_from_name', which will guess the type
        from the name.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            type = 'json'
        elif (name.startswith('postgresql://') or
              name.startswith('postgres://')):
            type = 'postgresql'
        elif name.startswith('mysql://') or name.startswith('mariadb://'):
            type = 'mysql'
        else:
            type = os.path.splitext(name)[1][1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    if not append and world.rank == 0:
        if isinstance(name, str) and os.path.isfile(name):
            os.remove(name)

    if type not in ['postgresql', 'mysql'] and isinstance(name, str):
        name = os.path.abspath(name)

    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)

    if type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)
    raise ValueError('Unknown database type: ' + type)
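
# Illustrative usage (not part of the original module); the type is normally
# inferred from the file extension or URL scheme:
#   db = connect('results.db')                  # SQLite3 file
#   db = connect('results.json')                # JSON file
#   db = connect('postgresql://user@host/db')   # PostgreSQL server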

def lock(method):
    """Decorator for using a lock-file."""
    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self.lock is None:
            return method(self, *args, **kwargs)
        else:
            with self.lock:
                return method(self, *args, **kwargs)
    return new_method

def convert_str_to_int_float_or_str(value):
    """Convert string to int, float, bool or (unchanged) string.

    A safe alternative to eval()."""
    try:
        return int(value)
    except ValueError:
        try:
            value = float(value)
        except ValueError:
            value = {'True': True, 'False': False}.get(value, value)
        return value
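
# Illustrative examples (not part of the original module):
#   convert_str_to_int_float_or_str('42')     # -> 42 (int)
#   convert_str_to_int_float_or_str('3.14')   # -> 3.14 (float)
#   convert_str_to_int_float_or_str('True')   # -> True (bool)
#   convert_str_to_int_float_or_str('abc')    # -> 'abc' (unchanged)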

def parse_selection(selection, **kwargs):
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]
    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps
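
# Illustrative example (not part of the original module): a comma-separated
# query is split into plain keys and (key, op, value) comparisons, e.g.
#   parse_selection('H2O,energy<0,relaxed')
# returns keys == ['relaxed'] and cmps == [(1, '>', 1), (8, '>', 0),
# ('energy', '<', 0)], where 1 and 8 are the atomic numbers of H and O.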

class Database:
    """Base class for all databases."""

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        serial: bool
            Let someone else handle parallelization. Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs={}, data={}, id=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions. If a calculator is attached, write also already
            calculated properties such as the energy and forces.
        key_value_pairs: dict
            Dictionary of key-value pairs. Values must be strings or numbers.
        data: dict
            Extra stuff (not for searching).
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        kvp = dict(key_value_pairs)  # modify a copy
        kvp.update(kwargs)

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id. If such a row already exists, don't write
        anything and return None.
        """

        for dct in self._select([],
                                [(key, '=', value)
                                 for key, value in key_value_pairs.items()]):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        self.delete([id])

    def get_atoms(self, selection=None,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it as a dictionary.

        selection: int, str or list
            See the select() method.
        """
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results. Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows by key. Prepend with minus sign for a descending sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy are needed,
            queries can be sped up by setting columns=['id', 'energy'].
        """

        if sort:
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            if filter is None or filter(row):
                yield row
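
    # Illustrative usage (not part of the original module):
    #   for row in db.select('natoms>2,energy<0', sort='-energy', limit=10):
    #       print(row.id, row.formula, row.energy)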

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax. Use db.count() or
        len(db) to count all rows.
        """
        n = 0
        for row in self.select(selection, **kwargs):
            n += 1
        return n

    def __len__(self):
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=[], data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        if atoms:
            oldrow = row
            row = AtomsRow(atoms)

            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        if atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n
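
    # Illustrative usage (not part of the original module):
    #   db.update(id, niter=5, delete_keys=['broken'])   # change key-value pairs
    #   db.update(id, atoms=new_atoms)                    # also replace the Atoms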

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError

def time_string_to_float(s):
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))
    if s[-2].isalpha() and s[-1] == 's':
        s = s[:-1]
    i = 1
    while s[i].isdigit():
        i += 1
    return seconds[s[i:]] * int(s[:i]) / YEAR
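
# Illustrative examples (not part of the original module): an age string is a
# number plus one of the unit letters from the `seconds` dict, converted to
# fractional years:
#   time_string_to_float('1y')       # -> 1.0
#   time_string_to_float('30d')      # -> 30 * 86400 / YEAR
#   time_string_to_float('1y+30d')   # -> sum of the two terms above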

def float_to_time_string(t, long=False):
    t *= YEAR
    for s in 'yMwdhms':
        x = t / seconds[s]
        if x > 5:
            break
    if long:
        return f'{x:.3f} {longwords[s]}s'
    else:
        return f'{round(x):.0f}{s}'
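
# Illustrative examples (not part of the original module): the largest unit
# that gives a value above 5 is chosen:
#   float_to_time_string(1.5)              # -> '18M'
#   float_to_time_string(0.5, long=True)   # -> '6.000 months'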

def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes."""
    parts = [b'12345678']
    obj = o2b(obj, parts)
    offset = sum(len(part) for part in parts)
    x = np.array(offset, np.int64)
    if not np.little_endian:
        x.byteswap(True)
    parts[0] = x.tobytes()
    parts.append(json.dumps(obj, separators=(',', ':')).encode())
    return b''.join(parts)

def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object."""
    x = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        x = x.byteswap()
    offset = x.item()
    obj = json.loads(b[offset:].decode())
    return b2o(obj, b)
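
# Illustrative round trip (not part of the original module): ndarrays are
# stored as raw buffers after the 8-byte offset header, everything else ends
# up in the JSON tail:
#   buf = object_to_bytes({'x': 1, 'a': np.zeros(3)})
#   bytes_to_object(buf)   # -> {'x': 1, 'a': array([0., 0., 0.])}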

def o2b(obj: Any, parts: List[bytes]):
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))

def b2o(obj: Any, b: bytes) -> Any:
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(value, b) for value in obj]

    assert isinstance(obj, dict)

    x = obj.get('__complex__')
    if x is not None:
        return complex(*x)

    x = obj.get('__ndarray__')
    if x is not None:
        shape, name, offset = x
        dtype = np.dtype(name)
        size = dtype.itemsize * np.prod(shape).astype(int)
        a = np.frombuffer(b[offset:offset + size], dtype)
        a.shape = shape
        if not np.little_endian:
            a = a.byteswap()
        return a

    dct = {key: b2o(value, b) for key, value in obj.items()}
    objtype = dct.pop('__ase_objtype__', None)
    if objtype is None:
        return dct
    return create_ase_object(objtype, dct)