Coverage for /builds/kinetik161/ase/ase/db/mysql.py: 98.51%
134 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import json
2import sys
3from copy import deepcopy
5import numpy as np
6from pymysql import connect
7from pymysql.err import ProgrammingError
9import ase.io.jsonio
10from ase.db.postgresql import insert_nan_and_inf, remove_nan_and_inf
11from ase.db.sqlite import VERSION, SQLite3Database, init_statements
class Connection:
    """
    Wrapper for the MySQL connection

    Arguments
    =========
    host: str
        Hostname. For a local database this is localhost.
    user: str
        Username.
    passwd: str
        Password
    db_name: str
        Name of the database
    port: int
        Port
    binary_prefix: bool
        MySQL checks if an argument can be interpreted as a UTF-8 string. This
        check fails for binary values. Binary values need to have _binary
        prefix in MySQL. By setting this to True, the prefix is automatically
        added for binary values.
    """

    def __init__(self, host=None, user=None, passwd=None, port=3306,
                 db_name=None, binary_prefix=False):
        # PyMySQL documents ``passwd`` and ``db`` as deprecated aliases;
        # ``password`` and ``database`` are the canonical keyword names.
        self.con = connect(host=host, user=user, password=passwd,
                           database=db_name, binary_prefix=binary_prefix,
                           port=port)

    def cursor(self):
        """Return a new cursor wrapped in :class:`MySQLCursor`."""
        return MySQLCursor(self.con.cursor())

    def commit(self):
        """Commit on the underlying connection."""
        self.con.commit()

    def close(self):
        """Close the underlying connection."""
        self.con.close()

    def rollback(self):
        """Roll back on the underlying connection."""
        self.con.rollback()
class MySQLCursor:
    """
    Wrapper for the MySQL cursor. The most important task performed by this
    class is to translate SQLite queries to MySQL. Translation is needed
    because ASE DB uses some field names that are reserved words in MySQL.
    Thus, these have to be mapped onto other field names.
    """
    # Ordered (old, new) text substitutions applied to every SQL string.
    # They are applied one after another, so the order is significant.
    sql_replace = [
        (' key TEXT', ' attribute_key TEXT'),
        ('(key TEXT', '(attribute_key TEXT'),
        ('SELECT key FROM', 'SELECT attribute_key FROM'),
        ('SELECT DISTINCT key FROM keys',
         'SELECT DISTINCT attribute_key FROM attribute_keys'),
        ('?', '%s'),
        (' keys ', ' attribute_keys '),
        (' key=', ' attribute_key='),
        ('table.key', 'table.attribute_key'),
        (' IF NOT EXISTS', '')]

    def __init__(self, cur):
        self.cur = cur

    def _translate(self, sql):
        """Return ``sql`` with every ``sql_replace`` substitution applied."""
        for old, new in self.sql_replace:
            sql = sql.replace(old, new)
        return sql

    def execute(self, sql, params=None):
        """Translate ``sql`` and execute it on the wrapped cursor.

        A missing ``params`` is passed on as an empty tuple.
        """
        query = self._translate(sql)
        arguments = () if params is None else params
        self.cur.execute(query, arguments)

    def fetchone(self):
        """Return the next row from the wrapped cursor."""
        return self.cur.fetchone()

    def fetchall(self):
        """Return all remaining rows from the wrapped cursor."""
        return self.cur.fetchall()

    def _replace_nan_inf_kvp(self, values):
        """Replace non-finite numbers in ``values`` in place.

        Each item is indexed as ``item[1]`` and assigned to, so the items
        must be mutable sequences. Any non-finite entry (NaN as well as
        positive or negative infinity) becomes ``sys.float_info.max / 2``.
        The same ``values`` object is returned.
        """
        substitute = sys.float_info.max / 2
        for row in values:
            if np.isfinite(row[1]):
                continue
            row[1] = substitute
        return values

    def executemany(self, sql, values):
        """Translate ``sql`` and run it once per entry of ``values``.

        Rows for the ``number_key_values`` table first have their
        non-finite numbers replaced.
        """
        rows = values
        if 'number_key_values' in sql:
            rows = self._replace_nan_inf_kvp(values)

        self.cur.executemany(self._translate(sql), rows)
class MySQLDatabase(SQLite3Database):
    """
    ASE interface to a MySQL database (via pymysql package).

    Arguments
    ==========
    url: str
        URL to the database. It should have the form
        mysql://username:password@host:port/database_name.
        Example URL with the following credentials
            username: john
            password: johnspasswd
            host: localhost (i.e. server is running locally)
            database: johns_calculations
            port: 3306
        mysql://john:johnspasswd@localhost:3306/johns_calculations
        If ``:port`` is left out, port 3306 is used.
    create_indices: bool
        Carried over from parent class. Currently indices are not
        created for MySQL, as TEXT fields cannot be hashed by MySQL.
    use_lock_file: bool
        See SQLite
    serial: bool
        See SQLite
    """
    type = 'mysql'
    default = 'DEFAULT'

    def __init__(self, url=None, create_indices=True,
                 use_lock_file=False, serial=False):
        super().__init__(
            url, create_indices, use_lock_file, serial)

        self.host = None
        self.username = None
        self.passwd = None
        self.db_name = None
        self.port = 3306
        self._parse_url(url)

    def _parse_url(self, url):
        """
        Parse the URL

        Fills ``username``, ``passwd``, ``host``, ``port`` and ``db_name``
        from ``[mysql://|mariadb://]username:password@host[:port]/database``.
        ``port`` keeps its current value when the URL does not give one.
        """
        # Strip the scheme only where it can legitimately appear: at the
        # start of the URL.
        for scheme in ('mysql://', 'mariadb://'):
            if url.startswith(scheme):
                url = url[len(scheme):]
                break

        user_and_rest = url.split(':', 1)
        self.username = user_and_rest[0]

        # Split at the LAST '@' so that a password containing '@' is
        # kept intact.
        passwd_and_location = user_and_rest[1].rsplit('@', 1)
        self.passwd = passwd_and_location[0]

        location = passwd_and_location[1].split('/')
        host_and_port = location[0].split(':')
        self.host = host_and_port[0]
        if len(host_and_port) > 1:
            self.port = int(host_and_port[1])
        self.db_name = location[1]

    def _connect(self):
        """Open a :class:`Connection` using the parsed credentials."""
        return Connection(host=self.host, user=self.username,
                          passwd=self.passwd, db_name=self.db_name,
                          port=self.port, binary_prefix=True)

    def _initialize(self, con):
        """Create the tables, or read version and metadata if they exist.

        Does nothing when ``self.initialized`` is already set.
        """
        if self.initialized:
            return

        cur = con.cursor()

        information_exists = True
        self._metadata = {}
        try:
            # Probe for the ``information`` table.
            cur.execute("SELECT 1 FROM information")
        except ProgrammingError:
            information_exists = False

        if not information_exists:
            # We need to initialize the DB
            # MySQL require that id is explicitly set as primary key
            # in the systems table
            init_statements_cpy = deepcopy(init_statements)
            init_statements_cpy[0] = init_statements_cpy[0][:-1] + \
                ', PRIMARY KEY(id))'

            statements = schema_update(init_statements_cpy)
            for statement in statements:
                cur.execute(statement)
            con.commit()
            self.version = VERSION
        else:
            cur.execute('select * from information')

            for name, value in cur.fetchall():
                if name == 'version':
                    self.version = int(value)
                elif name == 'metadata':
                    self._metadata = json.loads(value)

        self.initialized = True

    def blob(self, array):
        """Return the parent's blob for ``array`` as bytes; None stays None."""
        if array is None:
            return None
        return super().blob(array).tobytes()

    def get_offset_string(self, offset, limit=None):
        """Return the OFFSET clause, preceded by a LIMIT if none was given."""
        sql = ''
        if not limit:
            # mysql does not allow for setting limit to -1 so
            # instead we set a large number
            sql += '\nLIMIT 10000000000'
        sql += f'\nOFFSET {offset}'
        return sql

    def get_last_id(self, cur):
        """Return the largest ``id`` in the ``systems`` table."""
        cur.execute('select max(id) as ID from systems')
        last_id = cur.fetchone()[0]
        return last_id

    def create_select_statement(self, keys, cmps,
                                sort=None, order=None, sort_table=None,
                                what='systems.*'):
        """Build the parent's select statement and translate it for MySQL."""
        sql, value = super().create_select_statement(
            keys, cmps, sort, order, sort_table, what)

        # Apply the same substitutions the cursor wrapper uses.
        for subst in MySQLCursor.sql_replace:
            sql = sql.replace(subst[0], subst[1])
        return sql, value

    def encode(self, obj, binary=False):
        """JSON-encode ``obj`` after ``remove_nan_and_inf``.

        ``binary`` is accepted but not used.
        """
        return ase.io.jsonio.encode(remove_nan_and_inf(obj))

    def decode(self, obj, lazy=False):
        """JSON-decode ``obj`` and apply ``insert_nan_and_inf``.

        ``lazy`` is accepted but not used.
        """
        return insert_nan_and_inf(ase.io.jsonio.decode(obj))
def schema_update(statements):
    """Rewrite the SQLite table-creation statements for MySQL.

    ``statements`` is modified in place and also returned. It must have
    at least three entries: index 0 receives the ``unique_id`` and
    JSON column changes, and index 2 has ``keys`` renamed to
    ``attribute_keys``.
    """
    type_substitutions = [
        ('REAL', 'DOUBLE'),
        ('INTEGER PRIMARY KEY AUTOINCREMENT', 'INT NOT NULL AUTO_INCREMENT'),
    ]
    for i, statement in enumerate(statements):
        # Chain the replacements so that each one builds on the previous
        # result; otherwise only the last substitution would survive.
        for old, new in type_substitutions:
            statement = statement.replace(old, new)
        statements[i] = statement

    # MySQL does not support UNIQUE constraint on TEXT
    # need to use VARCHAR. The unique_id is generated with
    # randint(16**31, 16**32-1) so it will contain 32
    # hex-characters
    statements[0] = statements[0].replace('TEXT UNIQUE', 'VARCHAR(32) UNIQUE')

    # keys is a reserved word in MySQL redefine this table name to
    # attribute_keys
    statements[2] = statements[2].replace('keys', 'attribute_keys')

    txt2jsonb = ['calculator_parameters', 'key_value_pairs']

    for column in txt2jsonb:
        statements[0] = statements[0].replace(
            f'{column} TEXT,',
            f'{column} JSON,')

    statements[0] = statements[0].replace('data BLOB,', 'data JSON,')

    tab_with_key_field = ['attribute_keys', 'number_key_values',
                          'text_key_values']

    # key is a reserved word in MySQL redefine this to attribute_key
    for i, statement in enumerate(statements):
        # Replace once per statement: a second pass would also match the
        # 'key TEXT' inside 'attribute_key TEXT'.
        if any(tab in statement for tab in tab_with_key_field):
            statements[i] = statement.replace(
                'key TEXT', 'attribute_key TEXT')
    return statements