Coverage for /builds/kinetik161/ase/ase/io/jsonio.py: 89.09%

110 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import datetime 

2import json 

3 

4import numpy as np 

5 

6from ase.utils import reader, writer 

7 

8# Note: We are converting JSON classes to the recommended mechanisms 

9# by the json module. That means instead of classes, we will use the 

10# functions default() and object_hook(). 

11# 

12# The encoder classes are to be deprecated (but maybe not removed, if 

13# widely used). 

14 

15 

def default(obj):
    """Convert *obj* into a value the :mod:`json` encoder can serialize.

    Intended as the ``default`` hook of :class:`json.JSONEncoder`.

    Returns:
        * ``obj.todict()`` for objects with a ``todict`` method.  The result
          is tagged with ``'__ase_objtype__'`` when *obj* has an
          ``ase_objtype`` attribute.
        * ``{'__ndarray__': (shape, dtype_str, flat_list)}`` for numpy
          arrays.  Complex arrays are flattened as interleaved real/imaginary
          parts.
        * A Python ``int``, ``bool`` or ``float`` for numpy integer, bool
          and floating scalars.
        * ``{'__datetime__': isoformat_str}`` for ``datetime.datetime``.
        * ``{'__complex__': (real, imag)}`` for Python or numpy complex
          numbers.

    Raises:
        RuntimeError: if ``obj.todict()`` does not return a dict.
        TypeError: if *obj* is of an unsupported type.
    """
    if hasattr(obj, 'todict'):
        dct = obj.todict()

        if not isinstance(dct, dict):
            raise RuntimeError('todict() of {} returned object of type {} '
                               'but should have returned dict'
                               .format(obj, type(dct)))
        if hasattr(obj, 'ase_objtype'):
            # We modify the dictionary, so it is wise to take a copy.
            dct = dct.copy()
            dct['__ase_objtype__'] = obj.ase_objtype

        return dct
    if isinstance(obj, np.ndarray):
        flatobj = obj.ravel()
        if np.iscomplexobj(obj):
            # Reinterpret the complex buffer as interleaved (real, imag)
            # floats.  view() is used instead of assigning to .dtype, which
            # NumPy discourages.
            flatobj = flatobj.view(obj.real.dtype)
        # We use str(obj.dtype) here instead of obj.dtype.name, because
        # they are not always the same (e.g. for numpy arrays of strings).
        # Using obj.dtype.name can break the ability to recursively decode/
        # encode such arrays.
        return {'__ndarray__': (obj.shape,
                                str(obj.dtype),
                                flatobj.tolist())}
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.floating):
        # np.float64 subclasses float and is serialized natively, but other
        # widths (e.g. np.float32) do not and would otherwise be rejected.
        return float(obj)
    if isinstance(obj, datetime.datetime):
        return {'__datetime__': obj.isoformat()}
    if isinstance(obj, (complex, np.complexfloating)):
        # np.complexfloating covers e.g. np.complex64, which is not a
        # subclass of the builtin complex.
        return {'__complex__': (float(obj.real), float(obj.imag))}

    raise TypeError(f'Cannot convert object of type {type(obj)} to '
                    'dictionary for JSON')

52 

53 

class MyEncoder(json.JSONEncoder):
    """JSON encoder that hands unsupported objects to module-level ``default``.

    Kept for backwards compatibility: the stdlib mechanism is to pass the
    ``default`` function itself to :mod:`json`.
    """

    def default(self, obj):
        # Method bodies cannot see names bound in the class body, so the
        # bare name ``default`` below is looked up in the module globals.
        # This delegates to the module-level function instead of recursing.
        converted = default(obj)
        return converted


# Module-wide encoder instance; ``encode(obj)`` returns JSON text for *obj*.
encode = MyEncoder().encode

62 

63 

def object_hook(dct):
    """Turn the marker dicts written by the encoder back into Python objects.

    Passed as ``object_hook`` to :class:`json.JSONDecoder`, so it is called
    with every decoded JSON object and its return value replaces it.

    Recognized markers:
        * ``'__datetime__'``: ISO-8601 string -> ``datetime.datetime``.
        * ``'__complex__'``: ``(real, imag)`` -> ``complex``.
        * ``'__ndarray__'``: ``(shape, dtype, data)`` -> array via
          ``create_ndarray``.
        * ``'__complex_ndarray__'``: legacy ``(real, imag)`` lists -> complex
          array.
        * ``'__ase_objtype__'``: the ASE object built by
          ``create_ase_object``.  The marker key is popped from *dct*.

    Any other dict is returned unchanged.
    """
    if '__datetime__' in dct:
        text = dct['__datetime__']
        try:
            # fromisoformat() is the inverse of datetime.isoformat(), which
            # omits the fraction when microsecond == 0 and appends a UTC
            # offset for aware datetimes.  The fixed strptime format below
            # cannot parse either case.
            return datetime.datetime.fromisoformat(text)
        except ValueError:
            # Keep accepting whatever the previous strptime-only decoder
            # accepted (e.g. 1-5 digit fractions on Python < 3.11).
            return datetime.datetime.strptime(text,
                                              '%Y-%m-%dT%H:%M:%S.%f')

    if '__complex__' in dct:
        return complex(*dct['__complex__'])

    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])

    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j

    if '__ase_objtype__' in dct:
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)

    return dct

86 

87 

def create_ndarray(shape, dtype, data):
    """Create ndarray from shape, dtype and flattened data.

    This is the inverse of the ``'__ndarray__'`` encoding.

    Args:
        shape: Array shape (a list or tuple of ints).
        dtype: Anything ``np.empty`` accepts as a dtype, e.g. a dtype
            string such as ``'complex128'`` or ``'<U3'``.
        data: Flat list of elements in C order.  For complex dtypes it holds
            interleaved real and imaginary parts, i.e. two numbers per
            element.

    Returns:
        A new ``numpy.ndarray`` with the given shape and dtype.
    """
    array = np.empty(shape, dtype=dtype)
    # np.empty() allocates a C-ordered array, so ravel() returns a view and
    # assigning into it fills ``array`` in place.
    flatbuf = array.ravel()
    if np.iscomplexobj(array):
        # View the complex buffer as interleaved (real, imag) floats.  view()
        # is used instead of assigning to .dtype, which NumPy discourages.
        flatbuf = flatbuf.view(array.real.dtype)
    flatbuf[:] = data
    return array

96 

97 

def create_ase_object(objtype, dct):
    """Instantiate the ASE object named by *objtype* from its dict form.

    Each supported type has a small builder that imports its class lazily
    and constructs it from *dct*.  *dct* may be modified in place: ``'pbc'``
    is dropped for cells and ``'labelseq'`` is popped for band paths.

    Raises:
        ValueError: if *objtype* is not one of the supported names.
        AssertionError: if the created object's ``ase_objtype`` differs
            from *objtype*.
    """
    def build_cell():
        from ase.cell import Cell
        dct.pop('pbc', None)  # compatibility; we once had pbc
        return Cell(**dct)

    def build_bandstructure():
        from ase.spectrum.band_structure import BandStructure
        return BandStructure(**dct)

    def build_bandpath():
        from ase.dft.kpoints import BandPath
        labels = dct.pop('labelseq')
        return BandPath(path=labels, **dct)

    def build_atoms():
        from ase import Atoms
        return Atoms.fromdict(dct)

    def build_vibrationsdata():
        from ase.vibrations import VibrationsData
        return VibrationsData.fromdict(dct)

    # Ordered (name, builder) pairs compared with ``==`` one after another,
    # so even unhashable objtype values end in the ValueError below.
    builders = (
        ('cell', build_cell),
        ('bandstructure', build_bandstructure),
        ('bandpath', build_bandpath),
        ('atoms', build_atoms),
        ('vibrationsdata', build_vibrationsdata),
    )
    for name, build in builders:
        if objtype == name:
            created = build()
            break
    else:
        raise ValueError('Do not know how to decode object type {} '
                         'into an actual object'.format(objtype))

    assert created.ase_objtype == objtype
    return created

123 

124 

# Decoder that routes every decoded JSON object through object_hook(), which
# turns the special '__...__' marker dicts back into Python objects.  Only
# the bound decode() method is kept.
mydecode = json.JSONDecoder(object_hook=object_hook).decode

126 

127 

def intkey(key):
    """Return ``int(key)`` when that works, otherwise *key* unchanged.

    Only ValueError from ``int()`` is treated as "not an integer"; any other
    exception propagates.
    """
    try:
        converted = int(key)
    except ValueError:
        converted = key
    return converted

134 

135 

def fix_int_keys_in_dicts(obj):
    """Undo json's int-to-str key conversion: ``{"1": x}`` -> ``{1: x}``.

    Every key of a dict is passed through intkey(), and nested dict values
    are processed recursively.  Lists are not descended into, and non-dict
    input is returned unchanged.
    """
    if not isinstance(obj, dict):
        return obj
    fixed = {}
    for key, value in obj.items():
        new_key = intkey(key)
        fixed[new_key] = fix_int_keys_in_dicts(value)
    return fixed

146 

147 

def numpyfy(obj):
    """Recursively replace lists by numpy arrays where that is lossless.

    Behaviour by input type:
        * Non-empty list: if ``np.array`` turns it into a bool, int or float
          array, that array is returned.  Otherwise each element is
          processed recursively.
        * Dict: values are processed recursively and keys are kept.  A
          legacy ``{'__complex_ndarray__': (real, imag)}`` dict becomes a
          complex array.
        * Anything else, including the empty list, is returned unchanged.
    """
    if isinstance(obj, dict):
        if '__complex_ndarray__' in obj:
            r, i = (np.array(x) for x in obj['__complex_ndarray__'])
            return r + i * 1j
        # Recurse into the values.  Without this, lists nested in dicts
        # (the usual JSON document shape, and the payload of
        # '__ase_objtype__' dicts) were never converted.
        return {key: numpyfy(value) for key, value in obj.items()}
    if isinstance(obj, list) and len(obj) > 0:
        try:
            a = np.array(obj)
        except ValueError:
            # Recent NumPy raises this for e.g. ragged nested lists; fall
            # back to element-wise conversion below.
            pass
        else:
            if a.dtype in [bool, int, float]:
                return a
        obj = [numpyfy(value) for value in obj]
    return obj

163 

164 

def decode(txt, always_array=True):
    """Parse JSON text *txt* into Python objects.

    The text is decoded with mydecode() (marker dicts restored via
    object_hook), then dict keys that parse as integers are converted back
    by fix_int_keys_in_dicts().  When *always_array* is true the result is
    finally passed through numpyfy().
    """
    result = fix_int_keys_in_dicts(mydecode(txt))
    if not always_array:
        return result
    return numpyfy(result)

171 

172 

@reader
def read_json(fd, always_array=True):
    """Read all text from *fd* and return it decoded by decode().

    *fd* must provide ``read()``; presumably the ``@reader`` decorator
    supplies an open file here (TODO confirm against ase.utils).
    *always_array* is forwarded to decode().
    """
    return decode(fd.read(), always_array=always_array)

177 

178 

@writer
def write_json(fd, obj):
    """Serialize *obj* with encode() and write the JSON text to *fd*.

    *fd* must provide ``write()``; presumably the ``@writer`` decorator
    supplies an open file here (TODO confirm against ase.utils).
    """
    text = encode(obj)
    fd.write(text)