Coverage for /builds/kinetik161/ase/ase/db/table.py: 90.14%
142 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1from typing import List, Optional
3import numpy as np
5from ase.db.core import float_to_time_string, now
# Every column name known to this table module; get_sql_columns translates
# these display names into the SQL column names that must be fetched.
all_columns = tuple(
    'id age user formula calculator energy natoms fmax pbc volume '
    'charge mass smax magmom'.split())
def get_sql_columns(columns):
    """Translate table column names into the SQL column names to fetch.

    'age' is dropped and 'mtime' and 'ctime' are appended in its place.
    Display names that are derived from another stored column are renamed
    in place: user->username, formula->numbers, fmax->forces,
    smax->stress, volume->cell, mass->masses, charge->charges.
    'key_value_pairs' and 'constraints' are always appended. 'id' is
    appended when it is not already present.

    Returns a new list; ``columns`` itself is not modified.
    """
    renames = {
        'user': 'username',
        'formula': 'numbers',
        'fmax': 'forces',
        'smax': 'stress',
        'volume': 'cell',
        'mass': 'masses',
        'charge': 'charges',
    }

    result = list(columns)
    if 'age' in columns:
        result.remove('age')
        result.extend(['mtime', 'ctime'])

    for display_name, sql_name in renames.items():
        if display_name in columns:
            result[result.index(display_name)] = sql_name

    result.extend(['key_value_pairs', 'constraints'])
    if 'id' not in result:
        result.append('id')

    return result
def plural(n, word):
    """Return ``n`` and ``word`` joined by a space, adding an 's' to the
    word unless ``n`` equals one (e.g. '1 row', '3 rows')."""
    if n != 1:
        return '%d %ss' % (n, word)
    return '1 ' + word
def cut(txt, length):
    """Shorten ``txt`` to at most ``length`` characters.

    Text longer than ``length`` is truncated and ends in '...'. A
    ``length`` of 0 (or any non-positive value) disables cutting. When
    ``length`` is too small to hold the ellipsis (1 or 2), the text is
    truncated without one.
    """
    if length <= 0 or len(txt) <= length:
        return txt
    if length < 3:
        # txt[:length - 3] + '...' would be longer than the input here.
        return txt[:length]
    return txt[:length - 3] + '...'
def cutlist(lst, length):
    """Shorten ``lst`` for display when it has more than ``length`` items.

    A long list is replaced by its first ``min(9, length)`` items followed
    by a single '... (N more)' marker, where N counts the omitted items. A
    ``length`` of 0 (or any non-positive value) disables cutting. Lists
    that are short enough are returned as-is (same object).
    """
    if length <= 0 or len(lst) <= length:
        return lst
    # Never keep more than ``length`` items. A fixed slice of 9 produced
    # "(0 more)" or negative counts whenever length was smaller than 9.
    keep = min(9, length)
    return lst[:keep] + [f'... ({len(lst) - keep} more)']
class Table:
    """Tabular view of rows selected from a database connection.

    Call ``select`` to fetch rows. Then call ``write`` (aligned text
    columns) or ``write_csv`` (comma separated) to print them.
    """

    def __init__(self, connection, unique_key='id', verbosity=1, cut=35):
        self.connection = connection
        self.verbosity = verbosity
        # Passed to cutlist() by write() when listing key names; 0 = no cut.
        self.cut = cut
        self.rows = []
        self.columns = None
        self.id = None
        self.right = None  # per-column flag: right-align (set by format)
        self.keys = None  # sorted key_value_pairs keys (set by format)
        self.unique_key = unique_key
        self.addcolumns: Optional[List[str]] = None
        # Set by select(); initialised here so write() does not raise
        # AttributeError when rows/columns were filled in another way.
        self.limit = None
        self.offset = None

    def select(self, query, columns, sort, limit, offset,
               show_empty_columns=False):
        """Query database and create rows.

        ``query``, ``sort``, ``limit`` and ``offset`` are forwarded to
        ``self.connection.select``. ``columns`` are table column names; they
        are translated with get_sql_columns for the query. Unless
        ``show_empty_columns`` is true, columns whose value is None in every
        selected row are removed from ``self.columns`` and from each row.
        """
        sql_columns = get_sql_columns(columns)
        self.limit = limit
        self.offset = offset
        self.rows = [Row(row, columns, self.unique_key)
                     for row in self.connection.select(
                         query, verbosity=self.verbosity,
                         limit=limit, offset=offset, sort=sort,
                         include_data=False, columns=sql_columns)]

        self.columns = list(columns)

        if not show_empty_columns:
            # Indices of columns that are None in every row seen so far.
            delete = set(range(len(columns)))
            for row in self.rows:
                if not delete:
                    break  # every column has a value somewhere
                for n in delete.copy():
                    if row.values[n] is not None:
                        delete.remove(n)
            # Delete from the highest index down so lower indices stay valid.
            delete = sorted(delete, reverse=True)
            for row in self.rows:
                for n in delete:
                    del row.values[n]

            for n in delete:
                del self.columns[n]

    def format(self, subscript=None):
        """Build display strings for all rows and collect table metadata.

        Sets ``self.right`` (one bool per column: right-adjust) and
        ``self.keys`` (sorted union of key_value_pairs keys over all rows).
        ``subscript`` is handed to Row.format for the formula column
        (presumably a compiled regex; it must provide ``.sub``).
        """
        right = set()  # right-adjust numbers
        allkeys = set()
        for row in self.rows:
            numbers = row.format(self.columns, subscript)
            right.update(numbers)
            allkeys.update(row.dct.get('key_value_pairs', {}))

        right.add('age')
        self.right = [column in right for column in self.columns]

        self.keys = sorted(allkeys)

    def write(self, query=None):
        """Print the table as '|'-separated, padded columns.

        The header line and the trailing 'Rows:'/'Keys:' summary are printed
        only when verbosity > 0. ``query`` is passed to
        ``self.connection.count`` to report the total number of matches
        when the output was truncated by ``limit``.
        """
        self.format()
        lengths = [[len(s) for s in row.strings]
                   for row in self.rows]
        lengths.append([len(c) for c in self.columns])
        # Column widths: the longest entry in each column, header included.
        widths = np.max(lengths, axis=0)

        fmt = '{:{align}{width}}'
        if self.verbosity > 0:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(self.columns, self.right, widths)))
        for row in self.rows:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(row.strings, self.right, widths)))

        # Quiet modes print only the table body. A negative verbosity used
        # to suppress the header but still print the summary below.
        if self.verbosity <= 0:
            return

        nrows = len(self.rows)

        if self.limit and nrows == self.limit:
            n = self.connection.count(query)
            print('Rows:', n, f'(showing first {self.limit})')
        else:
            print('Rows:', nrows)

        if self.keys:
            print('Keys:', ', '.join(cutlist(self.keys, self.cut)))

    def write_csv(self):
        """Print the rows as comma-separated values.

        A header line of column names is printed only when verbosity > 0.
        Each value is converted with ``str``, so missing values show as
        'None'.
        """
        if self.verbosity > 0:
            print(', '.join(self.columns))
        for row in self.rows:
            print(', '.join(str(val) for val in row.values))
class Row:
    """One table row: the values of a database row for a list of columns.

    ``dct`` is a row object from the connection. Values are read from it
    with ``getattr`` (presumably an AtomsRow-like object — TODO confirm).
    """

    def __init__(self, dct, columns, unique_key='id'):
        self.dct = dct
        self.values = None  # raw value per column (set by set_columns)
        self.strings = None  # display string per column (set by format)
        self.set_columns(columns)
        # Identifier under ``unique_key``; AttributeError if dct lacks it.
        self.uid = getattr(dct, unique_key)

    def set_columns(self, columns):
        """Fill ``self.values`` with one raw value per entry of ``columns``.

        'age' is the elapsed time since ``dct.ctime``, as a string. 'pbc' is
        a string with one 'T'/'F' per periodic-boundary flag. Any other
        column is ``getattr(dct, column)``, or None when the attribute is
        missing.
        """
        self.values = []
        for c in columns:
            if c == 'age':
                value = float_to_time_string(now() - self.dct.ctime)
            elif c == 'pbc':
                value = ''.join('FT'[int(p)] for p in self.dct.pbc)
            else:
                value = getattr(self.dct, c, None)
            self.values.append(value)

    def format(self, columns, subscript=None):
        """Convert ``self.values`` into display strings in ``self.strings``.

        The formula is passed through ``subscript.sub`` when ``subscript``
        is given. Containers and arrays are stringified. Integers are shown
        as-is and floats with three decimals; numpy scalars are included.
        None becomes ''. Any other non-string value goes through ``str``, so
        every entry of ``self.strings`` is a string.

        Returns the set of column names that held numeric values, so the
        caller can right-align them.
        """
        self.strings = []
        numbers = set()
        for value, column in zip(self.values, columns):
            if column == 'formula' and subscript:
                value = subscript.sub(r'<sub>\1</sub>', value)
            elif isinstance(value, (dict, list)):
                value = str(value)
            elif isinstance(value, np.ndarray):
                value = str(value.tolist())
            elif isinstance(value, (int, np.integer)):
                # np.integer is not an int subclass and was previously left
                # unconverted, which broke len() in Table.write.
                value = str(value)
                numbers.add(column)
            elif isinstance(value, (float, np.floating)):
                numbers.add(column)
                value = f'{value:.3f}'
            elif value is None:
                value = ''
            elif not isinstance(value, str):
                # e.g. np.bool_: make sure a string is always stored.
                value = str(value)
            self.strings.append(value)

        return numbers