Coverage for /builds/kinetik161/ase/ase/io/vtkxml.py: 5.33%
75 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import numpy as np
# Module-level switch read by write_vti and write_vtu.
# When True, the VTK XML writers use appended data mode with encoding of the
# appended data turned off. When False, they write ASCII data mode (write_vtu
# additionally sets the compressor level to 0 in that case).
fast = False
def write_vti(filename, atoms, data=None):
    """Write a 3D scalar field on an orthogonal cell to a VTK XML image file.

    The grid spans the cell from the origin, with spacing
    ``cell.diagonal() / data.shape`` along each axis. The values are stored
    as double-precision point data named ``'scalars'``.

    Parameters
    ----------
    filename : str, os.PathLike or open file object
        Destination path. For an open file object, only its ``name``
        attribute is used; VTK opens and writes the file itself.
    atoms : Atoms or list of Atoms
        Structure whose (diagonal) cell defines the grid extent. A list must
        contain a single configuration.
    data : array_like, 3-dimensional
        Scalar values on the grid. Complex data is written as its modulus.

    Raises
    ------
    ValueError
        If more than one configuration is given, if ``data`` is missing or
        not 3-dimensional, or if the cell is not orthogonal.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError('VTI data must be a 3-dimensional array, '
                         'got shape {}'.format(data.shape))

    # np.iscomplexobj also catches complex64, which `dtype == complex` missed.
    if np.iscomplexobj(data):
        data = np.abs(data)

    # Normalise to a plain ndarray so diagonal()/comparisons behave uniformly.
    cell = np.asarray(atoms.get_cell(), dtype=float)
    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')
    lengths = cell.diagonal()

    # Import VTK only after the input has been validated.
    from vtk import (VTK_MAJOR_VERSION, vtkStructuredPoints,
                     vtkXMLImageDataWriter)
    from vtk.util.numpy_support import numpy_to_vtk

    # Create a VTK grid of structured points spanning the cell.
    spts = vtkStructuredPoints()
    if VTK_MAJOR_VERSION <= 5:
        # The whole-bounding-box pipeline hint was removed in VTK 6.
        bbox = np.array(list(zip(np.zeros(3), lengths))).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # VTK image data expects the x index to vary fastest. Swapping axes 0 and 2
    # and flattening in C order produces that ordering, assuming `data` is
    # indexed as data[ix, iy, iz] (matching the SetDimensions call above).
    # numpy_to_vtk requires a contiguous array; deep=1 makes VTK own a copy.
    flat = np.ascontiguousarray(data.swapaxes(0, 2), dtype=np.float64).ravel()
    da = numpy_to_vtk(flat, deep=1)
    da.SetName('scalars')

    # Assign the VTK array as point data of the grid.
    spts.GetPointData().SetScalars(da)

    # Save the ImageData dataset to a VTK XML file.
    w = vtkXMLImageDataWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.SetDataModeToAscii()

    if isinstance(filename, (str, os.PathLike)):
        w.SetFileName(os.fspath(filename))
    else:
        w.SetFileName(filename.name)
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(spts)
    else:
        w.SetInputData(spts)
    w.Write()
def write_vtu(filename, atoms, data=None):
    """Write atoms as points of a VTK XML unstructured-grid (VTU) file.

    Each atom becomes one VTK point at its Cartesian position. Three
    point-data arrays are attached: ``"atomic numbers"``, ``"tags"`` and
    ``"radii"`` (covalent radii from ``ase.data.covalent_radii``).

    Parameters
    ----------
    filename : str, os.PathLike or open file object
        Destination path. For an open file object, only its ``name``
        attribute is used; VTK opens and writes the file itself.
    atoms : Atoms or list of Atoms
        Structure to write. A list must contain a single configuration.
    data : ignored
        Accepted for signature compatibility with ``write_vti``; it is not
        written to the file.

    Raises
    ------
    ValueError
        If more than one configuration is given.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        atoms = atoms[0]

    # Import VTK only after the input has been validated.
    from vtk import (VTK_MAJOR_VERSION, vtkPoints, vtkUnstructuredGrid,
                     vtkXMLUnstructuredGridWriter)
    from vtk.util.numpy_support import numpy_to_vtk

    from ase.data import covalent_radii

    # Create a VTK unstructured grid holding one point per atom.
    ugd = vtkUnstructuredGrid()

    # Add atoms as VTK points (double precision).
    p = vtkPoints()
    p.SetNumberOfPoints(len(atoms))
    p.SetDataTypeToDouble()
    for i, pos in enumerate(atoms.get_positions()):
        p.InsertPoint(i, *pos)
    ugd.SetPoints(p)

    # Per-atom point data; deep=1 makes VTK own a copy of each array.
    numbers = numpy_to_vtk(atoms.get_atomic_numbers(), deep=1)
    ugd.GetPointData().AddArray(numbers)
    numbers.SetName("atomic numbers")

    tags = numpy_to_vtk(atoms.get_tags(), deep=1)
    ugd.GetPointData().AddArray(tags)
    tags.SetName("tags")

    radii = numpy_to_vtk(covalent_radii[atoms.numbers], deep=1)
    ugd.GetPointData().AddArray(radii)
    radii.SetName("radii")

    # Save the UnstructuredGrid dataset to a VTK XML file.
    w = vtkXMLUnstructuredGridWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.GetCompressor().SetCompressionLevel(0)
        w.SetDataModeToAscii()

    # Path-like objects must keep their full path; `.name` would keep only
    # the basename of a pathlib.Path. Open file objects expose the path via
    # `.name`.
    if isinstance(filename, (str, os.PathLike)):
        w.SetFileName(os.fspath(filename))
    else:
        w.SetFileName(filename.name)
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(ugd)
    else:
        w.SetInputData(ugd)
    w.Write()