diff --git a/book/ate/propensity_score_and_dml.ipynb b/book/ate/propensity_score_and_dml.ipynb index c2e7639..2cff6cd 100644 --- a/book/ate/propensity_score_and_dml.ipynb +++ b/book/ate/propensity_score_and_dml.ipynb @@ -11,9 +11,4268 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- Matching\n", - "- IPW, AIPW, Doubly Robust Estimator\n", - "- Double Machine Learning (비모수 버전의 Regression 처럼 활용 가능)" + "- IPW, AIPW, Doubly Robust Estimator" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### **Propensity Score 추정**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "출처: https://matheusfacure.github.io/python-causality-handbook/11-Propensity-Score.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "IPW와 AIPW, Doubly Robust 모두 Propensity Score를 활용한 개념들이기 때문에 먼저 Propensity Score부터 간단하게 구해보겠습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "from causalinference import CausalModel" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
| \n", + " | schoolid | \n", + "intervention | \n", + "achievement_score | \n", + "success_expect | \n", + "ethnicity | \n", + "gender | \n", + "frst_in_family | \n", + "school_urbanicity | \n", + "school_mindset | \n", + "school_achievement | \n", + "school_ethnic_minority | \n", + "school_poverty | \n", + "school_size | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 259 | \n", + "73 | \n", + "1 | \n", + "1.480828 | \n", + "5 | \n", + "1 | \n", + "2 | \n", + "0 | \n", + "1 | \n", + "-0.462945 | \n", + "0.652608 | \n", + "-0.515202 | \n", + "-0.169849 | \n", + "0.173954 | \n", + "
| 3435 | \n", + "76 | \n", + "0 | \n", + "-0.987277 | \n", + "5 | \n", + "13 | \n", + "1 | \n", + "1 | \n", + "4 | \n", + "0.334544 | \n", + "0.648586 | \n", + "-1.310927 | \n", + "0.224077 | \n", + "-0.426757 | \n", + "
| 9963 | \n", + "4 | \n", + "0 | \n", + "-0.152340 | \n", + "5 | \n", + "2 | \n", + "2 | \n", + "1 | \n", + "0 | \n", + "-2.289636 | \n", + "0.190797 | \n", + "0.875012 | \n", + "-0.724801 | \n", + "0.761781 | \n", + "
| 4488 | \n", + "67 | \n", + "0 | \n", + "0.358336 | \n", + "6 | \n", + "14 | \n", + "1 | \n", + "0 | \n", + "4 | \n", + "-1.115337 | \n", + "1.053089 | \n", + "0.315755 | \n", + "0.054586 | \n", + "1.862187 | \n", + "
| 2637 | \n", + "16 | \n", + "1 | \n", + "1.360920 | \n", + "6 | \n", + "4 | \n", + "1 | \n", + "0 | \n", + "1 | \n", + "-0.538975 | \n", + "1.433826 | \n", + "-0.033161 | \n", + "-0.982274 | \n", + "1.591641 | \n", + "
| \n", + " | schoolid | \n", + "intervention | \n", + "achievement_score | \n", + "success_expect | \n", + "ethnicity | \n", + "gender | \n", + "frst_in_family | \n", + "school_urbanicity | \n", + "school_mindset | \n", + "school_achievement | \n", + "school_ethnic_minority | \n", + "school_poverty | \n", + "school_size | \n", + "propensity_score | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "76 | \n", + "1 | \n", + "0.277359 | \n", + "6 | \n", + "4 | \n", + "2 | \n", + "1 | \n", + "4 | \n", + "0.334544 | \n", + "0.648586 | \n", + "-1.310927 | \n", + "0.224077 | \n", + "-0.426757 | \n", + "0.315271 | \n", + "
| 1 | \n", + "76 | \n", + "1 | \n", + "-0.449646 | \n", + "4 | \n", + "12 | \n", + "2 | \n", + "1 | \n", + "4 | \n", + "0.334544 | \n", + "0.648586 | \n", + "-1.310927 | \n", + "0.224077 | \n", + "-0.426757 | \n", + "0.263482 | \n", + "
| 2 | \n", + "76 | \n", + "1 | \n", + "0.769703 | \n", + "6 | \n", + "4 | \n", + "2 | \n", + "0 | \n", + "4 | \n", + "0.334544 | \n", + "0.648586 | \n", + "-1.310927 | \n", + "0.224077 | \n", + "-0.426757 | \n", + "0.343781 | \n", + "
| 3 | \n", + "76 | \n", + "1 | \n", + "-0.121763 | \n", + "6 | \n", + "4 | \n", + "2 | \n", + "0 | \n", + "4 | \n", + "0.334544 | \n", + "0.648586 | \n", + "-1.310927 | \n", + "0.224077 | \n", + "-0.426757 | \n", + "0.343781 | \n", + "
| 4 | \n", + "76 | \n", + "1 | \n", + "1.526147 | \n", + "6 | \n", + "4 | \n", + "1 | \n", + "0 | \n", + "4 | \n", + "0.334544 | \n", + "0.648586 | \n", + "-1.310927 | \n", + "0.224077 | \n", + "-0.426757 | \n", + "0.367474 | \n", + "
| ... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
| 10386 | \n", + "1 | \n", + "0 | \n", + "0.808867 | \n", + "7 | \n", + "4 | \n", + "2 | \n", + "1 | \n", + "3 | \n", + "1.185986 | \n", + "-1.129889 | \n", + "1.009875 | \n", + "1.005063 | \n", + "-1.174702 | \n", + "0.324195 | \n", + "
| 10387 | \n", + "1 | \n", + "0 | \n", + "-0.156063 | \n", + "7 | \n", + "4 | \n", + "2 | \n", + "1 | \n", + "3 | \n", + "1.185986 | \n", + "-1.129889 | \n", + "1.009875 | \n", + "1.005063 | \n", + "-1.174702 | \n", + "0.324195 | \n", + "
| 10388 | \n", + "1 | \n", + "0 | \n", + "0.370820 | \n", + "2 | \n", + "15 | \n", + "1 | \n", + "1 | \n", + "3 | \n", + "1.185986 | \n", + "-1.129889 | \n", + "1.009875 | \n", + "1.005063 | \n", + "-1.174702 | \n", + "0.248792 | \n", + "
| 10389 | \n", + "1 | \n", + "0 | \n", + "-0.396297 | \n", + "5 | \n", + "4 | \n", + "1 | \n", + "1 | \n", + "3 | \n", + "1.185986 | \n", + "-1.129889 | \n", + "1.009875 | \n", + "1.005063 | \n", + "-1.174702 | \n", + "0.303049 | \n", + "
| 10390 | \n", + "1 | \n", + "0 | \n", + "0.478970 | \n", + "5 | \n", + "1 | \n", + "2 | \n", + "1 | \n", + "3 | \n", + "1.185986 | \n", + "-1.129889 | \n", + "1.009875 | \n", + "1.005063 | \n", + "-1.174702 | \n", + "0.256371 | \n", + "
10391 rows × 14 columns
\n", + "AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=False, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=False, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge())Ridge()
Ridge()
IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000))LogisticRegression(max_iter=1000)
LogisticRegression(max_iter=1000)
AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=True, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=True, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge())Ridge()
Ridge()
IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000))LogisticRegression(max_iter=1000)
LogisticRegression(max_iter=1000)
| \n", + " | temp | \n", + "weekday | \n", + "cost | \n", + "price | \n", + "sales | \n", + "
|---|---|---|---|---|---|
| 0 | \n", + "17.3 | \n", + "6 | \n", + "1.5 | \n", + "5.6 | \n", + "173 | \n", + "
| 1 | \n", + "25.4 | \n", + "3 | \n", + "0.3 | \n", + "4.9 | \n", + "196 | \n", + "
| 2 | \n", + "23.3 | \n", + "5 | \n", + "1.5 | \n", + "7.6 | \n", + "207 | \n", + "
| 3 | \n", + "26.9 | \n", + "1 | \n", + "0.3 | \n", + "5.3 | \n", + "241 | \n", + "
| 4 | \n", + "20.2 | \n", + "1 | \n", + "1.0 | \n", + "7.2 | \n", + "227 | \n", + "
| coef | std err | t | P>|t| | [0.025 | 0.975] | \n", + "|
|---|---|---|---|---|---|---|
| Intercept | 0.0106 | 0.072 | 0.148 | 0.883 | -0.131 | 0.152 | \n", + "
| price_res | -3.9228 | 0.071 | -54.962 | 0.000 | -4.063 | -3.783 | \n", + "
| coef | std err | t | P>|t| | [0.025 | 0.975] | \n", + "|
|---|---|---|---|---|---|---|
| Intercept | 192.9679 | 1.013 | 190.414 | 0.000 | 190.981 | 194.954 | \n", + "
| price | 1.2294 | 0.162 | 7.575 | 0.000 | 0.911 | 1.547 | \n", + "
| \n", + " | nifa | \n", + "net_tfa | \n", + "tw | \n", + "age | \n", + "inc | \n", + "fsize | \n", + "educ | \n", + "db | \n", + "marr | \n", + "twoearn | \n", + "e401 | \n", + "p401 | \n", + "pira | \n", + "hown | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "0.0 | \n", + "0.0 | \n", + "4500.0 | \n", + "47 | \n", + "6765.0 | \n", + "2 | \n", + "8 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| 1 | \n", + "6215.0 | \n", + "1015.0 | \n", + "22390.0 | \n", + "36 | \n", + "28452.0 | \n", + "1 | \n", + "16 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| 2 | \n", + "0.0 | \n", + "-2000.0 | \n", + "-2000.0 | \n", + "37 | \n", + "3300.0 | \n", + "6 | \n", + "12 | \n", + "1 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "
| 3 | \n", + "15000.0 | \n", + "15000.0 | \n", + "155000.0 | \n", + "58 | \n", + "52590.0 | \n", + "2 | \n", + "16 | \n", + "0 | \n", + "1 | \n", + "1 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| 4 | \n", + "0.0 | \n", + "0.0 | \n", + "58000.0 | \n", + "32 | \n", + "21804.0 | \n", + "1 | \n", + "11 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| e401 | \n", + "6141.513231 | \n", + "1458.804562 | \n", + "4.209963 | \n", + "0.000026 | \n", + "3285.053486 | \n", + "9000.717633 | \n", + "
| \n", + " | theta.lower | \n", + "theta.upper | \n", + "
|---|---|---|
| 0 | \n", + "[2362.9817493722544] | \n", + "[9922.984266571242] | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| e401 | \n", + "6141.513231 | \n", + "1458.804562 | \n", + "4.209963 | \n", + "0.000026 | \n", + "3285.053486 | \n", + "9000.717633 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| p401 | \n", + "7197.609515 | \n", + "1952.711128 | \n", + "3.685957 | \n", + "0.000228 | \n", + "3370.366032 | \n", + "11024.852997 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| p401 | \n", + "6825.693722 | \n", + "2798.752919 | \n", + "2.438834 | \n", + "0.014735 | \n", + "1269.817391 | \n", + "12311.146167 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| p401 | \n", + "2081.049819 | \n", + "8891.713712 | \n", + "0.234044 | \n", + "0.814951 | \n", + "-15487.24316 | \n", + "19508.491103 | \n", + "
| \n", + " | X1 | \n", + "X2 | \n", + "X3 | \n", + "X4 | \n", + "X5 | \n", + "X6 | \n", + "X7 | \n", + "X8 | \n", + "X9 | \n", + "X10 | \n", + "y | \n", + "d | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "-0.368577 | \n", + "-0.688886 | \n", + "0.793315 | \n", + "-0.934066 | \n", + "-1.015731 | \n", + "-1.630031 | \n", + "-1.474996 | \n", + "-0.477593 | \n", + "-0.676716 | \n", + "0.552010 | \n", + "-1.063518 | \n", + "0.0 | \n", + "
| 1 | \n", + "0.078426 | \n", + "-1.028731 | \n", + "0.755885 | \n", + "-0.223044 | \n", + "-0.311049 | \n", + "-0.059540 | \n", + "-0.123388 | \n", + "-0.508408 | \n", + "-0.094020 | \n", + "0.292126 | \n", + "-0.327716 | \n", + "1.0 | \n", + "
| 2 | \n", + "-2.899021 | \n", + "-1.294123 | \n", + "-0.884821 | \n", + "0.421903 | \n", + "-0.290983 | \n", + "-0.740970 | \n", + "-2.104354 | \n", + "-0.020588 | \n", + "0.710170 | \n", + "0.135842 | \n", + "-0.150180 | \n", + "0.0 | \n", + "
| 3 | \n", + "0.502005 | \n", + "0.902920 | \n", + "-0.158726 | \n", + "0.529506 | \n", + "0.012832 | \n", + "0.987503 | \n", + "-0.935100 | \n", + "0.523039 | \n", + "0.016426 | \n", + "0.363400 | \n", + "0.529042 | \n", + "1.0 | \n", + "
| 4 | \n", + "-1.843018 | \n", + "0.170705 | \n", + "0.712846 | \n", + "0.118400 | \n", + "-0.128942 | \n", + "0.128420 | \n", + "0.333902 | \n", + "-1.666221 | \n", + "-0.455829 | \n", + "1.606466 | \n", + "-0.154075 | \n", + "0.0 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| d | \n", + "-0.127561 | \n", + "0.129644 | \n", + "-0.983928 | \n", + "0.325151 | \n", + "-0.381659 | \n", + "0.126538 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| d | \n", + "0.552643 | \n", + "0.115694 | \n", + "4.776756 | \n", + "0.000002 | \n", + "0.325886 | \n", + "0.779399 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| 0 | \n", + "-0.079156 | \n", + "0.118008 | \n", + "-0.670763 | \n", + "0.502371 | \n", + "-0.304181 | \n", + "0.152136 | \n", + "
| 1 | \n", + "0.498375 | \n", + "0.125547 | \n", + "3.969641 | \n", + "0.000072 | \n", + "0.252308 | \n", + "0.744441 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| 1 vs 0 | \n", + "0.772675 | \n", + "0.204353 | \n", + "3.781072 | \n", + "0.001665 | \n", + "0.267597 | \n", + "1.173201 | \n", + "