Coverage for /builds/kinetik161/ase/ase/io/vtkxml.py: 5.33%

75 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import numpy as np 

2 

# Module-level switch read by write_vti and write_vtu.  When True, their
# XML writers call SetDataModeToAppend() followed by EncodeAppendedDataOff();
# when False, they call SetDataModeToAscii() instead.
fast = False


def write_vti(filename, atoms, data=None):
    """Write volumetric grid data to a VTK XML Image Data (.vti) file.

    The grid spans the (orthogonal) unit cell of ``atoms``, starting at the
    origin, with spacing ``cell.diagonal() / data.shape`` along each axis.
    The values are stored as a point-data array named ``'scalars'``.

    Parameters
    ----------
    filename : str or file-like
        Output path.  For a non-string, its ``name`` attribute is used as
        the path, because VTK writes by file name.
    atoms : Atoms or list of Atoms
        Configuration whose unit cell defines the grid extent.  A list must
        contain exactly one configuration.
    data : array_like
        Three-dimensional array of grid values.  Complex input is written
        as its absolute value.

    Raises
    ------
    ValueError
        If zero or several configurations are given, if ``data`` is missing
        or not three-dimensional, or if the unit cell is not diagonal.
    """
    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError('VTI data must be a three-dimensional array, '
                         f'got shape {data.shape}')

    # iscomplexobj catches every complex dtype (complex64 too), unlike a
    # plain ``dtype == complex`` comparison; VTK gets the magnitude.
    if np.iscomplexobj(data):
        data = np.abs(data)

    # Plain ndarray so that comparison and .diagonal() behave like numpy.
    cell = np.array(atoms.get_cell(), dtype=float)
    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')

    # Import VTK only once the input is known to be valid.
    from vtk import (VTK_MAJOR_VERSION, vtkStructuredPoints,
                     vtkXMLImageDataWriter)
    from vtk.util.numpy_support import numpy_to_vtk

    lengths = cell.diagonal()

    # Create a VTK grid of structured points covering the unit cell.
    spts = vtkStructuredPoints()
    if VTK_MAJOR_VERSION <= 5:
        # SetWholeBoundingBox belongs to the VTK 5 pipeline API and no
        # longer exists from VTK 6 on.  Layout: (0, Lx, 0, Ly, 0, Lz).
        bbox = np.column_stack((np.zeros(3), lengths)).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # VTK point data is ordered with x varying fastest, hence the axis swap
    # before flattening.  float64 input gives a vtkDoubleArray.
    values = np.ascontiguousarray(data.swapaxes(0, 2).ravel(), dtype=float)
    scalars = numpy_to_vtk(values, deep=1)
    scalars.SetName('scalars')
    spts.GetPointData().SetScalars(scalars)

    # Save the ImageData dataset to a VTK XML file.
    writer = vtkXMLImageDataWriter()
    if fast:
        writer.SetDataModeToAppend()
        writer.EncodeAppendedDataOff()
    else:
        writer.SetDataModeToAscii()

    if isinstance(filename, str):
        writer.SetFileName(filename)
    else:
        writer.SetFileName(filename.name)

    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        writer.SetInput(spts)
    else:
        writer.SetInputData(spts)
    writer.Write()

def write_vtu(filename, atoms, data=None):
    """Write atoms to a VTK XML Unstructured Grid (.vtu) file.

    Each atom becomes one VTK point at its Cartesian position.  The points
    carry three point-data arrays: ``"atomic numbers"``, ``"tags"`` and
    ``"radii"`` (covalent radii from :mod:`ase.data`).  No cells are added
    to the grid.

    Parameters
    ----------
    filename : str or file-like
        Output path.  For a non-string, its ``name`` attribute is used as
        the path, because VTK writes by file name.
    atoms : Atoms or list of Atoms
        Configuration to write.  A list must contain exactly one
        configuration.
    data : optional
        Currently unused.

    Raises
    ------
    ValueError
        If a list with zero or several configurations is given.
    """
    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTU file!')
        atoms = atoms[0]

    # Import VTK only once the input is known to be valid.
    from vtk import (VTK_MAJOR_VERSION, vtkPoints, vtkUnstructuredGrid,
                     vtkXMLUnstructuredGridWriter)
    from vtk.util.numpy_support import numpy_to_vtk

    from ase.data import covalent_radii

    # Create an unstructured grid holding one point per atom.
    ugd = vtkUnstructuredGrid()

    # Choose the data type before sizing: SetDataType* swaps in a fresh,
    # empty array, which would discard a previously sized one.
    points = vtkPoints()
    points.SetDataTypeToDouble()
    points.SetNumberOfPoints(len(atoms))
    for i, pos in enumerate(atoms.get_positions()):
        points.SetPoint(i, *pos)
    ugd.SetPoints(points)

    # Per-atom properties as named point-data arrays (deep copies, so the
    # VTK arrays do not depend on the lifetime of the numpy arrays).
    point_data = ugd.GetPointData()
    for name, values in (('atomic numbers', atoms.get_atomic_numbers()),
                         ('tags', atoms.get_tags()),
                         ('radii', covalent_radii[atoms.numbers])):
        array = numpy_to_vtk(values, deep=1)
        array.SetName(name)
        point_data.AddArray(array)

    # Save the UnstructuredGrid dataset to a VTK XML file.
    writer = vtkXMLUnstructuredGridWriter()
    if fast:
        writer.SetDataModeToAppend()
        writer.EncodeAppendedDataOff()
    else:
        # The writer may have no compressor configured; skip it in that case.
        compressor = writer.GetCompressor()
        if compressor is not None:
            compressor.SetCompressionLevel(0)
        writer.SetDataModeToAscii()

    if isinstance(filename, str):
        writer.SetFileName(filename)
    else:
        writer.SetFileName(filename.name)

    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        writer.SetInput(ugd)
    else:
        writer.SetInputData(ugd)
    writer.Write()