Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ Tests mirror the source modules:
- `tests/test_pretrends.py` - Tests for pre-trends power analysis
- `tests/test_datasets.py` - Tests for dataset loading functions

A session-scoped `ci_params` fixture in `conftest.py` scales down bootstrap iteration counts and TROP grid sizes when running in pure Python mode. In new tests, wrap the values with `ci_params.bootstrap(n)` and `ci_params.grid(values)` whenever `n_bootstrap >= 20`; values of 10 or below pass through unchanged.

### Test Writing Guidelines

**For fallback/error handling paths:**
Expand Down
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
to avoid import-time subprocess latency.
"""

import math
import os
import subprocess

Expand Down Expand Up @@ -81,3 +82,44 @@ def test_comparison_with_r(require_r):
"""
if not r_available:
pytest.skip("R or did package not available")


# =============================================================================
# CI Performance: Backend-Aware Parameter Scaling
# =============================================================================

from diff_diff._backend import HAS_RUST_BACKEND

# True when the test run is on the pure-Python code path: either the user
# explicitly requested it via DIFF_DIFF_BACKEND=python, or the compiled
# Rust backend is not installed at all.
_PURE_PYTHON_MODE = (
    os.environ.get("DIFF_DIFF_BACKEND", "auto").lower() == "python"
    or not HAS_RUST_BACKEND
)


class CIParams:
    """Backend-aware scaling of test parameters for CI speed.

    When the Rust backend is available every value is returned untouched.
    Under the pure-Python backend, bootstrap iteration counts and LOOCV
    lambda grids are shrunk so CI stays fast while the same code paths
    are still exercised.
    """

    @staticmethod
    def bootstrap(n: int) -> int:
        """Return a possibly reduced bootstrap count.

        Non-decreasing in ``n`` (``bootstrap(n + 1) >= bootstrap(n)``);
        values of 10 or below always pass through unchanged.
        """
        if _PURE_PYTHON_MODE and n > 10:
            # sqrt scaling keeps large counts meaningfully smaller while
            # never dropping below 11, preserving monotonicity.
            return max(11, int(math.sqrt(n) * 1.6))
        return n

    @staticmethod
    def grid(values: list) -> list:
        """Shrink a TROP lambda grid to its first, middle, and last entries.

        Grids of three or fewer elements are returned unchanged.
        """
        if _PURE_PYTHON_MODE and len(values) > 3:
            mid = len(values) // 2
            return [values[0], values[mid], values[-1]]
        return values


@pytest.fixture(scope="session")
def ci_params():
    """Provide one shared CIParams scaler for the whole test session."""
    scaler = CIParams()
    return scaler
79 changes: 47 additions & 32 deletions tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,9 +1957,10 @@ def single_treated_unit_data(self):

return pd.DataFrame(data)

def test_basic_fit(self, sdid_panel_data):
def test_basic_fit(self, sdid_panel_data, ci_params):
"""Test basic SDID model fitting."""
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -1975,9 +1976,10 @@ def test_basic_fit(self, sdid_panel_data):
assert results.n_treated == 5
assert results.n_control == 25

def test_att_direction(self, sdid_panel_data):
def test_att_direction(self, sdid_panel_data, ci_params):
"""Test that ATT is estimated in correct direction."""
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand Down Expand Up @@ -2036,9 +2038,10 @@ def test_unit_weights_nonnegative(self, sdid_panel_data):
for w in results.unit_weights.values():
assert w >= 0

def test_single_treated_unit(self, single_treated_unit_data):
def test_single_treated_unit(self, single_treated_unit_data, ci_params):
"""Test SDID with a single treated unit (classic SC scenario)."""
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
single_treated_unit_data,
outcome="outcome",
Expand Down Expand Up @@ -2101,9 +2104,10 @@ def test_placebo_inference(self, sdid_panel_data):
assert len(results.placebo_effects) > 0
assert results.se > 0

def test_bootstrap_inference(self, sdid_panel_data):
def test_bootstrap_inference(self, sdid_panel_data, ci_params):
"""Test bootstrap-based inference."""
sdid = SyntheticDiD(variance_method="bootstrap", n_bootstrap=100, seed=42)
n_boot = ci_params.bootstrap(100)
sdid = SyntheticDiD(variance_method="bootstrap", n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -2114,7 +2118,7 @@ def test_bootstrap_inference(self, sdid_panel_data):
)

assert results.variance_method == "bootstrap"
assert results.n_bootstrap == 100
assert results.n_bootstrap == n_boot
assert results.se > 0
assert results.conf_int[0] < results.att < results.conf_int[1]

Expand Down Expand Up @@ -2174,9 +2178,10 @@ def test_pre_treatment_fit(self, sdid_panel_data):
assert results.pre_treatment_fit is not None
assert results.pre_treatment_fit >= 0

def test_summary_output(self, sdid_panel_data):
def test_summary_output(self, sdid_panel_data, ci_params):
"""Test that summary produces string output."""
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -2192,9 +2197,10 @@ def test_summary_output(self, sdid_panel_data):
assert "ATT" in summary
assert "Unit Weights" in summary

def test_to_dict(self, sdid_panel_data):
def test_to_dict(self, sdid_panel_data, ci_params):
"""Test conversion to dictionary."""
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -2211,9 +2217,10 @@ def test_to_dict(self, sdid_panel_data):
assert "n_post_periods" in result_dict
assert "pre_treatment_fit" in result_dict

def test_to_dataframe(self, sdid_panel_data):
def test_to_dataframe(self, sdid_panel_data, ci_params):
"""Test conversion to DataFrame."""
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -2228,9 +2235,10 @@ def test_to_dataframe(self, sdid_panel_data):
assert len(df) == 1
assert "att" in df.columns

def test_repr(self, sdid_panel_data):
def test_repr(self, sdid_panel_data, ci_params):
"""Test string representation."""
sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -2244,9 +2252,10 @@ def test_repr(self, sdid_panel_data):
assert "SyntheticDiDResults" in repr_str
assert "ATT=" in repr_str

def test_is_significant_property(self, sdid_panel_data):
def test_is_significant_property(self, sdid_panel_data, ci_params):
"""Test is_significant property."""
sdid = SyntheticDiD(n_bootstrap=100, seed=42)
n_boot = ci_params.bootstrap(100)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand Down Expand Up @@ -2356,12 +2365,13 @@ def test_auto_infer_post_periods(self, sdid_panel_data):
assert results.pre_periods == [0, 1, 2, 3]
assert results.post_periods == [4, 5, 6, 7]

def test_with_covariates(self, sdid_panel_data):
def test_with_covariates(self, sdid_panel_data, ci_params):
"""Test SDID with covariates."""
# Add a covariate
sdid_panel_data["size"] = np.random.normal(100, 10, len(sdid_panel_data))

sdid = SyntheticDiD(n_bootstrap=50, seed=42)
n_boot = ci_params.bootstrap(50)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -2375,9 +2385,10 @@ def test_with_covariates(self, sdid_panel_data):
assert results is not None
assert sdid.is_fitted_

def test_confidence_interval_contains_estimate(self, sdid_panel_data):
def test_confidence_interval_contains_estimate(self, sdid_panel_data, ci_params):
"""Test that confidence interval contains the estimate."""
sdid = SyntheticDiD(n_bootstrap=100, seed=42)
n_boot = ci_params.bootstrap(100)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)
results = sdid.fit(
sdid_panel_data,
outcome="outcome",
Expand All @@ -2390,9 +2401,10 @@ def test_confidence_interval_contains_estimate(self, sdid_panel_data):
lower, upper = results.conf_int
assert lower < results.att < upper

def test_reproducibility_with_seed(self, sdid_panel_data):
def test_reproducibility_with_seed(self, sdid_panel_data, ci_params):
"""Test that results are reproducible with the same seed."""
results1 = SyntheticDiD(n_bootstrap=50, seed=42).fit(
n_boot = ci_params.bootstrap(50)
results1 = SyntheticDiD(n_bootstrap=n_boot, seed=42).fit(
sdid_panel_data,
outcome="outcome",
treatment="treated",
Expand All @@ -2401,7 +2413,7 @@ def test_reproducibility_with_seed(self, sdid_panel_data):
post_periods=[4, 5, 6, 7]
)

results2 = SyntheticDiD(n_bootstrap=50, seed=42).fit(
results2 = SyntheticDiD(n_bootstrap=n_boot, seed=42).fit(
sdid_panel_data,
outcome="outcome",
treatment="treated",
Expand All @@ -2413,7 +2425,7 @@ def test_reproducibility_with_seed(self, sdid_panel_data):
assert results1.att == results2.att
assert results1.se == results2.se

def test_insufficient_pre_periods_warning(self):
def test_insufficient_pre_periods_warning(self, ci_params):
"""Test that SDID warns with very few pre-treatment periods."""
np.random.seed(42)

Expand Down Expand Up @@ -2448,7 +2460,8 @@ def test_insufficient_pre_periods_warning(self):

df = pd.DataFrame(data)

sdid = SyntheticDiD(n_bootstrap=30, seed=42)
n_boot = ci_params.bootstrap(30)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)

# Should work but may warn about few pre-periods
# (Depending on implementation - some may warn, some may not)
Expand All @@ -2465,7 +2478,7 @@ def test_insufficient_pre_periods_warning(self):
assert np.isfinite(results.att)
assert results.se > 0

def test_single_pre_period_edge_case(self):
def test_single_pre_period_edge_case(self, ci_params):
"""Test SDID with single pre-treatment period (extreme edge case)."""
np.random.seed(42)

Expand Down Expand Up @@ -2499,7 +2512,8 @@ def test_single_pre_period_edge_case(self):

df = pd.DataFrame(data)

sdid = SyntheticDiD(n_bootstrap=30, seed=42)
n_boot = ci_params.bootstrap(30)
sdid = SyntheticDiD(n_bootstrap=n_boot, seed=42)

# With single pre-period, time weights will be trivially [1.0]
results = sdid.fit(
Expand All @@ -2516,7 +2530,7 @@ def test_single_pre_period_edge_case(self):
# Time weights should have single entry
assert len(results.time_weights) == 1

def test_more_pre_periods_than_control_units(self):
def test_more_pre_periods_than_control_units(self, ci_params):
"""Test SDID when n_pre_periods > n_control_units (underdetermined)."""
np.random.seed(42)

Expand Down Expand Up @@ -2551,7 +2565,8 @@ def test_more_pre_periods_than_control_units(self):
df = pd.DataFrame(data)

# Use regularization to help with underdetermined system
sdid = SyntheticDiD(lambda_reg=1.0, n_bootstrap=30, seed=42)
n_boot = ci_params.bootstrap(30)
sdid = SyntheticDiD(lambda_reg=1.0, n_bootstrap=n_boot, seed=42)

results = sdid.fit(
df,
Expand Down
10 changes: 6 additions & 4 deletions tests/test_methodology_callaway.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ class TestSEFormulas:
"""Tests for standard error formula verification."""

@pytest.mark.slow
def test_analytical_se_close_to_bootstrap_se(self):
def test_analytical_se_close_to_bootstrap_se(self, ci_params):
"""
Analytical and bootstrap SEs should be within 20%.

Expand All @@ -812,6 +812,7 @@ def test_analytical_se_close_to_bootstrap_se(self):
This test is marked slow because it uses 499 bootstrap iterations
for thorough validation of SE convergence.
"""
n_boot = ci_params.bootstrap(499)
data = generate_staggered_data(
n_units=300,
n_periods=8,
Expand All @@ -821,7 +822,7 @@ def test_analytical_se_close_to_bootstrap_se(self):
)

cs_anal = CallawaySantAnna(n_bootstrap=0)
cs_boot = CallawaySantAnna(n_bootstrap=499, seed=42)
cs_boot = CallawaySantAnna(n_bootstrap=n_boot, seed=42)

results_anal = cs_anal.fit(
data, outcome='outcome', unit='unit',
Expand Down Expand Up @@ -893,12 +894,13 @@ def test_bootstrap_weight_moments_webb(self):
var_w = np.var(weights)
assert abs(var_w - 1.0) < 0.05, f"Webb Var(w) should be ~1.0, got {var_w}"

def test_bootstrap_produces_valid_inference(self):
def test_bootstrap_produces_valid_inference(self, ci_params):
"""Test that bootstrap produces valid inference with p-values and CIs.

Uses 99 bootstrap iterations - sufficient to verify the mechanism works
without being slow for CI runs.
"""
n_boot = ci_params.bootstrap(99)
data = generate_staggered_data(
n_units=100,
n_periods=6,
Expand All @@ -907,7 +909,7 @@ def test_bootstrap_produces_valid_inference(self):
seed=42
)

cs = CallawaySantAnna(n_bootstrap=99, seed=42)
cs = CallawaySantAnna(n_bootstrap=n_boot, seed=42)
results = cs.fit(
data, outcome='outcome', unit='unit',
time='period', first_treat='first_treat'
Expand Down
Loading
Loading