From bed787202a3e6baa066544bea444f23670279ced Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:25:16 +0200 Subject: [PATCH 1/9] add cov type to blp test --- doubleml/utils/tests/test_blp.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/doubleml/utils/tests/test_blp.py b/doubleml/utils/tests/test_blp.py index 9201b79a..25df6bc4 100644 --- a/doubleml/utils/tests/test_blp.py +++ b/doubleml/utils/tests/test_blp.py @@ -19,8 +19,14 @@ def ci_level(request): return request.param +@pytest.fixture(scope='module', + params=["nonrobust", "HC0", "HC1", "HC2", "HC3"]) +def cov_type(request): + return request.param + + @pytest.fixture(scope='module') -def dml_blp_fixture(ci_joint, ci_level): +def dml_blp_fixture(ci_joint, ci_level, cov_type): n = 50 np.random.seed(42) random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 3))) @@ -30,7 +36,7 @@ def dml_blp_fixture(ci_joint, ci_level): blp_obj = copy.copy(blp) blp.fit() - blp_manual = fit_blp(random_signal, random_basis) + blp_manual = fit_blp(random_signal, random_basis, cov_type) np.random.seed(42) ci_1 = blp.confint(random_basis, joint=ci_joint, level=ci_level, n_rep_boot=1000) @@ -49,7 +55,7 @@ def dml_blp_fixture(ci_joint, ci_level): 'values': blp.blp_model.fittedvalues, 'values_manual': blp_manual.fittedvalues, 'omega': blp.blp_omega, - 'omega_manual': blp_manual.cov_HC0, + 'omega_manual': blp_manual.cov_params().to_numpy(), 'basis': blp.basis, 'signal': blp.orth_signal, 'ci_1': ci_1, From 636479614f57faa29b4cc21ed07e0397471ead55 Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:25:42 +0200 Subject: [PATCH 2/9] add cov_type to manual implmentation --- doubleml/utils/tests/_utils_blp_manual.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doubleml/utils/tests/_utils_blp_manual.py b/doubleml/utils/tests/_utils_blp_manual.py index 3d9b5721..c64545aa 100644 --- a/doubleml/utils/tests/_utils_blp_manual.py +++ b/doubleml/utils/tests/_utils_blp_manual.py @@ -5,8 +5,8 @@ import pandas as pd -def fit_blp(orth_signal, basis): - blp_model = sm.OLS(orth_signal, basis).fit() +def fit_blp(orth_signal, basis, cov_type, **kwargs): + blp_model = sm.OLS(orth_signal, basis).fit(cov_type=cov_type, **kwargs) return blp_model @@ -15,7 +15,7 @@ def blp_confint(blp_model, basis, joint=False, level=0.95, n_rep_boot=500): alpha = 1 - level g_hat = blp_model.predict(basis) - blp_omega = blp_model.cov_HC0 + blp_omega = blp_model.cov_params().to_numpy() blp_se = np.sqrt((basis.dot(blp_omega) * basis).sum(axis=1)) From 7049c1d3806407a5e6a201a6e953f96ec07ed884 Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:30:52 +0200 Subject: [PATCH 3/9] add cov_type to blp --- doubleml/utils/blp.py | 15 ++++++++++++--- doubleml/utils/tests/test_blp.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/doubleml/utils/blp.py b/doubleml/utils/blp.py index bfdf7671..24bd807b 100644 --- a/doubleml/utils/blp.py +++ b/doubleml/utils/blp.py @@ -110,18 +110,27 @@ def summary(self): columns=col_names) return df_summary - def fit(self): + def fit(self, cov_type='HC0', **kwargs): """ Estimate DoubleMLBLP models. + Parameters + ---------- + cov_type : str + The covariance type to be used in the estimation. Default is ``'HC0'``. + See :meth:`statsmodels.regression.linear_model.OLS.fit` for more information. + + **kwargs: dict + Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`. + Returns ------- self : object """ # fit the best-linear-predictor of the orthogonal signal with respect to the grid - self._blp_model = sm.OLS(self._orth_signal, self._basis).fit() - self._blp_omega = self._blp_model.cov_HC0 + self._blp_model = sm.OLS(self._orth_signal, self._basis).fit(cov_type=cov_type, **kwargs) + self._blp_omega = self._blp_model.cov_params().to_numpy() return self diff --git a/doubleml/utils/tests/test_blp.py b/doubleml/utils/tests/test_blp.py index 25df6bc4..3ae8e85e 100644 --- a/doubleml/utils/tests/test_blp.py +++ b/doubleml/utils/tests/test_blp.py @@ -35,7 +35,7 @@ def dml_blp_fixture(ci_joint, ci_level, cov_type): blp = dml.DoubleMLBLP(random_signal, random_basis) blp_obj = copy.copy(blp) - blp.fit() + blp.fit(cov_type=cov_type) blp_manual = fit_blp(random_signal, random_basis, cov_type) np.random.seed(42) From f15d73e3d0f2691057963b63736ce8f323e2447a Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:40:34 +0200 Subject: [PATCH 4/9] add kwargs check --- doubleml/utils/tests/test_blp.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/doubleml/utils/tests/test_blp.py b/doubleml/utils/tests/test_blp.py index 3ae8e85e..ba19470b 100644 --- a/doubleml/utils/tests/test_blp.py +++ b/doubleml/utils/tests/test_blp.py @@ -25,9 +25,17 @@ def cov_type(request): return request.param +@pytest.fixture(scope='module', + params=[True, False]) +def use_t(request): + return request.param + + @pytest.fixture(scope='module') -def dml_blp_fixture(ci_joint, ci_level, cov_type): +def dml_blp_fixture(ci_joint, ci_level, cov_type, use_t): n = 50 + kwargs = {'cov_type': cov_type, 'use_t': use_t} + np.random.seed(42) random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 3))) random_signal = np.random.normal(0, 1, size=(n, )) @@ -35,8 +43,8 @@ def dml_blp_fixture(ci_joint, ci_level, cov_type): blp = dml.DoubleMLBLP(random_signal, random_basis) blp_obj = copy.copy(blp) - blp.fit(cov_type=cov_type) - blp_manual = fit_blp(random_signal, random_basis, cov_type) + blp.fit(**kwargs) + blp_manual = fit_blp(random_signal, random_basis, **kwargs) np.random.seed(42) ci_1 = blp.confint(random_basis, joint=ci_joint, level=ci_level, n_rep_boot=1000) From c093205363a1ca870b0c33bd52e39b48eff6b9a3 Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:44:08 +0200 Subject: [PATCH 5/9] add default test to blp --- doubleml/utils/tests/test_blp.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/doubleml/utils/tests/test_blp.py b/doubleml/utils/tests/test_blp.py index ba19470b..38c1fff4 100644 --- a/doubleml/utils/tests/test_blp.py +++ b/doubleml/utils/tests/test_blp.py @@ -118,6 +118,23 @@ def test_dml_blp_return_types(dml_blp_fixture): assert isinstance(dml_blp_fixture['unfitted_blp_model'].summary, pd.DataFrame) +@pytest.mark.ci +def test_dml_blp_defaults(): + n = 50 + np.random.seed(42) + random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 3))) + random_signal = np.random.normal(0, 1, size=(n, )) + + blp = dml.DoubleMLBLP(random_signal, random_basis) + blp.fit() + + assert np.allclose(blp.blp_omega, + blp.blp_model.cov_HC0, + rtol=1e-9, atol=1e-4) + + assert blp._is_gate is False + + @pytest.mark.ci def test_doubleml_exception_blp(): random_basis = pd.DataFrame(np.random.normal(0, 1, size=(2, 3))) From 88c8ec85b8c6e53f29da46d03c5f1121084290bd Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:59:06 +0200 Subject: [PATCH 6/9] add kwargs to plr cate and gate --- doubleml/plm/plr.py | 15 +++++++++++---- doubleml/plm/tests/test_plr.py | 14 ++++++++++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/doubleml/plm/plr.py b/doubleml/plm/plr.py index d5810b97..3f0a26ea 100644 --- a/doubleml/plm/plr.py +++ b/doubleml/plm/plr.py @@ -341,7 +341,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_ return res - def cate(self, basis, is_gate=False): + def cate(self, basis, is_gate=False, **kwargs): """ Calculate conditional average treatment effects (CATE) for a given basis. @@ -350,10 +350,14 @@ def cate(self, basis, is_gate=False): basis : :class:`pandas.DataFrame` The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of predictors. + is_gate : bool Indicates whether the basis is constructed for GATEs (dummy-basis). Default is ``False``. + **kwargs: dict + Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``. + Returns ------- model : :class:`doubleML.DoubleMLBLP` @@ -374,10 +378,10 @@ def cate(self, basis, is_gate=False): basis=D_basis, is_gate=is_gate, ) - model.fit() + model.fit(**kwargs) return model - def gate(self, groups): + def gate(self, groups, **kwargs): """ Calculate group average treatment effects (GATE) for groups. @@ -388,6 +392,9 @@ def gate(self, groups): Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str). + **kwargs: dict + Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``. + Returns ------- model : :class:`doubleML.DoubleMLBLP` @@ -407,7 +414,7 @@ def gate(self, groups): if any(groups.sum(0) <= 5): warnings.warn('At least one group effect is estimated with less than 6 observations.') - model = self.cate(groups, is_gate=True) + model = self.cate(groups, is_gate=True, **kwargs) return model def _partial_out(self): diff --git a/doubleml/plm/tests/test_plr.py b/doubleml/plm/tests/test_plr.py index 46cbba2e..43e0e216 100644 --- a/doubleml/plm/tests/test_plr.py +++ b/doubleml/plm/tests/test_plr.py @@ -301,8 +301,14 @@ def test_dml_plr_ols_manual_boot(dml_plr_ols_manual_fixture): rtol=1e-9, atol=1e-4) +@pytest.fixture(scope='module', + params=["nonrobust", "HC0", "HC1", "HC2", "HC3"]) +def cov_type(request): + return request.param + + @pytest.mark.ci -def test_dml_plr_cate_gate(score): +def test_dml_plr_cate_gate(score, cov_type): n = 9 # collect data @@ -318,7 +324,7 @@ def test_dml_plr_cate_gate(score): score=score) dml_plr_obj.fit() random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5))) - cate = dml_plr_obj.cate(random_basis) + cate = dml_plr_obj.cate(random_basis, cov_type=cov_type) assert isinstance(cate, dml.DoubleMLBLP) assert isinstance(cate.confint(), pd.DataFrame) @@ -328,7 +334,7 @@ def test_dml_plr_cate_gate(score): columns=['Group 1', 'Group 2']) msg = ('At least one group effect is estimated with less than 6 observations.') with pytest.warns(UserWarning, match=msg): - gate_1 = dml_plr_obj.gate(groups_1) + gate_1 = dml_plr_obj.gate(groups_1, cov_type=cov_type) assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_1.confint(), pd.DataFrame) assert all(gate_1.confint().index == groups_1.columns.tolist()) @@ -337,7 +343,7 @@ def test_dml_plr_cate_gate(score): groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n)) msg = ('At least one group effect is estimated with less than 6 observations.') with pytest.warns(UserWarning, match=msg): - gate_2 = dml_plr_obj.gate(groups_2) + gate_2 = dml_plr_obj.gate(groups_2, cov_type=cov_type) assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_2.confint(), pd.DataFrame) assert all(gate_2.confint().index == ["Group_1", "Group_2"]) From cb02bca189efdff8f8d685610901feeef220f77a Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:04:24 +0200 Subject: [PATCH 7/9] add kwargs to cate and gate irm --- doubleml/irm/irm.py | 15 +++++++++++---- doubleml/irm/tests/test_irm.py | 14 ++++++++++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/doubleml/irm/irm.py b/doubleml/irm/irm.py index 1b1695c6..e5acd45d 100644 --- a/doubleml/irm/irm.py +++ b/doubleml/irm/irm.py @@ -431,7 +431,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_ return res - def cate(self, basis, is_gate=False): + def cate(self, basis, is_gate=False, **kwargs): """ Calculate conditional average treatment effects (CATE) for a given basis. @@ -440,10 +440,14 @@ def cate(self, basis, is_gate=False): basis : :class:`pandas.DataFrame` The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of predictors. + is_gate : bool Indicates whether the basis is constructed for GATEs (dummy-basis). Default is ``False``. + **kwargs: dict + Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``. + Returns ------- model : :class:`doubleML.DoubleMLBLP` @@ -462,10 +466,10 @@ def cate(self, basis, is_gate=False): orth_signal = self.psi_elements['psi_b'].reshape(-1) # fit the best linear predictor model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate) - model.fit() + model.fit(**kwargs) return model - def gate(self, groups): + def gate(self, groups, **kwargs): """ Calculate group average treatment effects (GATE) for groups. @@ -476,6 +480,9 @@ def gate(self, groups): Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str). + **kwargs: dict + Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``. + Returns ------- model : :class:`doubleML.DoubleMLBLP` @@ -495,7 +502,7 @@ def gate(self, groups): if any(groups.sum(0) <= 5): warnings.warn('At least one group effect is estimated with less than 6 observations.') - model = self.cate(groups, is_gate=True) + model = self.cate(groups, is_gate=True, **kwargs) return model def policy_tree(self, features, depth=2, **tree_params): diff --git a/doubleml/irm/tests/test_irm.py b/doubleml/irm/tests/test_irm.py index a28cb3bb..dc6f0cf2 100644 --- a/doubleml/irm/tests/test_irm.py +++ b/doubleml/irm/tests/test_irm.py @@ -187,8 +187,14 @@ def test_dml_irm_sensitivity_rho0(dml_irm_fixture): rtol=1e-9, atol=1e-4) +@pytest.fixture(scope='module', + params=["nonrobust", "HC0", "HC1", "HC2", "HC3"]) +def cov_type(request): + return request.param + + @pytest.mark.ci -def test_dml_irm_cate_gate(): +def test_dml_irm_cate_gate(cov_type): n = 9 # collect data np.random.seed(42) @@ -207,7 +213,7 @@ def test_dml_irm_cate_gate(): dml_irm_obj.fit() # create a random basis random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5))) - cate = dml_irm_obj.cate(random_basis) + cate = dml_irm_obj.cate(random_basis, cov_type=cov_type) assert isinstance(cate, dml.utils.blp.DoubleMLBLP) assert isinstance(cate.confint(), pd.DataFrame) @@ -216,7 +222,7 @@ def test_dml_irm_cate_gate(): columns=['Group 1', 'Group 2']) msg = ('At least one group effect is estimated with less than 6 observations.') with pytest.warns(UserWarning, match=msg): - gate_1 = dml_irm_obj.gate(groups_1) + gate_1 = dml_irm_obj.gate(groups_1, cov_type=cov_type) assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_1.confint(), pd.DataFrame) assert all(gate_1.confint().index == groups_1.columns.to_list()) @@ -225,7 +231,7 @@ def test_dml_irm_cate_gate(): groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n)) msg = ('At least one group effect is estimated with less than 6 observations.') with pytest.warns(UserWarning, match=msg): - gate_2 = dml_irm_obj.gate(groups_2) + gate_2 = dml_irm_obj.gate(groups_2, cov_type=cov_type) assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_2.confint(), pd.DataFrame) assert all(gate_2.confint().index == ["Group_1", "Group_2"]) From 67c4c58017f9a453940e9fc2afb741b42b2bc307 Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:13:01 +0200 Subject: [PATCH 8/9] add apo kwargs for cate and gate --- doubleml/irm/apo.py | 15 +++++++++++---- doubleml/irm/tests/test_apo.py | 17 +++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/doubleml/irm/apo.py b/doubleml/irm/apo.py index 93c3c0df..91e028d1 100644 --- a/doubleml/irm/apo.py +++ b/doubleml/irm/apo.py @@ -389,7 +389,7 @@ def _check_data(self, obj_dml_data): return - def capo(self, basis, is_gate=False): + def capo(self, basis, is_gate=False, **kwargs): """ Calculate conditional average potential outcomes (CAPO) for a given basis. @@ -398,10 +398,14 @@ def capo(self, basis, is_gate=False): basis : :class:`pandas.DataFrame` The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of predictors. + is_gate : bool Indicates whether the basis is constructed for GATE/GAPOs (dummy-basis). Default is ``False``. + **kwargs: dict + Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``. + Returns ------- model : :class:`doubleML.DoubleMLBLP` @@ -420,10 +424,10 @@ def capo(self, basis, is_gate=False): orth_signal = self.psi_elements['psi_b'].reshape(-1) # fit the best linear predictor model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate) - model.fit() + model.fit(**kwargs) return model - def gapo(self, groups): + def gapo(self, groups, **kwargs): """ Calculate group average potential outcomes (GAPO) for groups. @@ -434,6 +438,9 @@ def gapo(self, groups): Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str). + **kwargs: dict + Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``. + Returns ------- model : :class:`doubleML.DoubleMLBLP` @@ -453,5 +460,5 @@ def gapo(self, groups): if any(groups.sum(0) <= 5): warnings.warn('At least one group effect is estimated with less than 6 observations.') - model = self.capo(groups, is_gate=True) + model = self.capo(groups, is_gate=True, **kwargs) return model diff --git a/doubleml/irm/tests/test_apo.py b/doubleml/irm/tests/test_apo.py index 7082e399..ad962e8e 100644 --- a/doubleml/irm/tests/test_apo.py +++ b/doubleml/irm/tests/test_apo.py @@ -200,8 +200,14 @@ def test_dml_apo_sensitivity(dml_apo_fixture): rtol=1e-9, atol=1e-4) +@pytest.fixture(scope='module', + params=["nonrobust", "HC0", "HC1", "HC2", "HC3"]) +def cov_type(request): + return request.param + + @pytest.mark.ci -def test_dml_apo_capo_gapo(treatment_level): +def test_dml_apo_capo_gapo(treatment_level, cov_type): n = 20 # collect data np.random.seed(42) @@ -221,25 +227,28 @@ def test_dml_apo_capo_gapo(treatment_level): dml_obj.fit() # create a random basis random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5))) - capo = dml_obj.capo(random_basis) + capo = dml_obj.capo(random_basis, cov_type=cov_type) assert isinstance(capo, dml.utils.blp.DoubleMLBLP) assert isinstance(capo.confint(), pd.DataFrame) + assert capo.blp_model.cov_type == cov_type groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= -1.0, obj_dml_data.data['X1'] > 0.2]), columns=['Group 1', 'Group 2']) msg = ('At least one group effect is estimated with less than 6 observations.') with pytest.warns(UserWarning, match=msg): - gapo_1 = dml_obj.gapo(groups_1) + gapo_1 = dml_obj.gapo(groups_1, cov_type=cov_type) assert isinstance(gapo_1, dml.utils.blp.DoubleMLBLP) assert isinstance(gapo_1.confint(), pd.DataFrame) assert all(gapo_1.confint().index == groups_1.columns.to_list()) + assert gapo_1.blp_model.cov_type == cov_type np.random.seed(42) groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n, p=[0.1, 0.9])) msg = ('At least one group effect is estimated with less than 6 observations.') with pytest.warns(UserWarning, match=msg): - gapo_2 = dml_obj.gapo(groups_2) + gapo_2 = dml_obj.gapo(groups_2, cov_type=cov_type) assert isinstance(gapo_2, dml.utils.blp.DoubleMLBLP) assert isinstance(gapo_2.confint(), pd.DataFrame) assert all(gapo_2.confint().index == ["Group_1", "Group_2"]) + assert gapo_2.blp_model.cov_type == cov_type From 0e64d8480e56a7e5c4b7190bd7c8db26c3688a27 Mon Sep 17 00:00:00 2001 From: Sven Klaassen <47529404+SvenKlaassen@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:14:08 +0200 Subject: [PATCH 9/9] test cov_type gate and cate irm and plr --- doubleml/irm/tests/test_irm.py | 3 +++ doubleml/plm/tests/test_plr.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/doubleml/irm/tests/test_irm.py b/doubleml/irm/tests/test_irm.py index dc6f0cf2..fde402b3 100644 --- a/doubleml/irm/tests/test_irm.py +++ b/doubleml/irm/tests/test_irm.py @@ -216,6 +216,7 @@ def test_dml_irm_cate_gate(cov_type): cate = dml_irm_obj.cate(random_basis, cov_type=cov_type) assert isinstance(cate, dml.utils.blp.DoubleMLBLP) assert isinstance(cate.confint(), pd.DataFrame) + assert cate.blp_model.cov_type == cov_type groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= 0, obj_dml_data.data['X1'] > 0.2]), @@ -226,6 +227,7 @@ def test_dml_irm_cate_gate(cov_type): assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_1.confint(), pd.DataFrame) assert all(gate_1.confint().index == groups_1.columns.to_list()) + assert gate_1.blp_model.cov_type == cov_type np.random.seed(42) groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n)) @@ -235,6 +237,7 @@ def test_dml_irm_cate_gate(cov_type): assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_2.confint(), pd.DataFrame) assert all(gate_2.confint().index == ["Group_1", "Group_2"]) + assert gate_2.blp_model.cov_type == cov_type @pytest.fixture(scope='module', diff --git a/doubleml/plm/tests/test_plr.py b/doubleml/plm/tests/test_plr.py index 43e0e216..43de605a 100644 --- a/doubleml/plm/tests/test_plr.py +++ b/doubleml/plm/tests/test_plr.py @@ -327,6 +327,7 @@ def test_dml_plr_cate_gate(score, cov_type): cate = dml_plr_obj.cate(random_basis, cov_type=cov_type) assert isinstance(cate, dml.DoubleMLBLP) assert isinstance(cate.confint(), pd.DataFrame) + assert cate.blp_model.cov_type == cov_type groups_1 = pd.DataFrame( np.column_stack([obj_dml_data.data['X1'] <= 0, @@ -338,6 +339,7 @@ def test_dml_plr_cate_gate(score, cov_type): assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_1.confint(), pd.DataFrame) assert all(gate_1.confint().index == groups_1.columns.tolist()) + assert gate_1.blp_model.cov_type == cov_type np.random.seed(42) groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n)) @@ -347,3 +349,4 @@ def test_dml_plr_cate_gate(score, cov_type): assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP) assert isinstance(gate_2.confint(), pd.DataFrame) assert all(gate_2.confint().index == ["Group_1", "Group_2"]) + assert gate_2.blp_model.cov_type == cov_type