Coverage for /builds/kinetik161/ase/ase/db/mysql.py: 98.51%

134 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import json 

2import sys 

3from copy import deepcopy 

4 

5import numpy as np 

6from pymysql import connect 

7from pymysql.err import ProgrammingError 

8 

9import ase.io.jsonio 

10from ase.db.postgresql import insert_nan_and_inf, remove_nan_and_inf 

11from ase.db.sqlite import VERSION, SQLite3Database, init_statements 

12 

13 

class Connection:
    """
    Wrapper for the MySQL connection.

    The underlying pymysql connection is opened on construction and kept in
    ``self.con``; the methods below forward to it.

    Arguments
    =========
    host: str
        Hostname. For a local database this is localhost.
    user: str
        Username.
    passwd: str
        Password
    db_name: str
        Name of the database
    port: int
        Port
    binary_prefix: bool
        MySQL checks if an argument can be interpreted as a UTF-8 string. This
        check fails for binary values. Binary values need to have _binary
        prefix in MySQL. By setting this to True, the prefix is automatically
        added for binary values.
    """

    def __init__(self, host=None, user=None, passwd=None, port=3306,
                 db_name=None, binary_prefix=False):
        # Collect the pymysql keyword arguments first; note that the
        # database name is handed over under the ``db`` keyword.
        settings = {
            'host': host,
            'user': user,
            'passwd': passwd,
            'db': db_name,
            'binary_prefix': binary_prefix,
            'port': port,
        }
        self.con = connect(**settings)

    def cursor(self):
        """Return a new cursor wrapped in a MySQLCursor."""
        raw_cursor = self.con.cursor()
        return MySQLCursor(raw_cursor)

    def commit(self):
        """Forward commit to the wrapped connection."""
        self.con.commit()

    def close(self):
        """Forward close to the wrapped connection."""
        self.con.close()

    def rollback(self):
        """Forward rollback to the wrapped connection."""
        self.con.rollback()

53 

54 

class MySQLCursor:
    """
    Wrapper for the MySQL cursor. The most important task performed by this
    class is to translate SQLite queries to MySQL. Translation is needed
    because ASE DB uses some field names that are reserved words in MySQL.
    Thus, these have to be mapped onto other field names.
    """
    # (search, replacement) pairs applied in this order to every statement.
    # Besides renaming key/keys, they turn '?' placeholders into '%s' and
    # drop ' IF NOT EXISTS'.
    sql_replace = [
        (' key TEXT', ' attribute_key TEXT'),
        ('(key TEXT', '(attribute_key TEXT'),
        ('SELECT key FROM', 'SELECT attribute_key FROM'),
        ('SELECT DISTINCT key FROM keys',
         'SELECT DISTINCT attribute_key FROM attribute_keys'),
        ('?', '%s'),
        (' keys ', ' attribute_keys '),
        (' key=', ' attribute_key='),
        ('table.key', 'table.attribute_key'),
        (' IF NOT EXISTS', '')]

    def __init__(self, cur):
        self.cur = cur

    def _to_mysql(self, sql):
        """Return ``sql`` with every ``sql_replace`` substitution applied."""
        for old, new in self.sql_replace:
            sql = sql.replace(old, new)
        return sql

    def execute(self, sql, params=None):
        """Translate ``sql`` and execute it on the wrapped cursor.

        A missing ``params`` is passed on as an empty tuple.
        """
        translated = self._to_mysql(sql)
        arguments = () if params is None else params
        self.cur.execute(translated, arguments)

    def fetchone(self):
        """Return the next row from the wrapped cursor."""
        return self.cur.fetchone()

    def fetchall(self):
        """Return all remaining rows from the wrapped cursor."""
        return self.cur.fetchall()

    def _replace_nan_inf_kvp(self, values):
        """Replace non-finite numbers in key-value rows, in place.

        The second entry of each row is overwritten with
        ``sys.float_info.max / 2`` when it is not finite. Rows must be
        mutable sequences. The same ``values`` object is returned.
        """
        placeholder = sys.float_info.max / 2
        for row in values:
            if np.isfinite(row[1]):
                continue
            row[1] = placeholder
        return values

    def executemany(self, sql, values):
        """Translate ``sql`` and run it once per entry of ``values``.

        Statements that mention ``number_key_values`` first get their
        non-finite numbers replaced.
        """
        if 'number_key_values' in sql:
            values = self._replace_nan_inf_kvp(values)

        self.cur.executemany(self._to_mysql(sql), values)

107 

108 

class MySQLDatabase(SQLite3Database):
    """
    ASE interface to a MySQL database (via pymysql package).

    Arguments
    ==========
    url: str
        URL to the database. It should have the form
        mysql://username:password@host:port/database_name.
        The ``:port`` part may be omitted, in which case 3306 is used.
        Example URL with the following credentials
            username: john
            password: johnspasswd
            host: localhost (i.e. server is running locally)
            database: johns_calculations
            port: 3306
        mysql://john:johnspasswd@localhost:3306/johns_calculations
    create_indices: bool
        Carried over from parent class. Currently indices are not
        created for MySQL, as TEXT fields cannot be hashed by MySQL.
    use_lock_file: bool
        See SQLite
    serial: bool
        See SQLite
    """
    type = 'mysql'
    default = 'DEFAULT'

    def __init__(self, url=None, create_indices=True,
                 use_lock_file=False, serial=False):
        super().__init__(
            url, create_indices, use_lock_file, serial)

        self.host = None
        self.username = None
        self.passwd = None
        self.db_name = None
        self.port = 3306
        self._parse_url(url)

    def _parse_url(self, url):
        """
        Parse the URL and store its parts on the instance.

        Sets ``username``, ``passwd``, ``host``, ``port`` and ``db_name``.
        The password may itself contain ``:`` or ``@``; the port is optional
        and falls back to 3306.

        Raises
        ======
        ValueError
            If the URL has no ``@`` separating credentials from the host,
            no ``/`` separating the host from the database name, or a port
            that is not an integer.
        """
        # Only strip the scheme where it belongs: at the start of the URL.
        for scheme in ('mysql://', 'mariadb://'):
            if url.startswith(scheme):
                url = url[len(scheme):]
                break

        expected = 'mysql://username:password@host:port/database_name'

        # Split on the LAST '@' so that a password containing '@' survives.
        credentials, at_sign, location = url.rpartition('@')
        if not at_sign:
            raise ValueError(
                f'Invalid MySQL URL: missing "@". Expected {expected}')

        # Split on the FIRST ':' so that a password containing ':' survives.
        self.username, _, self.passwd = credentials.partition(':')

        host_and_port, slash, path = location.partition('/')
        if not slash:
            raise ValueError(
                'Invalid MySQL URL: missing database name. '
                f'Expected {expected}')

        self.host, _, port = host_and_port.partition(':')
        self.port = int(port) if port else 3306

        # Anything after a further '/' is ignored.
        self.db_name = path.split('/')[0]

    def _connect(self):
        """Open a new Connection using the parsed URL parts."""
        return Connection(host=self.host, user=self.username,
                          passwd=self.passwd, db_name=self.db_name,
                          port=self.port, binary_prefix=True)

    def _initialize(self, con):
        """
        Create the tables if needed and read version and metadata.

        Does nothing when ``self.initialized`` is already set.
        """
        if self.initialized:
            return

        cur = con.cursor()

        information_exists = True
        self._metadata = {}
        try:
            # Probe for the information table; a ProgrammingError is taken
            # to mean that the database has not been set up yet.
            cur.execute("SELECT 1 FROM information")
        except ProgrammingError:
            information_exists = False

        if not information_exists:
            # We need to initialize the DB
            # MySQL require that id is explicitly set as primary key
            # in the systems table
            init_statements_cpy = deepcopy(init_statements)
            init_statements_cpy[0] = init_statements_cpy[0][:-1] + \
                ', PRIMARY KEY(id))'

            statements = schema_update(init_statements_cpy)
            for statement in statements:
                cur.execute(statement)
            con.commit()
            self.version = VERSION
        else:
            cur.execute('select * from information')

            for name, value in cur.fetchall():
                if name == 'version':
                    self.version = int(value)
                elif name == 'metadata':
                    self._metadata = json.loads(value)

        self.initialized = True

    def blob(self, array):
        """Return the parent's blob of ``array`` as bytes (None stays None)."""
        if array is None:
            return None
        return super().blob(array).tobytes()

    def get_offset_string(self, offset, limit=None):
        """Return the OFFSET clause, preceded by a LIMIT if none was given."""
        sql = ''
        if not limit:
            # mysql does not allow for setting limit to -1 so
            # instead we set a large number
            sql += '\nLIMIT 10000000000'
        sql += f'\nOFFSET {offset}'
        return sql

    def get_last_id(self, cur):
        """Return the largest id found in the systems table."""
        cur.execute('select max(id) as ID from systems')
        last_id = cur.fetchone()[0]
        return last_id

    def create_select_statement(self, keys, cmps,
                                sort=None, order=None, sort_table=None,
                                what='systems.*'):
        """
        Build the select statement with the parent class and translate it.

        The SQL returned by the parent is passed through the
        ``MySQLCursor.sql_replace`` substitutions; the values are returned
        unchanged.
        """
        sql, value = super().create_select_statement(
            keys, cmps, sort, order, sort_table, what)

        for subst in MySQLCursor.sql_replace:
            sql = sql.replace(subst[0], subst[1])
        return sql, value

    def encode(self, obj, binary=False):
        """JSON-encode ``obj`` after remove_nan_and_inf (``binary`` unused)."""
        return ase.io.jsonio.encode(remove_nan_and_inf(obj))

    def decode(self, obj, lazy=False):
        """JSON-decode ``obj`` and apply insert_nan_and_inf (``lazy`` unused)."""
        return insert_nan_and_inf(ase.io.jsonio.decode(obj))

243 

244 

def schema_update(statements):
    """
    Rewrite SQLite table-creation statements so that MySQL accepts them.

    The list is modified in place and also returned.

    Arguments
    =========
    statements: list of str
        SQL statements. The code addresses the entries at index 0 and 2
        directly, so at least three statements are required. Presumably
        index 0 creates the systems table and index 2 the keys table, as in
        ase.db.sqlite.init_statements (verify against that module).

    Returns
    =======
    The same list object with the translated statements.
    """
    type_substitutions = [
        ('REAL', 'DOUBLE'),
        ('INTEGER PRIMARY KEY AUTOINCREMENT', 'INT NOT NULL AUTO_INCREMENT'),
    ]
    for i, statement in enumerate(statements):
        # Accumulate the replacements on the same string. Replacing on the
        # untouched original each time would keep only the last substitution.
        for old, new in type_substitutions:
            statement = statement.replace(old, new)
        statements[i] = statement

    # MySQL does not support UNIQUE constraint on TEXT
    # need to use VARCHAR. The unique_id is generated with
    # randint(16**31, 16**32-1) so it will contain 32
    # hex-characters
    statements[0] = statements[0].replace('TEXT UNIQUE', 'VARCHAR(32) UNIQUE')

    # keys is a reserved word in MySQL redefine this table name to
    # attribute_keys
    statements[2] = statements[2].replace('keys', 'attribute_keys')

    txt2jsonb = ['calculator_parameters', 'key_value_pairs']

    for column in txt2jsonb:
        statements[0] = statements[0].replace(
            f'{column} TEXT,',
            f'{column} JSON,')

    statements[0] = statements[0].replace('data BLOB,', 'data JSON,')

    tab_with_key_field = ['attribute_keys', 'number_key_values',
                          'text_key_values']

    # key is a reserved word in MySQL redefine this to attribute_key.
    # Replace once per statement: the replacement text itself contains
    # 'key TEXT', so applying it twice would corrupt the column name.
    for i, statement in enumerate(statements):
        if any(tab in statement for tab in tab_with_key_field):
            statements[i] = statement.replace(
                'key TEXT', 'attribute_key TEXT')
    return statements