Coverage for /builds/kinetik161/ase/ase/io/jsonio.py: 89.09%
110 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import datetime
2import json
4import numpy as np
6from ase.utils import reader, writer
8# Note: We are converting JSON classes to the recommended mechanisms
9# by the json module. That means instead of classes, we will use the
10# functions default() and object_hook().
11#
12# The encoder classes are to be deprecated (but maybe not removed, if
13# widely used).
def default(obj):
    """Convert *obj* into something the json module can serialize.

    This is the fallback hook of the JSON encoder (see ``MyEncoder``): json
    only calls it for objects it cannot serialize natively.

    - Objects with a ``todict()`` method are replaced by that dict.  If the
      object also has an ``ase_objtype`` attribute, a copy of the dict is
      tagged with ``'__ase_objtype__'`` so the decoder can rebuild it.
    - ``np.ndarray`` becomes ``{'__ndarray__': (shape, dtype_str, flat)}``.
      For complex arrays, ``flat`` holds interleaved real/imag values.
    - NumPy integer, floating and bool scalars become Python int/float/bool.
    - ``datetime.datetime`` becomes ``{'__datetime__': iso_string}``, always
      including microseconds.
    - Python and NumPy complex scalars become ``{'__complex__': (re, im)}``.

    Raises:
        RuntimeError: if ``obj.todict()`` returns something other than a dict.
        TypeError: if *obj* is of an unsupported type.
    """
    if hasattr(obj, 'todict'):
        dct = obj.todict()

        if not isinstance(dct, dict):
            raise RuntimeError('todict() of {} returned object of type {} '
                               'but should have returned dict'
                               .format(obj, type(dct)))
        if hasattr(obj, 'ase_objtype'):
            # We modify the dictionary, so it is wise to take a copy.
            dct = dct.copy()
            dct['__ase_objtype__'] = obj.ase_objtype

        return dct
    if isinstance(obj, np.ndarray):
        flatobj = obj.ravel()
        if np.iscomplexobj(obj):
            # Reinterpret the complex storage as interleaved (real, imag)
            # floats.  ravel() returns a contiguous 1-D array, so view() is
            # valid.  Assigning to .dtype directly is discouraged by NumPy.
            flatobj = flatobj.view(obj.real.dtype)
        # We use str(obj.dtype) here instead of obj.dtype.name, because
        # they are not always the same (e.g. for numpy arrays of strings).
        # Using obj.dtype.name can break the ability to recursively decode/
        # encode such arrays.
        return {'__ndarray__': (obj.shape,
                                str(obj.dtype),
                                flatobj.tolist())}
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        # np.float64 subclasses float and never gets here, but
        # float16/float32 scalars do, and json cannot serialize them.
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, datetime.datetime):
        # Always emit the fractional part.  Plain isoformat() omits it when
        # microsecond == 0, and such strings do not match a strict
        # '%Y-%m-%dT%H:%M:%S.%f' parser.
        return {'__datetime__': obj.isoformat(timespec='microseconds')}
    if isinstance(obj, (complex, np.complexfloating)):
        # np.complex64 does not subclass complex; float() also turns NumPy
        # float parts into plain Python floats.
        return {'__complex__': (float(obj.real), float(obj.imag))}

    raise TypeError(f'Cannot convert object of type {type(obj)} to '
                    'dictionary for JSON')
class MyEncoder(json.JSONEncoder):
    """JSON encoder that hands unsupported objects to module-level default().

    Kept as a class for backwards compatibility; new code can pass
    ``default`` to the json functions directly.
    """

    def default(self, obj):
        # The bare name ``default`` here resolves to the module-level
        # function, not to this method, so this is delegation rather than
        # recursion.
        return default(obj)


# Module-level convenience: serialize an object to a JSON string.
encode = MyEncoder().encode
def object_hook(dct):
    """Rebuild a tagged object from a freshly decoded JSON dict.

    Used as the ``object_hook`` of the JSON decoder.  The tags recognized,
    in order, are:

    - ``'__datetime__'``: ISO-format string -> ``datetime.datetime``.
    - ``'__complex__'``: ``(real, imag)`` -> ``complex``.
    - ``'__ndarray__'``: ``(shape, dtype, flat_data)`` -> ``np.ndarray``.
    - ``'__complex_ndarray__'``: legacy ``(real, imag)`` arrays -> complex
      ``np.ndarray``.
    - ``'__ase_objtype__'``: the tag is popped and the remaining dict is
      handed to ``create_ase_object``.

    A dict with none of these keys is returned unchanged.
    """
    if '__datetime__' in dct:
        text = dct['__datetime__']
        try:
            # fromisoformat() parses everything datetime.isoformat() emits:
            # with or without microseconds, and with a UTC offset for aware
            # datetimes.
            return datetime.datetime.fromisoformat(text)
        except ValueError:
            # Fall back to the historical strict format, e.g. for fractions
            # with an unusual number of digits that older Pythons'
            # fromisoformat() rejects.
            return datetime.datetime.strptime(text, '%Y-%m-%dT%H:%M:%S.%f')

    if '__complex__' in dct:
        return complex(*dct['__complex__'])

    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])

    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j

    if '__ase_objtype__' in dct:
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)

    return dct
def create_ndarray(shape, dtype, data):
    """Create an ndarray from its shape, dtype and flattened data.

    This is the inverse of the ``'__ndarray__'`` encoding.

    Args:
        shape: Target shape (a sequence of ints; empty for a 0-d array).
        dtype: Anything ``np.empty`` accepts as a dtype, e.g. ``'float64'``.
        data: Flat sequence of elements in C order.  For complex dtypes it
            holds interleaved real/imag values, i.e. twice as many numbers
            as array elements.

    Returns:
        A new ndarray of the requested shape and dtype filled with *data*.
    """
    array = np.empty(shape, dtype=dtype)
    # np.empty gives a C-contiguous array, so ravel() returns a view and
    # writing through flatbuf fills `array` in place.
    flatbuf = array.ravel()
    if np.iscomplexobj(array):
        # View the complex storage as interleaved (real, imag) floats.
        # view() is the supported way to reinterpret memory; assigning to
        # .dtype directly is discouraged by NumPy.
        flatbuf = flatbuf.view(array.real.dtype)
    flatbuf[:] = data
    return array
def create_ase_object(objtype, dct):
    """Instantiate the ASE object described by *objtype* from *dct*.

    Supported tags are 'cell', 'bandstructure', 'bandpath', 'atoms' and
    'vibrationsdata'.  The matching ASE module is imported lazily, only when
    its tag is seen.  *dct* may be modified in place (keys are popped for
    'cell' and 'bandpath').

    Raises:
        ValueError: if *objtype* is not one of the supported tags.
    """
    # Each tag is handled in turn with a guard clause; this can be
    # formalized into a registry later if it ever becomes necessary.
    def build():
        if objtype == 'cell':
            from ase.cell import Cell
            # Compatibility: older files stored pbc on the cell.
            dct.pop('pbc', None)
            return Cell(**dct)
        if objtype == 'bandstructure':
            from ase.spectrum.band_structure import BandStructure
            return BandStructure(**dct)
        if objtype == 'bandpath':
            from ase.dft.kpoints import BandPath
            return BandPath(path=dct.pop('labelseq'), **dct)
        if objtype == 'atoms':
            from ase import Atoms
            return Atoms.fromdict(dct)
        if objtype == 'vibrationsdata':
            from ase.vibrations import VibrationsData
            return VibrationsData.fromdict(dct)
        raise ValueError('Do not know how to decode object type {} '
                         'into an actual object'.format(objtype))

    instance = build()
    # Sanity check that the constructed object reports the same tag.
    assert instance.ase_objtype == objtype
    return instance
# Decoder entry point: parse JSON text and pass every decoded JSON object
# (dict) through object_hook, which rebuilds datetimes, complex numbers,
# ndarrays and ASE objects from their tagged encodings.
mydecode = json.JSONDecoder(object_hook=object_hook).decode
def intkey(key):
    """Return ``int(key)`` if *key* parses as an int, else *key* unchanged."""
    try:
        converted = int(key)
    except ValueError:
        converted = key
    return converted
def fix_int_keys_in_dicts(obj):
    """Convert "int" keys: "1" -> 1.

    json.dump() turns int dict keys into str keys; this walks nested dicts
    and converts such keys back with intkey().  Only dict values are
    recursed into.  Non-dict objects, including lists, are returned as-is.
    """
    if not isinstance(obj, dict):
        return obj
    converted = {}
    for key, value in obj.items():
        new_key = intkey(key)
        converted[new_key] = fix_int_keys_in_dicts(value)
    return converted
def numpyfy(obj):
    """Convert decoded JSON lists into NumPy arrays where that is lossless.

    - A legacy ``{'__complex_ndarray__': (real, imag)}`` dict becomes a
      complex ndarray.  Any other dict is returned unchanged; its values are
      not visited.
    - A non-empty list that np.array() turns into a bool, int or float
      array is returned as that array.  Otherwise each element is processed
      recursively and a new list is returned.
    - Everything else, including empty lists, is returned unchanged.
    """
    if isinstance(obj, dict) and '__complex_ndarray__' in obj:
        real, imag = (np.array(part) for part in obj['__complex_ndarray__'])
        return real + imag * 1j
    if not isinstance(obj, list) or not obj:
        return obj
    try:
        candidate = np.array(obj)
    except ValueError:
        # e.g. ragged nested lists; fall through to element-wise conversion.
        candidate = None
    if candidate is not None and candidate.dtype in [bool, int, float]:
        return candidate
    return [numpyfy(item) for item in obj]
def decode(txt, always_array=True):
    """Decode JSON text *txt* into Python, NumPy and ASE objects.

    The text is parsed with mydecode(), so tagged objects are rebuilt by
    object_hook.  String dict keys that look like ints are then converted
    back to ints.  When *always_array* is true, the result is additionally
    passed through numpyfy().
    """
    result = fix_int_keys_in_dicts(mydecode(txt))
    if not always_array:
        return result
    return numpyfy(result)
@reader
def read_json(fd, always_array=True):
    """Read all JSON text from *fd* and return the result of decode().

    NOTE(review): the ``@reader`` decorator (ase.utils) presumably also
    lets callers pass a filename instead of an open file; confirm there.
    """
    return decode(fd.read(), always_array=always_array)
@writer
def write_json(fd, obj):
    """Serialize *obj* with encode() and write the JSON text to *fd*.

    NOTE(review): the ``@writer`` decorator (ase.utils) presumably also
    lets callers pass a filename instead of an open file; confirm there.
    """
    text = encode(obj)
    fd.write(text)