Coverage for /builds/kinetik161/ase/ase/db/table.py: 90.14%

142 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1from typing import List, Optional 

2 

3import numpy as np 

4 

5from ase.db.core import float_to_time_string, now 

6 

# Column names understood by this module; several of them are translated
# to differently named SQL columns by get_sql_columns().
all_columns = (
    'id', 'age', 'user', 'formula', 'calculator',
    'energy', 'natoms', 'fmax', 'pbc', 'volume',
    'charge', 'mass', 'smax', 'magmom',
)

10 

11 

def get_sql_columns(columns):
    """Translate table column names into the SQL columns that must be read.

    'age' is replaced by 'mtime' and 'ctime' (appended at the end), and
    several derived quantities are replaced by the stored column they are
    computed from (e.g. 'fmax' -> 'forces').  Only the first occurrence of
    each name is translated.  'key_value_pairs' and 'constraints' are always
    appended, followed by 'id' when it is not already present.
    """
    renames = {
        'user': 'username',
        'formula': 'numbers',
        'fmax': 'forces',
        'smax': 'stress',
        'volume': 'cell',
        'mass': 'masses',
        'charge': 'charges',
    }
    result = list(columns)
    if 'age' in columns:
        result.remove('age')
        result.extend(['mtime', 'ctime'])
    for name, sql_name in renames.items():
        if name in columns:
            result[result.index(name)] = sql_name

    result.extend(['key_value_pairs', 'constraints'])
    if 'id' not in result:
        result.append('id')
    return result

40 

41 

def plural(n, word):
    """Return '<n> <word>', appending an 's' to *word* unless n == 1."""
    if n != 1:
        return '%d %ss' % (n, word)
    return '1 ' + word

46 

47 

def cut(txt, length):
    """Shorten *txt* to at most *length* characters for display.

    Text that already fits, or any text when *length* is 0, is returned
    unchanged.  Otherwise the text is truncated so that, including a
    trailing '...', it is exactly *length* characters long.  When *length*
    is below 3 there is no room for the ellipsis, so the text is simply
    truncated (a negative *length* yields '').
    """
    if len(txt) <= length or length == 0:
        return txt
    if length < 3:
        # txt[:length - 3] + '...' would be longer than requested here.
        return txt[:max(length, 0)]
    return txt[:length - 3] + '...'

52 

53 

def cutlist(lst, length):
    """Abbreviate the list *lst* for display when it has too many items.

    Lists with at most *length* items, or any list when *length* is 0, are
    returned unchanged.  Otherwise the first ``min(9, length)`` items are
    kept and a final ``'... (N more)'`` entry reports how many were dropped.
    """
    if len(lst) <= length or length == 0:
        return lst
    # Cap at *length*: a fixed 9 kept too many items for small limits and
    # could report a zero or negative "more" count.
    shown = max(0, min(9, length))
    return lst[:shown] + [f'... ({len(lst) - shown} more)']

58 

59 

class Table:
    """Text table of rows selected from a database connection.

    Call ``select(...)`` to fetch rows, then ``write()`` for aligned text
    output or ``write_csv()`` for comma-separated output.
    """

    def __init__(self, connection, unique_key='id', verbosity=1, cut=35):
        self.connection = connection
        self.verbosity = verbosity
        # Length threshold passed to cutlist() when listing the key names in
        # the footer of write(); 0 disables abbreviation.
        self.cut = cut
        self.rows = []
        self.columns = None
        self.id = None
        # One bool per column (right-align or not), filled in by format().
        self.right = None
        # Sorted key-value-pair names seen in any row, filled in by format().
        self.keys = None
        self.unique_key = unique_key
        self.addcolumns: Optional[List[str]] = None
        # Paging parameters of the last select(); initialised here so the
        # attributes exist even before select() has been called.
        self.limit = None
        self.offset = None

    def select(self, query, columns, sort, limit, offset,
               show_empty_columns=False):
        """Query the database and create one Row per match.

        Only the SQL columns needed for *columns* are fetched (see
        get_sql_columns), without the data part of the rows.  Unless
        *show_empty_columns* is true, columns whose value is None in every
        selected row are removed from ``self.columns`` and from each row's
        ``values`` (with no rows at all, every column counts as empty).
        """
        sql_columns = get_sql_columns(columns)
        self.limit = limit
        self.offset = offset
        self.rows = [Row(row, columns, self.unique_key)
                     for row in self.connection.select(
                         query, verbosity=self.verbosity,
                         limit=limit, offset=offset, sort=sort,
                         include_data=False, columns=sql_columns)]

        self.columns = list(columns)

        if not show_empty_columns:
            empty = [n for n in range(len(columns))
                     if all(row.values[n] is None for row in self.rows)]
            # Delete from the right so the remaining indices stay valid.
            for n in reversed(empty):
                for row in self.rows:
                    del row.values[n]
                del self.columns[n]

    def format(self, subscript=None):
        """Format all rows and collect alignment and key information.

        Sets ``self.right``: for each column, whether any row reported it
        as numeric ('age' is always right-aligned).  Also sets ``self.keys``:
        the sorted names found in the rows' 'key_value_pairs'.
        *subscript* is forwarded unchanged to each row's ``format``.
        """
        right = set()  # right-adjust numbers
        allkeys = set()
        for row in self.rows:
            numbers = row.format(self.columns, subscript)
            right.update(numbers)
            allkeys.update(row.dct.get('key_value_pairs', {}))

        right.add('age')
        self.right = [column in right for column in self.columns]

        self.keys = sorted(allkeys)

    def write(self, query=None):
        """Print the rows as an aligned text table followed by a footer.

        The header line is printed only when verbosity > 0.  With
        verbosity == 0 the footer is also suppressed.  *query* is passed to
        ``connection.count`` to report the total number of matches whenever
        the printed rows may be only a page of them (limit reached, or an
        offset was used).
        """
        self.format()
        lengths = [[len(s) for s in row.strings] for row in self.rows]
        lengths.append([len(c) for c in self.columns])
        widths = np.max(lengths, axis=0)  # widest entry of each column

        fmt = '{:{align}{width}}'
        if self.verbosity > 0:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(self.columns, self.right, widths)))
        for row in self.rows:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(row.strings, self.right, widths)))

        if self.verbosity == 0:
            return

        nrows = len(self.rows)
        offset = self.offset or 0

        if offset or (self.limit and nrows == self.limit):
            # Only a page of the matches may have been printed: ask the
            # database for the total instead of reporting nrows.
            n = self.connection.count(query)
            if offset:
                print('Rows:', n,
                      f'(showing {nrows} starting after row {offset})')
            else:
                print('Rows:', n, f'(showing first {self.limit})')
        else:
            print('Rows:', nrows)

        if self.keys:
            print('Keys:', ', '.join(cutlist(self.keys, self.cut)))

    def write_csv(self):
        """Print the column names (if verbosity > 0) and rows, comma-separated.

        Each value is printed with ``str()``, so missing values appear as
        'None'.  No CSV quoting is applied.
        """
        if self.verbosity > 0:
            print(', '.join(self.columns))
        for row in self.rows:
            print(', '.join(str(val) for val in row.values))

149 

150 

class Row:
    """One database row prepared for display in a Table.

    ``values`` holds the raw value of each requested column.  ``strings``
    (filled in by ``format``) holds their text representations.
    """

    def __init__(self, dct, columns, unique_key='id'):
        self.dct = dct
        self.values = None
        self.strings = None
        self.set_columns(columns)
        # Identifier of the row, read from the attribute named unique_key.
        self.uid = getattr(dct, unique_key)

    def set_columns(self, columns):
        """Extract the raw value of each column from the row object.

        'age' is the time elapsed since ``dct.ctime``, as formatted by
        float_to_time_string.  'pbc' is a string with one 'T'/'F' per
        periodic flag.  Any other column is read as an attribute of the row
        object, giving None when the attribute is missing.
        """
        self.values = []
        for c in columns:
            if c == 'age':
                value = float_to_time_string(now() - self.dct.ctime)
            elif c == 'pbc':
                value = ''.join('FT'[int(p)] for p in self.dct.pbc)
            else:
                value = getattr(self.dct, c, None)
            self.values.append(value)

    def format(self, columns, subscript=None):
        """Convert ``values`` into display strings stored in ``strings``.

        If *subscript* is given (presumably a compiled regex with one
        group — verify against callers), its matches in the 'formula' column
        are wrapped in ``<sub>`` tags.  None becomes ''.  Integers are
        printed with str() and floats with three decimals, with both Python
        and NumPy scalar types accepted.  Arrays are printed as lists, and
        any other non-string value with str().

        Returns the set of column names that held numeric values.
        """
        self.strings = []
        numbers = set()
        for value, column in zip(self.values, columns):
            if column == 'formula' and subscript:
                value = subscript.sub(r'<sub>\1</sub>', value)
            elif value is None:
                value = ''
            elif isinstance(value, (int, np.integer)):
                # bool is an int subclass, so it is treated as a number too.
                value = str(value)
                numbers.add(column)
            elif isinstance(value, (float, np.floating)):
                # np.floating also covers np.float32, which is not a float.
                numbers.add(column)
                value = f'{value:.3f}'
            elif isinstance(value, np.ndarray):
                value = str(value.tolist())
            elif not isinstance(value, str):
                # dict, list, tuple, np.bool_, ...: make sure every entry is
                # a string so Table.write can measure and join it.
                value = str(value)
            self.strings.append(value)

        return numbers