11import numpy as np
22from math import log , inf
33import copy
4- from ...base .random import Random
4+ from ...base .random import Dice , Random
55from ...base import types
66
77
88class TreeProposal (types .DiscreteVariableProposal ):
9+ moves_prob = [0.4 , 0.1 , 0.1 , 0.4 ]
10+ moves = ["prune" , "swap" , "change" , "grow" ] # noqa
911 def __init__ (self , tree ):
1012 self .X_train = tree .X_train
1113 self .y_train = tree .y_train
1214 self .tree = copy .deepcopy (tree )
15+
1316
1417 @classmethod
1518 def norm (self , tree ):
@@ -21,40 +24,25 @@ def norm(self, tree):
2124 def heuristic (self , x , y ):
2225 return y < x or abs (x - y ) < 2
2326
24- def sample (self ):
25- # initialise the probabilities of each move
26- moves = ["prune" , "swap" , "change" , "grow" ] # noqa
27- moves_prob = [0.4 , 0.1 , 0.1 , 0.4 ]
27+ def get_moves_prob (self ):
2828 if len (self .tree .tree ) == 1 :
2929 moves_prob = [0.0 , 0.0 , 0.5 , 0.5 ]
30- moves_probabilities = np .cumsum (moves_prob )
31- random_number = Random ().eval ()
32- newTree = copy .deepcopy (self .tree )
33- if random_number < moves_probabilities [0 ]:
34- # prune
35- newTree = newTree .prune ()
36-
37- elif random_number < moves_probabilities [1 ]:
38- # swap
39- newTree = newTree .swap ()
40-
41- elif random_number < moves_probabilities [2 ]:
42- # change
43- newTree = newTree .change ()
44-
4530 else :
46- # grow
47- newTree = newTree . grow ()
31+ moves_prob = self . moves_prob
32+ return moves_prob
4833
49- return newTree
34+ def sample (self ):
35+ # initialise the probabilities of each move
36+ moves_prob = self .get_moves_prob ()
37+ newTree = copy .deepcopy (self .tree )
38+ moves_dice = Dice (moves_prob , [newTree .prune , newTree .swap , newTree .change , newTree .grow ])
39+
40+ return moves_dice .eval ()
5041
5142 def eval (self , sampledTree ):
5243 initialTree = self .tree
53- moves_prob = [0.4 , 0.1 , 0.1 , 0.4 ]
5444 logprobability = - inf
55- if len (initialTree .tree ) == 1 :
56- moves_prob = [0.0 , 0.0 , 0.5 , 0.5 ]
57-
45+ moves_prob = self .get_moves_prob ()
5846 nodes_differences = [i for i in sampledTree .tree + initialTree .tree
5947 if i not in sampledTree .tree or
6048 i not in initialTree .tree ]
0 commit comments