From 1f5127e58a299db915724e9bd994c7c129e6b4f6 Mon Sep 17 00:00:00 2001
From: igerber
Date: Sun, 1 Feb 2026 18:32:29 -0500
Subject: [PATCH] Add backend-aware test parameter scaling for pure Python CI performance

Introduce a session-scoped ci_params fixture that scales bootstrap
iterations and TROP grid sizes when running without the Rust backend,
reducing CI time while preserving code path coverage.

Co-Authored-By: Claude Opus 4.5
---
 CLAUDE.md                          |   2 +
 tests/conftest.py                  |  42 +++++++++
 tests/test_estimators.py           |  79 ++++++++++-------
 tests/test_methodology_callaway.py |  10 ++-
 tests/test_methodology_did.py      |  20 +++--
 tests/test_staggered.py            | 133 +++++++++++++++++------
 tests/test_sun_abraham.py          |  43 ++++++----
 tests/test_trop.py                 |  86 +++++++++++--------
 tests/test_wild_bootstrap.py       | 131 ++++++++++++++++------
 9 files changed, 341 insertions(+), 205 deletions(-)

diff --git a/CLAUDE.md b/CLAUDE.md
index eab6ab1..2739dd2 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -375,6 +375,8 @@ Tests mirror the source modules:
 - `tests/test_pretrends.py` - Tests for pre-trends power analysis
 - `tests/test_datasets.py` - Tests for dataset loading functions
 
+A session-scoped `ci_params` fixture in `conftest.py` scales bootstrap iterations and TROP grid sizes in pure Python mode. Use `ci_params.bootstrap(n)` and `ci_params.grid(values)` in any new test that requests `n_bootstrap >= 20`.
+
 ### Test Writing Guidelines
 
 **For fallback/error handling paths:**
diff --git a/tests/conftest.py b/tests/conftest.py
index c744808..5b98b34 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ to avoid import-time subprocess latency.
 """
 
+import math
 import os
 import subprocess
 
@@ -81,3 +82,44 @@ def test_comparison_with_r(require_r):
     """
     if not r_available:
         pytest.skip("R or did package not available")
+
+
+# =============================================================================
+# CI Performance: Backend-Aware Parameter Scaling
+# =============================================================================
+
+from diff_diff._backend import HAS_RUST_BACKEND
+
+_PURE_PYTHON_MODE = (
+    os.environ.get("DIFF_DIFF_BACKEND", "auto").lower() == "python"
+    or not HAS_RUST_BACKEND
+)
+
+
+class CIParams:
+    """Scale test parameters in pure Python mode for CI performance.
+
+    When the Rust backend is available, all values pass through unchanged.
+    In pure Python mode, bootstrap iterations and LOOCV grids are scaled
+    down to reduce CI time while preserving code path coverage.
+    """
+
+    @staticmethod
+    def bootstrap(n: int) -> int:
+        """Scale bootstrap iterations. Guaranteed monotonic: bootstrap(n+1) >= bootstrap(n)."""
+        if not _PURE_PYTHON_MODE or n <= 10:
+            return n
+        return max(11, int(math.sqrt(n) * 1.6))
+
+    @staticmethod
+    def grid(values: list) -> list:
+        """Scale TROP lambda grids.
Keeps first, middle, last for grids > 3 elements.""" + if not _PURE_PYTHON_MODE or len(values) <= 3: + return values + return [values[0], values[len(values) // 2], values[-1]] + + +@pytest.fixture(scope="session") +def ci_params(): + """Backend-aware parameter scaling for CI performance.""" + return CIParams() diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 141bd4e..7558231 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -1957,9 +1957,10 @@ def single_treated_unit_data(self): return pd.DataFrame(data) - def test_basic_fit(self, sdid_panel_data): + def test_basic_fit(self, sdid_panel_data, ci_params): """Test basic SDID model fitting.""" - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -1975,9 +1976,10 @@ def test_basic_fit(self, sdid_panel_data): assert results.n_treated == 5 assert results.n_control == 25 - def test_att_direction(self, sdid_panel_data): + def test_att_direction(self, sdid_panel_data, ci_params): """Test that ATT is estimated in correct direction.""" - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2036,9 +2038,10 @@ def test_unit_weights_nonnegative(self, sdid_panel_data): for w in results.unit_weights.values(): assert w >= 0 - def test_single_treated_unit(self, single_treated_unit_data): + def test_single_treated_unit(self, single_treated_unit_data, ci_params): """Test SDID with a single treated unit (classic SC scenario).""" - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( single_treated_unit_data, outcome="outcome", @@ -2101,9 +2104,10 @@ def test_placebo_inference(self, sdid_panel_data): assert len(results.placebo_effects) > 0 assert results.se > 0 - def test_bootstrap_inference(self, sdid_panel_data): + def test_bootstrap_inference(self, sdid_panel_data, ci_params): """Test bootstrap-based inference.""" - sdid = SyntheticDiD(variance_method="bootstrap", n_bootstrap=100, seed=42) + n_boot = ci_params.bootstrap(100) + sdid = SyntheticDiD(variance_method="bootstrap", n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2114,7 +2118,7 @@ def test_bootstrap_inference(self, sdid_panel_data): ) assert results.variance_method == "bootstrap" - assert results.n_bootstrap == 100 + assert results.n_bootstrap == n_boot assert results.se > 0 assert results.conf_int[0] < results.att < results.conf_int[1] @@ -2174,9 +2178,10 @@ def test_pre_treatment_fit(self, sdid_panel_data): assert results.pre_treatment_fit is not None assert results.pre_treatment_fit >= 0 - def test_summary_output(self, sdid_panel_data): + def test_summary_output(self, sdid_panel_data, ci_params): """Test that summary produces string output.""" - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2192,9 +2197,10 @@ def test_summary_output(self, sdid_panel_data): assert "ATT" in summary assert "Unit Weights" in summary - def test_to_dict(self, sdid_panel_data): + def test_to_dict(self, sdid_panel_data, ci_params): """Test conversion to dictionary.""" - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = 
ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2211,9 +2217,10 @@ def test_to_dict(self, sdid_panel_data): assert "n_post_periods" in result_dict assert "pre_treatment_fit" in result_dict - def test_to_dataframe(self, sdid_panel_data): + def test_to_dataframe(self, sdid_panel_data, ci_params): """Test conversion to DataFrame.""" - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2228,9 +2235,10 @@ def test_to_dataframe(self, sdid_panel_data): assert len(df) == 1 assert "att" in df.columns - def test_repr(self, sdid_panel_data): + def test_repr(self, sdid_panel_data, ci_params): """Test string representation.""" - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2244,9 +2252,10 @@ def test_repr(self, sdid_panel_data): assert "SyntheticDiDResults" in repr_str assert "ATT=" in repr_str - def test_is_significant_property(self, sdid_panel_data): + def test_is_significant_property(self, sdid_panel_data, ci_params): """Test is_significant property.""" - sdid = SyntheticDiD(n_bootstrap=100, seed=42) + n_boot = ci_params.bootstrap(100) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2356,12 +2365,13 @@ def test_auto_infer_post_periods(self, sdid_panel_data): assert results.pre_periods == [0, 1, 2, 3] assert results.post_periods == [4, 5, 6, 7] - def test_with_covariates(self, sdid_panel_data): + def test_with_covariates(self, sdid_panel_data, ci_params): """Test SDID with covariates.""" # Add a covariate sdid_panel_data["size"] = np.random.normal(100, 10, len(sdid_panel_data)) - sdid = SyntheticDiD(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2375,9 +2385,10 @@ def test_with_covariates(self, sdid_panel_data): assert results is not None assert sdid.is_fitted_ - def test_confidence_interval_contains_estimate(self, sdid_panel_data): + def test_confidence_interval_contains_estimate(self, sdid_panel_data, ci_params): """Test that confidence interval contains the estimate.""" - sdid = SyntheticDiD(n_bootstrap=100, seed=42) + n_boot = ci_params.bootstrap(100) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) results = sdid.fit( sdid_panel_data, outcome="outcome", @@ -2390,9 +2401,10 @@ def test_confidence_interval_contains_estimate(self, sdid_panel_data): lower, upper = results.conf_int assert lower < results.att < upper - def test_reproducibility_with_seed(self, sdid_panel_data): + def test_reproducibility_with_seed(self, sdid_panel_data, ci_params): """Test that results are reproducible with the same seed.""" - results1 = SyntheticDiD(n_bootstrap=50, seed=42).fit( + n_boot = ci_params.bootstrap(50) + results1 = SyntheticDiD(n_bootstrap=n_boot, seed=42).fit( sdid_panel_data, outcome="outcome", treatment="treated", @@ -2401,7 +2413,7 @@ def test_reproducibility_with_seed(self, sdid_panel_data): post_periods=[4, 5, 6, 7] ) - results2 = SyntheticDiD(n_bootstrap=50, seed=42).fit( + results2 = SyntheticDiD(n_bootstrap=n_boot, seed=42).fit( sdid_panel_data, outcome="outcome", treatment="treated", @@ -2413,7 +2425,7 @@ def test_reproducibility_with_seed(self, 
sdid_panel_data): assert results1.att == results2.att assert results1.se == results2.se - def test_insufficient_pre_periods_warning(self): + def test_insufficient_pre_periods_warning(self, ci_params): """Test that SDID warns with very few pre-treatment periods.""" np.random.seed(42) @@ -2448,7 +2460,8 @@ def test_insufficient_pre_periods_warning(self): df = pd.DataFrame(data) - sdid = SyntheticDiD(n_bootstrap=30, seed=42) + n_boot = ci_params.bootstrap(30) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) # Should work but may warn about few pre-periods # (Depending on implementation - some may warn, some may not) @@ -2465,7 +2478,7 @@ def test_insufficient_pre_periods_warning(self): assert np.isfinite(results.att) assert results.se > 0 - def test_single_pre_period_edge_case(self): + def test_single_pre_period_edge_case(self, ci_params): """Test SDID with single pre-treatment period (extreme edge case).""" np.random.seed(42) @@ -2499,7 +2512,8 @@ def test_single_pre_period_edge_case(self): df = pd.DataFrame(data) - sdid = SyntheticDiD(n_bootstrap=30, seed=42) + n_boot = ci_params.bootstrap(30) + sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42) # With single pre-period, time weights will be trivially [1.0] results = sdid.fit( @@ -2516,7 +2530,7 @@ def test_single_pre_period_edge_case(self): # Time weights should have single entry assert len(results.time_weights) == 1 - def test_more_pre_periods_than_control_units(self): + def test_more_pre_periods_than_control_units(self, ci_params): """Test SDID when n_pre_periods > n_control_units (underdetermined).""" np.random.seed(42) @@ -2551,7 +2565,8 @@ def test_more_pre_periods_than_control_units(self): df = pd.DataFrame(data) # Use regularization to help with underdetermined system - sdid = SyntheticDiD(lambda_reg=1.0, n_bootstrap=30, seed=42) + n_boot = ci_params.bootstrap(30) + sdid = SyntheticDiD(lambda_reg=1.0, n_bootstrap=n_boot, seed=42) results = sdid.fit( df, diff --git a/tests/test_methodology_callaway.py b/tests/test_methodology_callaway.py index 42b0cea..ceea641 100644 --- a/tests/test_methodology_callaway.py +++ b/tests/test_methodology_callaway.py @@ -801,7 +801,7 @@ class TestSEFormulas: """Tests for standard error formula verification.""" @pytest.mark.slow - def test_analytical_se_close_to_bootstrap_se(self): + def test_analytical_se_close_to_bootstrap_se(self, ci_params): """ Analytical and bootstrap SEs should be within 20%. @@ -812,6 +812,7 @@ def test_analytical_se_close_to_bootstrap_se(self): This test is marked slow because it uses 499 bootstrap iterations for thorough validation of SE convergence. """ + n_boot = ci_params.bootstrap(499) data = generate_staggered_data( n_units=300, n_periods=8, @@ -821,7 +822,7 @@ def test_analytical_se_close_to_bootstrap_se(self): ) cs_anal = CallawaySantAnna(n_bootstrap=0) - cs_boot = CallawaySantAnna(n_bootstrap=499, seed=42) + cs_boot = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results_anal = cs_anal.fit( data, outcome='outcome', unit='unit', @@ -893,12 +894,13 @@ def test_bootstrap_weight_moments_webb(self): var_w = np.var(weights) assert abs(var_w - 1.0) < 0.05, f"Webb Var(w) should be ~1.0, got {var_w}" - def test_bootstrap_produces_valid_inference(self): + def test_bootstrap_produces_valid_inference(self, ci_params): """Test that bootstrap produces valid inference with p-values and CIs. Uses 99 bootstrap iterations - sufficient to verify the mechanism works without being slow for CI runs. 
""" + n_boot = ci_params.bootstrap(99) data = generate_staggered_data( n_units=100, n_periods=6, @@ -907,7 +909,7 @@ def test_bootstrap_produces_valid_inference(self): seed=42 ) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', unit='unit', time='period', first_treat='first_treat' diff --git a/tests/test_methodology_did.py b/tests/test_methodology_did.py index b3cdf00..7d91a6a 100644 --- a/tests/test_methodology_did.py +++ b/tests/test_methodology_did.py @@ -787,8 +787,9 @@ def test_vcov_positive_semidefinite(self): class TestWildBootstrapInference: """Tests for wild cluster bootstrap inference.""" - def test_wild_bootstrap_produces_valid_se(self): + def test_wild_bootstrap_produces_valid_se(self, ci_params): """Test wild bootstrap produces finite, positive SE.""" + n_boot = ci_params.bootstrap(199) data = generate_clustered_did_data( n_clusters=15, cluster_size=10, @@ -799,7 +800,7 @@ def test_wild_bootstrap_produces_valid_se(self): did = DifferenceInDifferences( inference='wild_bootstrap', cluster='cluster_id', - n_bootstrap=199, + n_bootstrap=n_boot, bootstrap_weights='rademacher', seed=42 ) @@ -811,8 +812,9 @@ def test_wild_bootstrap_produces_valid_se(self): assert results.se > 0, "Bootstrap SE should be positive" assert results.inference_method == 'wild_bootstrap' - def test_wild_bootstrap_pvalue_in_valid_range(self): + def test_wild_bootstrap_pvalue_in_valid_range(self, ci_params): """Test wild bootstrap p-value is in [0, 1].""" + n_boot = ci_params.bootstrap(199) data = generate_clustered_did_data( n_clusters=15, cluster_size=10, @@ -823,7 +825,7 @@ def test_wild_bootstrap_pvalue_in_valid_range(self): did = DifferenceInDifferences( inference='wild_bootstrap', cluster='cluster_id', - n_bootstrap=199, + n_bootstrap=n_boot, seed=42 ) results = did.fit( @@ -833,8 +835,9 @@ def test_wild_bootstrap_pvalue_in_valid_range(self): assert 0 <= results.p_value <= 1, \ f"P-value {results.p_value} not in [0, 1]" - def test_wild_bootstrap_ci_contains_point_estimate(self): + def test_wild_bootstrap_ci_contains_point_estimate(self, ci_params): """Test wild bootstrap CI contains point estimate.""" + n_boot = ci_params.bootstrap(199) data = generate_clustered_did_data( n_clusters=15, cluster_size=10, @@ -845,7 +848,7 @@ def test_wild_bootstrap_ci_contains_point_estimate(self): did = DifferenceInDifferences( inference='wild_bootstrap', cluster='cluster_id', - n_bootstrap=199, + n_bootstrap=n_boot, seed=42 ) results = did.fit( @@ -859,8 +862,9 @@ def test_wild_bootstrap_ci_contains_point_estimate(self): f"CI [{lower}, {upper}] should approximately contain ATT {results.att}" @pytest.mark.parametrize("weight_type", ["rademacher", "mammen", "webb"]) - def test_wild_bootstrap_weight_types(self, weight_type): + def test_wild_bootstrap_weight_types(self, weight_type, ci_params): """Test all wild bootstrap weight types work.""" + n_boot = ci_params.bootstrap(99) data = generate_clustered_did_data( n_clusters=15, cluster_size=10, @@ -871,7 +875,7 @@ def test_wild_bootstrap_weight_types(self, weight_type): did = DifferenceInDifferences( inference='wild_bootstrap', cluster='cluster_id', - n_bootstrap=99, + n_bootstrap=n_boot, bootstrap_weights=weight_type, seed=42 ) diff --git a/tests/test_staggered.py b/tests/test_staggered.py index 382fea8..4901ba1 100644 --- a/tests/test_staggered.py +++ b/tests/test_staggered.py @@ -681,7 +681,7 @@ def test_extreme_propensity_scores(self): assert np.isfinite(results.overall_se), "SE 
should be finite" assert results.overall_se > 0, "SE should be positive" - def test_extreme_weights_warning(self): + def test_extreme_weights_warning(self, ci_params): """Test that extreme weights produce warnings and methodology-aligned behavior. Tests that: @@ -691,6 +691,7 @@ def test_extreme_weights_warning(self): """ import warnings np.random.seed(42) + n_boot = ci_params.bootstrap(100) # Minimal dataset: very small sample with unbalanced groups n_units, n_periods = 20, 4 @@ -729,7 +730,7 @@ def test_extreme_weights_warning(self): "SE should be finite or NaN (not inf)" # Test with bootstrap - should drop invalid samples with warning - cs_boot = CallawaySantAnna(n_bootstrap=100, seed=42) + cs_boot = CallawaySantAnna(n_bootstrap=n_boot, seed=42) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -920,11 +921,12 @@ def test_rank_deficient_action_silent_no_warning(self): class TestCallawaySantAnnaBootstrap: """Tests for Callaway-Sant'Anna multiplier bootstrap inference.""" - def test_bootstrap_basic(self): + def test_bootstrap_basic(self, ci_params): """Test basic bootstrap functionality.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -934,20 +936,21 @@ def test_bootstrap_basic(self): ) assert results.bootstrap_results is not None - assert results.bootstrap_results.n_bootstrap == 99 + assert results.bootstrap_results.n_bootstrap == n_boot assert results.bootstrap_results.weight_type == "rademacher" assert results.overall_se > 0 assert results.overall_conf_int[0] < results.overall_att < results.overall_conf_int[1] - def test_bootstrap_weight_types(self): + def test_bootstrap_weight_types(self, ci_params): """Test different bootstrap weight types.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(49) weight_types = ["rademacher", "mammen", "webb"] for wt in weight_types: cs = CallawaySantAnna( - n_bootstrap=49, + n_bootstrap=n_boot, bootstrap_weight_type=wt, seed=42 ) @@ -963,11 +966,12 @@ def test_bootstrap_weight_types(self): assert results.bootstrap_results.weight_type == wt assert results.overall_se > 0 - def test_bootstrap_event_study(self): + def test_bootstrap_event_study(self, ci_params): """Test bootstrap with event study aggregation.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -987,11 +991,12 @@ def test_bootstrap_event_study(self): assert effect['se'] > 0 assert effect['conf_int'][0] < effect['conf_int'][1] - def test_bootstrap_group_aggregation(self): + def test_bootstrap_group_aggregation(self, ci_params): """Test bootstrap with group aggregation.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -1011,11 +1016,12 @@ def test_bootstrap_group_aggregation(self): assert effect['se'] > 0 assert effect['conf_int'][0] < effect['conf_int'][1] - def test_bootstrap_all_aggregations(self): + def test_bootstrap_all_aggregations(self, ci_params): """Test bootstrap with all aggregations.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = 
ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -1029,11 +1035,12 @@ def test_bootstrap_all_aggregations(self): assert results.bootstrap_results.event_study_ses is not None assert results.bootstrap_results.group_effect_ses is not None - def test_bootstrap_reproducibility(self): + def test_bootstrap_reproducibility(self, ci_params): """Test that bootstrap is reproducible with same seed.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) - cs1 = CallawaySantAnna(n_bootstrap=99, seed=123) + cs1 = CallawaySantAnna(n_bootstrap=n_boot, seed=123) results1 = cs1.fit( data, outcome='outcome', @@ -1042,7 +1049,7 @@ def test_bootstrap_reproducibility(self): first_treat='first_treat' ) - cs2 = CallawaySantAnna(n_bootstrap=99, seed=123) + cs2 = CallawaySantAnna(n_bootstrap=n_boot, seed=123) results2 = cs2.fit( data, outcome='outcome', @@ -1055,11 +1062,12 @@ def test_bootstrap_reproducibility(self): assert results1.overall_se == results2.overall_se assert results1.overall_conf_int == results2.overall_conf_int - def test_bootstrap_different_seeds(self): + def test_bootstrap_different_seeds(self, ci_params): """Test that different seeds give different results.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) - cs1 = CallawaySantAnna(n_bootstrap=99, seed=123) + cs1 = CallawaySantAnna(n_bootstrap=n_boot, seed=123) results1 = cs1.fit( data, outcome='outcome', @@ -1068,7 +1076,7 @@ def test_bootstrap_different_seeds(self): first_treat='first_treat' ) - cs2 = CallawaySantAnna(n_bootstrap=99, seed=456) + cs2 = CallawaySantAnna(n_bootstrap=n_boot, seed=456) results2 = cs2.fit( data, outcome='outcome', @@ -1080,15 +1088,16 @@ def test_bootstrap_different_seeds(self): # Results should differ with different seeds assert results1.overall_se != results2.overall_se - def test_bootstrap_p_value_significance(self): + def test_bootstrap_p_value_significance(self, ci_params): """Test that strong effect has significant p-value with bootstrap.""" data = generate_staggered_data( n_units=100, treatment_effect=5.0, seed=42 ) + n_boot = ci_params.bootstrap(199) - cs = CallawaySantAnna(n_bootstrap=199, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -1101,15 +1110,16 @@ def test_bootstrap_p_value_significance(self): assert results.overall_p_value < 0.05 assert results.is_significant - def test_bootstrap_zero_effect_not_significant(self): + def test_bootstrap_zero_effect_not_significant(self, ci_params): """Test that zero effect is not significant with bootstrap.""" data = generate_staggered_data( n_units=50, treatment_effect=0.0, seed=42 ) + n_boot = ci_params.bootstrap(199) - cs = CallawaySantAnna(n_bootstrap=199, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -1122,11 +1132,12 @@ def test_bootstrap_zero_effect_not_significant(self): # (using 0.01 to be more conservative with finite sample) assert results.overall_p_value > 0.01 or abs(results.overall_att) < 2 * results.overall_se - def test_bootstrap_distribution_stored(self): + def test_bootstrap_distribution_stored(self, ci_params): """Test that bootstrap distribution is stored in results.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = 
CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -1136,13 +1147,14 @@ def test_bootstrap_distribution_stored(self): ) assert results.bootstrap_results.bootstrap_distribution is not None - assert len(results.bootstrap_results.bootstrap_distribution) == 99 + assert len(results.bootstrap_results.bootstrap_distribution) == n_boot - def test_bootstrap_with_covariates(self): + def test_bootstrap_with_covariates(self, ci_params): """Test bootstrap with covariate adjustment.""" data = generate_staggered_data_with_covariates(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -1155,9 +1167,10 @@ def test_bootstrap_with_covariates(self): assert results.bootstrap_results is not None assert results.overall_se > 0 - def test_bootstrap_group_time_effects(self): + def test_bootstrap_group_time_effects(self, ci_params): """Test that bootstrap updates group-time effect SEs.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) # Without bootstrap cs1 = CallawaySantAnna(n_bootstrap=0) @@ -1170,7 +1183,7 @@ def test_bootstrap_group_time_effects(self): ) # With bootstrap - cs2 = CallawaySantAnna(n_bootstrap=99, seed=42) + cs2 = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results2 = cs2.fit( data, outcome='outcome', @@ -1209,13 +1222,14 @@ def test_bootstrap_get_params(self): assert params['bootstrap_weight_type'] == "mammen" assert params['seed'] == 42 - def test_bootstrap_with_not_yet_treated(self): + def test_bootstrap_with_not_yet_treated(self, ci_params): """Test bootstrap with not_yet_treated control group.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(99) cs = CallawaySantAnna( control_group="not_yet_treated", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) results = cs.fit( @@ -1229,16 +1243,17 @@ def test_bootstrap_with_not_yet_treated(self): assert results.bootstrap_results is not None assert results.overall_se > 0 - def test_bootstrap_estimation_methods(self): + def test_bootstrap_estimation_methods(self, ci_params): """Test bootstrap with different estimation methods.""" data = generate_staggered_data(n_units=50, seed=42) + n_boot = ci_params.bootstrap(49) methods = ["reg", "ipw", "dr"] for method in methods: cs = CallawaySantAnna( estimation_method=method, - n_bootstrap=49, + n_bootstrap=n_boot, seed=42 ) results = cs.fit( @@ -1252,11 +1267,12 @@ def test_bootstrap_estimation_methods(self): assert results.bootstrap_results is not None assert results.overall_se > 0, f"Failed for method {method}" - def test_bootstrap_with_balanced_event_study(self): + def test_bootstrap_with_balanced_event_study(self, ci_params): """Test bootstrap with balanced event study aggregation.""" data = generate_staggered_data(n_units=100, n_periods=12, seed=42) + n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -1421,9 +1437,10 @@ def test_single_cohort_event_study(self): post_effects = [results.event_study_effects[e]['effect'] for e in post_periods] assert any(e > 0.5 for e in post_effects), f"Expected positive post-period effects, got {post_effects}" - def test_single_cohort_with_bootstrap(self): + def test_single_cohort_with_bootstrap(self, ci_params): """Test bootstrap inference with single cohort.""" np.random.seed(42) 
+ n_boot = ci_params.bootstrap(99) n_units = 50 n_periods = 6 @@ -1450,7 +1467,7 @@ def test_single_cohort_with_bootstrap(self): df = pd.DataFrame(data) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( df, outcome='outcome', @@ -1526,7 +1543,7 @@ def test_single_cohort_not_yet_treated_control(self): class TestCallawaySantAnnaAnalyticalSE: """Tests for analytical SE using influence function aggregation.""" - def test_analytical_se_vs_bootstrap_se(self): + def test_analytical_se_vs_bootstrap_se(self, ci_params): """Analytical SE should be close to bootstrap SE (within 15%).""" # Generate data with moderate size for stable comparison data = generate_staggered_data( @@ -1537,6 +1554,7 @@ def test_analytical_se_vs_bootstrap_se(self): never_treated_frac=0.3, seed=42 ) + n_boot = ci_params.bootstrap(499) # Run with analytical SE (n_bootstrap=0) cs_analytical = CallawaySantAnna(n_bootstrap=0, seed=42) @@ -1549,7 +1567,7 @@ def test_analytical_se_vs_bootstrap_se(self): ) # Run with bootstrap SE (n_bootstrap=499) - cs_bootstrap = CallawaySantAnna(n_bootstrap=499, seed=42) + cs_bootstrap = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results_bootstrap = cs_bootstrap.fit( data, outcome='outcome', @@ -1670,7 +1688,7 @@ def test_analytical_se_single_gt_pair(self): # Should be close (may not be exact due to normalization) assert abs(results.overall_se - individual_se) < individual_se * 0.01 - def test_event_study_analytical_se(self): + def test_event_study_analytical_se(self, ci_params): """Event study SEs should also use influence function aggregation.""" data = generate_staggered_data( n_units=200, @@ -1680,6 +1698,7 @@ def test_event_study_analytical_se(self): never_treated_frac=0.3, seed=42 ) + n_boot = ci_params.bootstrap(499) # Analytical cs_analytical = CallawaySantAnna(n_bootstrap=0, seed=42) @@ -1693,7 +1712,7 @@ def test_event_study_analytical_se(self): ) # Bootstrap - cs_bootstrap = CallawaySantAnna(n_bootstrap=499, seed=42) + cs_bootstrap = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results_bootstrap = cs_bootstrap.fit( data, outcome='outcome', @@ -1809,14 +1828,15 @@ def test_non_standard_all_column_names(self): assert np.isfinite(results.overall_att) assert results.overall_se > 0 - def test_non_standard_names_with_bootstrap(self): + def test_non_standard_names_with_bootstrap(self, ci_params): """Test non-standard column names with bootstrap inference.""" data = self.generate_data_with_custom_names( first_treat_name='g', # Short name like R's `did` package uses n_units=50 ) + n_boot = ci_params.bootstrap(99) - cs = CallawaySantAnna(n_bootstrap=99, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='y', @@ -2126,7 +2146,7 @@ def test_base_period_in_results(self): assert "Base period:" in summary assert "universal" in summary - def test_pre_treatment_bootstrap(self): + def test_pre_treatment_bootstrap(self, ci_params): """Bootstrap handles pre-treatment effects.""" data = generate_staggered_data( n_units=60, @@ -2135,10 +2155,11 @@ def test_pre_treatment_bootstrap(self): treatment_effect=2.0, seed=42 ) + n_boot = ci_params.bootstrap(99) cs = CallawaySantAnna( base_period="varying", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) results = cs.fit( @@ -2302,9 +2323,10 @@ def test_no_post_treatment_effects_returns_nan_with_warning(self): f"Expected NaN for overall_p_value, got {results.overall_p_value}" ) - def test_no_post_treatment_effects_bootstrap_returns_nan(self): + def 
test_no_post_treatment_effects_bootstrap_returns_nan(self, ci_params): """Bootstrap returns NaN inference when no post-treatment effects exist.""" import warnings + n_boot = ci_params.bootstrap(99) # Create data where treatment happens after the data ends n_units = 50 @@ -2325,7 +2347,7 @@ def test_no_post_treatment_effects_bootstrap_returns_nan(self): df = pd.DataFrame(data) - cs = CallawaySantAnna(base_period="varying", n_bootstrap=99, seed=42) + cs = CallawaySantAnna(base_period="varying", n_bootstrap=n_boot, seed=42) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -2357,13 +2379,14 @@ def test_no_post_treatment_effects_bootstrap_returns_nan(self): assert np.isnan(results.bootstrap_results.overall_att_se) assert np.isnan(results.bootstrap_results.overall_att_p_value) - def test_bootstrap_runs_for_pretreatment_effects(self): + def test_bootstrap_runs_for_pretreatment_effects(self, ci_params): """Bootstrap computes SEs for pre-treatment effects even when no post-treatment. When all treatment occurs after data ends, the overall ATT should be NaN, but pre-treatment effects should still get bootstrap SEs (not analytical). """ import warnings + n_boot = ci_params.bootstrap(99) # Create data where all treatment happens after the data ends # so we have only pre-treatment effects @@ -2388,7 +2411,7 @@ def test_bootstrap_runs_for_pretreatment_effects(self): df = pd.DataFrame(data) # Fit with bootstrap and base_period="varying" to get pre-treatment effects - cs = CallawaySantAnna(base_period="varying", n_bootstrap=99, seed=42) + cs = CallawaySantAnna(base_period="varying", n_bootstrap=n_boot, seed=42) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -2648,7 +2671,7 @@ def test_group_effects_anticipation_boundary(self): class TestCallawaySantAnnaTStatNaN: """Tests for NaN t_stat when SE is invalid.""" - def test_invalid_se_produces_nan_tstat_overall(self): + def test_invalid_se_produces_nan_tstat_overall(self, ci_params): """Overall t_stat is NaN when SE is non-finite.""" # Create data that will result in no valid post-treatment effects # This should produce NaN for overall statistics @@ -2659,6 +2682,7 @@ def test_invalid_se_produces_nan_tstat_overall(self): treatment_effect=2.0, seed=789 ) + n_boot = ci_params.bootstrap(50) # Modify first_treat so all treatment happens after data ends data['first_treat'] = data['first_treat'].replace( @@ -2669,7 +2693,7 @@ def test_invalid_se_produces_nan_tstat_overall(self): import warnings with warnings.catch_warnings(record=True): warnings.simplefilter("always") - cs = CallawaySantAnna(n_bootstrap=50, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', @@ -2684,7 +2708,7 @@ def test_invalid_se_produces_nan_tstat_overall(self): "overall_t_stat should be NaN when SE is invalid" ) - def test_per_effect_tstat_consistency(self): + def test_per_effect_tstat_consistency(self, ci_params): """Per-effect t_stat uses same NaN logic as overall t_stat. t_stat should be NaN (not 0.0) when SE is non-finite or zero. 
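The two t-stat hunks above pin down a NaN convention that is easy to regress: when the standard error is non-finite or zero, the t statistic must be NaN, never 0.0. A minimal standalone sketch of that guard follows; it is a hypothetical illustration, not the library's internal code, and only the NaN-when-SE-invalid behavior is taken from the tests themselves.

    import math

    def t_stat(att: float, se: float) -> float:
        # NaN-propagating t statistic: an invalid SE yields NaN (not 0.0),
        # which is exactly what the tests above assert.
        if not math.isfinite(se) or se == 0.0:
            return float("nan")
        return att / se

    assert math.isnan(t_stat(1.5, float("nan")))  # non-finite SE -> NaN
    assert math.isnan(t_stat(1.5, 0.0))           # zero SE -> NaN
    assert t_stat(1.0, 0.5) == 2.0                # valid SE -> ATT / SE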
@@ -2697,8 +2721,9 @@ def test_per_effect_tstat_consistency(self): treatment_effect=2.0, seed=456 ) + n_boot = ci_params.bootstrap(100) - cs = CallawaySantAnna(n_bootstrap=100, seed=42) + cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42) results = cs.fit( data, outcome='outcome', diff --git a/tests/test_sun_abraham.py b/tests/test_sun_abraham.py index c91f266..cbf9079 100644 --- a/tests/test_sun_abraham.py +++ b/tests/test_sun_abraham.py @@ -341,11 +341,12 @@ def test_invalid_level_error(self): class TestSunAbrahamBootstrap: """Tests for Sun-Abraham bootstrap inference.""" - def test_bootstrap_basic(self): + def test_bootstrap_basic(self, ci_params): """Test basic bootstrap functionality.""" data = generate_staggered_data(n_units=50, seed=42) - sa = SunAbraham(n_bootstrap=99, seed=42) + n_boot = ci_params.bootstrap(99) + sa = SunAbraham(n_bootstrap=n_boot, seed=42) results = sa.fit( data, outcome="outcome", @@ -355,7 +356,7 @@ def test_bootstrap_basic(self): ) assert results.bootstrap_results is not None - assert results.bootstrap_results.n_bootstrap == 99 + assert results.bootstrap_results.n_bootstrap == n_boot assert results.bootstrap_results.weight_type == "pairs" assert results.overall_se > 0 assert ( @@ -364,11 +365,12 @@ def test_bootstrap_basic(self): < results.overall_conf_int[1] ) - def test_bootstrap_reproducibility(self): + def test_bootstrap_reproducibility(self, ci_params): """Test that bootstrap is reproducible with same seed.""" data = generate_staggered_data(n_units=50, seed=42) - sa1 = SunAbraham(n_bootstrap=99, seed=123) + n_boot = ci_params.bootstrap(99) + sa1 = SunAbraham(n_bootstrap=n_boot, seed=123) results1 = sa1.fit( data, outcome="outcome", @@ -377,7 +379,7 @@ def test_bootstrap_reproducibility(self): first_treat="first_treat", ) - sa2 = SunAbraham(n_bootstrap=99, seed=123) + sa2 = SunAbraham(n_bootstrap=n_boot, seed=123) results2 = sa2.fit( data, outcome="outcome", @@ -390,11 +392,12 @@ def test_bootstrap_reproducibility(self): assert results1.overall_se == results2.overall_se assert results1.overall_conf_int == results2.overall_conf_int - def test_bootstrap_different_seeds(self): + def test_bootstrap_different_seeds(self, ci_params): """Test that different seeds give different results.""" data = generate_staggered_data(n_units=50, seed=42) - sa1 = SunAbraham(n_bootstrap=99, seed=123) + n_boot = ci_params.bootstrap(99) + sa1 = SunAbraham(n_bootstrap=n_boot, seed=123) results1 = sa1.fit( data, outcome="outcome", @@ -403,7 +406,7 @@ def test_bootstrap_different_seeds(self): first_treat="first_treat", ) - sa2 = SunAbraham(n_bootstrap=99, seed=456) + sa2 = SunAbraham(n_bootstrap=n_boot, seed=456) results2 = sa2.fit( data, outcome="outcome", @@ -415,11 +418,12 @@ def test_bootstrap_different_seeds(self): # Results should differ with different seeds assert results1.overall_se != results2.overall_se - def test_bootstrap_p_value_significance(self): + def test_bootstrap_p_value_significance(self, ci_params): """Test that strong effect has significant p-value with bootstrap.""" data = generate_staggered_data(n_units=100, treatment_effect=5.0, seed=42) - sa = SunAbraham(n_bootstrap=199, seed=42) + n_boot = ci_params.bootstrap(199) + sa = SunAbraham(n_bootstrap=n_boot, seed=42) results = sa.fit( data, outcome="outcome", @@ -432,11 +436,12 @@ def test_bootstrap_p_value_significance(self): assert results.overall_p_value < 0.05 assert results.is_significant - def test_bootstrap_distribution_stored(self): + def test_bootstrap_distribution_stored(self, ci_params): """Test that 
bootstrap distribution is stored in results.""" data = generate_staggered_data(n_units=50, seed=42) - sa = SunAbraham(n_bootstrap=99, seed=42) + n_boot = ci_params.bootstrap(99) + sa = SunAbraham(n_bootstrap=n_boot, seed=42) results = sa.fit( data, outcome="outcome", @@ -446,13 +451,14 @@ def test_bootstrap_distribution_stored(self): ) assert results.bootstrap_results.bootstrap_distribution is not None - assert len(results.bootstrap_results.bootstrap_distribution) == 99 + assert len(results.bootstrap_results.bootstrap_distribution) == n_boot - def test_bootstrap_event_study_effects(self): + def test_bootstrap_event_study_effects(self, ci_params): """Test that bootstrap updates event study effect SEs.""" data = generate_staggered_data(n_units=50, seed=42) - sa = SunAbraham(n_bootstrap=99, seed=42) + n_boot = ci_params.bootstrap(99) + sa = SunAbraham(n_bootstrap=n_boot, seed=42) results = sa.fit( data, outcome="outcome", @@ -964,7 +970,7 @@ def test_overall_tstat_nan_when_se_invalid(self): f"overall_t_stat should be ATT/SE, expected {expected}, got {t_stat}" ) - def test_bootstrap_tstat_nan_when_se_invalid(self): + def test_bootstrap_tstat_nan_when_se_invalid(self, ci_params): """Bootstrap t_stat uses NaN (not 0.0) when SE is non-finite or zero.""" data = generate_staggered_data( n_units=60, @@ -974,7 +980,8 @@ def test_bootstrap_tstat_nan_when_se_invalid(self): seed=456, ) - sa = SunAbraham(n_bootstrap=50, seed=42) + n_boot = ci_params.bootstrap(50) + sa = SunAbraham(n_bootstrap=n_boot, seed=42) results = sa.fit( data, outcome="outcome", diff --git a/tests/test_trop.py b/tests/test_trop.py index 985fe5d..37082f1 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -118,13 +118,14 @@ def test_basic_fit(self, simple_panel_data): assert results.n_control == 15 assert results.n_treated == 5 - def test_fit_with_factors(self, factor_dgp_data): + def test_fit_with_factors(self, factor_dgp_data, ci_params): """Test fitting with factor structure.""" + n_boot = ci_params.bootstrap(20) trop_est = TROP( lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1, 1.0], - n_bootstrap=20, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -139,15 +140,16 @@ def test_fit_with_factors(self, factor_dgp_data): assert results.effective_rank >= 0 assert results.factor_matrix.shape == (12, 30) # n_periods x n_units - def test_treatment_effect_recovery(self, factor_dgp_data): + def test_treatment_effect_recovery(self, factor_dgp_data, ci_params): """Test that TROP recovers treatment effect direction.""" true_att = 2.0 + n_boot = ci_params.bootstrap(30) trop_est = TROP( lambda_time_grid=[0.0, 0.5, 1.0], lambda_unit_grid=[0.0, 0.5, 1.0], lambda_nn_grid=[0.0, 0.1], - n_bootstrap=30, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -163,10 +165,11 @@ def test_treatment_effect_recovery(self, factor_dgp_data): # Should be reasonably close to true value assert abs(results.att - true_att) < 3.0 - def test_tuning_parameter_selection(self, simple_panel_data): + def test_tuning_parameter_selection(self, simple_panel_data, ci_params): """Test that LOOCV selects tuning parameters.""" + time_grid = ci_params.grid([0.0, 0.5, 1.0, 2.0]) trop_est = TROP( - lambda_time_grid=[0.0, 0.5, 1.0, 2.0], + lambda_time_grid=time_grid, lambda_unit_grid=[0.0, 0.5, 1.0], lambda_nn_grid=[0.0, 0.1, 1.0], n_bootstrap=10, @@ -185,13 +188,14 @@ def test_tuning_parameter_selection(self, simple_panel_data): assert results.lambda_unit in trop_est.lambda_unit_grid assert results.lambda_nn in 
trop_est.lambda_nn_grid - def test_bootstrap_variance(self, simple_panel_data): + def test_bootstrap_variance(self, simple_panel_data, ci_params): """Test bootstrap variance estimation.""" + n_boot = ci_params.bootstrap(30) trop_est = TROP( lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], - n_bootstrap=30, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -203,17 +207,18 @@ def test_bootstrap_variance(self, simple_panel_data): ) assert results.se > 0 - assert results.n_bootstrap == 30 + assert results.n_bootstrap == n_boot assert results.bootstrap_distribution is not None - def test_confidence_interval(self, simple_panel_data): + def test_confidence_interval(self, simple_panel_data, ci_params): """Test confidence interval properties.""" + n_boot = ci_params.bootstrap(30) trop_est = TROP( lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], alpha=0.05, - n_bootstrap=30, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -446,14 +451,15 @@ def test_get_time_effects_df(self, simple_panel_data): assert "time" in effects_df.columns assert "effect" in effects_df.columns - def test_is_significant(self, simple_panel_data): + def test_is_significant(self, simple_panel_data, ci_params): """Test significance property.""" + n_boot = ci_params.bootstrap(30) trop_est = TROP( lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], alpha=0.05, - n_bootstrap=30, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -466,13 +472,14 @@ def test_is_significant(self, simple_panel_data): assert isinstance(results.is_significant, bool) - def test_significance_stars(self, simple_panel_data): + def test_significance_stars(self, simple_panel_data, ci_params): """Test significance stars.""" + n_boot = ci_params.bootstrap(30) trop_est = TROP( lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], - n_bootstrap=30, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -529,7 +536,7 @@ def test_nan_propagation_when_se_zero(self): class TestTROPvsSDID: """Tests comparing TROP to SDID under different DGPs.""" - def test_trop_handles_factor_dgp(self): + def test_trop_handles_factor_dgp(self, ci_params): """Test that TROP works on factor DGP data.""" data = generate_factor_dgp( n_units=30, @@ -544,11 +551,12 @@ def test_trop_handles_factor_dgp(self): ) # TROP should complete without error + n_boot = ci_params.bootstrap(20) trop_est = TROP( lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1, 1.0], - n_bootstrap=20, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -792,7 +800,7 @@ def test_time_weights_reduce_bias(self): # Check that time weighting was considered assert results.lambda_time in [0.0, 0.5, 1.0] - def test_factor_model_reduces_bias(self): + def test_factor_model_reduces_bias(self, ci_params): """ Test that nuclear norm regularization reduces bias with factor structure. 
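A note on which TROP grids are wrapped in `ci_params.grid(...)` in these hunks: the helper passes grids of three or fewer elements through unchanged and collapses longer ones to first, middle, and last, so only the 4-element grids shrink and the LOOCV search always keeps both endpoints. A standalone mirror of that rule (`scaled_grid` is hypothetical, copied from `CIParams.grid` in this patch's `conftest.py` with the backend check dropped):

    def scaled_grid(values: list) -> list:
        # Pure Python mode behavior of CIParams.grid: keep first, middle,
        # and last element once a grid has more than 3 entries.
        if len(values) <= 3:
            return values
        return [values[0], values[len(values) // 2], values[-1]]

    assert scaled_grid([0.0, 1.0]) == [0.0, 1.0]                 # unchanged
    assert scaled_grid([0.0, 0.1, 1.0]) == [0.0, 0.1, 1.0]       # unchanged
    assert scaled_grid([0.0, 0.5, 1.0, 2.0]) == [0.0, 1.0, 2.0]  # 4 -> 3
    assert scaled_grid([0.0, 0.1, 1.0, 5.0]) == [0.0, 1.0, 5.0]  # 4 -> 3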
@@ -813,11 +821,13 @@ def test_factor_model_reduces_bias(self): ) # TROP with nuclear norm regularization + n_boot = ci_params.bootstrap(20) + nn_grid = ci_params.grid([0.0, 0.1, 1.0, 5.0]) trop_est = TROP( lambda_time_grid=[0.0, 0.5], lambda_unit_grid=[0.0, 0.5], - lambda_nn_grid=[0.0, 0.1, 1.0, 5.0], - n_bootstrap=20, + lambda_nn_grid=nn_grid, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -835,7 +845,7 @@ def test_factor_model_reduces_bias(self): # Factor matrix should capture some structure assert results.effective_rank > 0, "Factor matrix should have positive rank" - def test_paper_dgp_recovery(self): + def test_paper_dgp_recovery(self, ci_params): """ Test treatment effect recovery using paper's simulation DGP. @@ -893,11 +903,12 @@ def test_paper_dgp_recovery(self): df = pd.DataFrame(data) # TROP estimation + n_boot = ci_params.bootstrap(30) trop_est = TROP( lambda_time_grid=[0.0, 0.5, 1.0], lambda_unit_grid=[0.0, 0.5, 1.0], lambda_nn_grid=[0.0, 0.1, 1.0], - n_bootstrap=30, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -1139,12 +1150,13 @@ def test_pivot_vs_iterrows_equivalence(self): assert np.allclose(Y_iterrows, Y_pivot, equal_nan=True) assert np.array_equal(D_iterrows, D_pivot) - def test_reproducibility_with_seed(self, simple_panel_data): + def test_reproducibility_with_seed(self, simple_panel_data, ci_params): """ Test that results are reproducible with the same seed. Running TROP twice with the same seed should produce identical results. """ + n_boot = ci_params.bootstrap(20) results1 = trop( simple_panel_data, outcome="outcome", @@ -1154,7 +1166,7 @@ def test_reproducibility_with_seed(self, simple_panel_data): lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], - n_bootstrap=20, + n_bootstrap=n_boot, seed=42, ) @@ -1167,7 +1179,7 @@ def test_reproducibility_with_seed(self, simple_panel_data): lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], - n_bootstrap=20, + n_bootstrap=n_boot, seed=42, ) @@ -1600,7 +1612,7 @@ def test_issue_c_weighted_nuclear_norm(self): # ATT should recover treatment effect direction assert results.att > 0, f"ATT={results.att:.3f} should be positive" - def test_issue_d_stratified_bootstrap(self): + def test_issue_d_stratified_bootstrap(self, ci_params): """ Test Issue D fix: Bootstrap uses stratified sampling. 
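The hunk below scales this test's bootstrap count and lowers its success floor to match. For context, the stratified sampling named in the docstring above conventionally means resampling treated and control units within their own strata, so every draw contains both groups and the ATT stays defined across iterations. The sketch below illustrates that scheme under this assumption; `stratified_unit_resample` is a hypothetical helper, not the library's internal API.

    import numpy as np

    def stratified_unit_resample(units, treated_mask, rng):
        # Resample within each stratum so treated and control units both
        # appear in every bootstrap draw, keeping the ATT estimable in
        # each iteration.
        treated = units[treated_mask]
        control = units[~treated_mask]
        return np.concatenate([
            rng.choice(treated, size=treated.size, replace=True),
            rng.choice(control, size=control.size, replace=True),
        ])

    rng = np.random.default_rng(42)
    units = np.arange(20)
    sample = stratified_unit_resample(units, units < 5, rng)
    assert (sample < 5).any() and (sample >= 5).any()  # both strata present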
@@ -1635,11 +1647,12 @@ def test_issue_d_stratified_bootstrap(self): df = pd.DataFrame(data) # Run with bootstrap variance estimation + n_boot = ci_params.bootstrap(30) trop_est = TROP( lambda_time_grid=[0.0], lambda_unit_grid=[0.0], lambda_nn_grid=[0.0], - n_bootstrap=30, + n_bootstrap=n_boot, seed=42 ) results = trop_est.fit( @@ -1652,7 +1665,7 @@ def test_issue_d_stratified_bootstrap(self): # Bootstrap should complete successfully assert results.bootstrap_distribution is not None - assert len(results.bootstrap_distribution) >= 20 # Most iterations succeed + assert len(results.bootstrap_distribution) >= 11 # Most iterations succeed # SE should be positive and finite assert results.se > 0 assert np.isfinite(results.se) @@ -2703,14 +2716,15 @@ def test_joint_no_lowrank(self, simple_panel_data): # Factor matrix should be all zeros assert np.allclose(results.factor_matrix, 0.0) - def test_joint_with_lowrank(self, factor_dgp_data): + def test_joint_with_lowrank(self, factor_dgp_data, ci_params): """Joint method with finite lambda_nn (with low-rank).""" + n_boot = ci_params.bootstrap(20) trop_est = TROP( method="joint", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1, 1.0], - n_bootstrap=20, + n_bootstrap=n_boot, seed=42, ) results = trop_est.fit( @@ -2789,14 +2803,15 @@ def test_method_in_set_params(self): trop_est.set_params(method="joint") assert trop_est.method == "joint" - def test_joint_bootstrap_variance(self, simple_panel_data): + def test_joint_bootstrap_variance(self, simple_panel_data, ci_params): """Joint method bootstrap variance estimation works.""" + n_boot = ci_params.bootstrap(20) trop_est = TROP( method="joint", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], - n_bootstrap=20, + n_bootstrap=n_boot, seed=42, ) results = trop_est.fit( @@ -2808,18 +2823,19 @@ def test_joint_bootstrap_variance(self, simple_panel_data): ) assert results.se > 0 - assert results.n_bootstrap == 20 + assert results.n_bootstrap == n_boot assert results.bootstrap_distribution is not None - def test_joint_confidence_interval(self, simple_panel_data): + def test_joint_confidence_interval(self, simple_panel_data, ci_params): """Joint method produces valid confidence intervals.""" + n_boot = ci_params.bootstrap(30) trop_est = TROP( method="joint", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], alpha=0.05, - n_bootstrap=30, + n_bootstrap=n_boot, seed=42, ) results = trop_est.fit( diff --git a/tests/test_wild_bootstrap.py b/tests/test_wild_bootstrap.py index 77de74a..a9ddae3 100644 --- a/tests/test_wild_bootstrap.py +++ b/tests/test_wild_bootstrap.py @@ -194,73 +194,78 @@ def test_mammen_weights_moments(self): class TestWildBootstrapSE: """Tests for wild_bootstrap_se function.""" - def test_returns_wild_bootstrap_results(self, ols_components): + def test_returns_wild_bootstrap_results(self, ols_components, ci_params): """Test that function returns WildBootstrapResults.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) assert isinstance(results, WildBootstrapResults) - def test_se_is_positive(self, ols_components): + def test_se_is_positive(self, ols_components, ci_params): """Test bootstrap SE is positive.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results = wild_bootstrap_se( X, y, residuals, 
cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) assert results.se > 0 - def test_p_value_in_valid_range(self, ols_components): + def test_p_value_in_valid_range(self, ols_components, ci_params): """Test p-value is in [0, 1].""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) assert 0 <= results.p_value <= 1 - def test_ci_contains_reasonable_values(self, ols_components): + def test_ci_contains_reasonable_values(self, ols_components, ci_params): """Test CI bounds are ordered correctly.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(199) results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=199, + n_bootstrap=n_boot, seed=42 ) assert results.ci_lower < results.ci_upper - def test_reproducibility_with_seed(self, ols_components): + def test_reproducibility_with_seed(self, ols_components, ci_params): """Test same seed gives same results.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results1 = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) results2 = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -268,36 +273,38 @@ def test_reproducibility_with_seed(self, ols_components): assert results1.p_value == results2.p_value assert results1.ci_lower == results2.ci_lower - def test_different_seeds_different_results(self, ols_components): + def test_different_seeds_different_results(self, ols_components, ci_params): """Test different seeds give different results.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results1 = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) results2 = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=123 ) # Should be different (not exactly equal) assert results1.se != results2.se - def test_different_weight_types(self, ols_components): + def test_different_weight_types(self, ols_components, ci_params): """Test all weight types produce valid results.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) for weight_type in ["rademacher", "webb", "mammen"]: results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, weight_type=weight_type, seed=42 ) @@ -317,9 +324,10 @@ def test_invalid_weight_type_raises(self, ols_components): weight_type="invalid" ) - def test_few_clusters_warning(self, few_cluster_data): + def test_few_clusters_warning(self, few_cluster_data, ci_params): """Test warning when clusters < 5.""" data = few_cluster_data + n_boot = ci_params.bootstrap(99) y = data["outcome"].values.astype(float) d = data["treated"].values.astype(float) @@ -335,7 +343,7 @@ def test_few_clusters_warning(self, few_cluster_data): wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -352,31 +360,33 @@ def test_too_few_clusters_raises(self, ols_components): coefficient_index=3 ) - def test_n_clusters_reported_correctly(self, ols_components): + def test_n_clusters_reported_correctly(self, ols_components, 
ci_params): """Test n_clusters is reported correctly.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) assert results.n_clusters == 10 - def test_n_bootstrap_reported_correctly(self, ols_components): + def test_n_bootstrap_reported_correctly(self, ols_components, ci_params): """Test n_bootstrap is reported correctly.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(199) results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=199, + n_bootstrap=n_boot, seed=42 ) - assert results.n_bootstrap == 199 + assert results.n_bootstrap == n_boot # ============================================================================= @@ -387,12 +397,13 @@ def test_n_bootstrap_reported_correctly(self, ols_components): class TestEstimatorIntegration: """Tests for wild bootstrap integration with DiD estimators.""" - def test_did_with_wild_bootstrap(self, clustered_did_data): + def test_did_with_wild_bootstrap(self, clustered_did_data, ci_params): """Test DifferenceInDifferences with wild bootstrap.""" + n_boot = ci_params.bootstrap(99) did = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -404,23 +415,24 @@ def test_did_with_wild_bootstrap(self, clustered_did_data): ) assert results.inference_method == "wild_bootstrap" - assert results.n_bootstrap == 99 + assert results.n_bootstrap == n_boot assert results.n_clusters == 10 assert results.se > 0 - def test_did_wild_bootstrap_reproducibility(self, clustered_did_data): + def test_did_wild_bootstrap_reproducibility(self, clustered_did_data, ci_params): """Test wild bootstrap results are reproducible with seed.""" + n_boot = ci_params.bootstrap(99) did1 = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) did2 = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -441,13 +453,14 @@ def test_did_wild_bootstrap_reproducibility(self, clustered_did_data): assert results1.se == results2.se assert results1.p_value == results2.p_value - def test_did_analytical_vs_bootstrap_att_same(self, clustered_did_data): + def test_did_analytical_vs_bootstrap_att_same(self, clustered_did_data, ci_params): """Test that ATT is the same regardless of inference method.""" + n_boot = ci_params.bootstrap(99) did_analytical = DifferenceInDifferences(cluster="cluster") did_bootstrap = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -468,12 +481,13 @@ def test_did_analytical_vs_bootstrap_att_same(self, clustered_did_data): # ATT should be identical assert results_analytical.att == results_bootstrap.att - def test_did_wild_bootstrap_with_webb_weights(self, clustered_did_data): + def test_did_wild_bootstrap_with_webb_weights(self, clustered_did_data, ci_params): """Test wild bootstrap with Webb weights.""" + n_boot = ci_params.bootstrap(99) did = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, bootstrap_weights="webb", seed=42 ) @@ -488,11 +502,12 @@ def test_did_wild_bootstrap_with_webb_weights(self, clustered_did_data): assert results.inference_method == "wild_bootstrap" assert results.se > 0 - def 
test_did_wild_bootstrap_requires_cluster(self, clustered_did_data): + def test_did_wild_bootstrap_requires_cluster(self, clustered_did_data, ci_params): """Test that wild bootstrap is only used when cluster is specified.""" + n_boot = ci_params.bootstrap(99) did = DifferenceInDifferences( inference="wild_bootstrap", # No cluster specified - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -506,12 +521,13 @@ def test_did_wild_bootstrap_requires_cluster(self, clustered_did_data): # Should fall back to analytical since no cluster specified assert results.inference_method == "analytical" - def test_twfe_with_wild_bootstrap(self, clustered_did_data): + def test_twfe_with_wild_bootstrap(self, clustered_did_data, ci_params): """Test TwoWayFixedEffects with wild bootstrap.""" + n_boot = ci_params.bootstrap(99) twfe = TwoWayFixedEffects( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -524,15 +540,16 @@ def test_twfe_with_wild_bootstrap(self, clustered_did_data): ) assert results.inference_method == "wild_bootstrap" - assert results.n_bootstrap == 99 + assert results.n_bootstrap == n_boot assert results.se > 0 - def test_summary_shows_bootstrap_info(self, clustered_did_data): + def test_summary_shows_bootstrap_info(self, clustered_did_data, ci_params): """Test that summary shows bootstrap info.""" + n_boot = ci_params.bootstrap(99) did = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -546,7 +563,7 @@ def test_summary_shows_bootstrap_info(self, clustered_did_data): summary = results.summary() assert "wild_bootstrap" in summary - assert "99" in summary # n_bootstrap + assert str(n_boot) in summary # n_bootstrap assert "10" in summary # n_clusters def test_get_params_includes_bootstrap_params(self): @@ -588,14 +605,15 @@ def test_set_params_for_bootstrap(self): class TestWildBootstrapResults: """Tests for WildBootstrapResults dataclass.""" - def test_summary_format(self, ols_components): + def test_summary_format(self, ols_components, ci_params): """Test summary method produces readable output.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -606,14 +624,15 @@ def test_summary_format(self, ols_components): assert "Bootstrap p-value:" in summary assert "Number of clusters:" in summary - def test_print_summary(self, ols_components, capsys): + def test_print_summary(self, ols_components, capsys, ci_params): """Test print_summary outputs to stdout.""" X, y, residuals, cluster_ids = ols_components + n_boot = ci_params.bootstrap(99) results = wild_bootstrap_se( X, y, residuals, cluster_ids, coefficient_index=3, - n_bootstrap=99, + n_bootstrap=n_boot, seed=42 ) @@ -631,9 +650,10 @@ def test_print_summary(self, ols_components, capsys): class TestFewClustersEdgeCases: """Tests for wild bootstrap behavior with very few clusters.""" - def test_three_clusters_still_works(self): + def test_three_clusters_still_works(self, ci_params): """Test wild bootstrap works with 3 clusters (minimum viable).""" np.random.seed(42) + n_boot = ci_params.bootstrap(99) n_clusters = 3 obs_per_cluster = 40 @@ -666,7 +686,7 @@ def test_three_clusters_still_works(self): did = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, bootstrap_weights="webb", # Webb recommended for few 
clusters seed=42 ) @@ -684,9 +704,10 @@ def test_three_clusters_still_works(self): assert results.inference_method == "wild_bootstrap" assert results.n_clusters == 3 - def test_two_clusters_minimum(self): + def test_two_clusters_minimum(self, ci_params): """Test wild bootstrap works with exactly 2 clusters (absolute minimum).""" np.random.seed(42) + n_boot = ci_params.bootstrap(99) n_clusters = 2 obs_per_cluster = 50 @@ -719,7 +740,7 @@ def test_two_clusters_minimum(self): did = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=99, + n_bootstrap=n_boot, bootstrap_weights="webb", seed=42 ) @@ -738,12 +759,13 @@ def test_two_clusters_minimum(self): assert np.isfinite(results.att) assert results.n_clusters == 2 - def test_few_clusters_webb_vs_rademacher(self, few_cluster_data): + def test_few_clusters_webb_vs_rademacher(self, few_cluster_data, ci_params): """Test that Webb weights produce different (often more conservative) SEs than Rademacher with few clusters.""" + n_boot = ci_params.bootstrap(199) did_webb = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=199, + n_bootstrap=n_boot, bootstrap_weights="webb", seed=42 ) @@ -751,7 +773,7 @@ def test_few_clusters_webb_vs_rademacher(self, few_cluster_data): did_rademacher = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=199, + n_bootstrap=n_boot, bootstrap_weights="rademacher", seed=42 ) @@ -780,12 +802,13 @@ def test_few_clusters_webb_vs_rademacher(self, few_cluster_data): # SEs will differ due to different weight distributions # (This is expected, not necessarily one > other) - def test_few_clusters_confidence_intervals_valid(self, few_cluster_data): + def test_few_clusters_confidence_intervals_valid(self, few_cluster_data, ci_params): """Test that CIs are valid even with few clusters.""" + n_boot = ci_params.bootstrap(199) did = DifferenceInDifferences( cluster="cluster", inference="wild_bootstrap", - n_bootstrap=199, + n_bootstrap=n_boot, bootstrap_weights="webb", seed=42 )
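For reviewers sizing up the CI impact: every bootstrap literal touched by this patch maps to a much smaller pure Python value. The snippet below mirrors `CIParams.bootstrap` from this patch's `conftest.py`, with the backend probe replaced by an explicit `pure_python` flag so the mapping can be checked in isolation (that flag is an illustrative assumption, not part of the fixture).

    import math

    def scaled_bootstrap(n: int, pure_python: bool = True) -> int:
        # Same body as CIParams.bootstrap: unchanged with the Rust backend
        # or for n <= 10, otherwise sqrt-scaled with a floor of 11.
        if not pure_python or n <= 10:
            return n
        return max(11, int(math.sqrt(n) * 1.6))

    # The literals used across these test files and their scaled values:
    for n, expected in [(20, 11), (30, 11), (49, 11), (50, 11),
                        (99, 15), (100, 16), (199, 22), (499, 35)]:
        assert scaled_bootstrap(n) == expected              # pure Python mode
        assert scaled_bootstrap(n, pure_python=False) == n  # Rust: unchanged

In particular, `scaled_bootstrap(30) == 11`, which is where the `>= 11` floor in the stratified-bootstrap assertion in `test_trop.py` comes from.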