diff --git a/econml/score/ensemble_cate.py b/econml/score/ensemble_cate.py index de4d52fa3..d61690ddd 100644 --- a/econml/score/ensemble_cate.py +++ b/econml/score/ensemble_cate.py @@ -1,70 +1,254 @@ # Copyright (c) PyWhy contributors. All rights reserved. -# Licensed under the MIT License. +# Licensed under the MIT License import numpy as np from sklearn.utils.validation import check_array from .._cate_estimator import BaseCateEstimator, LinearCateEstimator -class EnsembleCateEstimator: +class EnsembleCateEstimator(BaseCateEstimator): """ A CATE estimator that represents a weighted ensemble of many CATE estimators. - Returns their weighted effect prediction. + Predicts treatment effects as the weighted average of predictions from base estimators. Parameters ---------- - cate_models : list of BaseCateEstimator objects - A list of fitted cate estimator objects that will be used in the ensemble. - The models are passed by reference, and not copied internally, because we - need the fitted objects, so any change to the passed models will affect - the internal predictions (e.g. if the input models are refitted). - weights : np.ndarray of shape (len(cate_models),) - The weight placed on each model. Weights must be non-positive. The - ensemble will predict effects based on the weighted average predictions - of the cate_models estiamtors, weighted by the corresponding weight in `weights`. - """ + cate_models : list of BaseCateEstimator + List of *fitted* CATE estimators. Models are held by reference — changes to them affect ensemble predictions. + All models must implement the methods being called (e.g., `effect`, `const_marginal_effect`). - def __init__(self, *, cate_models, weights): - self.cate_models = cate_models - self.weights = weights + weights : array-like of shape (n_models,) + Non-negative weights for each model. Must sum to > 0. If not normalized, will be normalized internally. + Weights determine contribution of each model to the ensemble prediction. - def effect(self, X=None, *, T0=0, T1=1): - return np.average([mdl.effect(X=X, T0=T0, T1=T1) for mdl in self.cate_models], - weights=self.weights, axis=0) - effect.__doc__ = BaseCateEstimator.effect.__doc__ + normalize_weights : bool, default=True + If True, weights are normalized to sum to 1. If False, raw weights are used. - def marginal_effect(self, T, X=None): - return np.average([mdl.marginal_effect(T, X=X) for mdl in self.cate_models], - weights=self.weights, axis=0) - marginal_effect.__doc__ = BaseCateEstimator.marginal_effect.__doc__ + Attributes + ---------- + n_models_ : int + Number of base models in the ensemble. - def const_marginal_effect(self, X=None): - if np.any([not hasattr(mdl, 'const_marginal_effect') for mdl in self.cate_models]): - raise ValueError("One of the base CATE models in parameter `cate_models` does not support " - "the `const_marginal_effect` method.") - return np.average([mdl.const_marginal_effect(X=X) for mdl in self.cate_models], - weights=self.weights, axis=0) - const_marginal_effect.__doc__ = LinearCateEstimator.const_marginal_effect.__doc__ + d_t_ : int or None + Dimensionality of treatment (inferred from first model supporting `marginal_effect` or `const_marginal_effect`). + + d_y_ : int or None + Dimensionality of outcome (inferred similarly). + + Notes + ----- + - This class inherits from `BaseCateEstimator` to ensure compatibility with EconML APIs. + - Lazy inference of `d_t_`, `d_y_` avoids forcing all models to expose these unless needed. 
+ - Supports heterogeneous models: some may support `effect`, others only `const_marginal_effect`. + """ + + def __init__(self, *, cate_models, weights, normalize_weights=True): + self.cate_models = cate_models + self.weights = weights + self.normalize_weights = normalize_weights @property def cate_models(self): + """List of base CATE estimators.""" return self._cate_models @cate_models.setter def cate_models(self, value): - if (not isinstance(value, list)) or (not np.all([isinstance(model, BaseCateEstimator) for model in value])): - raise ValueError('Parameter `cate_models` should be a list of `BaseCateEstimator` objects.') + if not isinstance(value, list) or len(value) == 0: + raise ValueError("`cate_models` must be a non-empty list.") + if not all(isinstance(model, BaseCateEstimator) for model in value): + raise ValueError("All elements in `cate_models` must be instances of `BaseCateEstimator`.") self._cate_models = value + # Invalidate cached metadata + self._d_t = None + self._d_y = None @property def weights(self): + """Weights assigned to each base model.""" return self._weights @weights.setter def weights(self, value): - weights = check_array(value, accept_sparse=False, ensure_2d=False, allow_nd=False, dtype='numeric', - force_all_finite=True) + weights = check_array(value, accept_sparse=False, ensure_2d=False, dtype='numeric', + force_all_finite=True, copy=True).ravel() + if weights.shape[0] != len(self.cate_models): + raise ValueError(f"Length of `weights` ({weights.shape[0]}) must match " + f"number of models ({len(self.cate_models)}).") if np.any(weights < 0): - raise ValueError("All weights in parameter `weights` must be non-negative.") + raise ValueError("All weights must be non-negative.") + if np.sum(weights) <= 0: + raise ValueError("Sum of weights must be positive.") + + if getattr(self, 'normalize_weights', True): + weights = weights / np.sum(weights) + self._weights = weights + + @property + def d_t(self): + """Treatment dimensionality (lazy inference).""" + if self._d_t is None: + self._infer_shapes() + return self._d_t + + @property + def d_y(self): + """Outcome dimensionality (lazy inference).""" + if self._d_y is None: + self._infer_shapes() + return self._d_y + + def _infer_shapes(self): + """Infer d_t and d_y from first model that supports const_marginal_effect or marginal_effect.""" + for mdl in self.cate_models: + if hasattr(mdl, 'const_marginal_effect'): + try: + # Try dummy call to infer shapes + dummy_X = np.zeros((1, 1)) # minimal shape + eff = mdl.const_marginal_effect(X=dummy_X) + if eff.ndim == 3: + _, d_y, d_t = eff.shape + self._d_t = d_t + self._d_y = d_y + return + elif eff.ndim == 2: + # Assume (n, d_t) and d_y=1 + self._d_t = eff.shape[1] + self._d_y = 1 + return + except Exception: + continue + elif hasattr(mdl, 'marginal_effect'): + try: + dummy_T = np.zeros((1, 1)) + dummy_X = np.zeros((1, 1)) + meff = mdl.marginal_effect(T=dummy_T, X=dummy_X) + if meff.ndim == 3: + _, d_y, d_t = meff.shape + self._d_t = d_t + self._d_y = d_y + return + except Exception: + continue + # Fallback: unknown + self._d_t = None + self._d_y = None + + def effect(self, X=None, *, T0=0, T1=1): + """ + Calculate the average treatment effect. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), optional + Features for each sample. + T0 : array-like or scalar, default=0 + Baseline treatment. + T1 : array-like or scalar, default=1 + Target treatment. 
+ + Returns + ------- + τ : array-like of shape (n_samples,) or (n_samples, d_y) + Estimated treatment effects. + """ + if not self.cate_models: + raise ValueError("No models in ensemble.") + + predictions = [] + for mdl in self.cate_models: + if not hasattr(mdl, 'effect'): + raise AttributeError(f"Model {type(mdl).__name__} does not implement 'effect' method.") + pred = mdl.effect(X=X, T0=T0, T1=T1) + predictions.append(np.asarray(pred)) + + # Stack and validate shapes + stacked = np.stack(predictions, axis=0) # (n_models, n_samples, ...) + return np.average(stacked, weights=self.weights, axis=0) + + effect.__doc__ = BaseCateEstimator.effect.__doc__ + + def marginal_effect(self, T, X=None): + """ + Calculate the heterogeneous marginal effect. + + Parameters + ---------- + T : array-like of shape (n_samples, d_t) + Treatment values at which to calculate the effect. + X : array-like of shape (n_samples, n_features), optional + Features for each sample. + + Returns + ------- + τ : array-like of shape (n_samples, d_y, d_t) + Estimated marginal effects. + """ + if not self.cate_models: + raise ValueError("No models in ensemble.") + + predictions = [] + for mdl in self.cate_models: + if not hasattr(mdl, 'marginal_effect'): + raise AttributeError(f"Model {type(mdl).__name__} does not implement 'marginal_effect' method.") + pred = mdl.marginal_effect(T=T, X=X) + pred = np.asarray(pred) + # Ensure 3D: (n, d_y, d_t) + if pred.ndim == 2: + pred = pred[:, None, :] # assume d_y=1 + elif pred.ndim != 3: + raise ValueError(f"Unexpected shape {pred.shape} from {type(mdl).__name__}.marginal_effect") + predictions.append(pred) + + stacked = np.stack(predictions, axis=0) # (n_models, n, d_y, d_t) + return np.average(stacked, weights=self.weights, axis=0) + + marginal_effect.__doc__ = BaseCateEstimator.marginal_effect.__doc__ + + def const_marginal_effect(self, X=None): + """ + Calculate the constant marginal CATE. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), optional + Features for each sample. + + Returns + ------- + τ : array-like of shape (n_samples, d_y, d_t) + Estimated constant marginal effects. + """ + if not self.cate_models: + raise ValueError("No models in ensemble.") + + predictions = [] + for mdl in self.cate_models: + if not hasattr(mdl, 'const_marginal_effect'): + raise AttributeError( + f"Model {type(mdl).__name__} does not implement 'const_marginal_effect' method." + ) + pred = mdl.const_marginal_effect(X=X) + pred = np.asarray(pred) + if pred.ndim == 2: + pred = pred[:, None, :] # assume d_y=1 + elif pred.ndim != 3: + raise ValueError(f"Unexpected shape {pred.shape} from {type(mdl).__name__}.const_marginal_effect") + predictions.append(pred) + + stacked = np.stack(predictions, axis=0) # (n_models, n, d_y, d_t) + return np.average(stacked, weights=self.weights, axis=0) + + const_marginal_effect.__doc__ = LinearCateEstimator.const_marginal_effect.__doc__ + + def __repr__(self): + return (f"{self.__class__.__name__}(n_models={len(self.cate_models)}, " + f"normalize_weights={getattr(self, 'normalize_weights', True)})") + + def __str__(self): + model_types = [type(mdl).__name__ for mdl in self.cate_models] + return (f"Ensemble of {len(self.cate_models)} models: {model_types}\n" + f"Weights: {self.weights}") + diff --git a/econml/score/rscorer.py b/econml/score/rscorer.py index cf04ceb1a..77d9ae860 100644 --- a/econml/score/rscorer.py +++ b/econml/score/rscorer.py @@ -1,8 +1,9 @@ # Copyright (c) PyWhy contributors. All rights reserved. 
# Licensed under the MIT License. +from typing import List, Optional, Tuple, Union, Any from ..dml import LinearDML -from sklearn.base import clone +from sklearn.base import clone, BaseEstimator import numpy as np from scipy.special import softmax from .ensemble_cate import EnsembleCateEstimator @@ -32,222 +33,285 @@ class RScorer: This corresponds to the extra variance of the outcome explained by introducing heterogeneity in the effect as captured by the cate model, as opposed to always predicting a constant effect. - A negative score, means that the cate model performs even worse than a constant effect model - and hints at overfitting during training of the cate model. - - This method was also advocated in recent work of [Schuleretal2018]_ when compared among several alternatives - for causal model selection and introduced in the work of [NieWager2017]_. + A negative score means that the cate model performs worse than a constant effect model + and may indicate overfitting. Parameters ---------- model_y: estimator - The estimator for fitting the response to the features. Must implement - `fit` and `predict` methods. + The estimator for fitting the response to the features. Must implement `fit` and `predict`. model_t: estimator - The estimator for fitting the treatment to the features. Must implement - `fit` and `predict` methods. - - discrete_treatment: bool, default ``False`` - Whether the treatment values should be treated as categorical, rather than continuous, quantities - - discrete_outcome: bool, default ``False`` - Whether the outcome should be treated as binary - - categories: 'auto' or list, default 'auto' - The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values). - The first category will be treated as the control treatment. + The estimator for fitting the treatment to the features. Must implement `fit` and `predict`. - cv: int, cross-validation generator or an iterable, default 2 - Determines the cross-validation splitting strategy. - Possible inputs for cv are: + discrete_treatment: bool, default=False + Whether the treatment values should be treated as categorical. - - None, to use the default 3-fold cross-validation, - - integer, to specify the number of folds. - - :term:`CV splitter` - - An iterable yielding (train, test) splits as arrays of indices. + discrete_outcome: bool, default=False + Whether the outcome should be treated as binary. - For integer/None inputs, if the treatment is discrete - :class:`~sklearn.model_selection.StratifiedKFold` is used, else, - :class:`~sklearn.model_selection.KFold` is used - (with a random shuffle in either case). + categories: 'auto' or list, default='auto' + Categories to use when encoding discrete treatments. 'auto' uses unique sorted values. + The first category is treated as the control. - Unless an iterable is used, we call `split(concat[W, X], T)` to generate the splits. If all - W, X are None, then we call `split(ones((T.shape[0], 1)), T)`. + cv: int, cross-validation generator or iterable, default=2 + Determines the cross-validation splitting strategy. See sklearn docs for options. mc_iters: int, optional - The number of times to rerun the first stage models to reduce the variance of the nuisances. + Number of Monte Carlo iterations to reduce nuisance variance. - mc_agg: {'mean', 'median'}, default 'mean' - How to aggregate the nuisance value for each sample across the `mc_iters` monte carlo iterations of - cross-fitting. 
- - random_state : int, RandomState instance, or None, default None - - If int, random_state is the seed used by the random number generator; - If :class:`~numpy.random.mtrand.RandomState` instance, random_state is the random number generator; - If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used - by :mod:`np.random`. - - References - ---------- - .. [NieWager2017] X. Nie and S. Wager. - Quasi-Oracle Estimation of Heterogeneous Treatment Effects. - arXiv preprint arXiv:1712.04912, 2017. - ``_ - - .. [Schuleretal2018] Alejandro Schuler, Michael Baiocchi, Robert Tibshirani, Nigam Shah. - "A comparison of methods for model selection when estimating individual treatment effects." - Arxiv, 2018 - ``_ + mc_agg: {'mean', 'median'}, default='mean' + How to aggregate nuisance values across MC iterations. + random_state: int, RandomState instance or None, default=None + Controls randomness for reproducibility. """ def __init__(self, *, - model_y, - model_t, - discrete_treatment=False, - discrete_outcome=False, - categories='auto', - cv=2, - mc_iters=None, - mc_agg='mean', - random_state=None): + model_y: BaseEstimator, + model_t: BaseEstimator, + discrete_treatment: bool = False, + discrete_outcome: bool = False, + categories: Union[str, List] = 'auto', + cv: Union[int, Any] = 2, + mc_iters: Optional[int] = None, + mc_agg: str = 'mean', + random_state: Optional[Union[int, np.random.RandomState]] = None): self.model_y = clone(model_y, safe=False) self.model_t = clone(model_t, safe=False) self.discrete_treatment = discrete_treatment self.discrete_outcome = discrete_outcome - self.cv = cv self.categories = categories - self.random_state = random_state + self.cv = cv self.mc_iters = mc_iters self.mc_agg = mc_agg + self.random_state = random_state - def fit(self, y, T, X=None, W=None, sample_weight=None, groups=None): + # Internal state + self.lineardml_: Optional[LinearDML] = None + self.base_score_: Optional[float] = None + self.dx_: Optional[int] = None + + def fit(self, + y: np.ndarray, + T: np.ndarray, + X: Optional[np.ndarray] = None, + W: Optional[np.ndarray] = None, + sample_weight: Optional[np.ndarray] = None, + groups: Optional[np.ndarray] = None) -> 'RScorer': """ - Fit a baseline model to the data. + Fit residual models and compute baseline score. Parameters ---------- - Y: (n × d_y) matrix or vector of length n - Outcomes for each sample - T: (n × dₜ) matrix or vector of length n - Treatments for each sample - X: (n × dₓ) matrix, optional - Features for each sample - W: (n × d_w) matrix, optional - Controls for each sample - sample_weight: (n,) vector, optional - Weights for each row - groups: (n,) vector, optional - All rows corresponding to the same group will be kept together during splitting. - If groups is not None, the `cv` argument passed to this class's initializer - must support a 'groups' argument to its split method. + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + Outcome(s) for each sample. + T : array-like of shape (n_samples,) or (n_samples, n_treatments) + Treatment(s) for each sample. + X : array-like of shape (n_samples, n_features), optional + Features for heterogeneity. + W : array-like of shape (n_samples, n_controls), optional + Control variables. + sample_weight : array-like of shape (n_samples,), optional + Sample weights. + groups : array-like of shape (n_samples,), optional + Group labels for grouped CV splits. Returns ------- - self + self : RScorer + Fitted scorer. 
""" if X is None: raise ValueError("X cannot be None for the RScorer!") - self.lineardml_ = LinearDML(model_y=self.model_y, - model_t=self.model_t, - cv=self.cv, - discrete_treatment=self.discrete_treatment, - discrete_outcome=self.discrete_outcome, - categories=self.categories, - random_state=self.random_state, - mc_iters=self.mc_iters, - mc_agg=self.mc_agg) - self.lineardml_.fit(y, T, X=None, W=np.hstack([v for v in [X, W] if v is not None]), - sample_weight=sample_weight, groups=groups, cache_values=True) + # Combine X and W for controls in DML + W_full = np.hstack([v for v in [X, W] if v is not None]) if W is not None or X is not None else None + + self.lineardml_ = LinearDML( + model_y=self.model_y, + model_t=self.model_t, + cv=self.cv, + discrete_treatment=self.discrete_treatment, + discrete_outcome=self.discrete_outcome, + categories=self.categories, + random_state=self.random_state, + mc_iters=self.mc_iters, + mc_agg=self.mc_agg + ) + + self.lineardml_.fit( + y, T, X=None, W=W_full, + sample_weight=sample_weight, groups=groups, cache_values=True + ) + + if not hasattr(self.lineardml_, '_cached_values') or self.lineardml_._cached_values is None: + raise RuntimeError("LinearDML did not cache values. Ensure cache_values=True.") + self.base_score_ = self.lineardml_.score_ + if self.base_score_ <= 0: + raise ValueError(f"Base score must be positive. Got {self.base_score_}.") self.dx_ = X.shape[1] + return self - def score(self, cate_model): + def _get_X_from_cached_W(self) -> np.ndarray: + """Extract X from cached W (first dx_ columns).""" + if self.lineardml_ is None or self.dx_ is None: + raise RuntimeError("Must call fit() before score().") + W_cached = self.lineardml_._cached_values.W + return W_cached[:, :self.dx_] + + def _compute_loss(self, Y_res: np.ndarray, T_res: np.ndarray, effects: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> float: + """ + Compute mean squared error: E[(Yres - )^2] + + Parameters + ---------- + Y_res : (n, d_y) + T_res : (n, d_t) + effects : (n, d_y, d_t) + sample_weight : (n,), optional + + Returns + ------- + loss : float + """ + # Predicted residuals: sum over treatment dimension + # einsum: 'ijk,ik->ij' => for each sample i, output j: sum_k effects[i,j,k] * T_res[i,k] + Y_res_pred = np.einsum('ijk,ik->ij', effects, T_res) + + sq_errors = (Y_res - Y_res_pred) ** 2 # (n, d_y) + + if sample_weight is not None: + # Weighted average over samples, then mean over outputs + loss = np.mean(np.average(sq_errors, weights=sample_weight, axis=0)) + else: + loss = np.mean(sq_errors) + + return loss + + def score(self, cate_model: Any) -> float: """ Score a CATE model against the baseline. Parameters ---------- - cate_model : instance of fitted BaseCateEstimator + cate_model : fitted estimator + Must have `const_marginal_effect(X)` method returning (n, d_y, d_t) array. Returns ------- - score : double - An analogue of the R-square loss for the causal setting. + score : float + R-squared style score. Higher is better. Can be negative. 
""" + if self.lineardml_ is None or self.base_score_ is None: + raise RuntimeError("Must call fit() before score().") + + # Validate cate_model interface + if not hasattr(cate_model, 'const_marginal_effect'): + raise ValueError("cate_model must implement 'const_marginal_effect(X)' method.") + Y_res, T_res = self.lineardml_._cached_values.nuisances - X = self.lineardml_._cached_values.W[:, :self.dx_] + X = self._get_X_from_cached_W() sample_weight = self.lineardml_._cached_values.sample_weight + + # Ensure 2D if Y_res.ndim == 1: - Y_res = Y_res.reshape((-1, 1)) + Y_res = Y_res.reshape(-1, 1) if T_res.ndim == 1: - T_res = T_res.reshape((-1, 1)) - effects = cate_model.const_marginal_effect(X).reshape((-1, Y_res.shape[1], T_res.shape[1])) - Y_res_pred = np.einsum('ijk,ik->ij', effects, T_res).reshape(Y_res.shape) - if sample_weight is not None: - return 1 - np.mean(np.average((Y_res - Y_res_pred)**2, weights=sample_weight, axis=0)) / self.base_score_ - else: - return 1 - np.mean((Y_res - Y_res_pred) ** 2) / self.base_score_ + T_res = T_res.reshape(-1, 1) + + effects = cate_model.const_marginal_effect(X) + if effects.ndim != 3: + raise ValueError(f"Expected 3D effects (n, d_y, d_t), got shape {effects.shape}") + + loss = self._compute_loss(Y_res, T_res, effects, sample_weight) - def best_model(self, cate_models, return_scores=False): + # Guard against division by zero (shouldn't happen due to fit() check, but still) + if self.base_score_ <= 0: + return -np.inf if loss > 0 else 1.0 + + return 1 - loss / self.base_score_ + + def best_model(self, + cate_models: List[Any], + return_scores: bool = False + ) -> Union[Tuple[Any, float], Tuple[Any, float, List[float]]]: """ - Choose the best among a list of models. + Select the best model based on R-scores. Parameters ---------- - cate_models : list of instance of fitted BaseCateEstimator - return_scores : bool, default False - Whether to return the list scores of each model + cate_models : list of fitted estimators + return_scores : bool, default=False + If True, also return list of scores. Returns ------- - best_model : instance of fitted BaseCateEstimator - The model that achieves the best score - best_score : double - The score of the best model - scores : list of double - The list of scores for each of the input models. Returned only if `return_scores=True`. + best_model : estimator + best_score : float + scores : list of float, optional """ + if not cate_models: + raise ValueError("cate_models list is empty.") + rscores = [self.score(mdl) for mdl in cate_models] - best = np.nanargmax(rscores) + + # Handle all-NaN case + finite_scores = [s for s in rscores if np.isfinite(s)] + if not finite_scores: + raise ValueError("All model scores are invalid (NaN or inf).") + + best_idx = np.nanargmax(rscores) # nanargmax ignores NaNs + best_model = cate_models[best_idx] + best_score = rscores[best_idx] + if return_scores: - return cate_models[best], rscores[best], rscores + return best_model, best_score, rscores else: - return cate_models[best], rscores[best] - - def ensemble(self, cate_models, eta=1000.0, return_scores=False): + return best_model, best_score + + def ensemble(self, + cate_models: List[Any], + eta: float = 1000.0, + return_scores: bool = False + ) -> Union[Tuple[EnsembleCateEstimator, float], + Tuple[EnsembleCateEstimator, float, np.ndarray]]: """ - Ensemble a list of models based on their performance. + Create a weighted ensemble of models using softmax weights based on scores. 
Parameters ---------- - cate_models : list of instance of fitted BaseCateEstimator - eta : double, default 1000 - The soft-max parameter for the ensemble - return_scores : bool, default False - Whether to return the list scores of each model + cate_models : list of fitted estimators + eta : float, default=1000.0 + Temperature parameter for softmax weighting. + return_scores : bool, default=False + If True, also return raw scores. Returns ------- - ensemble_model : instance of fitted EnsembleCateEstimator - A fitted ensemble cate model that calculates effects based on a weighted - version of the input cate models, weighted by a softmax of their score - performance - ensemble_score : double - The score of the ensemble model - scores : list of double - The list of scores for each of the input models. Returned only if `return_scores=True`. + ensemble : EnsembleCateEstimator + ensemble_score : float + scores : array, optional """ + if not cate_models: + raise ValueError("cate_models list is empty.") + rscores = np.array([self.score(mdl) for mdl in cate_models]) goodinds = np.isfinite(rscores) + + if not np.any(goodinds): + raise ValueError("No valid (finite) scores to ensemble.") + + # Softmax weights on finite scores weights = softmax(eta * rscores[goodinds]) - goodmodels = [mdl for mdl, good in zip(cate_models, goodinds) if good] + goodmodels = [mdl for mdl, keep in zip(cate_models, goodinds) if keep] + ensemble = EnsembleCateEstimator(cate_models=goodmodels, weights=weights) ensemble_score = self.score(ensemble) + if return_scores: return ensemble, ensemble_score, rscores else: diff --git a/prototypes/orthogonal_forests/comparison_plots.py b/prototypes/orthogonal_forests/comparison_plots.py index 9fa77c961..1a238eb17 100644 --- a/prototypes/orthogonal_forests/comparison_plots.py +++ b/prototypes/orthogonal_forests/comparison_plots.py @@ -1,332 +1,558 @@ +""" +Treatment Effect Estimation Results Analysis and Visualization + +This module analyzes and visualizes results from various treatment effect estimation methods, +including bias, variance, RMSE, and R² comparisons across different experimental conditions. 
+""" + import argparse import copy import itertools +import os +import re +import sys +from pathlib import Path +from typing import Dict, List, Tuple, Optional, Any + import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np -import os import pandas as pd -import re -import sys -import time from joblib import Parallel, delayed -from matplotlib import rcParams, cm, rc +from matplotlib import rcParams from sklearn.metrics import r2_score +# Configure matplotlib matplotlib.rcParams['font.family'] = "serif" ################### -# Global settings # +# Constants # ################### -# Global plotting controls -# Control for support size, can control for more -plot_controls = ["support"] -label_order = ["ORF-CV", "ORF", "GRF-xW", "GRF-x", "GRF-Res", "HeteroDML-Lasso", "HeteroDML-RF"] -corresponding_str = ["OrthoForestCV", "OrthoForest", "GRF_Wx", "GRF_x", - "GRF_res_Wx", "HeteroDML", "ForestHeteroDML"] -################## -# File utilities # -################## -def has_plot_controls(fname, control_combination): - for c in control_combination: - if "_{0}_".format(c) not in fname: - return False - return True +PLOT_CONTROLS = ["support"] +LABEL_ORDER = ["ORF-CV", "ORF", "GRF-xW", "GRF-x", "GRF-Res", "HeteroDML-Lasso", "HeteroDML-RF"] +METHOD_MAPPING = { + "OrthoForestCV": "ORF-CV", + "OrthoForest": "ORF", + "GRF_Wx": "GRF-xW", + "GRF_x": "GRF-x", + "GRF_res_Wx": "GRF-Res", + "HeteroDML": "HeteroDML-Lasso", + "ForestHeteroDML": "HeteroDML-RF" +} +CORRESPONDING_STR = list(METHOD_MAPPING.keys()) -def get_file_key(fname): - if "GRF" in fname: - return "_" + "_".join(re.split("GRF_", fname)[0].split("_")[1:]) - else: - return "_" + "_".join(re.split("results", fname)[0].split("_")[1:]) +# Plot configuration +FIGURE_SIZE_JOINT = (10, 5) +FIGURE_SIZE_METRICS = (12, 3) +DPI_HIGH_RES = 300 +PERCENTILE_UPPER = 95 +PERCENTILE_LOWER = 5 -def sort_fnames(file_names): - sorted_file_names = [] - label_indices = [] - for i, s in enumerate(corresponding_str): - for f in file_names: - if ((f.split("_")[0]==s and "GRF" not in f) - or ("_{0}_".format(s) in f and "GRF" in f)): - sorted_file_names.append(f) - label_indices.append(i) - break - return sorted_file_names, np.array(label_order)[label_indices] - -def get_file_groups(agg_fnames, plot_controls): - all_file_names = {} - control_values = [] - for control in plot_controls: - vals = set() - for fname in agg_fnames: - control_prefix = control + '_' - val = re.search(control_prefix + '(\d+)', fname).group(1) - vals.add(control_prefix + val) - control_values.append(list(vals)) - control_combinations = list(itertools.product(*control_values)) - for control_combination in control_combinations: - file_names = [f for f in agg_fnames if has_plot_controls(f, control_combination)] - file_key = get_file_key(file_names[0]) - all_file_names[file_key], final_labels = sort_fnames(file_names) - return all_file_names, final_labels +# Color schemes +COLOR_INDICES = [0, 3, 12, 14, 15, 4, 6] -def merge_results(sf, input_dir, output_dir, split_files_seeds): - name_template = "{0}seed_{1}_{2}" - seeds = split_files_seeds[sf] - df = pd.read_csv(os.path.join(input_dir, name_template.format(sf[0], seeds[0], sf[1]))) - te_idx = len([c for c in df.columns if bool(re.search("TE_[0-9]", c))]) - for i, seed in enumerate(seeds[1:]): - new_df = pd.read_csv(os.path.join(input_dir, name_template.format(sf[0], seed, sf[1]))) - te_cols = [c for c in new_df.columns if bool(re.search("TE_[0-9]", c))] - for te_col in te_cols: - df["TE_"+str(te_idx)] = new_df[te_col] - 
te_idx += 1 - agg_fname = os.path.join(output_dir, sf[0]+sf[1]) - df.to_csv(agg_fname, index=False) +# Output directories +OUTPUT_SUBDIRS = ["jpg_low_res", "jpg_high_res", "pdf_low_res"] -def get_results(fname, dir_name): - df = pd.read_csv(os.path.join(dir_name, fname)) - return df[[c for c in df.columns if "x" in c]+[c for c in df.columns if "TE_" in c]] +################### +# Data Classes # +################### -def save_plots(fig, fname, lgd=None): - jpg_low_res_path = os.path.join(output_dir, "jpg_low_res") - if not os.path.exists(jpg_low_res_path): - os.makedirs(jpg_low_res_path) - jpg_high_res_path = os.path.join(output_dir, "jpg_high_res") - if not os.path.exists(jpg_high_res_path): - os.makedirs(jpg_high_res_path) - pdf_low_res_path = os.path.join(output_dir, "pdf_low_res") - if not os.path.exists(pdf_low_res_path): - os.makedirs(pdf_low_res_path) - if lgd is None: - fig.savefig(os.path.join(jpg_low_res_path, "{0}.png".format(fname)), bbox_inches='tight') - fig.savefig(os.path.join(jpg_high_res_path, "{0}.png".format(fname)), dpi=300, bbox_inches='tight') - fig.savefig(os.path.join(pdf_low_res_path, "{0}.pdf".format(fname)), bbox_inches='tight') - else: - fig.savefig(os.path.join(jpg_low_res_path, "{0}.png".format(fname)), bbox_inches='tight', bbox_extra_artists=(lgd,)) - fig.savefig(os.path.join(jpg_high_res_path, "{0}.png".format(fname)), dpi=300, bbox_inches='tight', bbox_extra_artists=(lgd,)) - fig.savefig(os.path.join(pdf_low_res_path, "{0}.pdf".format(fname)), bbox_inches='tight', bbox_extra_artists=(lgd,)) +class MetricResults: + """Container for metric calculation results.""" + + def __init__(self, mean: np.ndarray, std: np.ndarray): + self.mean = mean + self.std = std -################## -# Plotting utils # -################## -def get_r2(df): - r2_scores = np.array([r2_score(df["TE_hat"], df[c]) for c in df.columns if bool(re.search('TE_[0-9]+', c))]) - return r2_scores +class ExperimentResults: + """Container for all experimental results.""" + + def __init__(self): + self.bias = None + self.variance = None + self.rmse = None + self.r2 = None -def get_metrics(dfs): - biases = np.zeros((len(dfs[0]), len(dfs))) - variances = np.zeros((len(dfs[0]), len(dfs))) - rmses = np.zeros((len(dfs[0]), len(dfs))) - r2_scores = [] - for i, df in enumerate(dfs): - # bias - treatment_effects = df[[c for c in df.columns if bool(re.search('TE_[0-9]+', c))]] - bias = np.abs(np.mean(treatment_effects, axis=1) - df["TE_hat"]) - biases[:, i] = np.abs(np.mean(treatment_effects, axis=1) - df["TE_hat"]) - # var - variance = np.std(treatment_effects, axis=1) - variances[:, i] = np.std(treatment_effects, axis=1) - # rmse - rmse = np.mean(((treatment_effects.T - df["TE_hat"].values).T)**2, axis=1) - rmses[:, i] = np.mean(((treatment_effects.T - df["TE_hat"].values).T)**2, axis=1) - # r2 - r2_scores.append(get_r2(df)) - bias_lims = {"std": np.std(biases, axis=0), "mean": np.mean(biases, axis=0)} - var_lims = {"std": np.std(variances, axis=0), "mean": np.mean(variances, axis=0)} - rmse_lims = {"std": np.std(rmses, axis=0), "mean": np.mean(rmses, axis=0)} - print(r2_scores) - r2_lims = {"std": [np.std(r2_scores[i]) for i in range(len(r2_scores))], "mean": [np.mean(r2_scores[i]) for i in range(len(r2_scores))]} - return {"bias": bias_lims, "var": var_lims, "rmse": rmse_lims, "r2": r2_lims} +################### +# File Operations # +################### -def generic_joint_plots(file_key, dfs, labels, file_name_prefix): - m = min(4, len(dfs)) - n = np.ceil((len(dfs)) / m) - fig = plt.figure(figsize=(10, 
5)) - ymax = max([max(df["TE_hat"]) for df in dfs])+1 - print(file_key) - print(len(dfs)) - print(labels) - for i, df in enumerate(dfs): - ax = fig.add_subplot(n, m, i+1) - treatment_effects = df[[c for c in df.columns if bool(re.search('TE_[0-9]+', c))]] - y = np.mean(treatment_effects, axis=1) - err_up = np.percentile(treatment_effects, 95, axis=1) - err_bottom = np.percentile(treatment_effects, 5, axis=1) - ax.fill_between(df["x0"], err_up, err_bottom, alpha=0.5) - if i == 0: - ax.plot(df["x0"], y, label='Mean estimate') - ax.plot(df["x0"], df["TE_hat"].values, 'b--', label='True effect') +class FileProcessor: + """Handles file operations and data loading.""" + + def __init__(self, input_dir: str, output_dir: str): + self.input_dir = Path(input_dir) + self.output_dir = Path(output_dir) + self._ensure_output_dirs() + + def _ensure_output_dirs(self) -> None: + """Create necessary output directories.""" + for subdir in OUTPUT_SUBDIRS: + (self.output_dir / subdir).mkdir(parents=True, exist_ok=True) + + def has_plot_controls(self, fname: str, control_combination: List[str]) -> bool: + """Check if filename contains all required control parameters.""" + return all(f"_{control}_" in fname for control in control_combination) + + def get_file_key(self, fname: str) -> str: + """Extract file key for grouping related files.""" + if "GRF" in fname: + return "_" + "_".join(re.split("GRF_", fname)[0].split("_")[1:]) + else: + return "_" + "_".join(re.split("results", fname)[0].split("_")[1:]) + + def sort_filenames(self, file_names: List[str]) -> Tuple[List[str], np.ndarray]: + """Sort filenames according to predefined method order.""" + sorted_file_names = [] + label_indices = [] + + for i, method_str in enumerate(CORRESPONDING_STR): + for fname in file_names: + if self._matches_method(fname, method_str): + sorted_file_names.append(fname) + label_indices.append(i) + break + + return sorted_file_names, np.array(LABEL_ORDER)[label_indices] + + def _matches_method(self, fname: str, method_str: str) -> bool: + """Check if filename matches a specific method.""" + if "GRF" not in fname: + return fname.split("_")[0] == method_str else: - ax.plot(df["x0"], y) - ax.plot(df["x0"], df["TE_hat"].values, 'b--', label=None) - if i%m==0: - ax.set_ylabel("Treatment effect") - ax.set_ylim(ymax=ymax) - ax.set_title(labels[i]) - if i + 1 > m*(n-1): - ax.set_xlabel("x") - fig.legend(loc=(0.8, 0.25)) - fig.tight_layout() - save_plots(fig, file_name_prefix) - plt.clf() + return f"_{method_str}_" in fname + + def get_file_groups(self, agg_fnames: List[str]) -> Tuple[Dict[str, List[str]], np.ndarray]: + """Group files by experimental conditions.""" + all_file_names = {} + control_values = self._extract_control_values(agg_fnames) + control_combinations = list(itertools.product(*control_values)) + + final_labels = None + for control_combination in control_combinations: + file_names = [f for f in agg_fnames + if self.has_plot_controls(f, control_combination)] + + if file_names: + file_key = self.get_file_key(file_names[0]) + sorted_names, labels = self.sort_filenames(file_names) + all_file_names[file_key] = sorted_names + if final_labels is None: + final_labels = labels + + return all_file_names, final_labels + + def _extract_control_values(self, agg_fnames: List[str]) -> List[List[str]]: + """Extract unique control parameter values from filenames.""" + control_values = [] + for control in PLOT_CONTROLS: + vals = set() + control_prefix = f"{control}_" + + for fname in agg_fnames: + match = re.search(f"{control_prefix}(\\d+)", 
fname) + if match: + vals.add(f"{control_prefix}{match.group(1)}") + + control_values.append(list(vals)) + + return control_values + + def merge_results(self, sf: Tuple[str, str], split_files_seeds: Dict) -> None: + """Merge results from multiple seed runs.""" + name_template = "{0}seed_{1}_{2}" + seeds = split_files_seeds[sf] + + try: + # Load first file + first_file = self.input_dir / name_template.format(sf[0], seeds[0], sf[1]) + df = pd.read_csv(first_file) + + te_idx = len([c for c in df.columns if re.search("TE_[0-9]", c)]) + + # Merge additional seeds + for seed in seeds[1:]: + seed_file = self.input_dir / name_template.format(sf[0], seed, sf[1]) + new_df = pd.read_csv(seed_file) + te_cols = [c for c in new_df.columns if re.search("TE_[0-9]", c)] + + for te_col in te_cols: + df[f"TE_{te_idx}"] = new_df[te_col] + te_idx += 1 + + # Save merged results + agg_fname = self.output_dir / f"{sf[0]}{sf[1]}" + df.to_csv(agg_fname, index=False) + + except Exception as e: + print(f"Error merging results for {sf}: {e}") + raise + + def get_results(self, fname: str) -> pd.DataFrame: + """Load and filter results data.""" + try: + df = pd.read_csv(self.output_dir / fname) + x_cols = [c for c in df.columns if "x" in c] + te_cols = [c for c in df.columns if "TE_" in c] + return df[x_cols + te_cols] + except Exception as e: + print(f"Error loading results from {fname}: {e}") + raise + +################### +# Analysis # +################### + +class MetricsCalculator: + """Calculates performance metrics for treatment effect estimation.""" + + @staticmethod + def calculate_r2(df: pd.DataFrame) -> np.ndarray: + """Calculate R² scores for all treatment effect columns.""" + te_cols = [c for c in df.columns if re.search('TE_[0-9]+', c)] + return np.array([r2_score(df["TE_hat"], df[col]) for col in te_cols]) + + @staticmethod + def calculate_metrics(dfs: List[pd.DataFrame]) -> ExperimentResults: + """Calculate bias, variance, RMSE, and R² for all dataframes.""" + n_obs = len(dfs[0]) + n_methods = len(dfs) + + biases = np.zeros((n_obs, n_methods)) + variances = np.zeros((n_obs, n_methods)) + rmses = np.zeros((n_obs, n_methods)) + r2_scores = [] + + for i, df in enumerate(dfs): + te_cols = [c for c in df.columns if re.search('TE_[0-9]+', c)] + treatment_effects = df[te_cols] + + # Calculate metrics + mean_te = np.mean(treatment_effects, axis=1) + biases[:, i] = np.abs(mean_te - df["TE_hat"]) + variances[:, i] = np.std(treatment_effects, axis=1) + rmses[:, i] = np.mean(((treatment_effects.T - df["TE_hat"].values).T)**2, axis=1) + r2_scores.append(MetricsCalculator.calculate_r2(df)) + + # Create results object + results = ExperimentResults() + results.bias = MetricResults(np.mean(biases, axis=0), np.std(biases, axis=0)) + results.variance = MetricResults(np.mean(variances, axis=0), np.std(variances, axis=0)) + results.rmse = MetricResults(np.mean(rmses, axis=0), np.std(rmses, axis=0)) + results.r2 = MetricResults( + np.array([np.mean(r2_scores[i]) for i in range(len(r2_scores))]), + np.array([np.std(r2_scores[i]) for i in range(len(r2_scores))]) + ) + + return results + +################### +# Visualization # +################### -def metrics_subfig(dfs, ax, metric, c_scheme=0): - if c_scheme == 0: +class PlotGenerator: + """Generates various types of plots for results visualization.""" + + def __init__(self, output_dir: str): + self.output_dir = Path(output_dir) + + def save_plots(self, fig: plt.Figure, fname: str, lgd: Optional[Any] = None) -> None: + """Save figure in multiple formats.""" + save_kwargs = 
{'bbox_inches': 'tight'} + if lgd is not None: + save_kwargs['bbox_extra_artists'] = (lgd,) + + # Save in different formats and resolutions + fig.savefig(self.output_dir / "jpg_low_res" / f"{fname}.png", **save_kwargs) + fig.savefig(self.output_dir / "jpg_high_res" / f"{fname}.png", + dpi=DPI_HIGH_RES, **save_kwargs) + fig.savefig(self.output_dir / "pdf_low_res" / f"{fname}.pdf", **save_kwargs) + + def create_joint_plots(self, file_key: str, dfs: List[pd.DataFrame], + labels: List[str], file_name_prefix: str) -> None: + """Create joint treatment effect plots.""" + n_methods = len(dfs) + n_cols = min(4, n_methods) + n_rows = int(np.ceil(n_methods / n_cols)) + + fig = plt.figure(figsize=FIGURE_SIZE_JOINT) + ymax = max([df["TE_hat"].max() for df in dfs]) + 1 + + for i, df in enumerate(dfs): + ax = fig.add_subplot(n_rows, n_cols, i + 1) + + te_cols = [c for c in df.columns if re.search('TE_[0-9]+', c)] + treatment_effects = df[te_cols] + + y_mean = np.mean(treatment_effects, axis=1) + err_up = np.percentile(treatment_effects, PERCENTILE_UPPER, axis=1) + err_bottom = np.percentile(treatment_effects, PERCENTILE_LOWER, axis=1) + + ax.fill_between(df["x0"], err_up, err_bottom, alpha=0.5) + + if i == 0: + ax.plot(df["x0"], y_mean, label='Mean estimate') + ax.plot(df["x0"], df["TE_hat"], 'b--', label='True effect') + else: + ax.plot(df["x0"], y_mean) + ax.plot(df["x0"], df["TE_hat"], 'b--') + + if i % n_cols == 0: + ax.set_ylabel("Treatment effect") + + ax.set_ylim(ymax=ymax) + ax.set_title(labels[i]) + + if i + 1 > n_cols * (n_rows - 1): + ax.set_xlabel("x") + + fig.legend(loc=(0.8, 0.25)) + fig.tight_layout() + self.save_plots(fig, file_name_prefix) + plt.close(fig) + + def create_metrics_plots(self, file_key: str, dfs: List[pd.DataFrame], + labels: List[str], file_name_prefix: str) -> None: + """Create violin plots for bias, variance, and RMSE metrics.""" + metrics = ["bias", "variance", "rmse"] + fig = plt.figure(figsize=FIGURE_SIZE_METRICS) + + violin_bodies = [] + for i, metric in enumerate(metrics): + ax = fig.add_subplot(1, len(metrics), i + 1) + bodies = self._create_metric_subplot(dfs, ax, metric) + if i == 0: + violin_bodies = bodies + + lgd = fig.legend(violin_bodies, labels, ncol=len(labels), + loc='lower center', bbox_to_anchor=(0.5, 0), frameon=False) + fig.tight_layout() + fig.subplots_adjust(bottom=0.15) + self.save_plots(fig, file_name_prefix, lgd) + plt.close(fig) + + def _create_metric_subplot(self, dfs: List[pd.DataFrame], ax: plt.Axes, + metric: str) -> List[Any]: + """Create subplot for a specific metric.""" palette = plt.get_cmap('Set1') - else: - palette = plt.get_cmap('tab20b') - if metric == "bias": + + if metric == "bias": + data = self._calculate_bias_data(dfs) + ax.set_title("Bias") + elif metric == "variance": + data = self._calculate_variance_data(dfs) + ax.set_title("Variance") + elif metric == "rmse": + data = self._calculate_rmse_data(dfs) + ax.set_title("RMSE") + else: + raise ValueError(f"Unknown metric: {metric}") + + vparts = ax.violinplot(data, showmedians=True) + ax.set_xticks([]) + + # Style violin plots + for i, body in enumerate(vparts['bodies']): + color_idx = i if i < 5 else i + 1 + body.set_facecolor(palette(color_idx)) + body.set_edgecolor(palette(color_idx)) + body.set_alpha(0.9) + + # Style other violin plot elements + for element in ['cbars', 'cmins', 'cmaxes', 'cmedians']: + if element in vparts: + vparts[element].set_color('black') + vparts[element].set_alpha(0.7 if element != 'cbars' else 0.3) + if element == 'cbars': + 
vparts[element].set_linestyle('--') + + return vparts['bodies'] + + def _calculate_bias_data(self, dfs: List[pd.DataFrame]) -> np.ndarray: + """Calculate bias data for violin plot.""" biases = np.zeros((len(dfs[0]), len(dfs))) for i, df in enumerate(dfs): - treatment_effects = df[[c for c in df.columns if bool(re.search('TE_[0-9]+', c))]] - bias = np.abs(np.mean(treatment_effects, axis=1) - df["TE_hat"]) + te_cols = [c for c in df.columns if re.search('TE_[0-9]+', c)] + treatment_effects = df[te_cols] biases[:, i] = np.abs(np.mean(treatment_effects, axis=1) - df["TE_hat"]) - vparts = ax.violinplot(biases, showmedians=True) - ax.set_title("Bias") - elif metric=="variance": + return biases + + def _calculate_variance_data(self, dfs: List[pd.DataFrame]) -> np.ndarray: + """Calculate variance data for violin plot.""" variances = np.zeros((len(dfs[0]), len(dfs))) for i, df in enumerate(dfs): - treatment_effects = df[[c for c in df.columns if bool(re.search('TE_[0-9]+', c))]] - variance = np.std(treatment_effects, axis=1) + te_cols = [c for c in df.columns if re.search('TE_[0-9]+', c)] + treatment_effects = df[te_cols] variances[:, i] = np.std(treatment_effects, axis=1) - vparts = ax.violinplot(variances, showmedians=True) - ax.set_title("Variance") - elif metric=="rmse": + return variances + + def _calculate_rmse_data(self, dfs: List[pd.DataFrame]) -> np.ndarray: + """Calculate RMSE data for violin plot.""" rmses = np.zeros((len(dfs[0]), len(dfs))) for i, df in enumerate(dfs): - treatment_effects = df[[c for c in df.columns if bool(re.search('TE_[0-9]+', c))]] - rmse = np.mean(((treatment_effects.T - df["TE_hat"].values).T)**2, axis=1) + te_cols = [c for c in df.columns if re.search('TE_[0-9]+', c)] + treatment_effects = df[te_cols] rmses[:, i] = np.mean(((treatment_effects.T - df["TE_hat"].values).T)**2, axis=1) - vparts = ax.violinplot(rmses, showmedians=True) - ax.set_title("RMSE") - elif metric == "R2": - r2_scores = [] - for i, df in enumerate(dfs): - r2_scores.append(get_r2(df)) - vparts = ax.violinplot(r2_scores, showmedians=True) - ax.set_title("$R^2$") - else: - print("No such metric") - return 0 - cs = [0, 3, 12, 14, 15, 4, 6] - ax.set_xticks([]) - for i, pc in enumerate(vparts['bodies']): - if i < 5: - c = i - else: - c = i+1 - if c_scheme == 1: - c = cs[i] - pc.set_facecolor(palette(c)) - pc.set_edgecolor(palette(c)) - pc.set_alpha(0.9) - - alpha = 0.7 - vparts['cbars'].set_color('black') - vparts['cbars'].set_alpha(0.3) - vparts['cbars'].set_linestyle('--') - - vparts['cmins'].set_color('black') - vparts['cmins'].set_alpha(alpha) - - vparts['cmaxes'].set_color('black') - vparts['cmaxes'].set_alpha(alpha) - - vparts['cmedians'].set_color('black') - vparts['cmedians'].set_alpha(alpha) - return vparts['bodies'] - -def metrics_plots(file_key, dfs, labels, c_scheme, file_name_prefix): - metrics = ["bias", "variance", "rmse"] - m = 1 - n = len(metrics) - fig = plt.figure(figsize=(12*n/3, 3)) - for i, metric in enumerate(metrics): - ax = fig.add_subplot(m, n, i+1) - vbodies = metrics_subfig(dfs, ax, metric, c_scheme) - lgd = fig.legend(vbodies, labels, ncol=len(labels), loc='lower center', bbox_to_anchor=(0.5, 0), frameon=False) - fig.tight_layout() - fig.subplots_adjust(bottom=0.15) - save_plots(fig, file_name_prefix, lgd) - plt.clf() + return rmses + + def create_support_plots(self, all_metrics: Dict, labels: List[str], + file_name_prefix: str) -> None: + """Create plots showing metrics vs support size.""" + palette = plt.get_cmap('Set1') + x_values = sorted(all_metrics.keys()) + 
metrics = ["bias", "variance", "rmse"] + titles = ["Bias", "Variance", "RMSE"] + + fig = plt.figure(figsize=FIGURE_SIZE_METRICS) + plot_objects = [] + + for metric_idx, metric in enumerate(metrics): + ax = fig.add_subplot(1, len(metrics), metric_idx + 1) + + for i, label in enumerate(labels): + color_idx = i if i < 5 else i + 1 + + # Extract metric values across support sizes + err_values = np.array([all_metrics[x][metric].std[i] for x in x_values]) + mean_values = np.array([all_metrics[x][metric].mean[i] for x in x_values]) + + # Plot with error bands + fill = ax.fill_between(x_values, mean_values - err_values/6, + mean_values + err_values/6, + alpha=0.5, color=palette(color_idx)) + ax.plot(x_values, mean_values, label=label, color=palette(color_idx)) + + if metric_idx == 0: + plot_obj = copy.copy(fill) + plot_obj.set_alpha(1.0) + plot_objects.append(plot_obj) + + ax.set_title(titles[metric_idx]) + ax.set_xlabel("Support size") + + lgd = fig.legend(plot_objects, labels, ncol=len(labels), + loc='lower center', bbox_to_anchor=(0.5, 0), frameon=False) + fig.tight_layout() + fig.subplots_adjust(bottom=0.25) + self.save_plots(fig, file_name_prefix, lgd) + plt.close(fig) -def support_plots(all_metrics, labels, file_name_prefix): - palette = plt.get_cmap('Set1') - x = sorted(list(all_metrics.keys())) - metrics = ["bias", "var", "rmse"] - titles = ["Bias", "Variance", "RMSE"] - m = 1 - n = len(metrics) - fig = plt.figure(figsize=(12*n/3, 3)) - all_plots = [] - for it, metric in enumerate(metrics): - ax = fig.add_subplot(m, n, it+1) - for i, l in enumerate(labels): - if i < 5: - c = i - else: - c = i+1 - err = np.array([all_metrics[j][metric]["std"][i] for j in x]) - mid = np.array([all_metrics[j][metric]["mean"][i] for j in x]) - p = ax.fill_between(x, mid-err/6, mid+err/6, alpha=0.5, color=palette(c)) - ax.plot(x, mid, label=labels[i], color=palette(c)) - if it == 0: - p1 = copy.copy(p) - p1.set_alpha(1.0) - all_plots.append(p1) - ax.set_title(titles[it]) - ax.set_xlabel("Support size") - fig.legend(all_plots, labels, ncol=len(labels), loc='lower center', bbox_to_anchor=(0.5, 0), frameon=False) - fig.tight_layout() - fig.subplots_adjust(bottom=0.25) - save_plots(fig, file_name_prefix) - plt.clf() +################### +# Main Analysis # +################### -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--output_dir", type=str, help="Directory for saving results", default=".") - parser.add_argument("--input_dir", type=str, help="", default=".") - parser.add_argument("-merge", action='store_true') +def main(): + """Main analysis pipeline.""" + parser = argparse.ArgumentParser(description="Analyze treatment effect estimation results") + parser.add_argument("--output_dir", type=str, default=".", + help="Directory for saving results") + parser.add_argument("--input_dir", type=str, default=".", + help="Directory containing input files") + parser.add_argument("--merge", action='store_true', + help="Merge results from multiple seeds") + + args = parser.parse_args() - args = parser.parse_args(sys.argv[1:]) - input_dir = args.input_dir - output_dir = args.output_dir + # Initialize processors + file_processor = FileProcessor(args.input_dir, args.output_dir) + plot_generator = PlotGenerator(args.output_dir) - all_files = os.listdir(input_dir) + # Process files + all_files = os.listdir(args.input_dir) results_files = [f for f in all_files if f.endswith("results.csv") and "seed" in f] - split_files = set([(re.split("seed_[0-9]+_", f)[0], re.split("seed_[0-9]+_", f)[1]) 
for f in results_files]) - split_files_seeds = {k:[int(re.search("seed_(\d+)_", f).group(1)) for f in results_files if f.startswith(k[0]) and f.endswith(k[1])] for k in split_files} - name_template = "{0}seed_{1}_{2}" - agg_fnames = [sf[0] + sf[1] for sf in split_files] + + split_files = set([ + (re.split("seed_[0-9]+_", f)[0], re.split("seed_[0-9]+_", f)[1]) + for f in results_files + ]) + + split_files_seeds = { + k: [int(re.search("seed_(\\d+)_", f).group(1)) + for f in results_files if f.startswith(k[0]) and f.endswith(k[1])] + for k in split_files + } + + agg_fnames = [f"{sf[0]}{sf[1]}" for sf in split_files] + + # Merge results if requested if args.merge: - Parallel(n_jobs=-1, verbose=3)(delayed(merge_results)(sf, input_dir, output_dir, split_files_seeds) for sf in split_files) + print("Merging results from multiple seeds...") + Parallel(n_jobs=-1, verbose=3)( + delayed(file_processor.merge_results)(sf, split_files_seeds) + for sf in split_files + ) + + # Group files and generate plots + agg_file_groups, labels = file_processor.get_file_groups(agg_fnames) - agg_file_groups, labels = get_file_groups(agg_fnames, plot_controls) - print(agg_fnames) - print(agg_file_groups) all_metrics = {} metrics_by_xgroup = [{}, {}] - for g in agg_file_groups: - agg_file_group = agg_file_groups[g] - dfs = [get_results(fname, output_dir) for fname in agg_file_group] - all_metrics[int(re.search("support_" + '(\d+)', g).group(1))] = get_metrics(dfs) - # Infer feature dimension - n_x = len([c for c in dfs[0].columns if bool(re.search("x[0-9]", c))]) - if n_x == 1: - generic_joint_plots(g, dfs, labels, "{0}{1}".format("Example", g)) - metrics_plots(g, dfs, labels, 0, "{0}{1}".format("Metrics", g)) + + for group_key in agg_file_groups: + print(f"Processing group: {group_key}") + agg_file_group = agg_file_groups[group_key] + dfs = [file_processor.get_results(fname) for fname in agg_file_group] + + # Calculate metrics + metrics = MetricsCalculator.calculate_metrics(dfs) + support_size = int(re.search("support_(\\d+)", group_key).group(1)) + all_metrics[support_size] = { + "bias": metrics.bias, + "variance": metrics.variance, + "rmse": metrics.rmse, + "r2": metrics.r2 + } + + # Determine feature dimensionality + n_features = len([c for c in dfs[0].columns if re.search("x[0-9]", c)]) + + if n_features == 1: + # Single feature case + plot_generator.create_joint_plots( + group_key, dfs, labels, f"Example{group_key}" + ) + plot_generator.create_metrics_plots( + group_key, dfs, labels, f"Metrics{group_key}" + ) else: - metrics_plots(g, dfs, labels, 0, "{0}_x1={2}{1}".format("Metrics", g, "all")) + # Multiple feature case + plot_generator.create_metrics_plots( + group_key, dfs, labels, f"Metrics{group_key}_x1=all" + ) + + # Create plots for each feature group for i in range(2): - dfs1 = [df[df["x1"]==i] for df in dfs] - generic_joint_plots(g, dfs1, labels, "{0}_x1={2}{1}".format("Example", g, str(i))) - metrics_plots(g, dfs1, labels, 0, "{0}_x1={2}{1}".format("Metrics", g, str(i))) - metrics_by_xgroup[i][int(re.search("support_" + '(\d+)', g).group(1))] = get_metrics(dfs1) - # Metrics by support plots - if n_x == 1: - support_plots(all_metrics, labels, "{0}".format("Metrics_by_support")) + dfs_subset = [df[df["x1"] == i] for df in dfs] + plot_generator.create_joint_plots( + group_key, dfs_subset, labels, f"Example{group_key}_x1={i}" + ) + plot_generator.create_metrics_plots( + group_key, dfs_subset, labels, f"Metrics{group_key}_x1={i}" + ) + + # Store metrics for support plots + subset_metrics = 
MetricsCalculator.calculate_metrics(dfs_subset) + metrics_by_xgroup[i][support_size] = { + "bias": subset_metrics.bias, + "variance": subset_metrics.variance, + "rmse": subset_metrics.rmse, + "r2": subset_metrics.r2 + } + + # Generate support size comparison plots + if n_features == 1: + plot_generator.create_support_plots(all_metrics, labels, "Metrics_by_support") else: - support_plots(all_metrics, labels, "{0}_x1={1}".format("Metrics_by_support", "all")) + plot_generator.create_support_plots(all_metrics, labels, "Metrics_by_support_x1=all") for i in range(2): - support_plots(metrics_by_xgroup[i], labels, "{0}_x1={1}".format("Metrics_by_support", str(i))) \ No newline at end of file + plot_generator.create_support_plots( + metrics_by_xgroup[i], labels, f"Metrics_by_support_x1={i}" + ) + + print("Analysis complete!") + +if __name__ == "__main__": + main()
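
A usage sketch of the scoring API touched by this diff — RScorer(model_y=..., model_t=...), RScorer.fit/score/best_model/ensemble, and the EnsembleCateEstimator returned by ensemble(). The synthetic data, the LassoCV / RandomForestRegressor nuisance models, and the 50/50 train/validation split are illustrative assumptions, not part of the change. It also assumes score() accepts the effects that LinearDML.const_marginal_effect returns for scalar treatments: the previous implementation reshaped them to (n, d_y, d_t), while the rewritten method raises unless the array is already 3-D, so a reshape may still be needed there.

    # Sketch only: data-generating process and first-stage models are illustrative.
    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.linear_model import LassoCV
    from sklearn.model_selection import train_test_split

    from econml.dml import LinearDML
    from econml.score import RScorer

    rng = np.random.default_rng(0)
    n = 2000
    X = rng.normal(size=(n, 5))   # effect modifiers
    W = rng.normal(size=(n, 3))   # controls
    T = X[:, 0] + W[:, 0] + rng.normal(size=n)                  # continuous treatment
    y = (1 + 0.5 * X[:, 0]) * T + W[:, 1] + rng.normal(size=n)  # heterogeneous effect in X[:, 0]

    # Fit candidate CATE models on one half, score them on the held-out half.
    X_tr, X_val, W_tr, W_val, T_tr, T_val, y_tr, y_val = train_test_split(
        X, W, T, y, test_size=0.5, random_state=0)

    candidates = [
        LinearDML(model_y=LassoCV(), model_t=LassoCV(), random_state=0),
        LinearDML(model_y=RandomForestRegressor(min_samples_leaf=20, random_state=0),
                  model_t=RandomForestRegressor(min_samples_leaf=20, random_state=0),
                  random_state=0),
    ]
    for est in candidates:
        est.fit(y_tr, T_tr, X=X_tr, W=W_tr)

    # The scorer fits its own LinearDML on the validation half to residualize y and T.
    scorer = RScorer(model_y=LassoCV(), model_t=LassoCV(), cv=3, random_state=0)
    scorer.fit(y_val, T_val, X=X_val, W=W_val)

    best_est, best_score, scores = scorer.best_model(candidates, return_scores=True)
    ensemble_est, ensemble_score = scorer.ensemble(candidates, eta=1000.0)

    print("per-model R-scores:", scores)
    print("best score:", best_score, "ensemble score:", ensemble_score)
    # EnsembleCateEstimator averages the base models' effects with softmax weights.
    print("ensemble effect estimates:", ensemble_est.effect(X_val[:5], T0=0, T1=1))

A larger eta concentrates the softmax weights on the best-scoring model, so ensemble() with a very high eta behaves like best_model(), while a small eta averages the candidates more evenly.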