Coverage for /builds/kinetik161/ase/ase/db/sqlite.py: 90.13%
557 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
"""SQLite3 backend.

Versions:

1) Added 3 more columns.
2) Changed "user" to "username".
3) Now adding keys to keyword table and added an "information" table containing
   a version number.
4) Got rid of keywords.
5) Add fmax, smax, mass, volume, charge
6) Use REAL for magmom and drop possibility for non-collinear spin
7) Volume can be None
8) Added name='metadata' row to "information" table
9) Row data is now stored in binary format.
"""
17import json
18import numbers
19import os
20import sqlite3
21import sys
22from contextlib import contextmanager
24import numpy as np
26import ase.io.jsonio
27from ase.calculators.calculator import all_properties
28from ase.data import atomic_numbers
29from ase.db.core import (Database, bytes_to_object, invop, lock, now,
30 object_to_bytes, ops, parse_selection)
31from ase.db.row import AtomsRow
32from ase.parallel import parallel_function
# Version of the on-disk format written by this module (the history is listed
# in the module docstring).
VERSION = 9

# Statements executed, in order, to set up a brand-new database.
# NOTE: keep exactly one column per line in the first statement -
# SQLite3Database.columnnames is derived from its lines.
init_statements = [
    """CREATE TABLE systems (
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- ID's, timestamps and user name
    unique_id TEXT UNIQUE,
    ctime REAL,
    mtime REAL,
    username TEXT,
    numbers BLOB, -- stuff that defines an Atoms object
    positions BLOB,
    cell BLOB,
    pbc INTEGER,
    initial_magmoms BLOB,
    initial_charges BLOB,
    masses BLOB,
    tags BLOB,
    momenta BLOB,
    constraints TEXT, -- constraints and calculator
    calculator TEXT,
    calculator_parameters TEXT,
    energy REAL, -- calculated properties
    free_energy REAL,
    forces BLOB,
    stress BLOB,
    dipole BLOB,
    magmoms BLOB,
    magmom REAL,
    charges BLOB,
    key_value_pairs TEXT, -- key-value pairs and data as json
    data BLOB,
    natoms INTEGER, -- stuff for making queries faster
    fmax REAL,
    smax REAL,
    volume REAL,
    mass REAL,
    charge REAL)""",

    """CREATE TABLE species (
    Z INTEGER,
    n INTEGER,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE keys (
    key TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE text_key_values (
    key TEXT,
    value TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE number_key_values (
    key TEXT,
    value REAL,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE information (
    name TEXT,
    value TEXT)""",

    f"INSERT INTO information VALUES ('version', '{VERSION}')"]

# Optional indices (index name, indexed table(column)) created right after
# the tables when the database asks for them.
index_statements = [
    f'CREATE INDEX {index} ON {target}'
    for index, target in [
        ('unique_id_index', 'systems(unique_id)'),
        ('ctime_index', 'systems(ctime)'),
        ('username_index', 'systems(username)'),
        ('calculator_index', 'systems(calculator)'),
        ('species_index', 'species(Z)'),
        ('key_index', 'keys(key)'),
        ('text_index', 'text_key_values(key)'),
        ('number_index', 'number_key_values(key)'),
    ]]

# Built-in tables that hold per-row data (each has an "id" column).
all_tables = [
    'systems',
    'species',
    'keys',
    'text_key_values',
    'number_key_values',
]
def float_if_not_none(x):
    """Return ``float(x)``, or ``None`` when *x* is ``None``.

    Turns e.g. numpy.float64 into a plain float - old db-interfaces need
    that.
    """
    return None if x is None else float(x)
class SQLite3Database(Database):
    # Backend identifier; several methods branch on it (e.g. 'postgresql').
    type = 'db'
    # Becomes True once _initialize() has inspected or created the schema.
    initialized = False
    # When False, _initialize() refuses files older than format version 5.
    _allow_reading_old_format = False
    default = 'NULL'  # used for autoincrement id
    # Connection held open between __enter__() and __exit__(), else None.
    connection = None
    # Format version of the opened file; filled in by _initialize().
    version = None
    # Column names of the "systems" table, in table order: the first word
    # of every line after the opening line of its CREATE TABLE statement.
    columnnames = [line.split()[0]
                   for line in init_statements[0].splitlines()[1:]]
def encode(self, obj, binary=False):
    """Serialize *obj*.

    With ``binary=True`` the result comes from ``object_to_bytes``;
    otherwise from ``ase.io.jsonio.encode``.
    """
    serializer = object_to_bytes if binary else ase.io.jsonio.encode
    return serializer(obj)
def decode(self, txt, lazy=False):
    """Inverse of :meth:`encode`.

    With ``lazy=True`` the stored value is handed back untouched.
    Otherwise ``str`` input is decoded with ``ase.io.jsonio.decode`` and
    anything else with ``bytes_to_object``.
    """
    if lazy:
        return txt
    decoder = ase.io.jsonio.decode if isinstance(txt, str) else bytes_to_object
    return decoder(txt)
def blob(self, array):
    """Convert array to blob/buffer object.

    ``None`` is passed through.  An empty input is replaced by an empty
    float array, int64 data is stored as int32, and on big-endian hosts
    the bytes are swapped before a memoryview of the contiguous data is
    returned.
    """
    if array is None:
        return None

    data = np.zeros(0) if len(array) == 0 else array
    if data.dtype == np.int64:
        data = data.astype(np.int32)
    if not np.little_endian:
        data = data.byteswap()
    return memoryview(np.ascontiguousarray(data))
def deblob(self, buf, dtype=float, shape=None):
    """Convert blob/buffer object to ndarray of correct dtype and shape.

    ``None`` is passed through and an empty buffer gives an empty array
    of *dtype*.  The shape is set in place (without creating an extra
    view).
    """
    if buf is None:
        return None

    if len(buf) > 0:
        array = np.frombuffer(buf, dtype)
        if not np.little_endian:
            array = array.byteswap()
    else:
        array = np.zeros(0, dtype)

    if shape is not None:
        array.shape = shape
    return array
172 def _connect(self):
173 return sqlite3.connect(self.filename, timeout=20)
175 def __enter__(self):
176 assert self.connection is None
177 self.change_count = 0
178 self.connection = self._connect()
179 return self
181 def __exit__(self, exc_type, exc_value, tb):
182 if exc_type is None:
183 self.connection.commit()
184 else:
185 self.connection.rollback()
186 self.connection.close()
187 self.connection = None
@contextmanager
def managed_connection(self, commit_frequency=5000):
    """Yield an initialized connection and tidy up afterwards.

    Outside a ``with db:`` block (``self.connection`` is None) a
    temporary connection is opened; it is committed on success and
    always closed, whatever exception leaves the block (including the
    ``GeneratorExit`` raised when a generator using it is abandoned).

    Inside a ``with db:`` block the shared connection is reused and left
    open; every successful use bumps ``self.change_count`` and a commit
    is issued on every *commit_frequency*-th use.
    """
    temporary = self.connection is None
    con = self._connect() if temporary else self.connection
    try:
        self._initialize(con)
        yield con
    except BaseException:
        # Nothing is committed on failure; just release what we opened.
        if temporary:
            con.close()
        raise

    if temporary:
        try:
            con.commit()
        finally:
            con.close()
    else:
        self.change_count += 1
        if self.change_count % commit_frequency == 0:
            con.commit()
def _initialize(self, con):
    """Create the schema or detect the format version, once per object.

    On an empty file all tables (and, if ``self.create_indices`` is set,
    the indices) are created and the version is set to ``VERSION``.
    Otherwise the version is worked out from the existing tables and any
    stored metadata is loaded into ``self._metadata``.

    Raises
    ======
    OSError
        If the file was written by a newer format version, or by a
        version older than 5 while ``_allow_reading_old_format`` is
        False.
    """
    if self.initialized:
        return

    self._metadata = {}

    # String literals are single-quoted: double quotes denote identifiers
    # in SQL and are only accepted as strings by a lenient SQLite build.
    cur = con.execute(
        "SELECT COUNT(*) FROM sqlite_master WHERE name='systems'")

    if cur.fetchone()[0] == 0:
        for statement in init_statements:
            con.execute(statement)
        if self.create_indices:
            for statement in index_statements:
                con.execute(statement)
        con.commit()
        self.version = VERSION
    else:
        cur = con.execute(
            "SELECT COUNT(*) FROM sqlite_master WHERE name='user_index'")
        if cur.fetchone()[0] == 1:
            # Old version with "user" instead of "username" column
            self.version = 1
        else:
            try:
                cur = con.execute(
                    "SELECT value FROM information WHERE name='version'")
            except sqlite3.OperationalError:
                # The "information" table could not be queried.
                self.version = 2
            else:
                self.version = int(cur.fetchone()[0])

        # Versions 1 and 2 are exactly the cases where "information" was
        # not readable above; querying it again would raise before the
        # "please convert" error below could be reported.
        if self.version >= 3:
            cur = con.execute(
                "SELECT value FROM information WHERE name='metadata'")
            results = cur.fetchall()
            if results:
                self._metadata = json.loads(results[0][0])

    if self.version > VERSION:
        raise OSError('Can not read new ase.db format '
                      '(version {}). Please update to latest ASE.'
                      .format(self.version))
    if self.version < 5 and not self._allow_reading_old_format:
        raise OSError('Please convert to new format. ' +
                      'Use: python -m ase.db.convert ' + self.filename)

    self.initialized = True
def _write(self, atoms, key_value_pairs, data, id):
    """Insert a new row (``id is None``) or overwrite row *id*.

    *atoms* may be an AtomsRow or anything AtomsRow() accepts.  An
    ``"external_tables"`` entry is popped from *key_value_pairs* (the
    caller's dict is modified) and written to the corresponding external
    tables.  Returns the id of the written row.
    """
    ext_tables = key_value_pairs.pop("external_tables", {})
    Database._write(self, atoms, key_value_pairs, data)

    mtime = now()

    if isinstance(atoms, AtomsRow):
        row = atoms
        # Extract the external tables from AtomsRow
        for table_name in self._get_external_table_names():
            table = row.get(table_name, {})
            if table:
                ext_tables[table_name] = table
    else:
        row = AtomsRow(atoms)
        row.ctime = mtime
        row.user = os.getenv('USER')

    # With nothing explicit given for a new row, keep the row's own pairs.
    if not (id or key_value_pairs or ext_tables):
        key_value_pairs = row.key_value_pairs

    for table_name, table in ext_tables.items():
        self._create_table_if_not_exists(table_name,
                                         self._guess_type(table))

    constraints = row._constraints
    if not constraints:
        constraints = None
    elif isinstance(constraints, list):
        constraints = self.encode(constraints)

    # Values in the column order of the "systems" table (id excluded).
    values = (row.unique_id,
              row.ctime,
              mtime,
              row.user,
              self.blob(row.numbers),
              self.blob(row.positions),
              self.blob(row.cell),
              int(np.dot(row.pbc, [1, 2, 4])))
    values += tuple(self.blob(row.get(name))
                    for name in ('initial_magmoms', 'initial_charges',
                                 'masses', 'tags', 'momenta'))
    values += (constraints,)

    if 'calculator' in row:
        values += (row.calculator,
                   self.encode(row.calculator_parameters))
    else:
        values += (None, None)

    if not data:
        data = row._data

    with self.managed_connection() as con:
        # self.version is set by _initialize(), which managed_connection()
        # runs, so the data is only encoded in here.
        if not isinstance(data, (str, bytes)):
            data = self.encode(data, binary=self.version >= 9)

        values += (row.get('energy'),
                   row.get('free_energy'),
                   self.blob(row.get('forces')),
                   self.blob(row.get('stress')),
                   self.blob(row.get('dipole')),
                   self.blob(row.get('magmoms')),
                   row.get('magmom'),
                   self.blob(row.get('charges')),
                   self.encode(key_value_pairs),
                   data,
                   len(row.numbers),
                   float_if_not_none(row.get('fmax')),
                   float_if_not_none(row.get('smax')),
                   float_if_not_none(row.get('volume')),
                   float(row.mass),
                   float(row.charge))

        cur = con.cursor()
        if id is None:
            placeholders = self.default + ', ' + ', '.join('?' * len(values))
            cur.execute(f'INSERT INTO systems VALUES ({placeholders})',
                        values)
            id = self.get_last_id(cur)
        else:
            # Drop the derived entries; they are re-inserted below.
            self._delete(cur, [id], ['keys', 'text_key_values',
                                     'number_key_values', 'species'])
            assignments = ', '.join(name + '=?'
                                    for name in self.columnnames[1:])
            cur.execute(f'UPDATE systems SET {assignments} WHERE id=?',
                        values + (id,))

        count = row.count_atoms()
        if count:
            species = [(atomic_numbers[symbol], n, id)
                       for symbol, n in count.items()]
            cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                            species)

        # Split the pairs into numeric and text entries.
        text_key_values = []
        number_key_values = []
        for key, value in key_value_pairs.items():
            if isinstance(value, (numbers.Real, np.bool_)):
                number_key_values.append([key, float(value), id])
            else:
                assert isinstance(value, str)
                text_key_values.append([key, value, id])

        cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                        text_key_values)
        cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                        number_key_values)
        cur.executemany('INSERT INTO keys VALUES (?, ?)',
                        [(key, id) for key in key_value_pairs])

        # Insert entries in the valid tables
        for tabname, entries in ext_tables.items():
            entries['id'] = id
            self._insert_in_external_table(
                cur, name=tabname, entries=entries)

    return id
def _update(self, id, key_value_pairs, data=None):
    """Update key_value_pairs and data for a single row.

    An ``"external_tables"`` entry is popped from *key_value_pairs* (the
    caller's dict is modified) and written to the external tables.  The
    stored data is only replaced when *data* is truthy.  Returns *id*.
    """
    ext_tables = key_value_pairs.pop('external_tables', {})

    for table_name, table in ext_tables.items():
        self._create_table_if_not_exists(table_name,
                                         self._guess_type(table))

    mtime = now()
    with self.managed_connection() as con:
        cur = con.cursor()
        cur.execute(
            'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
            (mtime, self.encode(key_value_pairs), id))
        if data:
            if not isinstance(data, (str, bytes)):
                data = self.encode(data, binary=self.version >= 9)
            cur.execute('UPDATE systems set data=? where id=?', (data, id))

        # Drop the derived key entries; they are re-inserted below.
        self._delete(cur, [id], ['keys', 'text_key_values',
                                 'number_key_values'])

        text_key_values = []
        number_key_values = []
        for key, value in key_value_pairs.items():
            if isinstance(value, (numbers.Real, np.bool_)):
                number_key_values.append([key, float(value), id])
            else:
                assert isinstance(value, str)
                text_key_values.append([key, value, id])

        cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                        text_key_values)
        cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                        number_key_values)
        cur.executemany('INSERT INTO keys VALUES (?, ?)',
                        [(key, id) for key in key_value_pairs])

        # Insert entries in the valid tables
        for tabname, entries in ext_tables.items():
            entries['id'] = id
            self._insert_in_external_table(
                cur, name=tabname, entries=entries)

    return id
def get_last_id(self, cur):
    """Return the last autoincrement id handed out for "systems".

    Reads the ``sqlite_sequence`` bookkeeping table through cursor *cur*;
    returns 0 when it has no entry for the table yet.
    """
    # 'systems' is single-quoted: double quotes denote identifiers in SQL
    # and are only accepted as string literals by a lenient SQLite build.
    cur.execute("SELECT seq FROM sqlite_sequence WHERE name='systems'")
    result = cur.fetchone()
    return 0 if result is None else result[0]
438 def _get_row(self, id):
439 with self.managed_connection() as con:
440 cur = con.cursor()
441 if id is None:
442 cur.execute('SELECT COUNT(*) FROM systems')
443 assert cur.fetchone()[0] == 1
444 cur.execute('SELECT * FROM systems')
445 else:
446 cur.execute('SELECT * FROM systems WHERE id=?', (id,))
447 values = cur.fetchone()
449 return self._convert_tuple_to_row(values)
def _convert_tuple_to_row(self, values):
    """Build an AtomsRow from one tuple of "systems" column values.

    *values* is indexed in table-column order.  Columns that are NULL
    are left out of the row, blobs are turned back into arrays, and the
    entries of every external table are added under the table's name.
    """
    deblob = self.deblob
    decode = self.decode

    values = self._old2new(values)
    dct = {'id': values[0],
           'unique_id': values[1],
           'ctime': values[2],
           'mtime': values[3],
           'user': values[4],
           'numbers': deblob(values[5], np.int32),
           'positions': deblob(values[6], shape=(-1, 3)),
           'cell': deblob(values[7], shape=(3, 3))}

    def as_is(value):
        return value

    def vectors(buf):
        return deblob(buf, shape=(-1, 3))

    # (name, converter) for the optional columns 8..24, in column order.
    optional = [
        ('pbc', lambda bits: (bits & np.array([1, 2, 4])).astype(bool)),
        ('initial_magmoms', deblob),
        ('initial_charges', deblob),
        ('masses', deblob),
        ('tags', lambda buf: deblob(buf, np.int32)),
        ('momenta', vectors),
        ('constraints', as_is),
        ('calculator', as_is),
        ('calculator_parameters', decode),
        ('energy', as_is),
        ('free_energy', as_is),
        ('forces', vectors),
        ('stress', deblob),
        ('dipole', deblob),
        ('magmoms', deblob),
        ('magmom', as_is),
        ('charges', deblob),
    ]
    for index, (name, convert) in enumerate(optional, start=8):
        raw = values[index]
        if raw is not None:
            dct[name] = convert(raw)

    if values[25] != '{}':
        dct['key_value_pairs'] = decode(values[25])
    if len(values) >= 27 and values[26] != 'null':
        # The data blob is handed on undecoded (lazy).
        dct['data'] = decode(values[26], lazy=True)

    # Now we need to update with info from the external tables
    row_id = dct["id"]
    tables = {tab: self._read_external_table(tab, row_id)
              for tab in self._get_external_table_names()}
    dct.update(tables)
    return AtomsRow(dct)
514 def _old2new(self, values):
515 if self.type == 'postgresql':
516 assert self.version >= 8, 'Your db-version is too old!'
517 assert self.version >= 4, 'Your db-file is too old!'
518 if self.version < 5:
519 pass # should be ok for reading by convert.py script
520 if self.version < 6:
521 m = values[23]
522 if m is not None and not isinstance(m, float):
523 magmom = float(self.deblob(m, shape=()))
524 values = values[:23] + (magmom,) + values[24:]
525 return values
def create_select_statement(self, keys, cmps,
                            sort=None, order=None, sort_table=None,
                            what='systems.*'):
    """Build the SELECT statement for a selection.

    *keys* are names that must be present (or, written with a ``-``,
    absent); *cmps* are ``(key, op, value)`` comparisons.  Returns the
    SQL text and the list of arguments for its ``?`` placeholders.
    """
    tables = ['systems']
    where = []
    args = []

    for key in keys:
        if key == 'forces':
            where.append('systems.fmax IS NOT NULL')
        elif key == 'strain':
            where.append('systems.smax IS NOT NULL')
        elif key in ['energy', 'fmax', 'smax',
                     'constraints', 'calculator']:
            where.append(f'systems.{key} IS NOT NULL')
        elif '-' in key:
            where.append(
                'systems.id not in (select id from keys where key=?)')
            args.append(key.replace('-', ''))
        else:
            where.append('systems.id in (select id from keys where key=?)')
            args.append(key)

    # Special handling of "H=0" and "H<2" type of selections:
    bad = {}
    for key, op, value in cmps:
        if isinstance(key, int):
            bad[key] = bad.get(key, True) and ops[op](0, value)

    for key, op, value in cmps:
        if key in ['id', 'energy', 'magmom', 'ctime', 'user',
                   'calculator', 'natoms', 'pbc', 'unique_id',
                   'fmax', 'smax', 'volume', 'mass', 'charge']:
            # Comparison against a column of the systems table.
            if key == 'user':
                key = 'username'
            elif key == 'pbc':
                assert op in ['=', '!=']
                value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
            elif key == 'magmom':
                assert self.version >= 6, 'Update your db-file'
            where.append(f'systems.{key}{op}?')
            args.append(value)
        elif isinstance(key, int):
            # Integer keys are matched against species.Z.
            if self.type == 'postgresql':
                where.append(
                    f'cardinality(array_positions(numbers::int[], ?)){op}?')
            elif bad[key]:
                where.append(
                    'systems.id not in (select id from species ' +
                    f'where Z=? and n{invop[op]}?)')
            else:
                where.append('systems.id in (select id from species ' +
                             f'where Z=? and n{op}?)')
            args += [key, value]
        elif self.type == 'postgresql':
            jsonop = '->'
            if isinstance(value, str):
                jsonop = '->>'
            elif isinstance(value, bool):
                jsonop = '->>'
                value = str(value).lower()
            where.append(f"systems.key_value_pairs {jsonop} '{key}'{op}?")
            args.append(str(value))
        elif isinstance(value, str):
            where.append('systems.id in (select id from text_key_values ' +
                         f'where key=? and value{op}?)')
            args += [key, value]
        else:
            where.append(
                'systems.id in (select id from number_key_values ' +
                f'where key=? and value{op}?)')
            args += [key, float(value)]

    if sort and sort_table != 'systems':
        # Sort on a key-value pair: join its table under a fixed alias.
        tables.append(f'{sort_table} AS sort_table')
        where.append('systems.id=sort_table.id AND '
                     'sort_table.key=?')
        args.append(sort)
        sort_table = 'sort_table'
        sort = 'value'

    sql = f'SELECT {what} FROM\n ' + ', '.join(tables)
    if where:
        sql += '\n WHERE\n ' + ' AND\n '.join(where)
    if sort:
        # XXX use "?" instead of "{}"
        sql += '\nORDER BY {0}.{1} IS NULL, {0}.{1} {2}'.format(
            sort_table, sort, order)

    return sql, args
626 def _select(self, keys, cmps, explain=False, verbosity=0,
627 limit=None, offset=0, sort=None, include_data=True,
628 columns='all'):
630 values = np.array([None for i in range(27)])
631 values[25] = '{}'
632 values[26] = 'null'
634 if columns == 'all':
635 columnindex = list(range(26))
636 else:
637 columnindex = [c for c in range(0, 26)
638 if self.columnnames[c] in columns]
639 if include_data:
640 columnindex.append(26)
642 if sort:
643 if sort[0] == '-':
644 order = 'DESC'
645 sort = sort[1:]
646 else:
647 order = 'ASC'
648 if sort in ['id', 'energy', 'username', 'calculator',
649 'ctime', 'mtime', 'magmom', 'pbc',
650 'fmax', 'smax', 'volume', 'mass', 'charge', 'natoms']:
651 sort_table = 'systems'
652 else:
653 for dct in self._select(keys + [sort], cmps=[], limit=1,
654 include_data=False,
655 columns=['key_value_pairs']):
656 if isinstance(dct['key_value_pairs'][sort], str):
657 sort_table = 'text_key_values'
658 else:
659 sort_table = 'number_key_values'
660 break
661 else:
662 # No rows. Just pick a table:
663 sort_table = 'number_key_values'
665 else:
666 order = None
667 sort_table = None
669 what = ', '.join('systems.' + name
670 for name in
671 np.array(self.columnnames)[np.array(columnindex)])
673 sql, args = self.create_select_statement(keys, cmps, sort, order,
674 sort_table, what)
676 if explain:
677 sql = 'EXPLAIN QUERY PLAN ' + sql
679 if limit:
680 sql += f'\nLIMIT {limit}'
682 if offset:
683 sql += self.get_offset_string(offset, limit=limit)
685 if verbosity == 2:
686 print(sql, args)
688 with self.managed_connection() as con:
689 cur = con.cursor()
690 cur.execute(sql, args)
691 if explain:
692 for row in cur.fetchall():
693 yield {'explain': row}
694 else:
695 n = 0
696 for shortvalues in cur.fetchall():
697 values[columnindex] = shortvalues
698 yield self._convert_tuple_to_row(tuple(values))
699 n += 1
701 if sort and sort_table != 'systems':
702 # Yield rows without sort key last:
703 if limit is not None:
704 if n == limit:
705 return
706 limit -= n
707 for row in self._select(keys + ['-' + sort], cmps,
708 limit=limit, offset=offset,
709 include_data=include_data,
710 columns=columns):
711 yield row
def get_offset_string(self, offset, limit=None):
    """Return the SQL fragment that skips the first *offset* rows."""
    parts = []
    if not limit:
        # In sqlite you cannot have offset without limit, so we
        # set it to -1 meaning no limit
        parts.append('\nLIMIT -1')
    parts.append(f'\nOFFSET {offset}')
    return ''.join(parts)
@parallel_function
def count(self, selection=None, **kwargs):
    """Return the number of rows matching *selection* / *kwargs*."""
    keys, cmps = parse_selection(selection, **kwargs)
    sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)')

    with self.managed_connection() as con:
        cursor = con.cursor()
        cursor.execute(sql, args)
        first_row = cursor.fetchone()
        return first_row[0]
def analyse(self):
    """Run SQLite's ``ANALYZE`` statement on the database."""
    with self.managed_connection() as connection:
        connection.execute('ANALYZE')
@parallel_function
@lock
def delete(self, ids):
    """Delete the rows with the given *ids* from every table.

    The external tables are cleaned as well, and the file is vacuumed
    afterwards.  An empty *ids* does nothing.
    """
    if len(ids) == 0:
        return
    table_names = self._get_external_table_names() + all_tables[::-1]
    with self.managed_connection() as con:
        cursor = con.cursor()
        self._delete(cursor, ids, tables=table_names)
    self.vacuum()
747 def _delete(self, cur, ids, tables=None):
748 tables = tables or all_tables[::-1]
749 for table in tables:
750 cur.execute('DELETE FROM {} WHERE id in ({});'.
751 format(table, ', '.join([str(id) for id in ids])))
def vacuum(self):
    """Commit and run ``VACUUM``; only done for the plain 'db' type."""
    if self.type != 'db':
        return

    with self.managed_connection() as connection:
        connection.commit()
        connection.cursor().execute("VACUUM")
@property
def metadata(self):
    """Copy of the metadata dictionary stored in the database.

    If the metadata has not been loaded yet, the database is initialized
    through a short-lived connection, which is closed again right away.
    """
    if self._metadata is None:
        con = self._connect()
        try:
            self._initialize(con)
        finally:
            con.close()
    return self._metadata.copy()

@metadata.setter
def metadata(self, dct):
    """Store *dct* (as JSON) in the 'metadata' row of "information"."""
    self._metadata = dct
    md = json.dumps(dct)
    with self.managed_connection() as con:
        cur = con.cursor()
        cur.execute(
            "SELECT COUNT(*) FROM information WHERE name='metadata'")

        # Overwrite an existing row, otherwise add the first one.
        if cur.fetchone()[0]:
            cur.execute(
                "UPDATE information SET value=? WHERE name='metadata'",
                [md])
        else:
            cur.execute('INSERT INTO information VALUES (?, ?)',
                        ('metadata', md))
784 def _get_external_table_names(self, db_con=None):
785 """Return a list with the external table names."""
786 sql = "SELECT value FROM information WHERE name='external_table_name'"
787 with self.managed_connection() as con:
788 cur = con.cursor()
789 cur.execute(sql)
790 ext_tab_names = [x[0] for x in cur.fetchall()]
791 return ext_tab_names
793 def _external_table_exists(self, name):
794 """Return True if an external table name exists."""
795 return name in self._get_external_table_names()
def _create_table_if_not_exists(self, name, dtype):
    """Create a new table if it does not exist.

    The table is also registered in the "information" table, together
    with the datatype of its value column.

    Arguments
    ==========
    name: str
        Name of the new table
    dtype: str
        Datatype of the value field (typically REAL, INTEGER, TEXT etc.)

    Raises
    ======
    ValueError
        If *name* clashes with a built-in table, property or column name.
    """

    taken_names = set(all_tables + all_properties + self.columnnames)
    if name in taken_names:
        raise ValueError("External table can not be any of {}"
                         "".format(taken_names))

    if self._external_table_exists(name):
        return

    create_sql = (f"CREATE TABLE IF NOT EXISTS {name} "
                  f"(key TEXT, value {dtype}, id INTEGER, "
                  "FOREIGN KEY (id) REFERENCES systems(id))")
    register_sql = "INSERT INTO information VALUES (?, ?)"
    with self.managed_connection() as con:
        cur = con.cursor()
        cur.execute(create_sql)
        # Insert an entry saying that there is a new external table
        # present and an entry with the datatype
        cur.execute(register_sql, ("external_table_name", name))
        cur.execute(register_sql, (name + "_dtype", dtype))
def delete_external_table(self, name):
    """Delete an external table.

    Drops the table and removes its two bookkeeping rows from the
    "information" table.  Unknown names are ignored.
    """
    if not self._external_table_exists(name):
        return

    with self.managed_connection() as con:
        cur = con.cursor()

        cur.execute(f"DROP TABLE {name}")

        # Only the registration row is removed: matching on the value
        # alone would also delete unrelated rows that happen to hold the
        # same value.
        cur.execute(
            "DELETE FROM information "
            "WHERE name='external_table_name' AND value=?", (name,))
        cur.execute("DELETE FROM information WHERE name=?",
                    (name + "_dtype",))
844 def _convert_to_recognized_types(self, value):
845 """Convert Numpy types to python types."""
846 if np.issubdtype(type(value), np.integer):
847 return int(value)
848 elif np.issubdtype(type(value), np.floating):
849 return float(value)
850 return value
852 def _insert_in_external_table(self, cursor, name=None, entries=None):
853 """Insert into external table"""
854 if name is None or entries is None:
855 # There is nothing to do
856 return
858 id = entries.pop("id")
859 dtype = self._guess_type(entries)
860 expected_dtype = self._get_value_type_of_table(cursor, name)
861 if dtype != expected_dtype:
862 raise ValueError("The provided data type for table {} "
863 "is {}, while it is initialized to "
864 "be of type {}"
865 "".format(name, dtype, expected_dtype))
867 # First we check if entries already exists
868 cursor.execute(f"SELECT key FROM {name} WHERE id=?", (id,))
869 updates = []
870 for item in cursor.fetchall():
871 value = entries.pop(item[0], None)
872 if value is not None:
873 updates.append(
874 (value, id, self._convert_to_recognized_types(item[0])))
876 # Update entry if key and ID already exists
877 sql = f"UPDATE {name} SET value=? WHERE id=? AND key=?"
878 cursor.executemany(sql, updates)
880 # Insert the ones that does not already exist
881 inserts = [(k, self._convert_to_recognized_types(v), id)
882 for k, v in entries.items()]
883 sql = f"INSERT INTO {name} VALUES (?, ?, ?)"
884 cursor.executemany(sql, inserts)
886 def _guess_type(self, entries):
887 """Guess the type based on the first entry."""
888 values = [v for _, v in entries.items()]
890 # Check if all datatypes are the same
891 all_types = [type(v) for v in values]
892 if any(t != all_types[0] for t in all_types):
893 typenames = [t.__name__ for t in all_types]
894 raise ValueError("Inconsistent datatypes in the table. "
895 "given types: {}".format(typenames))
897 val = values[0]
898 if isinstance(val, int) or np.issubdtype(type(val), np.integer):
899 return "INTEGER"
900 if isinstance(val, float) or np.issubdtype(type(val), np.floating):
901 return "REAL"
902 if isinstance(val, str):
903 return "TEXT"
904 raise ValueError("Unknown datatype!")
906 def _get_value_type_of_table(self, cursor, tab_name):
907 """Return the expected value name."""
908 sql = "SELECT value FROM information WHERE name=?"
909 cursor.execute(sql, (tab_name + "_dtype",))
910 return cursor.fetchone()[0]
912 def _read_external_table(self, name, id):
913 """Read row from external table."""
915 with self.managed_connection() as con:
916 cur = con.cursor()
917 cur.execute(f"SELECT * FROM {name} WHERE id=?", (id,))
918 items = cur.fetchall()
919 dictionary = {item[0]: item[1] for item in items}
921 return dictionary
def get_all_key_names(self):
    """Create set of all key names."""
    with self.managed_connection() as con:
        cursor = con.cursor()
        cursor.execute('SELECT DISTINCT key FROM keys;')
        all_keys = set()
        for record in cursor.fetchall():
            all_keys.add(record[0])
    return all_keys
if __name__ == '__main__':
    # Small command-line helper: print the format version of the database
    # file given as first argument.
    from ase.db import connect
    database = connect(sys.argv[1])
    database._initialize(database._connect())
    print('Version:', database.version)