diff --git a/align_system/algorithms/alignment_adm_component.py b/align_system/algorithms/alignment_adm_component.py index cc2382a3..dc5941e7 100644 --- a/align_system/algorithms/alignment_adm_component.py +++ b/align_system/algorithms/alignment_adm_component.py @@ -1,4 +1,5 @@ import math +import numpy as np from align_system.utils import call_with_coerced_args, logging from align_system.algorithms.abstracts import ADMComponent @@ -316,12 +317,15 @@ def _midpoint_eqn(self, kdma, opt_a_value, medical_delta, attr_delta): class RandomEffectsModelAlignmentADMComponent(ADMComponent): def __init__( self, - attributes=None + attributes=None, + probabilistic: bool=False, ): if attributes is None: attributes = {} self.attributes = attributes + self.probabilistic = probabilistic + def run_returns(self): return ('chosen_choice', 'best_sample_idx', 'alignment_info') @@ -451,7 +455,14 @@ def run( "p_choose_a": p_choose_a, } - if p_choose_a >= 0.5: - return (opt_a["choice"], best_sample_idx, alignment_info) + if not self.probabilistic: + if p_choose_a >= 0.5: + return (opt_a["choice"], best_sample_idx, alignment_info) + else: + return (opt_b["choice"], best_sample_idx, alignment_info) else: - return (opt_b["choice"], best_sample_idx, alignment_info) + choices = [opt_a["choice"], opt_b["choice"]] + probs = [p_choose_a, 1-p_choose_a] + + choice = np.random.choice(choices, p=probs) + return (choice, best_sample_idx, alignment_info) diff --git a/align_system/configs/adm_component/alignment/probabilistic_random_effects_tuple.yaml b/align_system/configs/adm_component/alignment/probabilistic_random_effects_tuple.yaml new file mode 100644 index 00000000..208653ac --- /dev/null +++ b/align_system/configs/adm_component/alignment/probabilistic_random_effects_tuple.yaml @@ -0,0 +1,4 @@ +_target_: align_system.algorithms.alignment_adm_component.RandomEffectsModelAlignmentADMComponent + +attributes: ${ref:adm.attribute_definitions} +probabilistic: true