-
Notifications
You must be signed in to change notification settings - Fork 13
Open
Description
To reproduce, run:
import numpy as np
from contextualized.easy import ContextualizedBayesianNetworks
# Reproduction script: fit a ContextualizedBayesianNetworks model with
# bootstrapping, then call ``predict``. The ``predict`` call is the step
# flagged as problematic in this report.
#
# A fixed seed makes the synthetic inputs identical on every run, so the
# reproduction (and any failure it triggers) is deterministic.
n_samples = 100
n_contexts = 5
n_features = 10
n_bootstraps = 3
rng = np.random.default_rng(0)
C = rng.uniform(-1, 1, size=(n_samples, n_contexts))  # contexts drawn from [-1, 1)
X = rng.normal(0, 1, size=(n_samples, n_features))  # standard-normal features
cbn = ContextualizedBayesianNetworks(n_bootstraps=n_bootstraps)
cbn.fit(C, X, max_epochs=1)  # a single epoch is enough to reach predict()
y_pred = cbn.predict(C, X)  # Problematic line

# Extra tests: the output shapes the reporter expects from predict().
expected_shape = (n_samples, n_features)
assert y_pred.shape == expected_shape, (
    f"predict(): got shape {y_pred.shape}, expected {expected_shape}"
)
y_pred_avg = cbn.predict(C, X, individual_preds=False)
assert y_pred_avg.shape == expected_shape, (
    f"predict(individual_preds=False): got shape {y_pred_avg.shape}, "
    f"expected {expected_shape}"
)
y_pred_individual = cbn.predict(C, X, individual_preds=True)
expected_individual_shape = (n_bootstraps, n_samples, n_features)
assert y_pred_individual.shape == expected_individual_shape, (
    f"predict(individual_preds=True): got shape {y_pred_individual.shape}, "
    f"expected {expected_individual_shape}"
)
Metadata
Assignees
Labels
No labels