Coverage for /builds/kinetik161/ase/ase/db/postgresql.py: 95.65%
138 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import json
3import numpy as np
4from psycopg2 import connect
5from psycopg2.extras import execute_values
7from ase.db.sqlite import (VERSION, SQLite3Database, index_statements,
8 init_statements)
9from ase.io.jsonio import create_ase_object, create_ndarray
10from ase.io.jsonio import encode as ase_encode
# GIN indices on the two JSONB columns, created together with the regular
# index statements when the database is initialized with indices enabled.
jsonb_indices = [
    'CREATE INDEX idxkeys ON systems USING GIN (key_value_pairs);',
    'CREATE INDEX idxcalc ON systems USING GIN (calculator_parameters);',
]
def remove_nan_and_inf(obj):
    """Replace non-finite floats by JSON-safe placeholders.

    Every NaN/+Inf/-Inf float (including numpy floating scalars) is replaced
    by ``{'__special_number__': str(value)}``; ``insert_nan_and_inf`` turns
    these back into floats.  Lists, tuples and dict values are processed
    recursively.  A floating or complex ndarray that contains non-finite
    values is converted to a (processed) nested list.  Anything else is
    returned unchanged.
    """
    if isinstance(obj, (float, np.floating)) and not np.isfinite(obj):
        return {'__special_number__': str(obj)}
    if isinstance(obj, list):
        return [remove_nan_and_inf(x) for x in obj]
    if isinstance(obj, tuple):
        # JSON encodes tuples as arrays too, so they may hide NaN/Inf as well.
        return tuple(remove_nan_and_inf(x) for x in obj)
    if isinstance(obj, dict):
        return {key: remove_nan_and_inf(value) for key, value in obj.items()}
    # np.isfinite raises TypeError for string/object arrays, and integer/bool
    # arrays are always finite, so only float and complex arrays are checked.
    if (isinstance(obj, np.ndarray) and obj.dtype.kind in 'fc'
            and not np.isfinite(obj).all()):
        return remove_nan_and_inf(obj.tolist())
    return obj
def insert_nan_and_inf(obj):
    """Undo the placeholder encoding produced by ``remove_nan_and_inf``.

    A dict carrying a ``'__special_number__'`` key becomes the float parsed
    from its value.  Lists and other dicts are rebuilt recursively.  Any other
    object is handed back as-is.
    """
    if isinstance(obj, list):
        return [insert_nan_and_inf(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    if '__special_number__' in obj:
        return float(obj['__special_number__'])
    return {k: insert_nan_and_inf(v) for k, v in obj.items()}
class Connection:
    """Thin wrapper around a database connection object.

    ``cursor()`` hands out :class:`Cursor` wrappers.  ``commit()`` and
    ``close()`` are forwarded to the wrapped connection unchanged.
    """

    def __init__(self, con):
        # The wrapped DB-API connection.
        self.con = con

    def cursor(self):
        """Open a cursor on the wrapped connection and wrap it."""
        raw_cursor = self.con.cursor()
        return Cursor(raw_cursor)

    def commit(self):
        """Commit the current transaction of the wrapped connection."""
        self.con.commit()

    def close(self):
        """Close the wrapped connection."""
        self.con.close()
class Cursor:
    """Cursor wrapper translating SQLite-style ``?`` placeholders to ``%s``.

    This lets SQL written for the SQLite backend run unchanged through
    psycopg2.
    """

    def __init__(self, cur):
        # The wrapped DB-API cursor.
        self.cur = cur

    def fetchone(self):
        """Return the next result row from the wrapped cursor."""
        return self.cur.fetchone()

    def fetchall(self):
        """Return all remaining result rows from the wrapped cursor."""
        return self.cur.fetchall()

    def execute(self, statement, *args):
        """Run *statement* after rewriting every ``?`` as ``%s``."""
        translated = statement.replace('?', '%s')
        self.cur.execute(translated, *args)

    def executemany(self, statement, *args):
        """Insert many rows in a single ``execute_values`` round trip.

        ``args[0]`` is the sequence of rows.  When it is empty, nothing is
        executed.  The ``(?, ?, ...)`` group in *statement* becomes the single
        ``%s`` that ``execute_values`` expands, and a matching row template
        is built.  For ``INSERT INTO systems`` the first column is filled with
        ``DEFAULT`` so the id is generated by the database.
        """
        rows = args[0]
        if len(rows) == 0:
            return
        placeholders = ', '.join('?' * len(rows[0]))
        if 'INSERT INTO systems' in statement:
            placeholders = 'DEFAULT, ' + placeholders  # DEFAULT for id
        statement = statement.replace('(' + placeholders + ')', '%s')
        template = '(' + placeholders.replace('?', '%s') + ')'
        execute_values(self.cur, statement.replace('?', '%s'),
                       argslist=rows, template=template,
                       page_size=len(rows))
def insert_ase_and_ndarray_objects(obj):
    """Recursively rebuild ASE objects and ndarrays from decoded JSON.

    A dict with a non-None ``'__ase_objtype__'`` entry is passed (with that
    entry removed and its values processed recursively) to
    ``create_ase_object``.  A dict with a non-None ``'__ndarray__'`` entry is
    turned into an array via ``create_ndarray(*data)``.  Other dicts and lists
    are rebuilt recursively.  Anything else is returned unchanged.

    The input is never mutated: dicts are processed as shallow copies.
    """
    if isinstance(obj, dict):
        # Copy before popping so the caller's dict keeps '__ase_objtype__'.
        obj = dict(obj)
        objtype = obj.pop('__ase_objtype__', None)
        if objtype is not None:
            return create_ase_object(objtype,
                                     insert_ase_and_ndarray_objects(obj))
        data = obj.get('__ndarray__')
        if data is not None:
            return create_ndarray(*data)
        return {key: insert_ase_and_ndarray_objects(value)
                for key, value in obj.items()}
    if isinstance(obj, list):
        return [insert_ase_and_ndarray_objects(value) for value in obj]
    return obj
class PostgreSQLDatabase(SQLite3Database):
    """ASE database backend storing rows in a PostgreSQL server.

    Reuses the SQLite3 backend and overrides the parts that differ.  JSON
    values have NaN/Inf replaced by placeholders.  Arrays are stored as
    native PostgreSQL arrays (nested lists) instead of binary blobs.  The
    connection and schema initialization are PostgreSQL-specific.
    """

    type = 'postgresql'
    default = 'DEFAULT'

    def encode(self, obj, binary=False):
        """Encode *obj* as JSON, with NaN/Inf replaced by placeholders.

        ``binary`` is accepted for interface compatibility and ignored.
        """
        return ase_encode(remove_nan_and_inf(obj))

    def decode(self, obj, lazy=False):
        """Rebuild Python objects (NaN/Inf, ASE objects, ndarrays) from *obj*.

        ``lazy`` is accepted for interface compatibility and ignored.
        """
        return insert_ase_and_ndarray_objects(insert_nan_and_inf(obj))

    def blob(self, array):
        """Convert array to a (nested) list for a PostgreSQL array column.

        ``None`` stays ``None`` and empty arrays become ``[]``.  int64 arrays
        are downcast to int32, because the integer array columns are declared
        ``INTEGER[]`` by ``schema_update``.
        """
        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            array = array.astype(np.int32)
        return array.tolist()

    def deblob(self, buf, dtype=float, shape=None):
        """Convert a PostgreSQL array value to an ndarray of *dtype*.

        If *shape* is given, the result is reshaped to it.  This is needed
        for empty arrays: ``blob`` stores them as ``[]``, which would
        otherwise come back with shape ``(0,)`` instead of e.g. ``(0, 3)``.
        """
        if buf is None:
            return None
        array = np.array(buf, dtype=dtype)
        if shape is not None:
            array = array.reshape(shape)
        return array

    def _connect(self):
        """Open a psycopg2 connection to ``self.filename`` and wrap it."""
        return Connection(connect(self.filename))

    def _initialize(self, con):
        """Create the tables on first use, or read version/metadata.

        The schema to inspect is the first entry of ``search_path``,
        skipping a leading ``"$user"``.
        """
        if self.initialized:
            return

        self._metadata = {}

        cur = con.cursor()
        cur.execute("show search_path;")
        search_path = cur.fetchone()[0].split(', ')
        if search_path[0] == '"$user"':
            schema = search_path[1]
        else:
            schema = search_path[0]

        # Pass the schema name as a query parameter instead of formatting it
        # into the SQL text, so unusual names cannot break the statement.
        cur.execute("""
        SELECT EXISTS(select * from information_schema.tables where
        table_name='information' and table_schema=%s);
        """, (schema,))

        if not cur.fetchone()[0]:  # the 'information' table doesn't exist.
            # Initialize database:
            sql = ';\n'.join(init_statements)
            sql = schema_update(sql)
            cur.execute(sql)
            if self.create_indices:
                cur.execute(';\n'.join(index_statements))
                cur.execute(';\n'.join(jsonb_indices))
            con.commit()
            self.version = VERSION
        else:
            cur.execute('select * from information;')
            for name, value in cur.fetchall():
                if name == 'version':
                    self.version = int(value)
                elif name == 'metadata':
                    self._metadata = json.loads(value)

        assert 5 < self.version <= VERSION

        self.initialized = True

    def get_offset_string(self, offset, limit=None):
        """Return an OFFSET clause; *limit* is not needed here.

        PostgreSQL allows OFFSET without LIMIT, unlike SQLite.
        """
        return f'\nOFFSET {offset}'

    def get_last_id(self, cur):
        """Return the last value handed out by the ``systems`` id sequence.

        The sequence ``systems_id_seq`` comes from the ``SERIAL PRIMARY KEY``
        id column.
        """
        cur.execute('SELECT last_value FROM systems_id_seq')
        last_id = cur.fetchone()[0]
        return int(last_id)
def schema_update(sql):
    """Rewrite SQLite CREATE TABLE statements into PostgreSQL syntax.

    REAL becomes DOUBLE PRECISION and the autoincrement primary key becomes
    SERIAL.  The per-atom/per-system BLOB columns become native INTEGER[] or
    DOUBLE PRECISION[] / [][] arrays, and the JSON text columns (plus
    ``data``) become JSONB.  The replacements are plain substring
    substitutions applied in a fixed order.  That order matters:
    ``initial_magmoms`` / ``initial_charges`` must be rewritten before
    ``magmoms`` / ``charges``, which they contain.
    """
    replacements = [
        ('REAL', 'DOUBLE PRECISION'),
        ('INTEGER PRIMARY KEY AUTOINCREMENT', 'SERIAL PRIMARY KEY'),
    ]

    integer_columns = ('numbers', 'tags')
    for name in ('numbers', 'initial_magmoms', 'initial_charges', 'masses',
                 'tags', 'momenta', 'stress', 'dipole', 'magmoms',
                 'charges'):
        element = 'INTEGER' if name in integer_columns else 'DOUBLE PRECISION'
        replacements.append((f'{name} BLOB,', f'{name} {element}[],'))

    replacements.extend((f'{name} BLOB,', f'{name} DOUBLE PRECISION[][],')
                        for name in ('positions', 'cell', 'forces'))

    replacements.extend((f'{name} TEXT,', f'{name} JSONB,')
                        for name in ('calculator_parameters',
                                     'key_value_pairs'))

    replacements.append(('data BLOB,', 'data JSONB,'))

    for old, new in replacements:
        sql = sql.replace(old, new)
    return sql