@@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample():
19581958 atol = 0.6 / np .sqrt (10000 ),
19591959 )
19601960 assert np .all (np .abs (vect_obs - x_posterior [..., None ]) < 1 )
1961+
1962+
1963+ def test_vectorize_over_posterior_with_intermediate_rvs ():
1964+ with pm .Model () as model :
1965+ a = pm .Normal ("a" )
1966+ b = pm .Normal .dist (a )
1967+ c = b + 1
1968+ d = pm .Normal .dist (c )
1969+ idata = pm .sample_prior_predictive (100 , var_names = ["a" ])
1970+ idata .add_groups ({"posterior" : idata .prior })
1971+ _ , _ , vectorized_no_intermediate = vectorize_over_posterior (
1972+ outputs = [b , c , d ],
1973+ posterior = idata .posterior ,
1974+ input_rvs = [a ],
1975+ allow_rvs_in_graph = True ,
1976+ )
1977+ [vectorized_intermediate_rvs ] = vectorize_over_posterior (
1978+ outputs = [d ],
1979+ posterior = idata .posterior ,
1980+ input_rvs = [a ],
1981+ allow_rvs_in_graph = True ,
1982+ )
1983+ assert vectorized_no_intermediate .type .shape == (1 , 100 )
1984+ assert vectorized_no_intermediate .type .shape == vectorized_intermediate_rvs .type .shape
1985+ a_ancestor1 = get_var_by_name ([vectorized_no_intermediate ], "a" )[0 ]
1986+ a_ancestor2 = get_var_by_name ([vectorized_intermediate_rvs ], "a" )[0 ]
1987+ assert isinstance (a_ancestor1 , TensorConstant )
1988+ assert np .array_equiv (a_ancestor1 .eval (), idata .posterior .a .data )
1989+ assert isinstance (a_ancestor2 , TensorConstant )
1990+ assert np .array_equiv (a_ancestor2 .eval (), idata .posterior .a .data )
0 commit comments