Skip to content

Commit 11da6ad

Browse files
committed
Implement Tree proposal with Dice
1 parent 279c4ce commit 11da6ad

3 files changed

Lines changed: 47 additions & 42 deletions

File tree

discretesampling/base/random.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from random import randint
22
import random
3+
from typing import Callable
4+
import numpy as np
35

46

57
class Random(object):
@@ -35,6 +37,32 @@ def eval(self):
3537
return random.choices(population=self.population, weights=self.weights, cum_weights=self.cum_weights, k=self.k)
3638

3739

40+
class Dice(object):
    """Weighted die: draws one of ``outcomes`` according to ``probabilities``.

    An outcome may be a callable; in that case it is invoked (repeatedly, while
    the result is still callable) and the final non-callable value is returned
    by :meth:`eval`.
    """

    def __init__(self, probabilities, outcomes):
        """
        :param probabilities: probability of each outcome; must be non-negative
            and sum to 1 within sqrt(float64 machine epsilon)
        :param outcomes: sequence of outcomes (values or callables), same
            length as ``probabilities``
        :raises AssertionError: if the PMF is invalid
        """
        assert len(outcomes) == len(probabilities), "Invalid PMF specified, x and p" +\
            " of different lengths"
        probabilities = np.array(probabilities)
        tolerance = np.sqrt(np.finfo(np.float64).eps)
        assert abs(1 - probabilities.sum()) < tolerance, "Invalid PMF specified," +\
            " sum of probabilities !~= 1.0"
        # Zero-probability outcomes are allowed (e.g. moves that are impossible
        # in the current state), so the message says ">= 0" to match the check.
        assert np.all(probabilities >= 0.), "Invalid PMF specified, all probabilities" +\
            " must be >= 0"
        self.probabilities = probabilities
        self.outcomes = outcomes
        self.pmf = probabilities
        self.cmf = np.cumsum(probabilities)
        self.randomiser = Random()

    def eval(self):
        """
        Draw one outcome.

        :return: the selected outcome, or the result of calling it if the
            outcome is callable
        """
        # NOTE(review): assumes Random().eval() yields a uniform draw on
        # [0, 1) - confirm against the Random class.
        q = self.randomiser.eval()

        # First index whose cumulative mass is strictly greater than q. The
        # strict comparison means an outcome with zero probability can never
        # be selected, even when q == 0.0.
        index = int(np.searchsorted(self.cmf, q, side="right"))

        # The probabilities only sum to 1 within a tolerance, so q may exceed
        # cmf[-1]; fall back to the last outcome that has non-zero mass rather
        # than wrapping around to the first one.
        if index >= len(self.outcomes):
            index = int(np.flatnonzero(self.pmf)[-1])

        x = self.outcomes[index]

        # Outcomes may be (nested) callables, e.g. bound move methods.
        while callable(x):
            x = x()
        return x
63+
64+
65+
3866
def set_seed(seed):
3967
"""
4068
:param seed: random seed

discretesampling/base/types.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,9 @@ def getOptimalLKernelType(self):
2828

2929

3030
class DiscreteVariableProposal:
31-
def __init__(self, values, probs):
32-
# Check dims and probs are valid
33-
assert len(values) == len(probs), "Invalid PMF specified, x and p" +\
34-
" of different lengths"
35-
probs = np.array(probs)
36-
tolerance = np.sqrt(np.finfo(np.float64).eps)
37-
assert abs(1 - sum(probs)) < tolerance, "Invalid PMF specified," +\
38-
" sum of probabilities !~= 1.0"
39-
assert all(probs > 0), "Invalid PMF specified, all probabilities" +\
40-
" must be > 0"
41-
self.x = values
42-
self.pmf = probs
43-
self.cmf = np.cumsum(probs)
31+
def __init__(self, moves_dice):
32+
self.moves_dice = moves_dice
33+
4434

4535
@classmethod
4636
def norm(self, x):
@@ -54,8 +44,7 @@ def heuristic(self, x, y):
5444
return True
5545

5646
def sample(self):
57-
q = random.random() # random unif(0,1)
58-
return self.x[np.argmax(self.cmf >= q)]
47+
return self.moves_dice.eval()
5948

6049
def eval(self, y):
6150
try:

discretesampling/domain/decision_tree/tree_distribution.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import numpy as np
22
from math import log, inf
33
import copy
4-
from ...base.random import Random
4+
from ...base.random import Dice, Random
55
from ...base import types
66

77

88
class TreeProposal(types.DiscreteVariableProposal):
9+
moves_prob = [0.4, 0.1, 0.1, 0.4]
10+
moves = ["prune", "swap", "change", "grow"] # noqa
911
def __init__(self, tree):
1012
self.X_train = tree.X_train
1113
self.y_train = tree.y_train
1214
self.tree = copy.deepcopy(tree)
15+
1316

1417
@classmethod
1518
def norm(self, tree):
@@ -21,40 +24,25 @@ def norm(self, tree):
2124
def heuristic(self, x, y):
2225
return y < x or abs(x-y) < 2
2326

24-
def sample(self):
25-
# initialise the probabilities of each move
26-
moves = ["prune", "swap", "change", "grow"] # noqa
27-
moves_prob = [0.4, 0.1, 0.1, 0.4]
27+
def get_moves_prob(self):
2828
if len(self.tree.tree) == 1:
2929
moves_prob = [0.0, 0.0, 0.5, 0.5]
30-
moves_probabilities = np.cumsum(moves_prob)
31-
random_number = Random().eval()
32-
newTree = copy.deepcopy(self.tree)
33-
if random_number < moves_probabilities[0]:
34-
# prune
35-
newTree = newTree.prune()
36-
37-
elif random_number < moves_probabilities[1]:
38-
# swap
39-
newTree = newTree.swap()
40-
41-
elif random_number < moves_probabilities[2]:
42-
# change
43-
newTree = newTree.change()
44-
4530
else:
46-
# grow
47-
newTree = newTree.grow()
31+
moves_prob = self.moves_prob
32+
return moves_prob
4833

49-
return newTree
34+
def sample(self):
35+
# initialise the probabilities of each move
36+
moves_prob = self.get_moves_prob()
37+
newTree = copy.deepcopy(self.tree)
38+
moves_dice = Dice(moves_prob, [newTree.prune, newTree.swap, newTree.change, newTree.grow])
39+
40+
return moves_dice.eval()
5041

5142
def eval(self, sampledTree):
5243
initialTree = self.tree
53-
moves_prob = [0.4, 0.1, 0.1, 0.4]
5444
logprobability = -inf
55-
if len(initialTree.tree) == 1:
56-
moves_prob = [0.0, 0.0, 0.5, 0.5]
57-
45+
moves_prob = self.get_moves_prob()
5846
nodes_differences = [i for i in sampledTree.tree + initialTree.tree
5947
if i not in sampledTree.tree or
6048
i not in initialTree.tree]

0 commit comments

Comments
 (0)