Coverage for /builds/kinetik161/ase/ase/ga/element_crossovers.py: 77.46%

71 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1"""Crossover classes, that cross the elements in the supplied 

2atoms objects. 

3 

4""" 

5import numpy as np 

6 

7from ase.ga.offspring_creator import OffspringCreator 

8 

9 

class ElementCrossover(OffspringCreator):
    """Base class for all operators where the elements of
    the atoms objects cross.

    Parameters:

    element_pool: Either a flat list of element symbols (a single group)
        or a list of such lists/tuples/arrays (one entry per group).

    max_diff_elements: None (no limit for any group), a single integer
        (applied to every group) or a sequence with one value per group.

    min_percentage_elements: None (no lower bound for any group), a single
        number (applied to every group) or a sequence with one value per
        group.

    verbose: Passed on to OffspringCreator.

    rng: Random number generator, passed on to OffspringCreator.
    """

    def __init__(self, element_pool, max_diff_elements,
                 min_percentage_elements, verbose, rng=np.random):
        import numbers

        OffspringCreator.__init__(self, verbose, rng=rng)

        if len(element_pool) == 0:
            raise ValueError('element_pool must not be empty')

        # A flat list of symbols is one group; a list whose entries are
        # themselves sequences of symbols is already grouped.
        if not isinstance(element_pool[0], (list, tuple, np.ndarray)):
            self.element_pools = [element_pool]
        else:
            self.element_pools = element_pool
        n_pools = len(self.element_pools)

        if max_diff_elements is None:
            self.max_diff_elements = [None] * n_pools
        elif isinstance(max_diff_elements, numbers.Integral):
            # numbers.Integral also accepts numpy integer scalars; a
            # single value is used as the limit for every group.
            self.max_diff_elements = [max_diff_elements] * n_pools
        else:
            self.max_diff_elements = max_diff_elements
        assert len(self.max_diff_elements) == n_pools, \
            'need one max_diff_elements entry per element group'

        if min_percentage_elements is None:
            self.min_percentage_elements = [0] * n_pools
        elif isinstance(min_percentage_elements, numbers.Real):
            # numbers.Real covers int, float and numpy scalar types; a
            # single value is used as the bound for every group.
            self.min_percentage_elements = [min_percentage_elements] * n_pools
        else:
            self.min_percentage_elements = min_percentage_elements
        assert len(self.min_percentage_elements) == n_pools, \
            'need one min_percentage_elements entry per element group'

        self.min_inputs = 2

    def get_new_individual(self, parents):
        """Create an offspring from ``parents``; subclasses must override."""
        raise NotImplementedError

45 

46 

class OnePointElementCrossover(ElementCrossover):
    """Crossover of the elements in the atoms objects. The point of the
    cut is chosen randomly.

    Parameters:

    element_pool: List of elements in the phase space. The elements can be
        grouped if the individual consists of different types of elements.
        The list should then be a list of lists, e.g. [[list1], [list2]].

    max_diff_elements: The maximum number of different elements in the
        individual. Default is no limit. If the elements are grouped,
        max_diff_elements should be supplied as a list with each entry
        corresponding to the group at the same position in element_pool.

    min_percentage_elements: The minimum fraction (0 to 1) of any element
        present in the individual. Default is that any amount is allowed.
        If the elements are grouped, min_percentage_elements should be
        supplied as a list with each entry corresponding to the group at
        the same position in element_pool.

    Example: element_pool=[[A,B,C,D],[x,y,z]], max_diff_elements=[3,2],
        min_percentage_elements=[.25, .5]
        An individual could be "D,B,B,C,x,x,x,x,z,z,z,z"

    rng: Random number generator
        By default numpy.random.
    """

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False, rng=np.random):
        ElementCrossover.__init__(self, element_pool, max_diff_elements,
                                  min_percentage_elements, verbose, rng=rng)
        self.descriptor = 'OnePointElementCrossover'

    def get_new_individual(self, parents):
        """Create one offspring by cutting the two parents at one point.

        The first ``cut`` atoms come from the first parent and the rest
        from the second. Cut points are tried in random order until one
        satisfies max_diff_elements and min_percentage_elements for every
        element group. If none does, the last cut tried is used anyway.

        Parameters:

        parents: Two atoms objects, each carrying info['confid']. Both
            parents are sliced at the same index, so they presumably have
            the same number of atoms in the same order (not checked here).

        Returns a tuple of the finalized offspring and a description
        string naming both parents.

        Raises ValueError if the first parent has fewer than two atoms,
        since no cut can then take atoms from both parents.
        """
        f, m = parents
        if len(f) < 2:
            raise ValueError('OnePointElementCrossover needs parents with '
                             'at least two atoms, got {}'.format(len(f)))

        indi = self.initialize_individual(f)
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        fsyms = f.get_chemical_symbols()
        msyms = m.get_chemical_symbols()
        # Indices (in the first parent) of the atoms of each element group.
        # They do not depend on the cut, so they are computed once.
        group_indices = [[a.index for a in f if a.symbol in pool]
                         for pool in self.element_pools]

        # Every cut in 1..len(f)-1 takes at least one atom from each parent.
        cut_choices = list(range(1, len(f)))
        self.rng.shuffle(cut_choices)
        for cut in cut_choices:
            syms = fsyms[:cut] + msyms[cut:]
            if self._satisfies_constraints(syms, group_indices):
                break
        # If no cut satisfied the constraints, ``cut`` still holds the last
        # one tried and the offspring is built from it regardless.

        for a in f[:cut] + m[cut:]:
            indi.append(a)

        parent_message = ':Parents {} {}'.format(f.info['confid'],
                                                 m.info['confid'])
        return (self.finalize_individual(indi),
                self.descriptor + parent_message)

    def _satisfies_constraints(self, syms, group_indices):
        """Return True if ``syms`` obeys every group's element limits."""
        from collections import Counter

        for indices, max_diff, min_frac in zip(group_indices,
                                               self.max_diff_elements,
                                               self.min_percentage_elements):
            if not indices:
                # The group is absent from the parent: nothing to constrain.
                continue
            counts = Counter(syms[j] for j in indices)
            if max_diff is not None and len(counts) > max_diff:
                return False
            n_group = float(len(indices))
            if any(c / n_group < min_frac for c in counts.values()):
                return False
        return True

132 

133 

class TwoPointElementCrossover(ElementCrossover):
    """Crosses two individuals by choosing two cross points at random.

    The constructor parameters have the same meaning as for
    OnePointElementCrossover, including the random number generator
    ``rng`` (by default numpy.random).

    The crossover itself is not implemented yet: get_new_individual
    raises NotImplementedError.
    """

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False,
                 rng=np.random):
        # Forward rng like the other crossover operators so that seeded
        # runs can control this operator's randomness too.
        ElementCrossover.__init__(self, element_pool,
                                  max_diff_elements,
                                  min_percentage_elements, verbose, rng=rng)
        self.descriptor = 'TwoPointElementCrossover'

    def get_new_individual(self, parents):
        """Not implemented for the two-point variant."""
        raise NotImplementedError