Coverage for /builds/kinetik161/ase/ase/db/postgresql.py: 95.65%

138 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import json 

2 

3import numpy as np 

4from psycopg2 import connect 

5from psycopg2.extras import execute_values 

6 

7from ase.db.sqlite import (VERSION, SQLite3Database, index_statements, 

8 init_statements) 

9from ase.io.jsonio import create_ase_object, create_ndarray 

10from ase.io.jsonio import encode as ase_encode 

11 

# GIN indices over the two columns that schema_update() turns into JSONB,
# so queries on key-value pairs and calculator parameters can use an index.
jsonb_indices = [
    'CREATE INDEX idxkeys ON systems USING GIN (key_value_pairs);',
    'CREATE INDEX idxcalc ON systems USING GIN (calculator_parameters);',
]

15 

16 

def remove_nan_and_inf(obj):
    """Replace non-finite floats in *obj* with JSON-safe placeholders.

    Every NaN or +/-Infinity (Python ``float`` or NumPy floating scalar) is
    replaced by ``{'__special_number__': str(value)}``; insert_nan_and_inf()
    performs the inverse mapping.  Lists, tuples and dicts are processed
    recursively (tuples come back as lists, which is how JSON encodes them
    anyway).  Float/complex ndarrays that contain non-finite values are
    converted to nested lists so their elements can be replaced; every other
    array (including all-finite, integer, string and object arrays) is
    returned unchanged.
    """
    if isinstance(obj, (float, np.floating)) and not np.isfinite(obj):
        return {'__special_number__': str(obj)}
    if isinstance(obj, (list, tuple)):
        return [remove_nan_and_inf(x) for x in obj]
    if isinstance(obj, dict):
        return {key: remove_nan_and_inf(value) for key, value in obj.items()}
    # np.isfinite() raises TypeError for string/object dtypes, and only
    # float ('f') or complex ('c') arrays can hold non-finite values at all.
    if (isinstance(obj, np.ndarray) and obj.dtype.kind in 'fc'
            and not np.isfinite(obj).all()):
        return remove_nan_and_inf(obj.tolist())
    return obj

27 

28 

def insert_nan_and_inf(obj):
    """Undo remove_nan_and_inf(): turn placeholders back into floats.

    A dict carrying a ``'__special_number__'`` key is replaced by
    ``float(obj['__special_number__'])``.  Other dicts and lists are rebuilt
    with their contents converted recursively; any other value is returned
    as it is.
    """
    if isinstance(obj, list):
        return [insert_nan_and_inf(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    if '__special_number__' in obj:
        return float(obj['__special_number__'])
    return {k: insert_nan_and_inf(v) for k, v in obj.items()}

37 

38 

class Connection:
    """Wrap a psycopg2 connection so that it hands out ``Cursor`` wrappers."""

    def __init__(self, con):
        # The wrapped psycopg2 connection (see PostgreSQLDatabase._connect).
        self.con = con

    def cursor(self):
        """Open a new cursor on the wrapped connection and wrap it."""
        raw_cursor = self.con.cursor()
        return Cursor(raw_cursor)

    def commit(self):
        """Commit the current transaction of the wrapped connection."""
        self.con.commit()

    def close(self):
        """Close the wrapped connection."""
        self.con.close()

51 

52 

class Cursor:
    """Wrap a psycopg2 cursor so it accepts sqlite3-style ``?`` placeholders.

    Statements are rewritten to psycopg2's ``%s`` placeholder style before
    they are handed to the wrapped cursor.
    """

    def __init__(self, cur):
        self.cur = cur

    def fetchone(self):
        """Return the next row from the wrapped cursor."""
        return self.cur.fetchone()

    def fetchall(self):
        """Return all remaining rows from the wrapped cursor."""
        return self.cur.fetchall()

    def execute(self, statement, *args):
        """Execute *statement* after translating ``?`` into ``%s``.

        NOTE: every ``?`` is replaced, including any inside string literals.
        """
        self.cur.execute(statement.replace('?', '%s'), *args)

    def executemany(self, statement, *args):
        """Execute *statement* for every parameter row in ``args[0]``.

        ``args[0]`` may be any iterable of rows (list, generator, ...); it is
        materialized once.  An empty iterable is a no-op.

        The rows are sent with psycopg2's ``execute_values``: the
        ``(?, ?, ...)`` group in *statement* (prefixed with ``DEFAULT`` for
        ``INSERT INTO systems``, filling the id column) is replaced by the
        single ``%s`` that execute_values expands, and the same group, with
        ``%s`` placeholders, becomes the per-row template.  ``page_size`` is
        the number of rows, so everything goes out in one statement.
        """
        # A generator has no len(); build the list once and reuse it.
        rows = list(args[0])
        if not rows:
            return
        ncols = len(rows[0])
        placeholders = ', '.join('?' * ncols)
        if 'INSERT INTO systems' in statement:
            placeholders = 'DEFAULT' + ', ' + placeholders  # DEFAULT for id
        statement = statement.replace(f'({placeholders})', '%s')
        template = '({})'.format(placeholders.replace('?', '%s'))

        execute_values(self.cur, statement.replace('?', '%s'),
                       argslist=rows, template=template,
                       page_size=len(rows))

80 

81 

def insert_ase_and_ndarray_objects(obj):
    """Recursively rebuild ASE objects and ndarrays from decoded JSON data.

    * A dict with a non-None ``'__ase_objtype__'`` entry is passed, without
      that tag and with its contents processed recursively, to
      ``create_ase_object(objtype, ...)``.
    * A dict with a non-None ``'__ndarray__'`` entry becomes
      ``create_ndarray(*data)``.
    * Other dicts and lists are rebuilt element by element; any other value
      is returned unchanged.

    The ``'__ase_objtype__'`` key never appears in a returned dict.  The
    input is not modified.
    """
    if isinstance(obj, dict):
        objtype = obj.get('__ase_objtype__')
        # Filtered copy instead of obj.pop(): same result, but the caller's
        # dict is left untouched.
        fields = {key: value for key, value in obj.items()
                  if key != '__ase_objtype__'}
        if objtype is not None:
            return create_ase_object(objtype,
                                     insert_ase_and_ndarray_objects(fields))
        data = fields.get('__ndarray__')
        if data is not None:
            return create_ndarray(*data)
        return {key: insert_ase_and_ndarray_objects(value)
                for key, value in fields.items()}
    if isinstance(obj, list):
        return [insert_ase_and_ndarray_objects(value) for value in obj]
    return obj

96 

97 

class PostgreSQLDatabase(SQLite3Database):
    """ASE database backend that stores rows in PostgreSQL.

    Builds on SQLite3Database.  Array columns are stored as native PostgreSQL
    arrays (nested lists), and the dict-like columns as JSONB (see
    schema_update).
    """

    type = 'postgresql'
    # NOTE(review): presumably substituted by SQLite3Database for the id value
    # of new rows so the SERIAL column fills it in -- confirm in sqlite.py.
    default = 'DEFAULT'

    def encode(self, obj, binary=False):
        """Serialize *obj* with ase.io.jsonio.encode for a JSONB column.

        Non-finite floats are replaced by placeholders first, because JSONB
        cannot store NaN/Infinity.  *binary* is accepted for interface
        compatibility and ignored.
        """
        return ase_encode(remove_nan_and_inf(obj))

    def decode(self, obj, lazy=False):
        """Rebuild Python objects from an already-parsed JSONB value.

        *obj* is the dict/list structure of the JSONB value.  The placeholders
        written by encode() become floats again, and tagged dicts become ASE
        objects/ndarrays.  *lazy* is ignored.
        """
        return insert_ase_and_ndarray_objects(insert_nan_and_inf(obj))

    def blob(self, array):
        """Convert an ndarray to a (nested) list for a PostgreSQL array column.

        None is passed through.  Empty input is stored as a flat empty list.
        int64 arrays are narrowed to int32, presumably to match the INTEGER
        (int4) columns created by schema_update.
        """
        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            array = array.astype(np.int32)
        return array.tolist()

    def deblob(self, buf, dtype=float, shape=None):
        """Convert a stored PostgreSQL array back into an ndarray.

        *buf* is the (possibly nested) list read from an array column, or None
        (returned as None).  The result has the requested *dtype*.  When
        *shape* is given (e.g. ``(-1, 3)``), the array is reshaped to it.
        This matters because blob() stores empty arrays flat, so an empty
        positions array would otherwise come back as ``(0,)`` instead of
        ``(0, 3)``.
        """
        if buf is None:
            return None
        array = np.array(buf, dtype=dtype)
        if shape is not None:
            array = array.reshape(shape)
        return array

    def _connect(self):
        """Open a psycopg2 connection using ``self.filename`` as the DSN."""
        return Connection(connect(self.filename))

    def _initialize(self, con):
        """Create the schema on first use, or read version and metadata.

        Does nothing once ``self.initialized`` is set.  If no ``information``
        relation is visible, the tables (and, if ``self.create_indices``,
        the regular and JSONB indices) are created and ``self.version`` is
        set to VERSION.  Otherwise ``version`` and ``metadata`` are loaded
        from the ``information`` table.

        Raises AssertionError if the stored version is not in (5, VERSION].
        """
        if self.initialized:
            return

        self._metadata = {}

        cur = con.cursor()
        # to_regclass() resolves the unqualified name through search_path,
        # exactly like the unqualified CREATE/SELECT statements below, and
        # returns NULL instead of raising when nothing is found.  This
        # replaces hand-parsing ``show search_path`` (which checked the wrong
        # schema when a "$user" schema existed) and avoids formatting a
        # schema name into the SQL text.
        cur.execute("SELECT to_regclass('information') IS NOT NULL;")

        if not cur.fetchone()[0]:  # information table doesn't exist.
            # Initialize database:
            sql = ';\n'.join(init_statements)
            sql = schema_update(sql)
            cur.execute(sql)
            if self.create_indices:
                cur.execute(';\n'.join(index_statements))
                cur.execute(';\n'.join(jsonb_indices))
            con.commit()
            self.version = VERSION
        else:
            cur.execute('select * from information;')
            for name, value in cur.fetchall():
                if name == 'version':
                    self.version = int(value)
                elif name == 'metadata':
                    self._metadata = json.loads(value)

        assert 5 < self.version <= VERSION

        self.initialized = True

    def get_offset_string(self, offset, limit=None):
        """Return an OFFSET clause; *limit* is not needed by PostgreSQL."""
        # postgresql allows you to set offset without setting limit;
        # very practical
        return f'\nOFFSET {offset}'

    def get_last_id(self, cur):
        """Return ``last_value`` of the ``systems_id_seq`` sequence as int."""
        # NOTE(review): last_value is not session-local, so with concurrent
        # writers this may return another session's id.  currval()/lastval()
        # are session-local but raise if nextval() has not run in this
        # session -- confirm all callers call this right after an INSERT
        # before switching.
        cur.execute('SELECT last_value FROM systems_id_seq')
        last_id = cur.fetchone()[0]
        return int(last_id)

180 

181 

def schema_update(sql):
    """Rewrite SQLite CREATE statements into PostgreSQL syntax.

    REAL becomes DOUBLE PRECISION, and the AUTOINCREMENT integer primary key
    becomes SERIAL.  Known array columns declared as ``<name> BLOB,`` become
    native arrays: INTEGER[] for numbers and tags, DOUBLE PRECISION[] for the
    other 1-D columns, and DOUBLE PRECISION[][] for positions, cell and
    forces.  calculator_parameters and key_value_pairs (TEXT) and data (BLOB)
    become JSONB.  All rewrites are plain substring replacements applied in a
    fixed order.  A column definition only matches when it is followed by a
    comma.
    """
    rules = [('REAL', 'DOUBLE PRECISION'),
             ('INTEGER PRIMARY KEY AUTOINCREMENT', 'SERIAL PRIMARY KEY')]

    array_columns = [
        ('numbers', 'INTEGER[]'),
        ('initial_magmoms', 'DOUBLE PRECISION[]'),
        ('initial_charges', 'DOUBLE PRECISION[]'),
        ('masses', 'DOUBLE PRECISION[]'),
        ('tags', 'INTEGER[]'),
        ('momenta', 'DOUBLE PRECISION[]'),
        ('stress', 'DOUBLE PRECISION[]'),
        ('dipole', 'DOUBLE PRECISION[]'),
        ('magmoms', 'DOUBLE PRECISION[]'),
        ('charges', 'DOUBLE PRECISION[]'),
        ('positions', 'DOUBLE PRECISION[][]'),
        ('cell', 'DOUBLE PRECISION[][]'),
        ('forces', 'DOUBLE PRECISION[][]'),
    ]
    rules.extend((f'{name} BLOB,', f'{name} {pg_type},')
                 for name, pg_type in array_columns)

    json_columns = [('calculator_parameters', 'TEXT'),
                    ('key_value_pairs', 'TEXT'),
                    ('data', 'BLOB')]
    rules.extend((f'{name} {sqlite_type},', f'{name} JSONB,')
                 for name, sqlite_type in json_columns)

    for old, new in rules:
        sql = sql.replace(old, new)
    return sql