From 842eefc33c60ef84b8f5af0fee384f0c2101351d Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 12:37:26 -0500 Subject: [PATCH 1/8] MultiPeriodDiD: full event-study specification with pre-period coefficients MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BREAKING: Transform MultiPeriodDiD from a post-period-only estimator into a proper event-study estimator with pre-period coefficients for parallel trends assessment. Core changes: - Treatment × period interactions for ALL non-reference periods (pre and post) - Default reference period is last pre-period (e=-1 convention) with FutureWarning - New `unit` parameter for staggered adoption detection warning - `period_effects` now contains both pre and post period effects - `summary()` shows pre-period section with reference period indicator - `to_dataframe()` includes `is_post` column - `interaction_indices` stored on results for robust sub-VCV extraction Bug fix: - HonestDiD/PreTrendsPower VCV extraction uses interaction sub-VCV instead of full regression VCV (via stored interaction_indices) Co-Authored-By: Claude Opus 4.5 --- CHANGELOG.md | 25 + CLAUDE.md | 2 +- README.md | 45 +- diff_diff/estimators.py | 136 +++- diff_diff/honest_did.py | 275 ++++--- diff_diff/pretrends.py | 240 +++--- diff_diff/results.py | 252 ++++--- diff_diff/visualization.py | 419 ++++++----- docs/api/results.rst | 4 + docs/choosing_estimator.rst | 7 +- docs/methodology/REGISTRY.md | 114 ++- tests/test_estimators.py | 1346 ++++++++++++++++++++-------------- tests/test_honest_did.py | 266 ++++--- tests/test_visualization.py | 244 +++--- 14 files changed, 1934 insertions(+), 1441 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73e2e81..2685221 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,11 +7,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- **MultiPeriodDiD: Full event-study specification** (BREAKING) + - 
Treatment × period interactions now created for ALL periods (pre and post), + not just post-treatment + - Pre-period coefficients available for parallel trends assessment + - Default reference period changed from first to last pre-period (e=-1 convention) + with FutureWarning for one release cycle + - `period_effects` dict now contains both pre and post period effects + - `to_dataframe()` includes `is_post` column + - `summary()` output now shows pre-period effects section + - t_stat uses `np.isfinite(se) and se > 0` guard (consistent with other estimators) + +### Added +- `unit` parameter to `MultiPeriodDiD.fit()` for staggered adoption detection +- `reference_period` and `interaction_indices` attributes on `MultiPeriodDiDResults` +- `pre_period_effects` and `post_period_effects` convenience properties on results +- Pre-period section in `summary()` output with reference period indicator +- Warning when `reference_period` is set to a post-treatment period +- Staggered adoption warning when treatment timing varies across units (with `unit` param) +- Informative KeyError when accessing reference period via `get_effect()` + ### Removed - **TROP `variance_method` parameter** — Jackknife variance estimation removed. Bootstrap (the only method specified in Athey et al. 2025) is now always used. The `variance_method` field has also been removed from `TROPResults`. +### Fixed +- HonestDiD VCV extraction: now uses interaction sub-VCV instead of full regression VCV + (via `interaction_indices` period → column index mapping) + ## [2.2.0] - 2026-01-27 ### Added diff --git a/CLAUDE.md b/CLAUDE.md index eab6ab1..28b4756 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -60,7 +60,7 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. 
- **`diff_diff/estimators.py`** - Core estimator classes implementing DiD methods: - `DifferenceInDifferences` - Basic 2x2 DiD with formula or column-name interface - - `MultiPeriodDiD` - Event-study style DiD with period-specific treatment effects + - `MultiPeriodDiD` - Full event-study DiD with treatment × period interactions for ALL periods (pre and post). Supports `unit` parameter for staggered adoption detection. Default reference period is last pre-period (e=-1 convention). Pre-period coefficients enable parallel trends assessment. `interaction_indices` maps periods to VCV column indices for robust sub-VCV extraction in HonestDiD/PreTrendsPower. - Re-exports `TwoWayFixedEffects` and `SyntheticDiD` for backward compatibility - **`diff_diff/twfe.py`** - Two-Way Fixed Effects estimator: diff --git a/README.md b/README.md index a675e35..a902ea6 100644 --- a/README.md +++ b/README.md @@ -561,12 +561,13 @@ results = twfe.fit( ### Multi-Period DiD (Event Study) -For settings with multiple pre- and post-treatment periods: +For settings with multiple pre- and post-treatment periods. 
Estimates treatment × period +interactions for ALL periods (pre and post), enabling parallel trends assessment: ```python from diff_diff import MultiPeriodDiD -# Fit with multiple time periods +# Fit full event study with pre and post period effects did = MultiPeriodDiD() results = did.fit( panel_data, @@ -574,18 +575,23 @@ results = did.fit( treatment='treated', time='period', post_periods=[3, 4, 5], # Periods 3-5 are post-treatment - reference_period=0 # Reference period for comparison + reference_period=2, # Last pre-period (e=-1 convention) + unit='unit_id', # Optional: warns if staggered adoption detected ) -# View period-specific treatment effects -for period, effect in results.period_effects.items(): - print(f"Period {period}: {effect.effect:.3f} (SE: {effect.se:.3f})") +# Pre-period effects test parallel trends (should be ≈ 0) +for period, effect in results.pre_period_effects.items(): + print(f"Pre {period}: {effect.effect:.3f} (SE: {effect.se:.3f})") + +# Post-period effects estimate dynamic treatment effects +for period, effect in results.post_period_effects.items(): + print(f"Post {period}: {effect.effect:.3f} (SE: {effect.se:.3f})") # View average treatment effect across post-periods print(f"Average ATT: {results.avg_att:.3f}") print(f"Average SE: {results.avg_se:.3f}") -# Full summary with all period effects +# Full summary with pre and post period effects results.print_summary() ``` @@ -951,10 +957,10 @@ Create publication-ready event study plots: ```python from diff_diff import plot_event_study, MultiPeriodDiD, CallawaySantAnna, SunAbraham -# From MultiPeriodDiD +# From MultiPeriodDiD (full event study with pre and post period effects) did = MultiPeriodDiD() results = did.fit(data, outcome='y', treatment='treated', - time='period', post_periods=[3, 4, 5]) + time='period', post_periods=[3, 4, 5], reference_period=2) plot_event_study(results, title="Treatment Effects Over Time") # From CallawaySantAnna (with event study aggregation) @@ -1413,14 
+1419,15 @@ Pre-trends tests have low power and can exacerbate bias. **Honest DiD** (Rambach ```python from diff_diff import HonestDiD, MultiPeriodDiD -# First, fit a standard event study +# First, fit a full event study (pre + post period effects) did = MultiPeriodDiD() event_results = did.fit( data, outcome='outcome', treatment='treated', time='period', - post_periods=[5, 6, 7, 8, 9] + post_periods=[5, 6, 7, 8, 9], + reference_period=4, # Last pre-period (e=-1 convention) ) # Compute honest bounds with relative magnitudes restriction @@ -1488,14 +1495,15 @@ A passing pre-trends test doesn't mean parallel trends holds—it may just mean ```python from diff_diff import PreTrendsPower, MultiPeriodDiD -# First, fit an event study +# First, fit a full event study did = MultiPeriodDiD() event_results = did.fit( data, outcome='outcome', treatment='treated', time='period', - post_periods=[5, 6, 7, 8, 9] + post_periods=[5, 6, 7, 8, 9], + reference_period=4, ) # Analyze pre-trends test power @@ -1764,7 +1772,8 @@ MultiPeriodDiD( | `covariates` | list | Linear control variables | | `fixed_effects` | list | Categorical FE columns (creates dummies) | | `absorb` | list | High-dimensional FE (within-transformation) | -| `reference_period` | any | Omitted period for time dummies | +| `reference_period` | any | Omitted period (default: last pre-period, e=-1 convention) | +| `unit` | str | Unit identifier column (for staggered adoption warning) | ### MultiPeriodDiDResults @@ -1772,8 +1781,8 @@ MultiPeriodDiD( | Attribute | Description | |-----------|-------------| -| `period_effects` | Dict mapping periods to PeriodEffect objects | -| `avg_att` | Average ATT across post-treatment periods | +| `period_effects` | Dict mapping periods to PeriodEffect objects (pre and post, excluding reference) | +| `avg_att` | Average ATT across post-treatment periods only | | `avg_se` | Standard error of average ATT | | `avg_t_stat` | T-statistic for average ATT | | `avg_p_value` | P-value for average 
ATT | @@ -1781,6 +1790,10 @@ MultiPeriodDiD( | `n_obs` | Number of observations | | `pre_periods` | List of pre-treatment periods | | `post_periods` | List of post-treatment periods | +| `reference_period` | The omitted reference period (coefficient = 0 by construction) | +| `interaction_indices` | Dict mapping period → column index in VCV (for sub-VCV extraction) | +| `pre_period_effects` | Property: pre-period effects only (for parallel trends assessment) | +| `post_period_effects` | Property: post-period effects only | **Methods:** diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index 83fc72b..36cfa00 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -12,6 +12,7 @@ For backward compatibility, all estimators are re-exported from this module. """ +import warnings from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -152,7 +153,7 @@ def fit( formula: Optional[str] = None, covariates: Optional[List[str]] = None, fixed_effects: Optional[List[str]] = None, - absorb: Optional[List[str]] = None + absorb: Optional[List[str]] = None, ) -> DiDResults: """ Fit the Difference-in-Differences model. 
@@ -380,9 +381,7 @@ def _fit_ols( (coefficients, residuals, fitted_values, r_squared) """ # Use unified OLS backend - coefficients, residuals, fitted, _ = solve_ols( - X, y, return_fitted=True, return_vcov=False - ) + coefficients, residuals, fitted, _ = solve_ols(X, y, return_fitted=True, return_vcov=False) r_squared = compute_r_squared(y, residuals) return coefficients, residuals, fitted, r_squared @@ -417,13 +416,16 @@ def _run_wild_bootstrap_inference( (se, p_value, conf_int, t_stat, vcov, bootstrap_results) """ bootstrap_results = wild_bootstrap_se( - X, y, residuals, cluster_ids, + X, + y, + residuals, + cluster_ids, coefficient_index=coefficient_index, n_bootstrap=self.n_bootstrap, weight_type=self.bootstrap_weights, alpha=self.alpha, seed=self.seed, - return_distribution=False + return_distribution=False, ) self._bootstrap_results = bootstrap_results @@ -536,7 +538,7 @@ def _validate_data( outcome: str, treatment: str, time: str, - covariates: Optional[List[str]] = None + covariates: Optional[List[str]] = None, ) -> None: """Validate input data.""" # Check DataFrame @@ -702,15 +704,18 @@ class MultiPeriodDiD(DifferenceInDifferences): ----- The model estimates: - Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_t∈post δ_t*(D_i × Post_t) + ε_it + Y_it = α + β*D_i + Σ_t γ_t*Period_t + Σ_{t≠ref} δ_t*(D_i × 1{t}) + ε_it Where: - D_i is the treatment indicator - - Period_t are time period dummies - - D_i × Post_t are treatment-by-post-period interactions + - Period_t are time period dummies (all non-reference periods) + - D_i × 1{t} are treatment-by-period interactions (all non-reference) - δ_t are the period-specific treatment effects + - The reference period (default: last pre-period) has δ_ref = 0 by construction - The average ATT is computed as the mean of the δ_t coefficients. + Pre-treatment δ_t test the parallel trends assumption (should be ≈ 0). + Post-treatment δ_t estimate dynamic treatment effects. + The average ATT is computed from post-treatment δ_t only. 
""" def fit( # type: ignore[override] @@ -723,7 +728,8 @@ def fit( # type: ignore[override] covariates: Optional[List[str]] = None, fixed_effects: Optional[List[str]] = None, absorb: Optional[List[str]] = None, - reference_period: Any = None + reference_period: Any = None, + unit: Optional[str] = None, ) -> MultiPeriodDiDResults: """ Fit the Multi-Period Difference-in-Differences model. @@ -749,7 +755,13 @@ def fit( # type: ignore[override] List of categorical column names for high-dimensional fixed effects. reference_period : any, optional The reference (omitted) time period for the period dummies. - Defaults to the first pre-treatment period. + Defaults to the last pre-treatment period (e=-1 convention). + unit : str, optional + Name of the unit identifier column. When provided, checks whether + treatment timing varies across units and warns if staggered adoption + is detected (suggests CallawaySantAnna instead). Does NOT affect + standard error computation -- use the ``cluster`` parameter for + cluster-robust SEs. Returns ------- @@ -763,18 +775,15 @@ def fit( # type: ignore[override] """ # Warn if wild bootstrap is requested but not supported if self.inference == "wild_bootstrap": - import warnings warnings.warn( "Wild bootstrap inference is not yet supported for MultiPeriodDiD. 
" "Using analytical inference instead.", - UserWarning + UserWarning, ) # Validate basic inputs if outcome is None or treatment is None or time is None: - raise ValueError( - "Must provide 'outcome', 'treatment', and 'time'" - ) + raise ValueError("Must provide 'outcome', 'treatment', and 'time'") # Validate columns exist self._validate_data(data, outcome, treatment, time, covariates) @@ -782,6 +791,25 @@ def fit( # type: ignore[override] # Validate treatment is binary validate_binary(data[treatment].values, "treatment") + # Validate unit column and check for staggered adoption + if unit is not None: + if unit not in data.columns: + raise ValueError(f"Unit column '{unit}' not found in data") + + # Check for staggered treatment timing + treated_mask = data[treatment] == 1 + if treated_mask.any(): + treatment_timing = data.loc[treated_mask].groupby(unit)[time].min() + if treatment_timing.nunique() > 1: + warnings.warn( + "Treatment timing varies across units (staggered adoption " + "detected). MultiPeriodDiD assumes simultaneous adoption " + "and may produce biased estimates with staggered treatment. " + "Consider using CallawaySantAnna or SunAbraham instead.", + UserWarning, + stacklevel=2, + ) + # Get all unique time periods all_periods = sorted(data[time].unique()) @@ -811,10 +839,34 @@ def fit( # type: ignore[override] # Determine reference period (omitted dummy) if reference_period is None: - reference_period = pre_periods[0] + # Default: last pre-period (e=-1 convention, matches fixest) + if len(pre_periods) > 1: + warnings.warn( + f"The default reference_period is changing from the first " + f"pre-period ({pre_periods[0]}) to the last pre-period " + f"({pre_periods[-1]}) to match the standard e=-1 convention. " + f"To silence this warning, pass " + f"reference_period={pre_periods[-1]} explicitly. 
" + f"In a future version, the default will be the last " + f"pre-period.", + FutureWarning, + stacklevel=2, + ) + reference_period = pre_periods[-1] elif reference_period not in all_periods: raise ValueError(f"Reference period '{reference_period}' not found in time column") + # Warn if reference period is a post-treatment period + if reference_period in post_periods: + warnings.warn( + f"reference_period={reference_period} is a post-treatment period. " + f"The reference period should typically be a pre-treatment period " + f"(e.g., the last pre-period). Post-period references alter the " + f"interpretation of all coefficients.", + UserWarning, + stacklevel=2, + ) + # Validate fixed effects and absorb columns if fixed_effects: for fe in fixed_effects: @@ -857,11 +909,12 @@ def fit( # type: ignore[override] var_names.append(f"period_{period}") period_dummy_indices[period] = X.shape[1] - 1 - # Add treatment × post-period interactions - # These are our coefficients of interest - interaction_indices = {} # Map post-period -> column index in X + # Add treatment × period interactions for ALL non-reference periods + # Pre-period interactions test parallel trends; post-period interactions + # estimate dynamic treatment effects + interaction_indices = {} # Map period -> column index in X - for period in post_periods: + for period in non_ref_periods: interaction = d * (t == period).astype(float) X = np.column_stack([X, interaction]) var_names.append(f"{treatment}:period_{period}") @@ -889,7 +942,8 @@ def fit( # type: ignore[override] # Note: Wild bootstrap for multi-period effects is complex (multiple coefficients) # For now, we use analytical inference even if inference="wild_bootstrap" coefficients, residuals, fitted, vcov = solve_ols( - X, y, + X, + y, return_fitted=True, return_vcov=True, cluster_ids=cluster_ids, @@ -915,21 +969,23 @@ def fit( # type: ignore[override] else: # For rank-deficient case, compute vcov on reduced matrix then expand X_reduced = X[:, 
identified_mask] - vcov_reduced = np.linalg.solve(X_reduced.T @ X_reduced, mse * np.eye(X_reduced.shape[1])) + vcov_reduced = np.linalg.solve( + X_reduced.T @ X_reduced, mse * np.eye(X_reduced.shape[1]) + ) # Expand to full size with NaN for dropped columns vcov = np.full((X.shape[1], X.shape[1]), np.nan) vcov[np.ix_(identified_mask, identified_mask)] = vcov_reduced - # Extract period-specific treatment effects + # Extract period-specific treatment effects for ALL non-reference periods period_effects = {} - effect_values = [] - effect_indices = [] + post_effect_values = [] + post_effect_indices = [] - for period in post_periods: + for period in non_ref_periods: idx = interaction_indices[period] effect = coefficients[idx] se = np.sqrt(vcov[idx, idx]) - t_stat = effect / se + t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan p_value = compute_p_value(t_stat, df=df) conf_int = compute_confidence_interval(effect, se, self.alpha, df=df) @@ -939,14 +995,16 @@ def fit( # type: ignore[override] se=se, t_stat=t_stat, p_value=p_value, - conf_int=conf_int + conf_int=conf_int, ) - effect_values.append(effect) - effect_indices.append(idx) - # Compute average treatment effect - # R-style NA propagation: if ANY period effect is NaN, average is undefined - effect_arr = np.array(effect_values) + if period in post_periods: + post_effect_values.append(effect) + post_effect_indices.append(idx) + + # Compute average treatment effect (post-periods only) + # R-style NA propagation: if ANY post-period effect is NaN, average is undefined + effect_arr = np.array(post_effect_values) if np.any(np.isnan(effect_arr)): # Some period effects are NaN (unidentified) - cannot compute valid average @@ -962,8 +1020,8 @@ def fit( # type: ignore[override] # Standard error of average: need to account for covariance n_post = len(post_periods) - sub_vcov = vcov[np.ix_(effect_indices, effect_indices)] - avg_var = np.sum(sub_vcov) / (n_post ** 2) + sub_vcov = vcov[np.ix_(post_effect_indices, 
post_effect_indices)] + avg_var = np.sum(sub_vcov) / (n_post**2) if np.isnan(avg_var) or avg_var < 0: # Vcov has NaN (dropped columns) - propagate NaN @@ -1009,6 +1067,8 @@ def fit( # type: ignore[override] residuals=residuals, fitted_values=fitted, r_squared=r_squared, + reference_period=reference_period, + interaction_indices=interaction_indices, ) self._coefficients = coefficients diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 81c8caf..ddae7ff 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -192,9 +192,7 @@ class HonestDiDResults: ci_method: str = "FLCI" original_results: Optional[Any] = field(default=None, repr=False) # Event study bounds (optional) - event_study_bounds: Optional[Dict[Any, Dict[str, float]]] = field( - default=None, repr=False - ) + event_study_bounds: Optional[Dict[Any, Dict[str, float]]] = field(default=None, repr=False) def __repr__(self) -> str: sig = "" if self.ci_lb <= 0 <= self.ci_ub else "*" @@ -276,11 +274,13 @@ def summary(self) -> str: ] # Interpretation - lines.extend([ - "-" * 70, - "Interpretation".center(70), - "-" * 70, - ]) + lines.extend( + [ + "-" * 70, + "Interpretation".center(70), + "-" * 70, + ] + ) if self.method == "relative_magnitude": lines.append( @@ -294,9 +294,7 @@ def summary(self) -> str: f"Violation curvature (second diff) bounded by {self.M:.4f} per period." ) else: - lines.append( - f"Combined smoothness (M={self.M:.2f}) and relative magnitude bounds." - ) + lines.append(f"Combined smoothness (M={self.M:.2f}) and relative magnitude bounds.") if self.is_significant: if self.ci_lb > 0: @@ -304,9 +302,7 @@ def summary(self) -> str: else: lines.append(f"Effect remains NEGATIVE even with violations up to M={self.M}.") else: - lines.append( - f"Cannot rule out zero effect when allowing violations up to M={self.M}." 
- ) + lines.append(f"Cannot rule out zero effect when allowing violations up to M={self.M}.") lines.extend(["", "=" * 70]) @@ -378,10 +374,7 @@ class SensitivityResults: def __repr__(self) -> str: breakdown_str = f"{self.breakdown_M:.4f}" if self.breakdown_M else "None" - return ( - f"SensitivityResults(n_M={len(self.M_values)}, " - f"breakdown_M={breakdown_str})" - ) + return f"SensitivityResults(n_M={len(self.M_values)}, " f"breakdown_M={breakdown_str})" @property def has_breakdown(self) -> bool: @@ -405,18 +398,18 @@ def summary(self) -> str: if self.breakdown_M is not None: lines.append(f"{'Breakdown value:':<30} {self.breakdown_M:.4f}") lines.append("") - lines.append( - f"Result is robust to violations up to M = {self.breakdown_M:.4f}" - ) + lines.append(f"Result is robust to violations up to M = {self.breakdown_M:.4f}") else: lines.append(f"{'Breakdown value:':<30} None (always significant)") - lines.extend([ - "", - "-" * 70, - f"{'M':<10} {'Lower Bound':>12} {'Upper Bound':>12} {'CI Lower':>12} {'CI Upper':>12}", - "-" * 70, - ]) + lines.extend( + [ + "", + "-" * 70, + f"{'M':<10} {'Lower Bound':>12} {'Upper Bound':>12} {'CI Lower':>12} {'CI Upper':>12}", + "-" * 70, + ] + ) for i, M in enumerate(self.M_values): lb, ub = self.bounds[i] @@ -437,18 +430,26 @@ def to_dataframe(self) -> pd.DataFrame: for i, M in enumerate(self.M_values): lb, ub = self.bounds[i] ci_lb, ci_ub = self.robust_cis[i] - rows.append({ - "M": M, - "lb": lb, - "ub": ub, - "ci_lb": ci_lb, - "ci_ub": ci_ub, - "is_significant": not (ci_lb <= 0 <= ci_ub), - }) + rows.append( + { + "M": M, + "lb": lb, + "ub": ub, + "ci_lb": ci_lb, + "ci_ub": ci_ub, + "is_significant": not (ci_lb <= 0 <= ci_ub), + } + ) return pd.DataFrame(rows) - def plot(self, ax=None, show_bounds: bool = True, show_ci: bool = True, - breakdown_line: bool = True, **kwargs): + def plot( + self, + ax=None, + show_bounds: bool = True, + show_ci: bool = True, + breakdown_line: bool = True, + **kwargs, + ): """ Plot sensitivity 
analysis results. @@ -483,28 +484,45 @@ def plot(self, ax=None, show_bounds: bool = True, show_ci: bool = True, ci_arr = np.array(self.robust_cis) # Plot original estimate - ax.axhline(y=self.original_estimate, color='black', linestyle='-', - linewidth=1.5, label='Original estimate', alpha=0.7) + ax.axhline( + y=self.original_estimate, + color="black", + linestyle="-", + linewidth=1.5, + label="Original estimate", + alpha=0.7, + ) # Plot zero line - ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5) + ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5) if show_bounds: - ax.fill_between(M, bounds_arr[:, 0], bounds_arr[:, 1], - alpha=0.3, color='blue', label='Identified set') + ax.fill_between( + M, + bounds_arr[:, 0], + bounds_arr[:, 1], + alpha=0.3, + color="blue", + label="Identified set", + ) if show_ci: - ax.plot(M, ci_arr[:, 0], 'b-', linewidth=1.5, label='Robust CI') - ax.plot(M, ci_arr[:, 1], 'b-', linewidth=1.5) + ax.plot(M, ci_arr[:, 0], "b-", linewidth=1.5, label="Robust CI") + ax.plot(M, ci_arr[:, 1], "b-", linewidth=1.5) if breakdown_line and self.breakdown_M is not None: - ax.axvline(x=self.breakdown_M, color='red', linestyle=':', - linewidth=2, label=f'Breakdown (M={self.breakdown_M:.2f})') + ax.axvline( + x=self.breakdown_M, + color="red", + linestyle=":", + linewidth=2, + label=f"Breakdown (M={self.breakdown_M:.2f})", + ) - ax.set_xlabel('M (restriction parameter)') - ax.set_ylabel('Treatment Effect') - ax.set_title('Sensitivity Analysis: Treatment Effect Bounds') - ax.legend(loc='best') + ax.set_xlabel("M (restriction parameter)") + ax.set_ylabel("Treatment Effect") + ax.set_title("Sensitivity Analysis: Treatment Effect Bounds") + ax.legend(loc="best") return ax @@ -515,7 +533,7 @@ def plot(self, ax=None, show_bounds: bool = True, show_ci: bool = True, def _extract_event_study_params( - results: Union[MultiPeriodDiDResults, Any] + results: Union[MultiPeriodDiDResults, Any], ) -> Tuple[np.ndarray, np.ndarray, int, 
int, List[Any], List[Any]]: """ Extract event study parameters from results objects. @@ -545,28 +563,25 @@ def _extract_event_study_params( pre_periods = results.pre_periods post_periods = results.post_periods - # Get coefficients - need to extract from period_effects - # Note: MultiPeriodDiD stores effects for post-periods only in period_effects - # Pre-period effects would be in the coefficients dict if estimated - effects = [] - ses = [] - - # For now, we'll work with post-period effects - # In a full event study, we'd also have pre-period coefficients - for period in post_periods: - pe = results.period_effects[period] - effects.append(pe.effect) - ses.append(pe.se) + # Extract all estimated effects in chronological order + all_estimated = sorted(results.period_effects.keys()) + effects = [results.period_effects[p].effect for p in all_estimated] + ses = [results.period_effects[p].se for p in all_estimated] beta_hat = np.array(effects) - num_post_periods = len(post_periods) - num_pre_periods = len(pre_periods) if pre_periods else 0 - - # Get vcov if available - if results.vcov is not None: - sigma = results.vcov + num_pre_periods = sum(1 for p in all_estimated if p in pre_periods) + num_post_periods = sum(1 for p in all_estimated if p in post_periods) + + # Extract proper sub-VCV for interaction terms + if ( + results.vcov is not None + and hasattr(results, "interaction_indices") + and results.interaction_indices is not None + ): + indices = [results.interaction_indices[p] for p in all_estimated] + sigma = results.vcov[np.ix_(indices, indices)] else: - # Construct diagonal vcov from SEs + # Fallback: diagonal from SEs sigma = np.diag(np.array(ses) ** 2) return beta_hat, sigma, num_pre_periods, num_post_periods, pre_periods, post_periods @@ -575,6 +590,7 @@ def _extract_event_study_params( # Try CallawaySantAnnaResults try: from diff_diff.staggered import CallawaySantAnnaResults + if isinstance(results, CallawaySantAnnaResults): if results.event_study_effects is 
None: raise ValueError( @@ -586,9 +602,9 @@ def _extract_event_study_params( # Extract event study effects by relative time # Filter out normalization constraints (n_groups=0) and non-finite SEs event_effects = { - t: data for t, data in results.event_study_effects.items() - if data.get('n_groups', 1) > 0 - and np.isfinite(data.get('se', np.nan)) + t: data + for t, data in results.event_study_effects.items() + if data.get("n_groups", 1) > 0 and np.isfinite(data.get("se", np.nan)) } rel_times = sorted(event_effects.keys()) @@ -599,17 +615,13 @@ def _extract_event_study_params( effects = [] ses = [] for t in rel_times: - effects.append(event_effects[t]['effect']) - ses.append(event_effects[t]['se']) + effects.append(event_effects[t]["effect"]) + ses.append(event_effects[t]["se"]) beta_hat = np.array(effects) sigma = np.diag(np.array(ses) ** 2) - return ( - beta_hat, sigma, - len(pre_times), len(post_times), - pre_times, post_times - ) + return (beta_hat, sigma, len(pre_times), len(post_times), pre_times, post_times) except ImportError: pass @@ -644,17 +656,15 @@ def _construct_A_sd(num_periods: int) -> np.ndarray: for i in range(n_constraints): # Second difference: delta_{t+1} - 2*delta_t + delta_{t-1} - A[i, i] = 1 # delta_{t-1} + A[i, i] = 1 # delta_{t-1} A[i, i + 1] = -2 # delta_t - A[i, i + 2] = 1 # delta_{t+1} + A[i, i + 2] = 1 # delta_{t+1} return A def _construct_constraints_sd( - num_pre_periods: int, - num_post_periods: int, - M: float + num_pre_periods: int, num_post_periods: int, M: float ) -> Tuple[np.ndarray, np.ndarray]: """ Construct smoothness constraint matrices. @@ -692,10 +702,7 @@ def _construct_constraints_sd( def _construct_constraints_rm( - num_pre_periods: int, - num_post_periods: int, - Mbar: float, - max_pre_violation: float + num_pre_periods: int, num_post_periods: int, Mbar: float, max_pre_violation: float ) -> Tuple[np.ndarray, np.ndarray]: """ Construct relative magnitudes constraint matrices. 
@@ -731,7 +738,7 @@ def _construct_constraints_rm( for i in range(num_post_periods): post_idx = num_pre_periods + i - A_ineq[2 * i, post_idx] = 1 # delta <= bound + A_ineq[2 * i, post_idx] = 1 # delta <= bound A_ineq[2 * i + 1, post_idx] = -1 # -delta <= bound return A_ineq, b_ineq @@ -743,7 +750,7 @@ def _solve_bounds_lp( A_ineq: np.ndarray, b_ineq: np.ndarray, num_pre_periods: int, - lp_method: str = 'highs' + lp_method: str = "highs", ) -> Tuple[float, float]: """ Solve for identified set bounds using linear programming. @@ -789,7 +796,7 @@ def _solve_bounds_lp( # where delta_post = delta[num_pre_periods:] c = np.zeros(total_periods) - c[num_pre_periods:num_pre_periods + num_post] = -l_vec # min -l'@delta = max l'@delta + c[num_pre_periods : num_pre_periods + num_post] = -l_vec # min -l'@delta = max l'@delta # For upper bound: max l'@(beta - delta) = l'@beta + max(-l'@delta) # For lower bound: min l'@(beta - delta) = l'@beta + min(-l'@delta) @@ -801,9 +808,7 @@ def _solve_bounds_lp( # Solve for lower bound of -l'@delta (which gives upper bound of theta) try: result_min = optimize.linprog( - c, A_ub=A_ineq, b_ub=b_ineq, - bounds=(None, None), - method=lp_method + c, A_ub=A_ineq, b_ub=b_ineq, bounds=(None, None), method=lp_method ) if result_min.success: min_val = result_min.fun @@ -816,9 +821,7 @@ def _solve_bounds_lp( # Solve for upper bound of -l'@delta (which gives lower bound of theta) try: result_max = optimize.linprog( - -c, A_ub=A_ineq, b_ub=b_ineq, - bounds=(None, None), - method=lp_method + -c, A_ub=A_ineq, b_ub=b_ineq, bounds=(None, None), method=lp_method ) if result_max.success: max_val = -result_max.fun @@ -835,12 +838,7 @@ def _solve_bounds_lp( return lb, ub -def _compute_flci( - lb: float, - ub: float, - se: float, - alpha: float = 0.05 -) -> Tuple[float, float]: +def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple[float, float]: """ Compute Fixed Length Confidence Interval (FLCI). 
@@ -888,7 +886,7 @@ def _compute_clf_ci( Mbar: float, max_pre_violation: float, alpha: float = 0.05, - n_draws: int = 1000 + n_draws: int = 1000, ) -> Tuple[float, float, float, float]: """ Compute Conditional Least Favorable (C-LF) confidence interval. @@ -1066,8 +1064,9 @@ def fit( M = M if M is not None else self.M # Extract event study parameters - (beta_hat, sigma, num_pre, num_post, - pre_periods, post_periods) = _extract_event_study_params(results) + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + _extract_event_study_params(results) + ) # beta_hat from MultiPeriodDiDResults already contains only post-periods # Check if we have the right number of coefficients @@ -1089,7 +1088,7 @@ def fit( sigma_post = sigma[num_pre:, num_pre:] else: # Construct diagonal from available dimensions - sigma_post = sigma[:len(beta_post), :len(beta_post)] + sigma_post = sigma[: len(beta_post), : len(beta_post)] # Update num_post to match actual data num_post = len(beta_post) @@ -1100,9 +1099,7 @@ def fit( else: l_vec = np.asarray(self.l_vec) if len(l_vec) != num_post: - raise ValueError( - f"l_vec must have length {num_post}, got {len(l_vec)}" - ) + raise ValueError(f"l_vec must have length {num_post}, got {len(l_vec)}") # Compute original estimate and SE original_estimate = np.dot(l_vec, beta_post) @@ -1117,15 +1114,13 @@ def fit( elif self.method == "relative_magnitude": lb, ub, ci_lb, ci_ub = self._compute_rm_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, - pre_periods, results + beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results ) ci_method = "C-LF" else: # combined lb, ub, ci_lb, ci_ub = self._compute_combined_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, - pre_periods, results + beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results ) ci_method = "FLCI" @@ -1150,7 +1145,7 @@ def _compute_smoothness_bounds( l_vec: np.ndarray, num_pre: int, num_post: int, - M: float + M: float, ) -> 
Tuple[float, float, float, float]: """Compute bounds under smoothness restriction.""" # Construct constraints @@ -1174,7 +1169,7 @@ def _compute_rm_bounds( num_post: int, Mbar: float, pre_periods: List, - results: Any + results: Any, ) -> Tuple[float, float, float, float]: """Compute bounds under relative magnitudes restriction.""" # Estimate max pre-period violation from pre-trends @@ -1204,7 +1199,7 @@ def _compute_combined_bounds( num_post: int, M: float, pre_periods: List, - results: Any + results: Any, ) -> Tuple[float, float, float, float]: """Compute bounds under combined smoothness + RM restriction.""" # Get smoothness bounds @@ -1232,11 +1227,7 @@ def _compute_combined_bounds( return lb, ub, ci_lb, ci_ub - def _estimate_max_pre_violation( - self, - results: Any, - pre_periods: List - ) -> float: + def _estimate_max_pre_violation(self, results: Any, pre_periods: List) -> float: """ Estimate the maximum pre-period violation. @@ -1244,19 +1235,14 @@ def _estimate_max_pre_violation( a default based on the overall SE. 
""" if isinstance(results, MultiPeriodDiDResults): - # Check if we have pre-period effects - # In a standard event study, pre-period coefficients should be ~0 - # Their magnitude indicates the pre-trend violation - if hasattr(results, 'coefficients') and results.coefficients: - # Look for pre-period coefficients - pre_effects = [] - for period in pre_periods: - key = f"treated:period_{period}" - if key in results.coefficients: - pre_effects.append(abs(results.coefficients[key])) - - if pre_effects: - return max(pre_effects) + # Pre-period effects are now in period_effects directly + pre_effects = [ + abs(results.period_effects[p].effect) + for p in pre_periods + if p in results.period_effects + ] + if pre_effects: + return max(pre_effects) # Fallback: use avg_se as a scale return results.avg_se @@ -1264,14 +1250,14 @@ def _estimate_max_pre_violation( # For CallawaySantAnna, use pre-period event study effects try: from diff_diff.staggered import CallawaySantAnnaResults + if isinstance(results, CallawaySantAnnaResults): if results.event_study_effects: # Filter out normalization constraints (n_groups=0, e.g. reference period) pre_effects = [ - abs(results.event_study_effects[t]['effect']) + abs(results.event_study_effects[t]["effect"]) for t in results.event_study_effects - if t < 0 - and results.event_study_effects[t].get('n_groups', 1) > 0 + if t < 0 and results.event_study_effects[t].get("n_groups", 1) > 0 ] if pre_effects: return max(pre_effects) @@ -1336,10 +1322,7 @@ def sensitivity_analysis( ) def _find_breakdown( - self, - results: Any, - M_values: np.ndarray, - ci_list: List[Tuple[float, float]] + self, results: Any, M_values: np.ndarray, ci_list: List[Tuple[float, float]] ) -> Optional[float]: """ Find the breakdown value where CI first includes zero. 
@@ -1379,9 +1362,7 @@ def _find_breakdown( return None def breakdown_value( - self, - results: Union[MultiPeriodDiDResults, Any], - tol: float = 0.01 + self, results: Union[MultiPeriodDiDResults, Any], tol: float = 0.01 ) -> Optional[float]: """ Find the breakdown value directly using binary search. @@ -1470,7 +1451,7 @@ def sensitivity_plot( M_grid: Optional[List[float]] = None, alpha: float = 0.05, ax=None, - **kwargs + **kwargs, ): """ Create a sensitivity analysis plot. diff --git a/diff_diff/pretrends.py b/diff_diff/pretrends.py index d9bc1c6..e36bdbd 100644 --- a/diff_diff/pretrends.py +++ b/diff_diff/pretrends.py @@ -151,27 +151,19 @@ def summary(self) -> str: ] if self.power_adequate: - lines.append( - f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%})." - ) + lines.append(f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%}).") lines.append( f" The pre-trends test would detect violations of magnitude {self.violation_magnitude:.3f}." ) else: - lines.append( - f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%})." - ) + lines.append(f"✗ Power ({self.power:.0%}) below target ({self.target_power:.0%}).") lines.append( f" Would need violations of {self.mdv:.3f} to achieve {self.target_power:.0%} power." ) lines.append("") - lines.append( - f"Minimum detectable violation (MDV): {self.mdv:.4f}" - ) - lines.append( - " → Passing pre-trends test does NOT rule out violations up to this size." 
- ) + lines.append(f"Minimum detectable violation (MDV): {self.mdv:.4f}") + lines.append(" → Passing pre-trends test does NOT rule out violations up to this size.") lines.extend(["", "=" * 70]) @@ -289,21 +281,27 @@ class PreTrendsPowerCurve: violation_type: str def __repr__(self) -> str: - return ( - f"PreTrendsPowerCurve(n_points={len(self.M_values)}, " - f"mdv={self.mdv:.4f})" - ) + return f"PreTrendsPowerCurve(n_points={len(self.M_values)}, " f"mdv={self.mdv:.4f})" def to_dataframe(self) -> pd.DataFrame: """Convert to DataFrame with M and power columns.""" - return pd.DataFrame({ - "M": self.M_values, - "power": self.powers, - }) - - def plot(self, ax=None, show_mdv: bool = True, show_target: bool = True, - color: str = "#2563eb", mdv_color: str = "#dc2626", - target_color: str = "#22c55e", **kwargs): + return pd.DataFrame( + { + "M": self.M_values, + "power": self.powers, + } + ) + + def plot( + self, + ax=None, + show_mdv: bool = True, + show_target: bool = True, + color: str = "#2563eb", + mdv_color: str = "#dc2626", + target_color: str = "#22c55e", + **kwargs, + ): """ Plot the power curve. 
@@ -338,26 +336,35 @@ def plot(self, ax=None, show_mdv: bool = True, show_target: bool = True, fig, ax = plt.subplots(figsize=(10, 6)) # Plot power curve - ax.plot(self.M_values, self.powers, color=color, linewidth=2, - label="Power", **kwargs) + ax.plot(self.M_values, self.powers, color=color, linewidth=2, label="Power", **kwargs) # Target power line if show_target: - ax.axhline(y=self.target_power, color=target_color, linestyle="--", - linewidth=1.5, alpha=0.7, - label=f"Target power ({self.target_power:.0%})") + ax.axhline( + y=self.target_power, + color=target_color, + linestyle="--", + linewidth=1.5, + alpha=0.7, + label=f"Target power ({self.target_power:.0%})", + ) # MDV line if show_mdv and self.mdv is not None and np.isfinite(self.mdv): - ax.axvline(x=self.mdv, color=mdv_color, linestyle=":", - linewidth=1.5, alpha=0.7, - label=f"MDV = {self.mdv:.3f}") + ax.axvline( + x=self.mdv, + color=mdv_color, + linestyle=":", + linewidth=1.5, + alpha=0.7, + label=f"MDV = {self.mdv:.3f}", + ) ax.set_xlabel("Violation Magnitude (M)") ax.set_ylabel("Power") ax.set_title("Pre-Trends Test Power Curve") ax.set_ylim(0, 1.05) - ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}')) + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}")) ax.legend(loc="lower right") ax.grid(True, alpha=0.3) @@ -450,9 +457,7 @@ def __init__( f"got '{violation_type}'" ) if violation_type == "custom" and violation_weights is None: - raise ValueError( - "violation_weights must be provided when violation_type='custom'" - ) + raise ValueError("violation_weights must be provided when violation_type='custom'") self.alpha = alpha self.target_power = power @@ -566,88 +571,41 @@ def _extract_pre_period_params( "parameter to specify which are actually pre-treatment." 
) - # Only include periods with actual estimated coefficients - # (excludes the reference period which is omitted from estimation) - if hasattr(results, 'coefficients') and results.coefficients: - # Find which pre-periods have estimated coefficients - estimated_pre_periods = [ - p for p in all_pre_periods - if f"treated:period_{p}" in results.coefficients - ] + # Pre-period effects are in period_effects (excluding reference period) + estimated_pre_periods = [ + p + for p in all_pre_periods + if p in results.period_effects and results.period_effects[p].se > 0 + ] - if len(estimated_pre_periods) == 0: - raise ValueError( - "No estimated pre-period coefficients found. " - "The pre-trends test requires at least one estimated " - "pre-period coefficient (excluding the reference period)." - ) + if len(estimated_pre_periods) == 0: + raise ValueError( + "No estimated pre-period coefficients found. " + "The pre-trends test requires at least one estimated " + "pre-period coefficient (excluding the reference period)." 
+ ) - n_pre = len(estimated_pre_periods) - - # Extract effects for estimated periods only - effects = np.array([ - results.coefficients[f"treated:period_{p}"] - for p in estimated_pre_periods - ]) - - # Extract SEs - try period_effects first, fall back to avg_se - ses = [] - for p in estimated_pre_periods: - if p in results.period_effects: - ses.append(results.period_effects[p].se) - else: - ses.append(results.avg_se) - ses = np.array(ses) - - # Extract vcov for estimated pre-periods - # Build mapping from period to vcov index - if results.vcov is not None: - # Get ordered list of all coefficient keys - coef_keys = list(results.coefficients.keys()) - pre_indices = [ - coef_keys.index(f"treated:period_{p}") - for p in estimated_pre_periods - if f"treated:period_{p}" in coef_keys - ] - if len(pre_indices) == n_pre and results.vcov.shape[0] > max(pre_indices): - vcov = results.vcov[np.ix_(pre_indices, pre_indices)] - else: - # Fall back to diagonal - vcov = np.diag(ses ** 2) - else: - vcov = np.diag(ses ** 2) + n_pre = len(estimated_pre_periods) + effects = np.array([results.period_effects[p].effect for p in estimated_pre_periods]) + ses = np.array([results.period_effects[p].se for p in estimated_pre_periods]) + + # Extract vcov using stored interaction indices for robust extraction + if ( + results.vcov is not None + and hasattr(results, "interaction_indices") + and results.interaction_indices is not None + ): + indices = [results.interaction_indices[p] for p in estimated_pre_periods] + vcov = results.vcov[np.ix_(indices, indices)] else: - # No coefficients available - try period_effects for pre-periods - # Exclude reference period (the one with effect=0 and se=0 or missing) - estimated_pre_periods = [ - p for p in all_pre_periods - if p in results.period_effects - and results.period_effects[p].se > 0 - ] - - if len(estimated_pre_periods) == 0: - raise ValueError( - "No estimated pre-period effects found. 
" - "The pre-trends test requires at least one estimated " - "pre-period effect (excluding the reference period)." - ) - - n_pre = len(estimated_pre_periods) - effects = np.array([ - results.period_effects[p].effect - for p in estimated_pre_periods - ]) - ses = np.array([ - results.period_effects[p].se - for p in estimated_pre_periods - ]) - vcov = np.diag(ses ** 2) + vcov = np.diag(ses**2) return effects, ses, vcov, n_pre # Try CallawaySantAnnaResults try: from diff_diff.staggered import CallawaySantAnnaResults + if isinstance(results, CallawaySantAnnaResults): if results.event_study_effects is None: raise ValueError( @@ -658,10 +616,9 @@ def _extract_pre_period_params( # Get pre-period effects (negative relative times) # Filter out normalization constraints (n_groups=0) and non-finite SEs pre_effects = { - t: data for t, data in results.event_study_effects.items() - if t < 0 - and data.get('n_groups', 1) > 0 - and np.isfinite(data.get('se', np.nan)) + t: data + for t, data in results.event_study_effects.items() + if t < 0 and data.get("n_groups", 1) > 0 and np.isfinite(data.get("se", np.nan)) } if not pre_effects: @@ -670,9 +627,9 @@ def _extract_pre_period_params( pre_periods = sorted(pre_effects.keys()) n_pre = len(pre_periods) - effects = np.array([pre_effects[t]['effect'] for t in pre_periods]) - ses = np.array([pre_effects[t]['se'] for t in pre_periods]) - vcov = np.diag(ses ** 2) + effects = np.array([pre_effects[t]["effect"] for t in pre_periods]) + ses = np.array([pre_effects[t]["se"] for t in pre_periods]) + vcov = np.diag(ses**2) return effects, ses, vcov, n_pre except ImportError: @@ -681,14 +638,14 @@ def _extract_pre_period_params( # Try SunAbrahamResults try: from diff_diff.sun_abraham import SunAbrahamResults + if isinstance(results, SunAbrahamResults): # Get pre-period effects (negative relative times) # Filter out normalization constraints (n_groups=0) and non-finite SEs pre_effects = { - t: data for t, data in results.event_study_effects.items() 
- if t < 0 - and data.get('n_groups', 1) > 0 - and np.isfinite(data.get('se', np.nan)) + t: data + for t, data in results.event_study_effects.items() + if t < 0 and data.get("n_groups", 1) > 0 and np.isfinite(data.get("se", np.nan)) } if not pre_effects: @@ -697,9 +654,9 @@ def _extract_pre_period_params( pre_periods = sorted(pre_effects.keys()) n_pre = len(pre_periods) - effects = np.array([pre_effects[t]['effect'] for t in pre_periods]) - ses = np.array([pre_effects[t]['se'] for t in pre_periods]) - vcov = np.diag(ses ** 2) + effects = np.array([pre_effects[t]["effect"] for t in pre_periods]) + ses = np.array([pre_effects[t]["se"] for t in pre_periods]) + vcov = np.diag(ses**2) return effects, ses, vcov, n_pre except ImportError: @@ -888,9 +845,7 @@ def fit( M = mdv if np.isfinite(mdv) else np.max(ses) # Compute power at specified M - power, noncentrality, test_stat, critical_value = self._compute_power( - M, weights, vcov - ) + power, noncentrality, test_stat, critical_value = self._compute_power(M, weights, vcov) return PreTrendsPowerResults( power=power, @@ -977,10 +932,7 @@ def power_curve( M_grid = np.asarray(M_grid) # Compute power at each M - powers = np.array([ - self._compute_power(M, weights, vcov)[0] - for M in M_grid - ]) + powers = np.array([self._compute_power(M, weights, vcov)[0] for M in M_grid]) return PreTrendsPowerCurve( M_values=M_grid, @@ -1028,32 +980,20 @@ def sensitivity_to_honest_did( max_pre_se = np.max(pt_results.pre_period_ses) interpretation = [] - interpretation.append( - f"Minimum Detectable Violation (MDV): {mdv:.4f}" - ) - interpretation.append( - f"Max pre-period SE: {max_pre_se:.4f}" - ) + interpretation.append(f"Minimum Detectable Violation (MDV): {mdv:.4f}") + interpretation.append(f"Max pre-period SE: {max_pre_se:.4f}") if np.isfinite(mdv): # Ratio of MDV to max SE - gives sense of how many SEs the MDV is mdv_in_ses = mdv / max_pre_se if max_pre_se > 0 else np.inf - interpretation.append( - f"MDV / max(SE): {mdv_in_ses:.2f}" 
- ) + interpretation.append(f"MDV / max(SE): {mdv_in_ses:.2f}") if mdv_in_ses < 1: - interpretation.append( - "→ Pre-trends test is fairly sensitive to violations." - ) + interpretation.append("→ Pre-trends test is fairly sensitive to violations.") elif mdv_in_ses < 2: - interpretation.append( - "→ Pre-trends test has moderate sensitivity." - ) + interpretation.append("→ Pre-trends test has moderate sensitivity.") else: - interpretation.append( - "→ Pre-trends test has low power to detect violations." - ) + interpretation.append("→ Pre-trends test has low power to detect violations.") interpretation.append( " Consider using HonestDiD with larger M values for robustness." ) @@ -1061,9 +1001,7 @@ def sensitivity_to_honest_did( interpretation.append( "→ Pre-trends test cannot achieve target power for any violation size." ) - interpretation.append( - " Use HonestDiD sensitivity analysis for inference." - ) + interpretation.append(" Use HonestDiD sensitivity analysis for inference.") return { "mdv": mdv, diff --git a/diff_diff/results.py b/diff_diff/results.py index 213d09f..decdf5b 100644 --- a/diff_diff/results.py +++ b/diff_diff/results.py @@ -106,23 +106,27 @@ def summary(self, alpha: Optional[float] = None) -> str: if self.n_clusters is not None: lines.append(f"{'Number of clusters:':<25} {self.n_clusters:>10}") - lines.extend([ - "", - "-" * 70, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", - "-" * 70, - f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", - "-" * 70, - "", - f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", - ]) + lines.extend( + [ + "", + "-" * 70, + f"{'Parameter':<15} {'Estimate':>12} {'Std. 
Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", + "-" * 70, + f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", + "-" * 70, + "", + f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", + ] + ) # Add significance codes - lines.extend([ - "", - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 70, - ]) + lines.extend( + [ + "", + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 70, + ] + ) return "\n".join(lines) @@ -187,6 +191,7 @@ def _get_significance_stars(p_value: float) -> str: rank-deficient matrices). """ import numpy as np + if np.isnan(p_value): return "" if p_value < 0.001: @@ -259,8 +264,10 @@ class MultiPeriodDiDResults: ---------- period_effects : dict[any, PeriodEffect] Dictionary mapping period identifiers to their PeriodEffect objects. + Contains all estimated period effects (pre and post, excluding + the reference period which is normalized to zero). avg_att : float - Average Treatment effect on the Treated across all post-periods. + Average Treatment effect on the Treated across post-periods only. avg_se : float Standard error of the average ATT. avg_t_stat : float @@ -279,6 +286,13 @@ class MultiPeriodDiDResults: List of pre-treatment period identifiers. post_periods : list List of post-treatment period identifiers. + reference_period : any, optional + The reference (omitted) period. Its coefficient is zero by + construction and it is excluded from ``period_effects``. + interaction_indices : dict, optional + Mapping from period identifier to column index in the full + variance-covariance matrix. Used internally for sub-VCV + extraction (e.g., by HonestDiD and PreTrendsPower). 
""" period_effects: Dict[Any, PeriodEffect] @@ -298,6 +312,8 @@ class MultiPeriodDiDResults: residuals: Optional[np.ndarray] = field(default=None) fitted_values: Optional[np.ndarray] = field(default=None) r_squared: Optional[float] = field(default=None) + reference_period: Optional[Any] = field(default=None) + interaction_indices: Optional[Dict[Any, int]] = field(default=None, repr=False) def __repr__(self) -> str: """Concise string representation.""" @@ -308,6 +324,16 @@ def __repr__(self) -> str: f"n_post_periods={len(self.post_periods)})" ) + @property + def pre_period_effects(self) -> Dict[Any, PeriodEffect]: + """Pre-period effects only (for parallel trends assessment).""" + return {p: pe for p, pe in self.period_effects.items() if p in self.pre_periods} + + @property + def post_period_effects(self) -> Dict[Any, PeriodEffect]: + """Post-period effects only.""" + return {p: pe for p, pe in self.period_effects.items() if p in self.post_periods} + def summary(self, alpha: Optional[float] = None) -> str: """ Generate a formatted summary of the estimation results. @@ -341,15 +367,49 @@ def summary(self, alpha: Optional[float] = None) -> str: if self.r_squared is not None: lines.append(f"{'R-squared:':<25} {self.r_squared:>10.4f}") - # Period-specific effects - lines.extend([ - "", - "-" * 80, - "Period-Specific Treatment Effects".center(80), - "-" * 80, - f"{'Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 80, - ]) + # Pre-period effects (parallel trends test) + pre_effects = {p: pe for p, pe in self.period_effects.items() if p in self.pre_periods} + if pre_effects: + lines.extend( + [ + "", + "-" * 80, + "Pre-Period Effects (Parallel Trends Test)".center(80), + "-" * 80, + f"{'Period':<15} {'Estimate':>12} {'Std. 
Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 80, + ] + ) + + for period in self.pre_periods: + if period in self.period_effects: + pe = self.period_effects[period] + stars = pe.significance_stars + lines.append( + f"{str(period):<15} {pe.effect:>12.4f} {pe.se:>12.4f} " + f"{pe.t_stat:>10.3f} {pe.p_value:>10.4f} {stars:>6}" + ) + + # Show reference period + if self.reference_period is not None: + lines.append( + f"[ref: {self.reference_period}]" + f"{'0.0000':>21} {'---':>12} {'---':>10} {'---':>10} {'':>6}" + ) + + lines.append("-" * 80) + + # Post-period treatment effects + lines.extend( + [ + "", + "-" * 80, + "Post-Period Treatment Effects".center(80), + "-" * 80, + f"{'Period':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 80, + ] + ) for period in self.post_periods: pe = self.period_effects[period] @@ -360,27 +420,31 @@ def summary(self, alpha: Optional[float] = None) -> str: ) # Average effect - lines.extend([ - "-" * 80, - "", - "-" * 80, - "Average Treatment Effect (across post-periods)".center(80), - "-" * 80, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 80, - f"{'Avg ATT':<15} {self.avg_att:>12.4f} {self.avg_se:>12.4f} " - f"{self.avg_t_stat:>10.3f} {self.avg_p_value:>10.4f} {self.significance_stars:>6}", - "-" * 80, - "", - f"{conf_level}% Confidence Interval: [{self.avg_conf_int[0]:.4f}, {self.avg_conf_int[1]:.4f}]", - ]) + lines.extend( + [ + "-" * 80, + "", + "-" * 80, + "Average Treatment Effect (across post-periods)".center(80), + "-" * 80, + f"{'Parameter':<15} {'Estimate':>12} {'Std. 
Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 80, + f"{'Avg ATT':<15} {self.avg_att:>12.4f} {self.avg_se:>12.4f} " + f"{self.avg_t_stat:>10.3f} {self.avg_p_value:>10.4f} {self.significance_stars:>6}", + "-" * 80, + "", + f"{conf_level}% Confidence Interval: [{self.avg_conf_int[0]:.4f}, {self.avg_conf_int[1]:.4f}]", + ] + ) # Add significance codes - lines.extend([ - "", - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 80, - ]) + lines.extend( + [ + "", + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 80, + ] + ) return "\n".join(lines) @@ -408,9 +472,15 @@ def get_effect(self, period) -> PeriodEffect: If the period is not found in post-treatment periods. """ if period not in self.period_effects: + if hasattr(self, "reference_period") and period == self.reference_period: + raise KeyError( + f"Period '{period}' is the reference period (coefficient " + f"normalized to zero by construction). Its effect is 0.0 with " + f"no associated uncertainty." + ) raise KeyError( f"Period '{period}' not found. " - f"Available post-periods: {list(self.period_effects.keys())}" + f"Available periods: {list(self.period_effects.keys())}" ) return self.period_effects[period] @@ -436,6 +506,7 @@ def to_dict(self) -> Dict[str, Any]: "n_pre_periods": len(self.pre_periods), "n_post_periods": len(self.post_periods), "r_squared": self.r_squared, + "reference_period": self.reference_period, } # Add period-specific effects @@ -453,20 +524,23 @@ def to_dataframe(self) -> pd.DataFrame: Returns ------- pd.DataFrame - DataFrame with one row per post-treatment period. + DataFrame with one row per estimated period (pre and post). 
""" rows = [] for period, pe in self.period_effects.items(): - rows.append({ - "period": period, - "effect": pe.effect, - "se": pe.se, - "t_stat": pe.t_stat, - "p_value": pe.p_value, - "conf_int_lower": pe.conf_int[0], - "conf_int_upper": pe.conf_int[1], - "is_significant": pe.is_significant, - }) + rows.append( + { + "period": period, + "effect": pe.effect, + "se": pe.se, + "t_stat": pe.t_stat, + "p_value": pe.p_value, + "conf_int_lower": pe.conf_int[0], + "conf_int_upper": pe.conf_int[1], + "is_significant": pe.is_significant, + "is_post": period in self.post_periods, + } + ) return pd.DataFrame(rows) @property @@ -587,29 +661,31 @@ def summary(self, alpha: Optional[float] = None) -> str: if self.variance_method == "bootstrap" and self.n_bootstrap is not None: lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}") - lines.extend([ - "", - "-" * 75, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", - "-" * 75, - f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", - "-" * 75, - "", - f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", - ]) + lines.extend( + [ + "", + "-" * 75, + f"{'Parameter':<15} {'Estimate':>12} {'Std. 
Err.':>12} {'t-stat':>10} {'P>|t|':>10} {'':>5}", + "-" * 75, + f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} {self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", + "-" * 75, + "", + f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", + ] + ) # Show top unit weights if self.unit_weights: - sorted_weights = sorted( - self.unit_weights.items(), key=lambda x: x[1], reverse=True - ) + sorted_weights = sorted(self.unit_weights.items(), key=lambda x: x[1], reverse=True) top_n = min(5, len(sorted_weights)) - lines.extend([ - "", - "-" * 75, - "Top Unit Weights (Synthetic Control)".center(75), - "-" * 75, - ]) + lines.extend( + [ + "", + "-" * 75, + "Top Unit Weights (Synthetic Control)".center(75), + "-" * 75, + ] + ) for unit, weight in sorted_weights[:top_n]: if weight > 0.001: # Only show meaningful weights lines.append(f" Unit {unit}: {weight:.4f}") @@ -619,11 +695,13 @@ def summary(self, alpha: Optional[float] = None) -> str: lines.append(f" ({n_nonzero} units with weight > 0.001)") # Add significance codes - lines.extend([ - "", - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 75, - ]) + lines.extend( + [ + "", + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 75, + ] + ) return "\n".join(lines) @@ -680,10 +758,9 @@ def get_unit_weights_df(self) -> pd.DataFrame: pd.DataFrame DataFrame with unit IDs and their weights. """ - return pd.DataFrame([ - {"unit": unit, "weight": weight} - for unit, weight in self.unit_weights.items() - ]).sort_values("weight", ascending=False) + return pd.DataFrame( + [{"unit": unit, "weight": weight} for unit, weight in self.unit_weights.items()] + ).sort_values("weight", ascending=False) def get_time_weights_df(self) -> pd.DataFrame: """ @@ -694,10 +771,9 @@ def get_time_weights_df(self) -> pd.DataFrame: pd.DataFrame DataFrame with time periods and their weights. 
""" - return pd.DataFrame([ - {"period": period, "weight": weight} - for period, weight in self.time_weights.items() - ]) + return pd.DataFrame( + [{"period": period, "weight": weight} for period, weight in self.time_weights.items()] + ) @property def is_significant(self) -> bool: diff --git a/diff_diff/visualization.py b/diff_diff/visualization.py index 3ba3a80..df9c301 100644 --- a/diff_diff/visualization.py +++ b/diff_diff/visualization.py @@ -167,8 +167,7 @@ def plot_event_study( import matplotlib.pyplot as plt except ImportError: raise ImportError( - "matplotlib is required for plotting. " - "Install it with: pip install matplotlib" + "matplotlib is required for plotting. " "Install it with: pip install matplotlib" ) from scipy import stats as scipy_stats @@ -181,14 +180,14 @@ def plot_event_study( extracted = _extract_plot_data( results, periods, pre_periods, post_periods, reference_period ) - effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred = extracted + effects, se, periods, pre_periods, post_periods, reference_period, reference_inferred = ( + extracted + ) # If reference was inferred from results, it was NOT explicitly provided if reference_inferred: reference_period_explicit = False elif effects is None or se is None: - raise ValueError( - "Must provide either 'results' or both 'effects' and 'se'" - ) + raise ValueError("Must provide either 'results' or both 'effects' and 'se'") # Ensure effects and se are dicts if not isinstance(effects, dict): @@ -207,8 +206,7 @@ def plot_event_study( # Auto-inferred reference periods (from CallawaySantAnna) just get hollow marker styling, # NO normalization. This prevents unintended normalization when the reference period # isn't a true identifying constraint (e.g., CallawaySantAnna with base_period="varying"). 
- if (reference_period is not None and reference_period in effects and - reference_period_explicit): + if reference_period is not None and reference_period in effects and reference_period_explicit: ref_effect = effects[reference_period] if np.isfinite(ref_effect): effects = {p: e - ref_effect for p, e in effects.items()} @@ -233,14 +231,16 @@ def plot_event_study( ci_lower = np.nan ci_upper = np.nan - plot_data.append({ - 'period': period, - 'effect': effect, - 'se': std_err, - 'ci_lower': ci_lower, - 'ci_upper': ci_upper, - 'is_reference': period == reference_period, - }) + plot_data.append( + { + "period": period, + "effect": effect, + "se": std_err, + "ci_lower": ci_lower, + "ci_upper": ci_upper, + "is_reference": period == reference_period, + } + ) if not plot_data: raise ValueError("No valid data to plot") @@ -254,52 +254,63 @@ def plot_event_study( fig = ax.get_figure() # Convert periods to numeric for plotting - period_to_x = {p: i for i, p in enumerate(df['period'])} - x_vals = [period_to_x[p] for p in df['period']] + period_to_x = {p: i for i, p in enumerate(df["period"])} + x_vals = [period_to_x[p] for p in df["period"]] # Shade pre-treatment region if shade_pre and pre_periods is not None: pre_x = [period_to_x[p] for p in pre_periods if p in period_to_x] if pre_x: - ax.axvspan(min(pre_x) - 0.5, max(pre_x) + 0.5, - color=shade_color, alpha=0.5, zorder=0) + ax.axvspan(min(pre_x) - 0.5, max(pre_x) + 0.5, color=shade_color, alpha=0.5, zorder=0) # Draw horizontal zero line if show_zero_line: - ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, zorder=1) + ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, zorder=1) # Draw vertical reference line if show_reference_line and reference_period is not None: if reference_period in period_to_x: ref_x = period_to_x[reference_period] - ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1) + ax.axvline(x=ref_x, color="gray", linestyle=":", linewidth=1, zorder=1) # Plot error bars (only 
for entries with finite CI) - has_ci = df['ci_lower'].notna() & df['ci_upper'].notna() + has_ci = df["ci_lower"].notna() & df["ci_upper"].notna() if has_ci.any(): df_with_ci = df[has_ci] - x_with_ci = [period_to_x[p] for p in df_with_ci['period']] + x_with_ci = [period_to_x[p] for p in df_with_ci["period"]] yerr = [ - df_with_ci['effect'] - df_with_ci['ci_lower'], - df_with_ci['ci_upper'] - df_with_ci['effect'] + df_with_ci["effect"] - df_with_ci["ci_lower"], + df_with_ci["ci_upper"] - df_with_ci["effect"], ] ax.errorbar( - x_with_ci, df_with_ci['effect'], yerr=yerr, - fmt='none', color=color, capsize=capsize, linewidth=linewidth, - capthick=linewidth, zorder=2 + x_with_ci, + df_with_ci["effect"], + yerr=yerr, + fmt="none", + color=color, + capsize=capsize, + linewidth=linewidth, + capthick=linewidth, + zorder=2, ) # Plot point estimates for i, row in df.iterrows(): - x = period_to_x[row['period']] - if row['is_reference']: + x = period_to_x[row["period"]] + if row["is_reference"]: # Hollow marker for reference period - ax.plot(x, row['effect'], marker=marker, markersize=markersize, - markerfacecolor='white', markeredgecolor=color, - markeredgewidth=2, zorder=3) + ax.plot( + x, + row["effect"], + marker=marker, + markersize=markersize, + markerfacecolor="white", + markeredgecolor=color, + markeredgewidth=2, + zorder=3, + ) else: - ax.plot(x, row['effect'], marker=marker, markersize=markersize, - color=color, zorder=3) + ax.plot(x, row["effect"], marker=marker, markersize=markersize, color=color, zorder=3) # Set labels and title ax.set_xlabel(xlabel) @@ -308,10 +319,10 @@ def plot_event_study( # Set x-axis ticks ax.set_xticks(x_vals) - ax.set_xticklabels([str(p) for p in df['period']]) + ax.set_xticklabels([str(p) for p in df["period"]]) # Add grid - ax.grid(True, alpha=0.3, axis='y') + ax.grid(True, alpha=0.3, axis="y") # Tight layout fig.tight_layout() @@ -342,24 +353,24 @@ def _extract_plot_data( """ # Handle DataFrame input if isinstance(results, pd.DataFrame): 
- if 'period' not in results.columns: + if "period" not in results.columns: raise ValueError("DataFrame must have 'period' column") - if 'effect' not in results.columns: + if "effect" not in results.columns: raise ValueError("DataFrame must have 'effect' column") - if 'se' not in results.columns: + if "se" not in results.columns: raise ValueError("DataFrame must have 'se' column") - effects = dict(zip(results['period'], results['effect'])) - se = dict(zip(results['period'], results['se'])) + effects = dict(zip(results["period"], results["effect"])) + se = dict(zip(results["period"], results["se"])) if periods is None: - periods = list(results['period']) + periods = list(results["period"]) # DataFrame input: reference_period was already set by caller, never inferred here return effects, se, periods, pre_periods, post_periods, reference_period, False # Handle MultiPeriodDiDResults - if hasattr(results, 'period_effects'): + if hasattr(results, "period_effects"): effects = {} se = {} @@ -367,26 +378,35 @@ def _extract_plot_data( effects[period] = pe.effect se[period] = pe.se - if pre_periods is None and hasattr(results, 'pre_periods'): + if pre_periods is None and hasattr(results, "pre_periods"): pre_periods = results.pre_periods - if post_periods is None and hasattr(results, 'post_periods'): + if post_periods is None and hasattr(results, "post_periods"): post_periods = results.post_periods if periods is None: - periods = post_periods + periods = sorted(results.period_effects.keys()) - # MultiPeriodDiDResults: reference_period was already set by caller, never inferred here - return effects, se, periods, pre_periods, post_periods, reference_period, False + # Auto-detect reference period from results if not explicitly provided + ref_inferred = False + if ( + reference_period is None + and hasattr(results, "reference_period") + and results.reference_period is not None + ): + reference_period = results.reference_period + ref_inferred = True + + return effects, se, periods, 
pre_periods, post_periods, reference_period, ref_inferred # Handle CallawaySantAnnaResults (event study aggregation) - if hasattr(results, 'event_study_effects') and results.event_study_effects is not None: + if hasattr(results, "event_study_effects") and results.event_study_effects is not None: effects = {} se = {} for rel_period, effect_data in results.event_study_effects.items(): - effects[rel_period] = effect_data['effect'] - se[rel_period] = effect_data['se'] + effects[rel_period] = effect_data["effect"] + se[rel_period] = effect_data["se"] if periods is None: periods = sorted(effects.keys()) @@ -400,7 +420,7 @@ def _extract_plot_data( # Detect reference period from n_groups=0 marker (normalization constraint) # This handles anticipation > 0 where reference is at e = -1 - anticipation for period, effect_data in results.event_study_effects.items(): - if effect_data.get('n_groups', 1) == 0: + if effect_data.get("n_groups", 1) == 0: reference_period = period break # Fallback to -1 if no marker found (backward compatibility) @@ -467,13 +487,12 @@ def plot_group_effects( import matplotlib.pyplot as plt except ImportError: raise ImportError( - "matplotlib is required for plotting. " - "Install it with: pip install matplotlib" + "matplotlib is required for plotting. 
" "Install it with: pip install matplotlib" ) from scipy import stats as scipy_stats - if not hasattr(results, 'group_time_effects'): + if not hasattr(results, "group_time_effects"): raise TypeError("results must be a CallawaySantAnnaResults object") # Get groups to plot @@ -494,8 +513,7 @@ def plot_group_effects( for i, group in enumerate(groups): # Get effects for this group group_effects = [ - (t, data) for (g, t), data in results.group_time_effects.items() - if g == group + (t, data) for (g, t), data in results.group_time_effects.items() if g == group ] group_effects.sort(key=lambda x: x[0]) @@ -503,26 +521,31 @@ def plot_group_effects( continue times = [t for t, _ in group_effects] - effects = [data['effect'] for _, data in group_effects] - ses = [data['se'] for _, data in group_effects] + effects = [data["effect"] for _, data in group_effects] + ses = [data["se"] for _, data in group_effects] yerr = [ [e - (e - critical_value * s) for e, s in zip(effects, ses)], - [(e + critical_value * s) - e for e, s in zip(effects, ses)] + [(e + critical_value * s) - e for e, s in zip(effects, ses)], ] ax.errorbar( - times, effects, yerr=yerr, - label=f'Cohort {group}', color=colors[i], - marker='o', capsize=3, linewidth=1.5 + times, + effects, + yerr=yerr, + label=f"Cohort {group}", + color=colors[i], + marker="o", + capsize=3, + linewidth=1.5, ) - ax.axhline(y=0, color='gray', linestyle='--', linewidth=1) + ax.axhline(y=0, color="gray", linestyle="--", linewidth=1) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) - ax.legend(loc='best') - ax.grid(True, alpha=0.3, axis='y') + ax.legend(loc="best") + ax.grid(True, alpha=0.3, axis="y") fig.tight_layout() @@ -615,8 +638,7 @@ def plot_sensitivity( import matplotlib.pyplot as plt except ImportError: raise ImportError( - "matplotlib is required for plotting. " - "Install it with: pip install matplotlib" + "matplotlib is required for plotting. 
" "Install it with: pip install matplotlib" ) # Create figure if needed @@ -633,52 +655,45 @@ def plot_sensitivity( ax.axhline( y=sensitivity_results.original_estimate, color=original_color, - linestyle='-', + linestyle="-", linewidth=1.5, - label='Original estimate', - alpha=0.7 + label="Original estimate", + alpha=0.7, ) # Plot zero line - ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5) + ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5) # Plot identified set bounds if show_bounds: ax.fill_between( - M, bounds_arr[:, 0], bounds_arr[:, 1], + M, + bounds_arr[:, 0], + bounds_arr[:, 1], alpha=bounds_alpha, color=bounds_color, - label='Identified set' + label="Identified set", ) # Plot confidence intervals if show_ci: - ax.plot( - M, ci_arr[:, 0], - color=ci_color, - linewidth=ci_linewidth, - label='Robust CI' - ) - ax.plot( - M, ci_arr[:, 1], - color=ci_color, - linewidth=ci_linewidth - ) + ax.plot(M, ci_arr[:, 0], color=ci_color, linewidth=ci_linewidth, label="Robust CI") + ax.plot(M, ci_arr[:, 1], color=ci_color, linewidth=ci_linewidth) # Plot breakdown line if breakdown_line and sensitivity_results.breakdown_M is not None: ax.axvline( x=sensitivity_results.breakdown_M, color=breakdown_color, - linestyle=':', + linestyle=":", linewidth=2, - label=f'Breakdown (M={sensitivity_results.breakdown_M:.2f})' + label=f"Breakdown (M={sensitivity_results.breakdown_M:.2f})", ) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) - ax.legend(loc='best') + ax.legend(loc="best") ax.grid(True, alpha=0.3) fig.tight_layout() @@ -758,8 +773,7 @@ def plot_honest_event_study( import matplotlib.pyplot as plt except ImportError: raise ImportError( - "matplotlib is required for plotting. " - "Install it with: pip install matplotlib" + "matplotlib is required for plotting. 
" "Install it with: pip install matplotlib" ) from scipy import stats as scipy_stats @@ -767,31 +781,21 @@ def plot_honest_event_study( # Get original results for standard CIs original_results = honest_results.original_results if original_results is None: - raise ValueError( - "HonestDiDResults must have original_results to plot event study" - ) + raise ValueError("HonestDiDResults must have original_results to plot event study") # Extract data from original results - if hasattr(original_results, 'period_effects'): + if hasattr(original_results, "period_effects"): # MultiPeriodDiDResults - effects_dict = { - p: pe.effect for p, pe in original_results.period_effects.items() - } - se_dict = { - p: pe.se for p, pe in original_results.period_effects.items() - } + effects_dict = {p: pe.effect for p, pe in original_results.period_effects.items()} + se_dict = {p: pe.se for p, pe in original_results.period_effects.items()} if periods is None: periods = list(original_results.period_effects.keys()) - elif hasattr(original_results, 'event_study_effects'): + elif hasattr(original_results, "event_study_effects"): # CallawaySantAnnaResults effects_dict = { - t: data['effect'] - for t, data in original_results.event_study_effects.items() - } - se_dict = { - t: data['se'] - for t, data in original_results.event_study_effects.items() + t: data["effect"] for t, data in original_results.event_study_effects.items() } + se_dict = {t: data["se"] for t, data in original_results.event_study_effects.items()} if periods is None: periods = sorted(original_results.event_study_effects.keys()) else: @@ -815,62 +819,73 @@ def plot_honest_event_study( # Get honest bounds if available for each period if honest_results.event_study_bounds: - honest_ci_lower = [ - honest_results.event_study_bounds[p]['ci_lb'] - for p in periods - ] - honest_ci_upper = [ - honest_results.event_study_bounds[p]['ci_ub'] - for p in periods - ] + honest_ci_lower = [honest_results.event_study_bounds[p]["ci_lb"] for p in 
periods] + honest_ci_upper = [honest_results.event_study_bounds[p]["ci_ub"] for p in periods] else: # Use scalar bounds applied to all periods honest_ci_lower = [honest_results.ci_lb] * len(periods) honest_ci_upper = [honest_results.ci_ub] * len(periods) # Zero line - ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5) + ax.axhline(y=0, color="gray", linestyle="--", linewidth=1, alpha=0.5) # Plot original CIs (thinner, background) yerr_orig = [ [e - lower for e, lower in zip(effects, original_ci_lower)], - [u - e for e, u in zip(effects, original_ci_upper)] + [u - e for e, u in zip(effects, original_ci_upper)], ] ax.errorbar( - x_vals, effects, yerr=yerr_orig, - fmt='none', color=original_color, capsize=capsize - 1, - linewidth=1, alpha=0.6, label='Standard CI' + x_vals, + effects, + yerr=yerr_orig, + fmt="none", + color=original_color, + capsize=capsize - 1, + linewidth=1, + alpha=0.6, + label="Standard CI", ) # Plot honest CIs (thicker, foreground) yerr_honest = [ [e - lower for e, lower in zip(effects, honest_ci_lower)], - [u - e for e, u in zip(effects, honest_ci_upper)] + [u - e for e, u in zip(effects, honest_ci_upper)], ] ax.errorbar( - x_vals, effects, yerr=yerr_honest, - fmt='none', color=honest_color, capsize=capsize, - linewidth=2, label=f'Honest CI (M={honest_results.M:.2f})' + x_vals, + effects, + yerr=yerr_honest, + fmt="none", + color=honest_color, + capsize=capsize, + linewidth=2, + label=f"Honest CI (M={honest_results.M:.2f})", ) # Plot point estimates for i, (x, effect, period) in enumerate(zip(x_vals, effects, periods)): is_ref = period == reference_period if is_ref: - ax.plot(x, effect, marker=marker, markersize=markersize, - markerfacecolor='white', markeredgecolor=honest_color, - markeredgewidth=2, zorder=3) + ax.plot( + x, + effect, + marker=marker, + markersize=markersize, + markerfacecolor="white", + markeredgecolor=honest_color, + markeredgewidth=2, + zorder=3, + ) else: - ax.plot(x, effect, marker=marker, 
markersize=markersize, - color=honest_color, zorder=3) + ax.plot(x, effect, marker=marker, markersize=markersize, color=honest_color, zorder=3) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) ax.set_xticks(x_vals) ax.set_xticklabels([str(p) for p in periods]) - ax.legend(loc='best') - ax.grid(True, alpha=0.3, axis='y') + ax.legend(loc="best") + ax.grid(True, alpha=0.3, axis="y") fig.tight_layout() @@ -993,16 +1008,15 @@ def plot_bacon( import matplotlib.pyplot as plt except ImportError: raise ImportError( - "matplotlib is required for plotting. " - "Install it with: pip install matplotlib" + "matplotlib is required for plotting. " "Install it with: pip install matplotlib" ) # Default colors if colors is None: colors = { - "treated_vs_never": "#22c55e", # Green - clean comparison - "earlier_vs_later": "#3b82f6", # Blue - valid comparison - "later_vs_earlier": "#ef4444", # Red - forbidden comparison + "treated_vs_never": "#22c55e", # Green - clean comparison + "earlier_vs_later": "#3b82f6", # Blue - valid comparison + "later_vs_earlier": "#ef4444", # Red - forbidden comparison } # Default titles @@ -1020,8 +1034,17 @@ def plot_bacon( if plot_type == "scatter": _plot_bacon_scatter( - ax, results, colors, marker, markersize, alpha, - show_weighted_avg, show_twfe_line, xlabel, ylabel, title + ax, + results, + colors, + marker, + markersize, + alpha, + show_weighted_avg, + show_twfe_line, + xlabel, + ylabel, + title, ) elif plot_type == "bar": _plot_bacon_bar(ax, results, colors, alpha, ylabel, title) @@ -1073,13 +1096,14 @@ def _plot_bacon_scatter( estimates = [p[0] for p in points] weights = [p[1] for p in points] ax.scatter( - estimates, weights, + estimates, + weights, c=colors[ctype], label=labels[ctype], marker=marker, s=markersize, alpha=alpha, - edgecolors='white', + edgecolors="white", linewidths=0.5, ) @@ -1091,7 +1115,7 @@ def _plot_bacon_scatter( ax.axvline( x=avg_effect, color=colors[ctype], - linestyle='--', + linestyle="--", alpha=0.5, 
linewidth=1.5, ) @@ -1100,20 +1124,20 @@ def _plot_bacon_scatter( if show_twfe_line: ax.axvline( x=results.twfe_estimate, - color='black', - linestyle='-', + color="black", + linestyle="-", linewidth=2, - label=f'TWFE = {results.twfe_estimate:.4f}', + label=f"TWFE = {results.twfe_estimate:.4f}", ) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) - ax.legend(loc='best') + ax.legend(loc="best") ax.grid(True, alpha=0.3) # Add zero line - ax.axvline(x=0, color='gray', linestyle=':', alpha=0.5) + ax.axvline(x=0, color="gray", linestyle=":", alpha=0.5) def _plot_bacon_bar( @@ -1147,7 +1171,7 @@ def _plot_bacon_bar( bar_weights, color=bar_colors, alpha=alpha, - edgecolor='white', + edgecolor="white", linewidth=1, ) @@ -1156,14 +1180,14 @@ def _plot_bacon_bar( if weight > 0.01: # Only label if > 1% height = bar.get_height() ax.annotate( - f'{weight:.1%}', + f"{weight:.1%}", xy=(bar.get_x() + bar.get_width() / 2, height), xytext=(0, 3), textcoords="offset points", - ha='center', - va='bottom', + ha="center", + va="bottom", fontsize=10, - fontweight='bold', + fontweight="bold", ) # Add weighted average effect annotations @@ -1172,13 +1196,13 @@ def _plot_bacon_bar( effect = effects[ctype] if effect is not None and weights[ctype] > 0.01: ax.annotate( - f'β = {effect:.3f}', + f"β = {effect:.3f}", xy=(bar.get_x() + bar.get_width() / 2, bar.get_height() / 2), - ha='center', - va='center', + ha="center", + va="center", fontsize=9, - color='white', - fontweight='bold', + color="white", + fontweight="bold", ) ax.set_ylabel(ylabel) @@ -1186,17 +1210,18 @@ def _plot_bacon_bar( ax.set_ylim(0, 1.1) # Add horizontal line at total weight = 1 - ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5) + ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5) # Add TWFE estimate as text ax.text( - 0.98, 0.98, - f'TWFE = {results.twfe_estimate:.4f}', + 0.98, + 0.98, + f"TWFE = {results.twfe_estimate:.4f}", transform=ax.transAxes, - ha='right', - va='top', + ha="right", + 
va="top", fontsize=10, - bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), + bbox=dict(boxstyle="round", facecolor="white", alpha=0.8), ) @@ -1310,8 +1335,7 @@ def plot_power_curve( import matplotlib.pyplot as plt except ImportError: raise ImportError( - "matplotlib is required for plotting. " - "Install it with: pip install matplotlib" + "matplotlib is required for plotting. " "Install it with: pip install matplotlib" ) # Extract data from results if provided @@ -1325,9 +1349,7 @@ def plot_power_curve( effect_sizes = results["sample_size"].tolist() plot_type = "sample" else: - raise ValueError( - "DataFrame must have 'effect_size' or 'sample_size' column" - ) + raise ValueError("DataFrame must have 'effect_size' or 'sample_size' column") powers = results["power"].tolist() elif hasattr(results, "effect_sizes") and hasattr(results, "powers"): # SimulationPowerResults @@ -1343,13 +1365,9 @@ def plot_power_curve( "Use PowerAnalysis.power_curve() to generate curve data." ) else: - raise TypeError( - f"Cannot extract power curve data from {type(results).__name__}" - ) + raise TypeError(f"Cannot extract power curve data from {type(results).__name__}") elif effect_sizes is None or powers is None: - raise ValueError( - "Must provide either 'results' or both 'effect_sizes' and 'powers'" - ) + raise ValueError("Must provide either 'results' or both 'effect_sizes' and 'powers'") # Default titles and labels if title is None: @@ -1371,12 +1389,7 @@ def plot_power_curve( fig = ax.get_figure() # Plot power curve - ax.plot( - effect_sizes, powers, - color=color, - linewidth=linewidth, - label="Power" - ) + ax.plot(effect_sizes, powers, color=color, linewidth=linewidth, label="Power") # Add target power line if show_target_line: @@ -1386,7 +1399,7 @@ def plot_power_curve( linestyle="--", linewidth=1.5, alpha=0.7, - label=f"Target power ({target_power:.0%})" + label=f"Target power ({target_power:.0%})", ) # Add MDE line @@ -1397,7 +1410,7 @@ def plot_power_curve( 
linestyle=":", linewidth=1.5, alpha=0.7, - label=f"MDE = {mde:.3f}" + label=f"MDE = {mde:.3f}", ) # Mark intersection point @@ -1415,12 +1428,7 @@ def plot_power_curve( power_at_mde = None if power_at_mde is not None: - ax.scatter( - [mde], [power_at_mde], - color=mde_color, - s=50, - zorder=5 - ) + ax.scatter([mde], [power_at_mde], color=mde_color, s=50, zorder=5) # Configure axes ax.set_xlabel(xlabel) @@ -1431,7 +1439,7 @@ def plot_power_curve( ax.set_ylim(0, 1.05) # Format y-axis as percentage - ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}')) + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}")) if show_grid: ax.grid(True, alpha=0.3) @@ -1566,8 +1574,7 @@ def plot_pretrends_power( import matplotlib.pyplot as plt except ImportError: raise ImportError( - "matplotlib is required for plotting. " - "Install it with: pip install matplotlib" + "matplotlib is required for plotting. " "Install it with: pip install matplotlib" ) # Extract data from results if provided @@ -1598,13 +1605,9 @@ def plot_pretrends_power( # Just show MDV marker powers = None else: - raise TypeError( - f"Cannot extract power curve data from {type(results).__name__}" - ) + raise TypeError(f"Cannot extract power curve data from {type(results).__name__}") elif M_values is None or powers is None: - raise ValueError( - "Must provide either 'results' or both 'M_values' and 'powers'" - ) + raise ValueError("Must provide either 'results' or both 'M_values' and 'powers'") # Create figure if needed if ax is None: @@ -1614,12 +1617,7 @@ def plot_pretrends_power( # Plot power curve if we have powers if powers is not None: - ax.plot( - M_values, powers, - color=color, - linewidth=linewidth, - label="Power" - ) + ax.plot(M_values, powers, color=color, linewidth=linewidth, label="Power") # Add target power line if show_target_line: @@ -1629,7 +1627,7 @@ def plot_pretrends_power( linestyle="--", linewidth=1.5, alpha=0.7, - label=f"Target power ({target_power:.0%})" 
+ label=f"Target power ({target_power:.0%})", ) # Add MDV line @@ -1640,7 +1638,7 @@ def plot_pretrends_power( linestyle=":", linewidth=1.5, alpha=0.7, - label=f"MDV = {mdv:.3f}" + label=f"MDV = {mdv:.3f}", ) # Mark intersection point if we have powers @@ -1650,12 +1648,7 @@ def plot_pretrends_power( power_arr = np.array(powers) if M_arr.min() <= mdv <= M_arr.max(): power_at_mdv = np.interp(mdv, M_arr, power_arr) - ax.scatter( - [mdv], [power_at_mdv], - color=mdv_color, - s=50, - zorder=5 - ) + ax.scatter([mdv], [power_at_mdv], color=mdv_color, s=50, zorder=5) # Configure axes ax.set_xlabel(xlabel) @@ -1666,7 +1659,7 @@ def plot_pretrends_power( ax.set_ylim(0, 1.05) # Format y-axis as percentage - ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}')) + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.0%}")) if show_grid: ax.grid(True, alpha=0.3) diff --git a/docs/api/results.rst b/docs/api/results.rst index 7d38643..bd26d7a 100644 --- a/docs/api/results.rst +++ b/docs/api/results.rst @@ -54,6 +54,10 @@ Results from MultiPeriodDiD event study estimation. 
~MultiPeriodDiDResults.att ~MultiPeriodDiDResults.pre_periods ~MultiPeriodDiDResults.post_periods + ~MultiPeriodDiDResults.reference_period + ~MultiPeriodDiDResults.interaction_indices + ~MultiPeriodDiDResults.pre_period_effects + ~MultiPeriodDiDResults.post_period_effects PeriodEffect ------------ diff --git a/docs/choosing_estimator.rst b/docs/choosing_estimator.rst index ceb48dd..45eb179 100644 --- a/docs/choosing_estimator.rst +++ b/docs/choosing_estimator.rst @@ -107,9 +107,10 @@ Multi-Period Event Study Use :class:`~diff_diff.MultiPeriodDiD` when: -- You want to visualize treatment effects over time -- You need to test for pre-trends (placebo effects before treatment) -- You want to examine treatment effect dynamics +- You want a full event-study with pre and post treatment effects +- You need pre-period coefficients to assess parallel trends +- You want to visualize treatment effect dynamics over time +- All treated units receive treatment at the same time (simultaneous adoption) .. code-block:: python diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 1bf16b9..82755d4 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -81,41 +81,111 @@ where τ is the ATT. ## MultiPeriodDiD **Primary source:** Event study methodology -- Freyaldenhoven, S., Hansen, C., Pérez, J.P., & Shapiro, J.M. (2021). Visualization, identification, and estimation in the linear panel event-study design. NBER Working Paper 29170. +- Freyaldenhoven, S., Hansen, C., Pérez, J.P., & Shapiro, J.M. (2021). Visualization, + identification, and estimation in the linear panel event-study design. NBER Working Paper 29170. +- Wooldridge, J.M. (2010). *Econometric Analysis of Cross Section and Panel Data*, 2nd ed. + MIT Press, Ch. 10, 13. +- Angrist, J.D., & Pischke, J.-S. (2009). *Mostly Harmless Econometrics*. Princeton University Press. + +**Scope:** Simultaneous adoption event study. All treated units receive treatment at the +same time. 
For staggered adoption (different units treated at different times), use +CallawaySantAnna or SunAbraham instead. **Key implementation requirements:** *Assumption checks / warnings:* -- Requires multiple pre and post periods -- Reference period (typically t=-1) must be specified or defaulted -- Warns if treatment timing varies across units (suggests staggered estimator) +- Treatment indicator must be binary (0/1) with variation in both groups +- Requires at least 2 pre-treatment and 1 post-treatment period + (need ≥2 pre-periods to test parallel trends) +- Reference period defaults to last pre-treatment period (e=-1 convention) +- Warns if treatment timing varies across units (suggests CallawaySantAnna) +- Treatment must be an absorbing state (once treated, always treated) + +*Estimator equation (target specification):* + +With unit and time fixed effects absorbed: + +``` +Y_it = α_i + γ_t + Σ_{e≠-1} δ_e × D_i × 1(t = E + e) + X'β + ε_it +``` + +where: +- α_i = unit fixed effects (absorbed) +- γ_t = time fixed effects (absorbed) +- E = common treatment time (same for all treated units) +- D_i = treatment group indicator (1=treated, 0=control) +- e = t - E = event time (relative periods to treatment) +- δ_e = treatment effect at event time e +- δ_{-1} = 0 (reference period, omitted for identification) + +For simultaneous treatment, this is equivalent to interacting treatment with +calendar-time indicators: -*Estimator equation (as implemented):* ``` -Y_it = α_i + γ_t + Σ_{e≠-1} δ_e × 1(t - E_i = e) + X'β + ε_it +Y_it = α_i + γ_t + Σ_{t≠t_ref} δ_t × (D_i × Period_t) + X'β + ε_it ``` -where E_i is treatment time for unit i, and δ_e are event-study coefficients. + +where interactions are included for ALL periods (pre and post), not just post-treatment. + +Pre-treatment coefficients (e < -1) test the parallel trends assumption: +under H0 of parallel trends, δ_e = 0 for all e < 0. + +Post-treatment coefficients (e ≥ 0) estimate dynamic treatment effects. 
+ +Average ATT over post-treatment periods: + +``` +ATT_avg = (1/|post|) × Σ_{e≥0} δ_e +``` + +with SE computed from the sub-VCV matrix: + +``` +Var(ATT_avg) = 1'V1 / |post|² +``` + +where V is the VCV sub-matrix for post-treatment δ_e coefficients. *Standard errors:* -- Default: Cluster-robust at unit level -- Event-study coefficients use appropriate degrees of freedom +- Default: Cluster-robust at unit level (accounts for within-unit serial correlation) +- Alternative: HC1 heteroskedasticity-robust (for cross-sectional data) +- Optional: Wild cluster bootstrap (complex for multi-coefficient testing; + requires joint bootstrap distribution) +- Degrees of freedom adjusted for absorbed fixed effects *Edge cases:* -- Unbalanced panels: only uses observations where event-time is defined -- Never-treated units: event-time indicators are all zero -- Endpoint binning: distant event times can be binned -- Rank-deficient design matrix (collinearity): warns and sets NA for dropped coefficients (R-style, matches `lm()`) -- Average ATT (`avg_att`) is NA if any post-period effect is unidentified (R-style NA propagation) +- Reference period: omitted from design matrix; coefficient is zero by construction. + Default is last pre-treatment period (e=-1). User can override via `reference_period`. +- Never-treated units: all event-time indicators are zero; they identify the time + fixed effects and serve as comparison group. +- Endpoint binning: distant event times (e.g., e < -K or e > K) should be binned + into endpoint indicators to avoid sparse cells. This prevents imprecise estimates + at extreme leads/lags. +- Unbalanced panels: only uses observations where event-time is defined. Units + not observed at all event times contribute to the periods they are present for. 
+- Rank-deficient design matrix (collinearity): warns and sets NA for dropped
+  coefficients (R-style, matches `lm()`)
+- Average ATT (`avg_att`) is NA if any post-period effect is unidentified
+  (R-style NA propagation)
+- Pre-test of parallel trends: joint F-test on pre-treatment δ_e coefficients.
+  A passed pre-test does not validate parallel trends, as pre-tests often have low power (Roth 2022).
 
 **Reference implementation(s):**
-- R: `fixest::feols()` with `i(event_time, ref=-1)`
-- Stata: `eventdd` or manual indicator regression
+- R: `fixest::feols(y ~ i(time, treatment, ref=ref_period) | unit + time, data, cluster=~unit)`
+  or equivalently `feols(y ~ i(event_time, ref=-1) | unit + time, data, cluster=~unit)`
+- Stata: `reghdfe y ib(-1).event_time#1.treatment, absorb(unit time) cluster(unit)`
 
 **Requirements checklist:**
-- [ ] Reference period coefficient is zero (normalized)
-- [ ] Pre-period coefficients test parallel trends assumption
-- [ ] Supports both balanced and unbalanced panels
-- [ ] Returns PeriodEffect objects with confidence intervals
+
+- [x] Event-time indicators for ALL periods (pre and post), not just post-treatment
+- [x] Reference period coefficient is zero (normalized by omission from design matrix)
+- [x] Pre-period coefficients available for parallel trends assessment
+- [ ] Default cluster-robust SE at unit level (currently HC1; cluster-robust via `cluster` param)
+- [ ] Supports unit and time FE via absorption
+- [ ] Endpoint binning for distant event times
+- [x] Average ATT correctly accounts for covariance between period effects
+- [x] Returns PeriodEffect objects with confidence intervals
+- [x] Supports both balanced and unbalanced panels
 
 ---
 
@@ -967,7 +1037,7 @@ should be a deliberate user choice. 
| Estimator | Default SE | Alternatives | |-----------|-----------|--------------| | DifferenceInDifferences | HC1 robust | Cluster-robust, wild bootstrap | -| MultiPeriodDiD | Cluster at unit | Wild bootstrap | +| MultiPeriodDiD | HC1 robust | Cluster-robust (via `cluster` param), wild bootstrap | | TwoWayFixedEffects | Cluster at unit | Wild bootstrap | | CallawaySantAnna | Analytical (influence fn) | Multiplier bootstrap | | SunAbraham | Cluster-robust + delta method | Pairs bootstrap | @@ -986,7 +1056,7 @@ should be a deliberate user choice. | diff-diff Estimator | R Package | Function | |---------------------|-----------|----------| | DifferenceInDifferences | fixest | `feols(y ~ treat:post, ...)` | -| MultiPeriodDiD | fixest | `feols(y ~ i(event_time), ...)` | +| MultiPeriodDiD | fixest | `feols(y ~ i(time, treat, ref=ref) \| unit + time)` | | TwoWayFixedEffects | fixest | `feols(y ~ treat \| unit + time, ...)` | | CallawaySantAnna | did | `att_gt()` | | SunAbraham | fixest | `sunab()` | diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 141bd4e..8091741 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -46,13 +46,15 @@ def simple_did_data(): # Add noise y += np.random.normal(0, 1) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": period, - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": period, + "outcome": y, + } + ) return pd.DataFrame(data) @@ -60,11 +62,13 @@ def simple_did_data(): @pytest.fixture def simple_2x2_data(): """Minimal 2x2 DiD data.""" - return pd.DataFrame({ - "outcome": [10, 11, 15, 18, 9, 10, 12, 13], - "treated": [1, 1, 1, 1, 0, 0, 0, 0], - "post": [0, 0, 1, 1, 0, 0, 1, 1], - }) + return pd.DataFrame( + { + "outcome": [10, 11, 15, 18, 9, 10, 12, 13], + "treated": [1, 1, 1, 1, 0, 0, 0, 0], + "post": [0, 0, 1, 1, 0, 0, 1, 1], + } + ) class TestDifferenceInDifferences: @@ -73,12 +77,7 @@ class 
TestDifferenceInDifferences: def test_basic_fit(self, simple_2x2_data): """Test basic model fitting.""" did = DifferenceInDifferences() - results = did.fit( - simple_2x2_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_2x2_data, outcome="outcome", treatment="treated", time="post") assert isinstance(results, DiDResults) assert did.is_fitted_ @@ -89,12 +88,7 @@ def test_basic_fit(self, simple_2x2_data): def test_att_direction(self, simple_did_data): """Test that ATT is estimated in correct direction.""" did = DifferenceInDifferences() - results = did.fit( - simple_did_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_did_data, outcome="outcome", treatment="treated", time="post") # True ATT is 3.0, estimate should be close assert results.att > 0 @@ -103,10 +97,7 @@ def test_att_direction(self, simple_did_data): def test_formula_interface(self, simple_2x2_data): """Test formula-based fitting.""" did = DifferenceInDifferences() - results = did.fit( - simple_2x2_data, - formula="outcome ~ treated * post" - ) + results = did.fit(simple_2x2_data, formula="outcome ~ treated * post") assert isinstance(results, DiDResults) assert did.is_fitted_ @@ -114,10 +105,7 @@ def test_formula_interface(self, simple_2x2_data): def test_formula_with_explicit_interaction(self, simple_2x2_data): """Test formula with explicit interaction syntax.""" did = DifferenceInDifferences() - results = did.fit( - simple_2x2_data, - formula="outcome ~ treated + post + treated:post" - ) + results = did.fit(simple_2x2_data, formula="outcome ~ treated + post + treated:post") assert isinstance(results, DiDResults) @@ -127,16 +115,10 @@ def test_robust_vs_classical_se(self, simple_did_data): did_classical = DifferenceInDifferences(robust=False) results_robust = did_robust.fit( - simple_did_data, - outcome="outcome", - treatment="treated", - time="post" + simple_did_data, outcome="outcome", treatment="treated", 
time="post" ) results_classical = did_classical.fit( - simple_did_data, - outcome="outcome", - treatment="treated", - time="post" + simple_did_data, outcome="outcome", treatment="treated", time="post" ) # The vcov matrices should differ (HC1 vs classical) @@ -149,12 +131,7 @@ def test_robust_vs_classical_se(self, simple_did_data): def test_confidence_interval(self, simple_did_data): """Test confidence interval properties.""" did = DifferenceInDifferences(alpha=0.05) - results = did.fit( - simple_did_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_did_data, outcome="outcome", treatment="treated", time="post") lower, upper = results.conf_int assert lower < results.att < upper @@ -183,11 +160,13 @@ def test_summary_output(self, simple_2x2_data): def test_invalid_treatment_values(self): """Test error on non-binary treatment.""" - data = pd.DataFrame({ - "outcome": [1, 2, 3, 4], - "treated": [0, 1, 2, 3], # Invalid: not binary - "post": [0, 0, 1, 1], - }) + data = pd.DataFrame( + { + "outcome": [1, 2, 3, 4], + "treated": [0, 1, 2, 3], # Invalid: not binary + "post": [0, 0, 1, 1], + } + ) did = DifferenceInDifferences() with pytest.raises(ValueError, match="binary"): @@ -195,10 +174,12 @@ def test_invalid_treatment_values(self): def test_missing_column_error(self): """Test error when column is missing.""" - data = pd.DataFrame({ - "outcome": [1, 2, 3, 4], - "treated": [0, 0, 1, 1], - }) + data = pd.DataFrame( + { + "outcome": [1, 2, 3, 4], + "treated": [0, 0, 1, 1], + } + ) did = DifferenceInDifferences() with pytest.raises(ValueError, match="Missing columns"): @@ -224,7 +205,7 @@ def test_rank_deficient_action_error_raises(self, simple_2x2_data): outcome="outcome", treatment="treated", time="post", - covariates=["collinear_cov"] + covariates=["collinear_cov"], ) def test_rank_deficient_action_silent_no_warning(self, simple_2x2_data): @@ -244,19 +225,23 @@ def test_rank_deficient_action_silent_no_warning(self, 
simple_2x2_data): outcome="outcome", treatment="treated", time="post", - covariates=["collinear_cov"] + covariates=["collinear_cov"], ) # No warnings about rank deficiency should be emitted - rank_warnings = [x for x in w if "Rank-deficient" in str(x.message) - or "rank-deficient" in str(x.message).lower()] + rank_warnings = [ + x + for x in w + if "Rank-deficient" in str(x.message) or "rank-deficient" in str(x.message).lower() + ] assert len(rank_warnings) == 0, f"Expected no rank warnings, got {rank_warnings}" # Should still have NaN for dropped coefficient assert "collinear_cov" in results.coefficients # Either collinear_cov or treated will be NaN - has_nan = (np.isnan(results.coefficients.get("collinear_cov", 0)) or - np.isnan(results.coefficients.get("treated", 0))) + has_nan = np.isnan(results.coefficients.get("collinear_cov", 0)) or np.isnan( + results.coefficients.get("treated", 0) + ) assert has_nan, "Expected NaN for one of the collinear coefficients" def test_rank_deficient_action_warn_default(self, simple_2x2_data): @@ -276,12 +261,15 @@ def test_rank_deficient_action_warn_default(self, simple_2x2_data): outcome="outcome", treatment="treated", time="post", - covariates=["collinear_cov"] + covariates=["collinear_cov"], ) # Should have a warning about rank deficiency - rank_warnings = [x for x in w if "Rank-deficient" in str(x.message) - or "rank-deficient" in str(x.message).lower()] + rank_warnings = [ + x + for x in w + if "Rank-deficient" in str(x.message) or "rank-deficient" in str(x.message).lower() + ] assert len(rank_warnings) > 0, "Expected warning about rank deficiency" @@ -291,12 +279,7 @@ class TestDiDResults: def test_repr(self, simple_2x2_data): """Test string representation.""" did = DifferenceInDifferences() - results = did.fit( - simple_2x2_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_2x2_data, outcome="outcome", treatment="treated", time="post") repr_str = repr(results) assert "DiDResults" 
in repr_str @@ -305,12 +288,7 @@ def test_repr(self, simple_2x2_data): def test_to_dict(self, simple_2x2_data): """Test conversion to dictionary.""" did = DifferenceInDifferences() - results = did.fit( - simple_2x2_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_2x2_data, outcome="outcome", treatment="treated", time="post") result_dict = results.to_dict() assert "att" in result_dict @@ -320,12 +298,7 @@ def test_to_dict(self, simple_2x2_data): def test_to_dataframe(self, simple_2x2_data): """Test conversion to DataFrame.""" did = DifferenceInDifferences() - results = did.fit( - simple_2x2_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_2x2_data, outcome="outcome", treatment="treated", time="post") df = results.to_dataframe() assert isinstance(df, pd.DataFrame) @@ -335,12 +308,7 @@ def test_to_dataframe(self, simple_2x2_data): def test_significance_stars(self, simple_did_data): """Test significance star notation.""" did = DifferenceInDifferences() - results = did.fit( - simple_did_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_did_data, outcome="outcome", treatment="treated", time="post") # With true effect of 3.0 and n=200, should be significant assert results.significance_stars in ["*", "**", "***"] @@ -348,12 +316,7 @@ def test_significance_stars(self, simple_did_data): def test_is_significant_property(self, simple_did_data): """Test is_significant property.""" did = DifferenceInDifferences(alpha=0.05) - results = did.fit( - simple_did_data, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(simple_did_data, outcome="outcome", treatment="treated", time="post") # Boolean check assert isinstance(results.is_significant, bool) @@ -388,14 +351,16 @@ def panel_data_with_fe(self): y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "state": f"state_{state}", - "period": period, - 
"treated": int(is_treated), - "post": post, - "outcome": y, - }) + data.append( + { + "unit": unit, + "state": f"state_{state}", + "period": period, + "treated": int(is_treated), + "post": post, + "outcome": y, + } + ) return pd.DataFrame(data) @@ -407,7 +372,7 @@ def test_fixed_effects_dummy(self, panel_data_with_fe): outcome="outcome", treatment="treated", time="post", - fixed_effects=["state"] + fixed_effects=["state"], ) assert results is not None @@ -423,7 +388,7 @@ def test_fixed_effects_coefficients_include_dummies(self, panel_data_with_fe): outcome="outcome", treatment="treated", time="post", - fixed_effects=["state"] + fixed_effects=["state"], ) # Should have state dummy coefficients @@ -434,11 +399,7 @@ def test_absorb_fixed_effects(self, panel_data_with_fe): """Test absorbed (within-transformed) fixed effects.""" did = DifferenceInDifferences() results = did.fit( - panel_data_with_fe, - outcome="outcome", - treatment="treated", - time="post", - absorb=["unit"] + panel_data_with_fe, outcome="outcome", treatment="treated", time="post", absorb=["unit"] ) assert results is not None @@ -452,10 +413,7 @@ def test_fixed_effects_vs_no_fe(self, panel_data_with_fe): did_with_fe = DifferenceInDifferences() results_no_fe = did_no_fe.fit( - panel_data_with_fe, - outcome="outcome", - treatment="treated", - time="post" + panel_data_with_fe, outcome="outcome", treatment="treated", time="post" ) results_with_fe = did_with_fe.fit( @@ -463,7 +421,7 @@ def test_fixed_effects_vs_no_fe(self, panel_data_with_fe): outcome="outcome", treatment="treated", time="post", - fixed_effects=["state"] + fixed_effects=["state"], ) # Both should estimate positive ATT @@ -482,7 +440,7 @@ def test_invalid_fixed_effects_column(self, panel_data_with_fe): outcome="outcome", treatment="treated", time="post", - fixed_effects=["nonexistent_column"] + fixed_effects=["nonexistent_column"], ) def test_invalid_absorb_column(self, panel_data_with_fe): @@ -494,7 +452,7 @@ def 
test_invalid_absorb_column(self, panel_data_with_fe): outcome="outcome", treatment="treated", time="post", - absorb=["nonexistent_column"] + absorb=["nonexistent_column"], ) def test_multiple_fixed_effects(self, panel_data_with_fe): @@ -508,7 +466,7 @@ def test_multiple_fixed_effects(self, panel_data_with_fe): outcome="outcome", treatment="treated", time="post", - fixed_effects=["state", "industry"] + fixed_effects=["state", "industry"], ) assert results is not None @@ -530,7 +488,7 @@ def test_covariates_with_fixed_effects(self, panel_data_with_fe): treatment="treated", time="post", covariates=["size"], - fixed_effects=["state"] + fixed_effects=["state"], ) assert results is not None @@ -564,12 +522,14 @@ def parallel_trends_data(self): y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) return pd.DataFrame(data) @@ -600,12 +560,14 @@ def non_parallel_trends_data(self): y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) return pd.DataFrame(data) @@ -620,7 +582,7 @@ def test_wasserstein_parallel_trends_valid(self, parallel_trends_data): treatment_group="treated", unit="unit", pre_periods=[0, 1, 2], - seed=42 + seed=42, ) assert "wasserstein_distance" in results @@ -641,7 +603,7 @@ def test_wasserstein_parallel_trends_violated(self, non_parallel_trends_data): treatment_group="treated", unit="unit", pre_periods=[0, 1, 2], - seed=42 + seed=42, ) # When trends are not parallel, should detect it @@ -661,7 +623,7 @@ def test_wasserstein_returns_changes(self, parallel_trends_data): treatment_group="treated", unit="unit", pre_periods=[0, 1, 2], - seed=42 + seed=42, ) assert "treated_changes" in 
results @@ -679,7 +641,7 @@ def test_wasserstein_without_unit(self, parallel_trends_data): time="period", treatment_group="treated", pre_periods=[0, 1, 2], - seed=42 + seed=42, ) assert "wasserstein_distance" in results @@ -695,7 +657,7 @@ def test_equivalence_test_parallel(self, parallel_trends_data): time="period", treatment_group="treated", unit="unit", - pre_periods=[0, 1, 2] + pre_periods=[0, 1, 2], ) assert "tost_p_value" in results @@ -714,7 +676,7 @@ def test_equivalence_test_non_parallel(self, non_parallel_trends_data): time="period", treatment_group="treated", unit="unit", - pre_periods=[0, 1, 2] + pre_periods=[0, 1, 2], ) # When trends are not parallel, should not be equivalent @@ -731,7 +693,7 @@ def test_equivalence_test_custom_margin(self, parallel_trends_data): treatment_group="treated", unit="unit", pre_periods=[0, 1, 2], - equivalence_margin=0.1 # Very tight margin + equivalence_margin=0.1, # Very tight margin ) assert results["equivalence_margin"] == 0.1 @@ -747,7 +709,7 @@ def test_ks_test_included(self, parallel_trends_data): treatment_group="treated", unit="unit", pre_periods=[0, 1, 2], - seed=42 + seed=42, ) assert "ks_statistic" in results @@ -766,7 +728,7 @@ def test_variance_ratio(self, parallel_trends_data): treatment_group="treated", unit="unit", pre_periods=[0, 1, 2], - seed=42 + seed=42, ) assert "variance_ratio" in results @@ -781,12 +743,14 @@ def test_multicollinearity_detection(self): import warnings # Create data where a covariate is perfectly correlated with treatment - data = pd.DataFrame({ - "outcome": [10, 11, 15, 18, 9, 10, 12, 13], - "treated": [1, 1, 1, 1, 0, 0, 0, 0], - "post": [0, 0, 1, 1, 0, 0, 1, 1], - "duplicate_treated": [1, 1, 1, 1, 0, 0, 0, 0], # Same as treated - }) + data = pd.DataFrame( + { + "outcome": [10, 11, 15, 18, 9, 10, 12, 13], + "treated": [1, 1, 1, 1, 0, 0, 0, 0], + "post": [0, 0, 1, 1, 0, 0, 1, 1], + "duplicate_treated": [1, 1, 1, 1, 0, 0, 0, 0], # Same as treated + } + ) did = DifferenceInDifferences() 
@@ -798,7 +762,7 @@ def test_multicollinearity_detection(self): outcome="outcome", treatment="treated", time="post", - covariates=["duplicate_treated"] + covariates=["duplicate_treated"], ) # Should emit a warning about rank deficiency rank_warnings = [x for x in w if "Rank-deficient" in str(x.message)] @@ -820,12 +784,14 @@ def test_wasserstein_custom_threshold(self): is_treated = unit < n_units // 2 for period in range(n_periods): y = 10.0 + period * 1.5 + np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) df = pd.DataFrame(data) @@ -838,7 +804,7 @@ def test_wasserstein_custom_threshold(self): unit="unit", pre_periods=[0, 1], seed=42, - wasserstein_threshold=0.01 # Very strict + wasserstein_threshold=0.01, # Very strict ) # Test with high threshold (more lenient) @@ -850,7 +816,7 @@ def test_wasserstein_custom_threshold(self): unit="unit", pre_periods=[0, 1], seed=42, - wasserstein_threshold=1.0 # Very lenient + wasserstein_threshold=1.0, # Very lenient ) # Both should return valid results @@ -862,12 +828,14 @@ def test_equivalence_test_insufficient_data(self): from diff_diff.utils import equivalence_test_trends # Create minimal data with only 1 observation per group - data = pd.DataFrame({ - "outcome": [10, 15], - "period": [0, 1], - "treated": [1, 0], - "unit": [0, 1], - }) + data = pd.DataFrame( + { + "outcome": [10, 15], + "period": [0, 1], + "treated": [1, 0], + "unit": [0, 1], + } + ) results = equivalence_test_trends( data, @@ -875,7 +843,7 @@ def test_equivalence_test_insufficient_data(self): time="period", treatment_group="treated", unit="unit", - pre_periods=[0] + pre_periods=[0], ) # Should return NaN values with error message @@ -887,18 +855,16 @@ def test_parallel_trends_single_period(self): """Test that single pre-period returns NaN values.""" from diff_diff.utils 
import check_parallel_trends - data = pd.DataFrame({ - "outcome": [10, 11, 12, 13], - "time": [0, 0, 0, 0], # All same period - "treated": [1, 1, 0, 0], - }) + data = pd.DataFrame( + { + "outcome": [10, 11, 12, 13], + "time": [0, 0, 0, 0], # All same period + "treated": [1, 1, 0, 0], + } + ) results = check_parallel_trends( - data, - outcome="outcome", - time="time", - treatment_group="treated", - pre_periods=[0] + data, outcome="outcome", time="time", treatment_group="treated", pre_periods=[0] ) # Should handle gracefully with NaN @@ -930,13 +896,15 @@ def twfe_panel_data(self): y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": post, - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": post, + "outcome": y, + } + ) return pd.DataFrame(data) @@ -946,11 +914,7 @@ def test_twfe_basic_fit(self, twfe_panel_data): twfe = TwoWayFixedEffects() results = twfe.fit( - twfe_panel_data, - outcome="outcome", - treatment="treated", - time="post", - unit="unit" + twfe_panel_data, outcome="outcome", treatment="treated", time="post", unit="unit" ) assert results is not None @@ -975,7 +939,7 @@ def test_twfe_with_covariates(self, twfe_panel_data): treatment="treated", time="post", unit="unit", - covariates=["size"] + covariates=["size"], ) assert results is not None @@ -992,7 +956,7 @@ def test_twfe_invalid_unit_column(self, twfe_panel_data): outcome="outcome", treatment="treated", time="post", - unit="nonexistent_unit" + unit="nonexistent_unit", ) def test_twfe_clusters_at_unit_level(self, twfe_panel_data): @@ -1001,11 +965,7 @@ def test_twfe_clusters_at_unit_level(self, twfe_panel_data): twfe = TwoWayFixedEffects() results = twfe.fit( - twfe_panel_data, - outcome="outcome", - treatment="treated", - time="post", - unit="unit" + twfe_panel_data, outcome="outcome", treatment="treated", time="post", unit="unit" ) # Cluster should NOT be mutated 
(remains None) - clustering is handled internally @@ -1024,13 +984,15 @@ def test_twfe_treatment_collinearity_raises_error(self): for unit in range(10): is_treated = unit < 5 for period in range(4): - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), # Same for all periods - "post": 1 if period >= 2 else 0, - "outcome": 10.0 + unit * 0.5 + period * 0.3 + np.random.normal(0, 0.1), - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), # Same for all periods + "post": 1 if period >= 2 else 0, + "outcome": 10.0 + unit * 0.5 + period * 0.3 + np.random.normal(0, 0.1), + } + ) df = pd.DataFrame(data) # Make treatment_post constant for treated units (collinear) @@ -1045,18 +1007,16 @@ def test_twfe_treatment_collinearity_raises_error(self): # The key is that it should NOT silently produce misleading results try: results = twfe.fit( - df_collinear, - outcome="outcome", - treatment="treated", - time="post", - unit="unit" + df_collinear, outcome="outcome", treatment="treated", time="post", unit="unit" ) # If we get here without error, the ATT should still be computed # (this means only covariates were dropped, not the treatment) assert results is not None except ValueError as e: # If treatment column is dropped, should get informative error - assert "collinear" in str(e).lower() or "Treatment effect cannot be identified" in str(e) + assert "collinear" in str(e).lower() or "Treatment effect cannot be identified" in str( + e + ) def test_rank_deficient_action_error_raises(self, twfe_panel_data): """Test that rank_deficient_action='error' raises ValueError on collinear data.""" @@ -1074,7 +1034,7 @@ def test_rank_deficient_action_error_raises(self, twfe_panel_data): treatment="treated", time="post", unit="unit", - covariates=["collinear_cov"] + covariates=["collinear_cov"], ) def test_rank_deficient_action_silent_no_warning(self, twfe_panel_data): @@ -1097,13 +1057,17 @@ def 
test_rank_deficient_action_silent_no_warning(self, twfe_panel_data): treatment="treated", time="post", unit="unit", - covariates=["size", "size_dup"] + covariates=["size", "size_dup"], ) # No warnings about rank deficiency or collinearity should be emitted - rank_warnings = [x for x in w if "Rank-deficient" in str(x.message) - or "rank-deficient" in str(x.message).lower() - or "collinear" in str(x.message).lower()] + rank_warnings = [ + x + for x in w + if "Rank-deficient" in str(x.message) + or "rank-deficient" in str(x.message).lower() + or "collinear" in str(x.message).lower() + ] assert len(rank_warnings) == 0, f"Expected no rank warnings, got {rank_warnings}" # Should still get valid results @@ -1125,20 +1089,20 @@ def test_cluster_robust_se(self): treated = cluster < 5 post = obs >= 5 y = 10 + (3.0 if treated and post else 0) + np.random.normal(0, 1) - data.append({ - "cluster": cluster, - "outcome": y, - "treated": int(treated), - "post": int(post), - }) + data.append( + { + "cluster": cluster, + "outcome": y, + "treated": int(treated), + "post": int(post), + } + ) df = pd.DataFrame(data) # With clustering did_cluster = DifferenceInDifferences(cluster="cluster") - results_cluster = did_cluster.fit( - df, outcome="outcome", treatment="treated", time="post" - ) + results_cluster = did_cluster.fit(df, outcome="outcome", treatment="treated", time="post") # Without clustering did_no_cluster = DifferenceInDifferences(robust=True) @@ -1180,12 +1144,14 @@ def multi_period_data(self): y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) return pd.DataFrame(data) @@ -1214,12 +1180,14 @@ def heterogeneous_effects_data(self): y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": 
unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) return pd.DataFrame(data), true_effects @@ -1231,13 +1199,15 @@ def test_basic_fit(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) assert isinstance(results, MultiPeriodDiDResults) assert did.is_fitted_ assert results.n_obs == 600 # 100 units * 6 periods - assert len(results.period_effects) == 3 # 3 post-periods + # 5 estimated periods: pre=[0,1] + post=[3,4,5] (ref=2 excluded) + assert len(results.period_effects) == 5 assert len(results.pre_periods) == 3 assert len(results.post_periods) == 3 @@ -1249,7 +1219,8 @@ def test_avg_att_close_to_true(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # True ATT is 3.0 @@ -1266,14 +1237,16 @@ def test_period_specific_effects(self, heterogeneous_effects_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # Each period-specific effect should be close to truth for period, true_effect in true_effects.items(): estimated = results.period_effects[period].effect - assert abs(estimated - true_effect) < 0.5, \ - f"Period {period}: expected ~{true_effect}, got {estimated}" + assert ( + abs(estimated - true_effect) < 0.5 + ), f"Period {period}: expected ~{true_effect}, got {estimated}" def test_period_effects_have_all_stats(self, multi_period_data): """Test that period effects contain all statistics.""" @@ -1283,16 +1256,17 @@ def test_period_effects_have_all_stats(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) for period, pe in results.period_effects.items(): assert isinstance(pe, PeriodEffect) - assert hasattr(pe, 'effect') - assert hasattr(pe, 'se') - 
assert hasattr(pe, 't_stat') - assert hasattr(pe, 'p_value') - assert hasattr(pe, 'conf_int') + assert hasattr(pe, "effect") + assert hasattr(pe, "se") + assert hasattr(pe, "t_stat") + assert hasattr(pe, "p_value") + assert hasattr(pe, "conf_int") assert pe.se > 0 assert len(pe.conf_int) == 2 assert pe.conf_int[0] < pe.conf_int[1] @@ -1305,17 +1279,27 @@ def test_get_effect_method(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) - # Valid period + # Valid post-period effect = results.get_effect(4) assert isinstance(effect, PeriodEffect) assert effect.period == 4 - # Invalid period + # Valid pre-period (now accessible) + pre_effect = results.get_effect(0) + assert isinstance(pre_effect, PeriodEffect) + assert pre_effect.period == 0 + + # Reference period raises with informative message + with pytest.raises(KeyError, match="reference period"): + results.get_effect(2) + + # Non-existent period raises with pytest.raises(KeyError): - results.get_effect(0) # Pre-period + results.get_effect(99) def test_auto_infer_post_periods(self, multi_period_data): """Test automatic inference of post-periods.""" @@ -1324,7 +1308,8 @@ def test_auto_infer_post_periods(self, multi_period_data): multi_period_data, outcome="outcome", treatment="treated", - time="period" + time="period", + reference_period=2, # post_periods not specified - should infer last half ) @@ -1341,14 +1326,21 @@ def test_custom_reference_period(self, multi_period_data): treatment="treated", time="period", post_periods=[3, 4, 5], - reference_period=2 # Use period 2 as reference + reference_period=1, # Use period 1 as reference (not default) ) # Should work and give reasonable results assert results is not None assert did.is_fitted_ # Reference period should not be in coefficients as a dummy - assert "period_2" not in results.coefficients + assert "period_1" not in results.coefficients + # Reference period 
should be stored on results + assert results.reference_period == 1 + # Reference period should not be in period_effects + assert 1 not in results.period_effects + # Other pre-periods should be in period_effects + assert 0 in results.period_effects + assert 2 in results.period_effects def test_with_covariates(self, multi_period_data): """Test multi-period DiD with covariates.""" @@ -1362,7 +1354,8 @@ def test_with_covariates(self, multi_period_data): treatment="treated", time="period", post_periods=[3, 4, 5], - covariates=["size"] + covariates=["size"], + reference_period=2, ) assert results is not None @@ -1387,13 +1380,15 @@ def test_with_fixed_effects(self): y += 3.0 y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "state": f"state_{state}", - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "state": f"state_{state}", + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) df = pd.DataFrame(data) @@ -1404,7 +1399,8 @@ def test_with_fixed_effects(self): treatment="treated", time="period", post_periods=[3, 4, 5], - fixed_effects=["state"] + reference_period=2, + fixed_effects=["state"], ) assert results is not None @@ -1421,7 +1417,8 @@ def test_with_absorbed_fe(self, multi_period_data): treatment="treated", time="period", post_periods=[3, 4, 5], - absorb=["unit"] + reference_period=2, + absorb=["unit"], ) assert results is not None @@ -1438,7 +1435,8 @@ def test_cluster_robust_se(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) results_robust = did_robust.fit( @@ -1446,7 +1444,8 @@ def test_cluster_robust_se(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # ATT should be similar @@ -1463,13 +1462,15 @@ def test_summary_output(self, multi_period_data): 
outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) summary = results.summary() assert isinstance(summary, str) assert "Multi-Period" in summary - assert "Period-Specific" in summary + assert "Post-Period Treatment Effects" in summary + assert "Pre-Period" in summary assert "Average Treatment Effect" in summary assert "Avg ATT" in summary @@ -1481,7 +1482,8 @@ def test_to_dict(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) result_dict = results.to_dict() @@ -1498,15 +1500,17 @@ def test_to_dataframe(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) df = results.to_dataframe() assert isinstance(df, pd.DataFrame) - assert len(df) == 3 # 3 post-periods + assert len(df) == 5 # 2 pre + 3 post periods assert "period" in df.columns assert "effect" in df.columns assert "p_value" in df.columns + assert "is_post" in df.columns def test_is_significant_property(self, multi_period_data): """Test is_significant property.""" @@ -1516,7 +1520,8 @@ def test_is_significant_property(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # With true effect of 3.0, should be significant @@ -1531,7 +1536,8 @@ def test_significance_stars(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # Should have significance stars @@ -1545,7 +1551,8 @@ def test_repr(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) repr_str = repr(results) @@ -1560,7 +1567,8 @@ def test_period_effect_repr(self, 
multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) pe = results.period_effects[3] @@ -1578,7 +1586,7 @@ def test_invalid_post_period(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 99] # 99 doesn't exist + post_periods=[3, 4, 99], # 99 doesn't exist ) def test_no_pre_periods_error(self, multi_period_data): @@ -1590,26 +1598,22 @@ def test_no_pre_periods_error(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[0, 1, 2, 3, 4, 5] # All periods + post_periods=[0, 1, 2, 3, 4, 5], # All periods ) def test_no_post_periods_error(self): """Test error when no post-treatment periods.""" - data = pd.DataFrame({ - "outcome": [10, 11, 12, 13], - "treated": [1, 1, 0, 0], - "period": [0, 1, 0, 1], - }) + data = pd.DataFrame( + { + "outcome": [10, 11, 12, 13], + "treated": [1, 1, 0, 0], + "period": [0, 1, 0, 1], + } + ) did = MultiPeriodDiD() with pytest.raises(ValueError, match="at least one post-treatment period"): - did.fit( - data, - outcome="outcome", - treatment="treated", - time="period", - post_periods=[] - ) + did.fit(data, outcome="outcome", treatment="treated", time="period", post_periods=[]) def test_invalid_treatment_values(self, multi_period_data): """Test error on non-binary treatment.""" @@ -1622,7 +1626,7 @@ def test_invalid_treatment_values(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], ) def test_unfitted_model_error(self): @@ -1639,7 +1643,8 @@ def test_confidence_interval_contains_estimate(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # Average ATT CI @@ -1660,22 +1665,20 @@ def test_two_periods_works(self): for period in [0, 1]: y = 10.0 + (3.0 if is_treated and period == 
1 else 0) y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) df = pd.DataFrame(data) did = MultiPeriodDiD() results = did.fit( - df, - outcome="outcome", - treatment="treated", - time="period", - post_periods=[1] + df, outcome="outcome", treatment="treated", time="period", post_periods=[1] ) assert len(results.period_effects) == 1 @@ -1694,12 +1697,14 @@ def test_many_periods(self): if is_treated and period >= 10: y += 2.5 y += np.random.normal(0, 0.3) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) df = pd.DataFrame(data) @@ -1709,10 +1714,11 @@ def test_many_periods(self): outcome="outcome", treatment="treated", time="period", - post_periods=list(range(10, 20)) + post_periods=list(range(10, 20)), + reference_period=9, ) - assert len(results.period_effects) == 10 + assert len(results.period_effects) == 19 # 9 pre + 10 post (ref=9 excluded) assert len(results.pre_periods) == 10 assert abs(results.avg_att - 2.5) < 0.5 @@ -1724,7 +1730,8 @@ def test_r_squared_reported(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) assert results.r_squared is not None @@ -1738,7 +1745,8 @@ def test_coefficients_dict(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # Should have treatment, period dummies, and interactions @@ -1768,18 +1776,23 @@ def test_rank_deficient_design_warns_and_sets_nan(self, multi_period_data): treatment="treated", time="period", post_periods=[3, 4, 5], - covariates=["collinear_cov"] + 
reference_period=2, + covariates=["collinear_cov"], ) # Should have warning about rank deficiency - rank_warnings = [x for x in w if "Rank-deficient" in str(x.message) - or "collinear" in str(x.message).lower()] + rank_warnings = [ + x + for x in w + if "Rank-deficient" in str(x.message) or "collinear" in str(x.message).lower() + ] assert len(rank_warnings) > 0, "Expected warning about rank deficiency" # The collinear covariate should have NaN coefficient assert "collinear_cov" in results.coefficients - assert np.isnan(results.coefficients["collinear_cov"]), \ - "Collinear covariate coefficient should be NaN" + assert np.isnan( + results.coefficients["collinear_cov"] + ), "Collinear covariate coefficient should be NaN" # Treatment effects should still be identified (not NaN) for period in [3, 4, 5]: @@ -1793,7 +1806,9 @@ def test_rank_deficient_design_warns_and_sets_nan(self, multi_period_data): assert np.any(np.isnan(results.vcov)), "Vcov should have NaN for dropped column" # avg_att should still be computed because all period effects are identified - assert not np.isnan(results.avg_att), "avg_att should be valid when all period effects are identified" + assert not np.isnan( + results.avg_att + ), "avg_att should be valid when all period effects are identified" def test_avg_att_nan_when_period_effect_nan(self, multi_period_data): """Test that avg_att is NaN if any period effect is NaN (R-style NA propagation).""" @@ -1814,12 +1829,16 @@ def test_avg_att_nan_when_period_effect_nan(self, multi_period_data): outcome="outcome", treatment="treated", time="period", - post_periods=[3, 4, 5] + post_periods=[3, 4, 5], + reference_period=2, ) # Should have warning about rank deficiency (treated:period_3 is all zeros) - rank_warnings = [x for x in w if "Rank-deficient" in str(x.message) - or "collinear" in str(x.message).lower()] + rank_warnings = [ + x + for x in w + if "Rank-deficient" in str(x.message) or "collinear" in str(x.message).lower() + ] assert len(rank_warnings) 
> 0, "Expected warning about rank deficiency" # The treated×period_3 interaction should have NaN coefficient (unidentified) @@ -1846,7 +1865,8 @@ def test_rank_deficient_action_error_raises(self, multi_period_data): treatment="treated", time="period", post_periods=[3, 4, 5], - covariates=["collinear_cov"] + reference_period=2, + covariates=["collinear_cov"], ) def test_rank_deficient_action_silent_no_warning(self, multi_period_data): @@ -1867,18 +1887,278 @@ def test_rank_deficient_action_silent_no_warning(self, multi_period_data): treatment="treated", time="period", post_periods=[3, 4, 5], - covariates=["collinear_cov"] + reference_period=2, + covariates=["collinear_cov"], ) # No warnings about rank deficiency should be emitted - rank_warnings = [x for x in w if "Rank-deficient" in str(x.message) - or "rank-deficient" in str(x.message).lower()] + rank_warnings = [ + x + for x in w + if "Rank-deficient" in str(x.message) or "rank-deficient" in str(x.message).lower() + ] assert len(rank_warnings) == 0, f"Expected no rank warnings, got {rank_warnings}" # Should still have NaN for dropped coefficient assert "collinear_cov" in results.coefficients - assert np.isnan(results.coefficients["collinear_cov"]), \ - "Collinear covariate coefficient should be NaN" + assert np.isnan( + results.coefficients["collinear_cov"] + ), "Collinear covariate coefficient should be NaN" + + +class TestMultiPeriodDiDEventStudy: + """Tests for MultiPeriodDiD full event-study specification (pre + post periods).""" + + @pytest.fixture + def panel_data(self): + """Panel data with 6 periods, treatment at period 3.""" + np.random.seed(42) + data = [] + for unit in range(100): + is_treated = unit < 50 + for period in range(6): + y = 10.0 + period * 0.5 + if is_treated and period >= 3: + y += 3.0 + y += np.random.normal(0, 0.5) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) + return pd.DataFrame(data) + + def 
test_default_reference_period_is_last_pre(self, panel_data): + """Verify reference_period defaults to the last pre-period.""" + import warnings + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + did = MultiPeriodDiD() + results = did.fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + ) + assert results.reference_period == 2 # last pre-period + + def test_reference_period_future_warning(self, panel_data): + """Verify FutureWarning is emitted when reference_period is None.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + ) + future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(future_warnings) > 0, "Expected FutureWarning for reference_period default" + assert "reference_period" in str(future_warnings[0].message) + + def test_pre_period_effects_near_zero(self, panel_data): + """Under parallel trends DGP, pre-period effects should be ~0.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + for period in [0, 1]: + pe = results.period_effects[period] + assert ( + abs(pe.effect) < 0.5 + ), f"Pre-period {period} effect should be near zero, got {pe.effect}" + + def test_reference_period_excluded_from_effects(self, panel_data): + """Reference period should not be a key in period_effects.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + assert 2 not in results.period_effects + + def test_reference_period_stored_in_results(self, panel_data): + """Results.reference_period should match the chosen reference.""" + results = MultiPeriodDiD().fit( + 
panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=1, + ) + assert results.reference_period == 1 + + def test_reference_period_in_post_warns(self, panel_data): + """Setting reference_period to a post-period should emit warning.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=4, + ) + post_ref_warnings = [x for x in w if "post-treatment period" in str(x.message)] + assert len(post_ref_warnings) > 0, "Expected warning about post-period reference" + + def test_staggered_treatment_warning(self): + """Staggered treatment timing with unit param should emit warning.""" + np.random.seed(42) + data = [] + for unit in range(40): + if unit < 10: + treat_start = 3 + elif unit < 20: + treat_start = 5 + else: + treat_start = None + for period in range(8): + is_treated = treat_start is not None and period >= treat_start + y = 10.0 + period * 0.5 + (2.0 if is_treated else 0) + y += np.random.normal(0, 0.3) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) + df = pd.DataFrame(data) + + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MultiPeriodDiD().fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5, 6, 7], + reference_period=2, + unit="unit", + ) + staggered_warnings = [x for x in w if "staggered" in str(x.message).lower()] + assert len(staggered_warnings) > 0, "Expected staggered adoption warning" + + def test_unit_param_without_unit_column_raises(self, panel_data): + """unit='nonexistent' should raise ValueError.""" + with pytest.raises(ValueError, match="not found in data"): + MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", 
+ time="period", + post_periods=[3, 4, 5], + reference_period=2, + unit="nonexistent", + ) + + def test_avg_att_uses_only_post_periods(self, panel_data): + """avg_att should be the mean of post-period effects only.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + post_effects = [results.period_effects[p].effect for p in [3, 4, 5]] + assert abs(results.avg_att - np.mean(post_effects)) < 1e-10 + + def test_pre_period_effects_property(self, panel_data): + """results.pre_period_effects returns correct subset.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + pre = results.pre_period_effects + assert set(pre.keys()) == {0, 1} + + def test_post_period_effects_property(self, panel_data): + """results.post_period_effects returns correct subset.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + post = results.post_period_effects + assert set(post.keys()) == {3, 4, 5} + + def test_to_dataframe_has_is_post_column(self, panel_data): + """to_dataframe() should include is_post column.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + df = results.to_dataframe() + assert "is_post" in df.columns + assert df["is_post"].sum() == 3 + assert (~df["is_post"]).sum() == 2 + + def test_interaction_indices_stored(self, panel_data): + """results.interaction_indices should be populated.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + assert results.interaction_indices is not None + assert 
set(results.interaction_indices.keys()) == {0, 1, 3, 4, 5} + # Each value should be a valid column index + for period, idx in results.interaction_indices.items(): + assert isinstance(idx, int) + assert idx >= 0 + + def test_to_dict_has_reference_period(self, panel_data): + """to_dict() should include reference_period.""" + results = MultiPeriodDiD().fit( + panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + ) + d = results.to_dict() + assert "reference_period" in d + assert d["reference_period"] == 2 class TestSyntheticDiD: @@ -1910,12 +2190,14 @@ def sdid_panel_data(self): y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) return pd.DataFrame(data) @@ -1934,12 +2216,14 @@ def single_treated_unit_data(self): if period >= 5: y += 10.0 # True ATT = 10 y += np.random.normal(0, 1) - data.append({ - "unit": 0, - "period": period, - "treated": 1, - "outcome": y, - }) + data.append( + { + "unit": 0, + "period": period, + "treated": 1, + "outcome": y, + } + ) # Control units with various patterns for unit in range(1, n_controls + 1): @@ -1948,12 +2232,14 @@ def single_treated_unit_data(self): for period in range(n_periods): y = unit_intercept + period * unit_slope y += np.random.normal(0, 1) - data.append({ - "unit": unit, - "period": period, - "treated": 0, - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": 0, + "outcome": y, + } + ) return pd.DataFrame(data) @@ -1966,7 +2252,7 @@ def test_basic_fit(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) assert isinstance(results, SyntheticDiDResults) @@ -1984,7 +2270,7 @@ def test_att_direction(self, sdid_panel_data): treatment="treated", unit="unit", 
time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) # True ATT is 5.0 @@ -2000,7 +2286,7 @@ def test_unit_weights_sum_to_one(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) weight_sum = sum(results.unit_weights.values()) @@ -2015,7 +2301,7 @@ def test_time_weights_sum_to_one(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) weight_sum = sum(results.time_weights.values()) @@ -2030,7 +2316,7 @@ def test_unit_weights_nonnegative(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) for w in results.unit_weights.values(): @@ -2045,7 +2331,7 @@ def test_single_treated_unit(self, single_treated_unit_data): treatment="treated", unit="unit", time="period", - post_periods=[5, 6, 7, 8, 9] + post_periods=[5, 6, 7, 8, 9], ) assert results.n_treated == 1 @@ -2065,7 +2351,7 @@ def test_regularization_effect(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) results_high_reg = sdid_high_reg.fit( @@ -2074,7 +2360,7 @@ def test_regularization_effect(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) # High regularization should give more uniform weights @@ -2093,7 +2379,7 @@ def test_placebo_inference(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) assert results.variance_method == "placebo" @@ -2110,7 +2396,7 @@ def test_bootstrap_inference(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) assert results.variance_method == "bootstrap" @@ -2132,7 +2418,7 @@ def test_get_unit_weights_df(self, 
sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) weights_df = results.get_unit_weights_df() @@ -2150,7 +2436,7 @@ def test_get_time_weights_df(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) weights_df = results.get_time_weights_df() @@ -2168,7 +2454,7 @@ def test_pre_treatment_fit(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) assert results.pre_treatment_fit is not None @@ -2183,7 +2469,7 @@ def test_summary_output(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) summary = results.summary() @@ -2201,7 +2487,7 @@ def test_to_dict(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) result_dict = results.to_dict() @@ -2220,7 +2506,7 @@ def test_to_dataframe(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) df = results.to_dataframe() @@ -2237,7 +2523,7 @@ def test_repr(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) repr_str = repr(results) @@ -2253,7 +2539,7 @@ def test_is_significant_property(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) assert isinstance(results.is_significant, bool) @@ -2284,7 +2570,7 @@ def test_missing_unit_column(self, sdid_panel_data): treatment="treated", unit="nonexistent", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) def test_missing_time_column(self, sdid_panel_data): @@ -2297,17 +2583,19 @@ def test_missing_time_column(self, sdid_panel_data): 
treatment="treated", unit="unit", time="nonexistent", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) def test_no_treated_units_error(self): """Test error when no treated units.""" - data = pd.DataFrame({ - "unit": [1, 1, 2, 2], - "period": [0, 1, 0, 1], - "treated": [0, 0, 0, 0], - "outcome": [10, 11, 12, 13], - }) + data = pd.DataFrame( + { + "unit": [1, 1, 2, 2], + "period": [0, 1, 0, 1], + "treated": [0, 0, 0, 0], + "outcome": [10, 11, 12, 13], + } + ) sdid = SyntheticDiD() with pytest.raises(ValueError, match="No treated units"): @@ -2317,17 +2605,19 @@ def test_no_treated_units_error(self): treatment="treated", unit="unit", time="period", - post_periods=[1] + post_periods=[1], ) def test_no_control_units_error(self): """Test error when no control units.""" - data = pd.DataFrame({ - "unit": [1, 1, 2, 2], - "period": [0, 1, 0, 1], - "treated": [1, 1, 1, 1], - "outcome": [10, 11, 12, 13], - }) + data = pd.DataFrame( + { + "unit": [1, 1, 2, 2], + "period": [0, 1, 0, 1], + "treated": [1, 1, 1, 1], + "outcome": [10, 11, 12, 13], + } + ) sdid = SyntheticDiD() with pytest.raises(ValueError, match="No control units"): @@ -2337,7 +2627,7 @@ def test_no_control_units_error(self): treatment="treated", unit="unit", time="period", - post_periods=[1] + post_periods=[1], ) def test_auto_infer_post_periods(self, sdid_panel_data): @@ -2348,7 +2638,7 @@ def test_auto_infer_post_periods(self, sdid_panel_data): outcome="outcome", treatment="treated", unit="unit", - time="period" + time="period", # post_periods not specified ) @@ -2369,7 +2659,7 @@ def test_with_covariates(self, sdid_panel_data): unit="unit", time="period", post_periods=[4, 5, 6, 7], - covariates=["size"] + covariates=["size"], ) assert results is not None @@ -2384,7 +2674,7 @@ def test_confidence_interval_contains_estimate(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) lower, upper = results.conf_int @@ -2398,7 +2688,7 
@@ def test_reproducibility_with_seed(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) results2 = SyntheticDiD(n_bootstrap=50, seed=42).fit( @@ -2407,7 +2697,7 @@ def test_reproducibility_with_seed(self, sdid_panel_data): treatment="treated", unit="unit", time="period", - post_periods=[4, 5, 6, 7] + post_periods=[4, 5, 6, 7], ) assert results1.att == results2.att @@ -2428,23 +2718,27 @@ def test_insufficient_pre_periods_warning(self): y = 10.0 + t * 0.5 + np.random.normal(0, 0.3) if t in post_periods: y += 3.0 - data.append({ - "unit": 0, - "period": t, - "outcome": y, - "treated": 1, - }) + data.append( + { + "unit": 0, + "period": t, + "outcome": y, + "treated": 1, + } + ) # Control units for unit in range(1, n_control + 1): for t in range(n_periods): y = 8.0 + unit * 0.2 + t * 0.4 + np.random.normal(0, 0.3) - data.append({ - "unit": unit, - "period": t, - "outcome": y, - "treated": 0, - }) + data.append( + { + "unit": unit, + "period": t, + "outcome": y, + "treated": 0, + } + ) df = pd.DataFrame(data) @@ -2458,7 +2752,7 @@ def test_insufficient_pre_periods_warning(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods + post_periods=post_periods, ) # Results should still be valid @@ -2479,23 +2773,27 @@ def test_single_pre_period_edge_case(self): y = 10.0 + np.random.normal(0, 0.2) if t in post_periods: y += 2.0 - data.append({ - "unit": 0, - "period": t, - "outcome": y, - "treated": 1, - }) + data.append( + { + "unit": 0, + "period": t, + "outcome": y, + "treated": 1, + } + ) # Control units for unit in range(1, n_control + 1): for t in range(n_periods): y = 9.0 + np.random.normal(0, 0.2) - data.append({ - "unit": unit, - "period": t, - "outcome": y, - "treated": 0, - }) + data.append( + { + "unit": unit, + "period": t, + "outcome": y, + "treated": 0, + } + ) df = pd.DataFrame(data) @@ -2508,7 +2806,7 @@ def test_single_pre_period_edge_case(self): 
treatment="treated", unit="unit", time="period", - post_periods=post_periods + post_periods=post_periods, ) # Should still produce results @@ -2530,23 +2828,27 @@ def test_more_pre_periods_than_control_units(self): y = 10.0 + t * 0.2 + np.random.normal(0, 0.3) if t in post_periods: y += 2.5 - data.append({ - "unit": 0, - "period": t, - "outcome": y, - "treated": 1, - }) + data.append( + { + "unit": 0, + "period": t, + "outcome": y, + "treated": 1, + } + ) # Control units for unit in range(1, n_control + 1): for t in range(n_periods): y = 8.0 + t * 0.15 + np.random.normal(0, 0.3) - data.append({ - "unit": unit, - "period": t, - "outcome": y, - "treated": 0, - }) + data.append( + { + "unit": unit, + "period": t, + "outcome": y, + "treated": 0, + } + ) df = pd.DataFrame(data) @@ -2559,7 +2861,7 @@ def test_more_pre_periods_than_control_units(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods + post_periods=post_periods, ) # Should produce valid results with regularization @@ -2646,9 +2948,7 @@ def test_compute_sdid_estimator(self): time_weights = np.array([0.5, 0.5]) tau = compute_sdid_estimator( - Y_pre_control, Y_post_control, - Y_pre_treated, Y_post_treated, - unit_weights, time_weights + Y_pre_control, Y_post_control, Y_pre_treated, Y_post_treated, unit_weights, time_weights ) # Treated: 15 - 10 = 5 @@ -2688,23 +2988,20 @@ def test_did_with_missing_periods(self): y += 3.0 y += np.random.normal(0, 1) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": period, - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": period, + "outcome": y, + } + ) df = pd.DataFrame(data) did = DifferenceInDifferences() - results = did.fit( - df, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(df, outcome="outcome", treatment="treated", time="post") # Should still produce valid results assert np.isfinite(results.att) @@ 
-2739,24 +3036,20 @@ def test_twfe_with_unbalanced_panel(self): y += 3.0 y += np.random.normal(0, 0.5) - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": post, - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": post, + "outcome": y, + } + ) df = pd.DataFrame(data) twfe = TwoWayFixedEffects() - results = twfe.fit( - df, - outcome="outcome", - treatment="post", - unit="unit", - time="period" - ) + results = twfe.fit(df, outcome="outcome", treatment="post", unit="unit", time="period") # Should produce valid results assert np.isfinite(results.att) @@ -2782,22 +3075,20 @@ def test_multiperiod_with_sparse_data(self): if is_treated and period >= 2: y += 3.0 # Treatment effect - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "outcome": y, + } + ) df = pd.DataFrame(data) mp_did = MultiPeriodDiD() results = mp_did.fit( - df, - outcome="outcome", - treatment="treated", - time="period", - reference_period=1 + df, outcome="outcome", treatment="treated", time="period", reference_period=1 ) # Should produce valid results @@ -2824,23 +3115,20 @@ def test_did_single_treated_unit(self): if is_treated and period == 1: y += 5.0 # Large effect for single unit - data.append({ - "unit": unit, - "period": period, - "treated": int(is_treated), - "post": period, - "outcome": y, - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": period, + "outcome": y, + } + ) df = pd.DataFrame(data) did = DifferenceInDifferences() - results = did.fit( - df, - outcome="outcome", - treatment="treated", - time="post" - ) + results = did.fit(df, outcome="outcome", treatment="treated", time="post") # Should produce valid results assert np.isfinite(results.att) @@ -2864,12 +3152,14 @@ def 
test_sdid_single_treated_unit(self): y = treated_base + treated_trend * t + np.random.normal(0, 0.3) if t in post_periods: y += 3.0 # Treatment effect - data.append({ - "unit": 0, - "period": t, - "outcome": y, - "treated": 1, - }) + data.append( + { + "unit": 0, + "period": t, + "outcome": y, + "treated": 1, + } + ) # Generate control units for unit in range(1, n_control + 1): @@ -2877,12 +3167,14 @@ def test_sdid_single_treated_unit(self): unit_trend = 0.4 + np.random.normal(0, 0.1) for t in range(n_periods): y = unit_base + unit_trend * t + np.random.normal(0, 0.3) - data.append({ - "unit": unit, - "period": t, - "outcome": y, - "treated": 0, - }) + data.append( + { + "unit": unit, + "period": t, + "outcome": y, + "treated": 0, + } + ) df = pd.DataFrame(data) @@ -2893,7 +3185,7 @@ def test_sdid_single_treated_unit(self): treatment="treated", unit="unit", time="period", - post_periods=post_periods + post_periods=post_periods, ) # SDID is designed for single/few treated units @@ -2916,12 +3208,14 @@ def test_did_with_redundant_covariate_emits_warning(self): import warnings np.random.seed(42) - data = pd.DataFrame({ - "outcome": np.random.normal(10, 1, 100), - "treated": np.repeat([0, 1], 50), - "post": np.tile([0, 1], 50), - "x1": np.random.normal(0, 1, 100), - }) + data = pd.DataFrame( + { + "outcome": np.random.normal(10, 1, 100), + "treated": np.repeat([0, 1], 50), + "post": np.tile([0, 1], 50), + "x1": np.random.normal(0, 1, 100), + } + ) # Add perfectly collinear covariate data["x2"] = data["x1"] * 2 + 3 @@ -2931,11 +3225,7 @@ def test_did_with_redundant_covariate_emits_warning(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") result = did.fit( - data, - outcome="outcome", - treatment="treated", - time="post", - covariates=["x1", "x2"] + data, outcome="outcome", treatment="treated", time="post", covariates=["x1", "x2"] ) # Should emit a warning about rank deficiency rank_warnings = [x for x in w if "Rank-deficient" in 
str(x.message)] @@ -2954,12 +3244,14 @@ def test_did_with_constant_covariate_emits_warning(self): import warnings np.random.seed(42) - data = pd.DataFrame({ - "outcome": np.random.normal(10, 1, 100), - "treated": np.repeat([0, 1], 50), - "post": np.tile([0, 1], 50), - "constant_x": np.ones(100), # Constant covariate - }) + data = pd.DataFrame( + { + "outcome": np.random.normal(10, 1, 100), + "treated": np.repeat([0, 1], 50), + "post": np.tile([0, 1], 50), + "constant_x": np.ones(100), # Constant covariate + } + ) did = DifferenceInDifferences() @@ -2968,11 +3260,7 @@ def test_did_with_constant_covariate_emits_warning(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") result = did.fit( - data, - outcome="outcome", - treatment="treated", - time="post", - covariates=["constant_x"] + data, outcome="outcome", treatment="treated", time="post", covariates=["constant_x"] ) # Should emit a warning about rank deficiency rank_warnings = [x for x in w if "Rank-deficient" in str(x.message)] @@ -2984,12 +3272,14 @@ def test_did_with_constant_covariate_emits_warning(self): def test_did_with_near_collinear_covariates(self): """Test DiD handles near-collinear covariates (not perfectly collinear).""" np.random.seed(42) - data = pd.DataFrame({ - "outcome": np.random.normal(10, 1, 100), - "treated": np.repeat([0, 1], 50), - "post": np.tile([0, 1], 50), - "x1": np.random.normal(0, 1, 100), - }) + data = pd.DataFrame( + { + "outcome": np.random.normal(10, 1, 100), + "treated": np.repeat([0, 1], 50), + "post": np.tile([0, 1], 50), + "x1": np.random.normal(0, 1, 100), + } + ) # Add near-collinear covariate (small noise breaks perfect collinearity) data["x2"] = data["x1"] * 2 + 3 + np.random.normal(0, 0.1, 100) @@ -2997,11 +3287,7 @@ def test_did_with_near_collinear_covariates(self): # Near-collinear should work (not perfectly rank-deficient) results = did.fit( - data, - outcome="outcome", - treatment="treated", - time="post", - covariates=["x1", "x2"] + 
data, outcome="outcome", treatment="treated", time="post", covariates=["x1", "x2"] ) assert np.isfinite(results.att) @@ -3025,26 +3311,22 @@ def test_twfe_with_absorbed_covariate(self): if unit < n_units // 2 and post: y += 2.0 - data.append({ - "unit": unit, - "period": period, - "outcome": y, - "treated": int(unit < n_units // 2), - "post": post, - "unit_covariate": unit_x, # Same for all periods within unit - }) + data.append( + { + "unit": unit, + "period": period, + "outcome": y, + "treated": int(unit < n_units // 2), + "post": post, + "unit_covariate": unit_x, # Same for all periods within unit + } + ) df = pd.DataFrame(data) twfe = TwoWayFixedEffects() # unit_covariate is absorbed by unit fixed effects - results = twfe.fit( - df, - outcome="outcome", - treatment="post", - unit="unit", - time="period" - ) + results = twfe.fit(df, outcome="outcome", treatment="post", unit="unit", time="period") assert np.isfinite(results.att) assert results.se > 0 diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index 5172ac7..a60a0bf 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -56,13 +56,15 @@ def simple_panel_data(): y += np.random.normal(0, 0.5) - data.append({ - 'unit': unit, - 'period': period, - 'treated': int(is_treated), - 'post': int(post), - 'outcome': y - }) + data.append( + { + "unit": unit, + "period": period, + "treated": int(is_treated), + "post": int(post), + "outcome": y, + } + ) return pd.DataFrame(data) @@ -73,42 +75,61 @@ def multiperiod_results(simple_panel_data): mp_did = MultiPeriodDiD() results = mp_did.fit( simple_panel_data, - outcome='outcome', - treatment='treated', - time='period', - post_periods=[4, 5, 6, 7] + outcome="outcome", + treatment="treated", + time="period", + post_periods=[4, 5, 6, 7], + reference_period=3, ) return results @pytest.fixture def mock_multiperiod_results(): - """Create mock MultiPeriodDiDResults for unit testing.""" + """Create mock MultiPeriodDiDResults for unit testing. 
+ + Simulates a full event-study with pre-period and post-period effects. + Reference period is 3 (last pre-period), so period_effects has + periods 0, 1, 2 (pre) and 4, 5, 6, 7 (post). + """ period_effects = { + # Pre-period effects (should be ~0 under parallel trends) + 0: PeriodEffect( + period=0, effect=0.05, se=0.4, t_stat=0.125, p_value=0.90, conf_int=(-0.73, 0.83) + ), + 1: PeriodEffect( + period=1, effect=-0.02, se=0.35, t_stat=-0.057, p_value=0.95, conf_int=(-0.71, 0.67) + ), + 2: PeriodEffect( + period=2, effect=0.08, se=0.3, t_stat=0.267, p_value=0.79, conf_int=(-0.51, 0.67) + ), + # Post-period effects 4: PeriodEffect( - period=4, effect=5.0, se=0.5, - t_stat=10.0, p_value=0.0001, - conf_int=(4.02, 5.98) + period=4, effect=5.0, se=0.5, t_stat=10.0, p_value=0.0001, conf_int=(4.02, 5.98) ), 5: PeriodEffect( - period=5, effect=5.2, se=0.5, - t_stat=10.4, p_value=0.0001, - conf_int=(4.22, 6.18) + period=5, effect=5.2, se=0.5, t_stat=10.4, p_value=0.0001, conf_int=(4.22, 6.18) ), 6: PeriodEffect( - period=6, effect=4.8, se=0.5, - t_stat=9.6, p_value=0.0001, - conf_int=(3.82, 5.78) + period=6, effect=4.8, se=0.5, t_stat=9.6, p_value=0.0001, conf_int=(3.82, 5.78) ), 7: PeriodEffect( - period=7, effect=5.0, se=0.5, - t_stat=10.0, p_value=0.0001, - conf_int=(4.02, 5.98) + period=7, effect=5.0, se=0.5, t_stat=10.0, p_value=0.0001, conf_int=(4.02, 5.98) ), } - # Create vcov matrix (diagonal for simplicity) - vcov = np.diag([0.25] * 4) + # SE^2 for all 7 interaction terms (periods 0,1,2,4,5,6,7) + vcov_diag = [0.4**2, 0.35**2, 0.3**2, 0.5**2, 0.5**2, 0.5**2, 0.5**2] + + # interaction_indices maps period -> column index in the full regression VCV + # (in a real fit, these would be the actual column positions) + interaction_indices = {0: 10, 1: 11, 2: 12, 4: 13, 5: 14, 6: 15, 7: 16} + + # Build a larger "full" VCV that the sub-extraction will index into + full_vcov = np.zeros((20, 20)) + for i, period in enumerate(sorted(interaction_indices.keys())): + col = 
interaction_indices[period] + full_vcov[col, col] = vcov_diag[i] return MultiPeriodDiDResults( period_effects=period_effects, @@ -122,7 +143,9 @@ def mock_multiperiod_results(): n_control=400, pre_periods=[0, 1, 2, 3], post_periods=[4, 5, 6, 7], - vcov=vcov, + vcov=full_vcov, + reference_period=3, + interaction_indices=interaction_indices, ) @@ -198,11 +221,7 @@ def test_construct_A_sd_small(self): def test_construct_constraints_sd(self): """Test smoothness constraints.""" - A_ineq, b_ineq = _construct_constraints_sd( - num_pre_periods=3, - num_post_periods=4, - M=0.5 - ) + A_ineq, b_ineq = _construct_constraints_sd(num_pre_periods=3, num_post_periods=4, M=0.5) # Should have 2 * (7 - 2) = 10 constraints assert A_ineq.shape[0] == 10 @@ -212,10 +231,7 @@ def test_construct_constraints_sd(self): def test_construct_constraints_rm(self): """Test relative magnitudes constraints.""" A_ineq, b_ineq = _construct_constraints_rm( - num_pre_periods=3, - num_post_periods=4, - Mbar=1.5, - max_pre_violation=0.2 + num_pre_periods=3, num_post_periods=4, Mbar=1.5, max_pre_violation=0.2 ) # Should have 2 * 4 = 8 constraints (upper and lower for each post period) @@ -242,6 +258,7 @@ def test_flci_symmetric(self): # FLCI extends each side by z * se from scipy import stats + z = stats.norm.ppf(1 - alpha / 2) expected_ci_lb = lb - z * se expected_ci_ub = ub + z * se @@ -259,6 +276,7 @@ def test_flci_point_identified(self): # Should be standard CI from scipy import stats + z = stats.norm.ppf(1 - alpha / 2) assert ci_lb == pytest.approx(point - z * se) assert ci_ub == pytest.approx(point + z * se) @@ -274,15 +292,22 @@ class TestParameterExtraction: def test_extract_from_multiperiod(self, mock_multiperiod_results): """Test extraction from MultiPeriodDiDResults.""" - (beta_hat, sigma, num_pre, num_post, - pre_periods, post_periods) = _extract_event_study_params(mock_multiperiod_results) + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + 
_extract_event_study_params(mock_multiperiod_results) + ) - assert len(beta_hat) == 4 - assert sigma.shape == (4, 4) - assert num_pre == 4 + # 7 estimated effects: 3 pre (0,1,2) + 4 post (4,5,6,7), ref=3 excluded + assert len(beta_hat) == 7 + assert sigma.shape == (7, 7) + assert num_pre == 3 assert num_post == 4 assert post_periods == [4, 5, 6, 7] + # Verify sub-VCV diagonal matches squared SEs + for i, period in enumerate(sorted(mock_multiperiod_results.period_effects.keys())): + pe = mock_multiperiod_results.period_effects[period] + assert sigma[i, i] == pytest.approx(pe.se**2, abs=1e-10) + def test_extract_unsupported_type_raises(self): """Test that unsupported types raise TypeError.""" with pytest.raises(TypeError, match="Unsupported results type"): @@ -389,7 +414,9 @@ def test_bounds_widen_with_M(self, mock_multiperiod_results): results_large = honest.fit(mock_multiperiod_results, M=2.0) # Larger M should give wider bounds - assert results_large.ci_ub - results_large.ci_lb >= results_small.ci_ub - results_small.ci_lb + assert ( + results_large.ci_ub - results_large.ci_lb >= results_small.ci_ub - results_small.ci_lb + ) class TestSensitivityAnalysis: @@ -459,13 +486,13 @@ def test_results_properties(self, mock_multiperiod_results): honest = HonestDiD(method="relative_magnitude", M=1.0) results = honest.fit(mock_multiperiod_results) - assert hasattr(results, 'lb') - assert hasattr(results, 'ub') - assert hasattr(results, 'ci_lb') - assert hasattr(results, 'ci_ub') - assert hasattr(results, 'is_significant') - assert hasattr(results, 'identified_set_width') - assert hasattr(results, 'ci_width') + assert hasattr(results, "lb") + assert hasattr(results, "ub") + assert hasattr(results, "ci_lb") + assert hasattr(results, "ci_ub") + assert hasattr(results, "is_significant") + assert hasattr(results, "identified_set_width") + assert hasattr(results, "ci_width") def test_results_is_significant(self, mock_multiperiod_results): """Test is_significant property.""" @@ 
-501,10 +528,10 @@ def test_results_to_dict(self, mock_multiperiod_results): d = results.to_dict() assert isinstance(d, dict) - assert 'lb' in d - assert 'ub' in d - assert 'M' in d - assert 'method' in d + assert "lb" in d + assert "ub" in d + assert "M" in d + assert "method" in d def test_results_to_dataframe(self, mock_multiperiod_results): """Test to_dataframe method.""" @@ -526,9 +553,9 @@ def test_sensitivity_results_to_dataframe(self, mock_multiperiod_results): df = sensitivity.to_dataframe() assert isinstance(df, pd.DataFrame) - assert 'M' in df.columns - assert 'lb' in df.columns - assert 'ci_lb' in df.columns + assert "M" in df.columns + assert "lb" in df.columns + assert "ci_lb" in df.columns def test_sensitivity_results_summary(self, mock_multiperiod_results): """Test summary method.""" @@ -549,11 +576,7 @@ class TestConvenienceFunctions: def test_compute_honest_did(self, mock_multiperiod_results): """Test compute_honest_did function.""" - results = compute_honest_did( - mock_multiperiod_results, - method='relative_magnitude', - M=1.0 - ) + results = compute_honest_did(mock_multiperiod_results, method="relative_magnitude", M=1.0) assert isinstance(results, HonestDiDResults) assert results.M == 1.0 @@ -573,14 +596,15 @@ def test_with_multiperiod_did(self, simple_panel_data): mp_did = MultiPeriodDiD() event_results = mp_did.fit( simple_panel_data, - outcome='outcome', - treatment='treated', - time='period', - post_periods=[4, 5, 6, 7] + outcome="outcome", + treatment="treated", + time="period", + post_periods=[4, 5, 6, 7], + reference_period=3, ) # Run Honest DiD - honest = HonestDiD(method='relative_magnitude', M=1.0) + honest = HonestDiD(method="relative_magnitude", M=1.0) bounds = honest.fit(event_results) # Check results are reasonable @@ -593,13 +617,14 @@ def test_sensitivity_analysis_integration(self, simple_panel_data): mp_did = MultiPeriodDiD() event_results = mp_did.fit( simple_panel_data, - outcome='outcome', - treatment='treated', - 
time='period', - post_periods=[4, 5, 6, 7] + outcome="outcome", + treatment="treated", + time="period", + post_periods=[4, 5, 6, 7], + reference_period=3, ) - honest = HonestDiD(method='relative_magnitude') + honest = HonestDiD(method="relative_magnitude") sensitivity = honest.sensitivity_analysis(event_results, M_grid=[0, 0.5, 1.0, 2.0]) # Bounds should widen as M increases @@ -611,17 +636,53 @@ def test_smoothness_method_integration(self, simple_panel_data): mp_did = MultiPeriodDiD() event_results = mp_did.fit( simple_panel_data, - outcome='outcome', - treatment='treated', - time='period', - post_periods=[4, 5, 6, 7] + outcome="outcome", + treatment="treated", + time="period", + post_periods=[4, 5, 6, 7], + reference_period=3, ) - honest = HonestDiD(method='smoothness', M=0.5) + honest = HonestDiD(method="smoothness", M=0.5) bounds = honest.fit(event_results) assert isinstance(bounds, HonestDiDResults) - assert bounds.method == 'smoothness' + assert bounds.method == "smoothness" + + def test_multiperiod_sub_vcov_extraction(self, simple_panel_data): + """Test that interaction_indices enables correct sub-VCV extraction. 
+ + Fit MultiPeriodDiD, pass to _extract_event_study_params, and verify: + - sigma shape matches len(period_effects) x len(period_effects) + - Diagonal of sigma matches squared SEs from period_effects + - beta_hat length equals num_pre + num_post + """ + mp_did = MultiPeriodDiD() + results = mp_did.fit( + simple_panel_data, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[4, 5, 6, 7], + reference_period=3, + ) + + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + _extract_event_study_params(results) + ) + + n_effects = len(results.period_effects) + assert len(beta_hat) == n_effects + assert sigma.shape == (n_effects, n_effects) + assert num_pre + num_post == n_effects + + # Verify sub-VCV diagonal matches squared SEs from period_effects + sorted_periods = sorted(results.period_effects.keys()) + for i, period in enumerate(sorted_periods): + pe = results.period_effects[period] + assert sigma[i, i] == pytest.approx( + pe.se**2, rel=1e-6 + ), f"sigma[{i},{i}] = {sigma[i, i]} != se^2 = {pe.se**2} for period {period}" # ============================================================================= @@ -636,9 +697,7 @@ def test_single_post_period(self): """Test with single post-period.""" period_effects = { 4: PeriodEffect( - period=4, effect=5.0, se=0.5, - t_stat=10.0, p_value=0.0001, - conf_int=(4.02, 5.98) + period=4, effect=5.0, se=0.5, t_stat=10.0, p_value=0.0001, conf_int=(4.02, 5.98) ), } @@ -657,14 +716,14 @@ def test_single_post_period(self): vcov=np.array([[0.25]]), ) - honest = HonestDiD(method='relative_magnitude', M=1.0) + honest = HonestDiD(method="relative_magnitude", M=1.0) bounds = honest.fit(results) assert isinstance(bounds, HonestDiDResults) def test_m_zero_recovers_standard(self, mock_multiperiod_results): """Test that M=0 gives tighter bounds.""" - honest = HonestDiD(method='relative_magnitude') + honest = HonestDiD(method="relative_magnitude") results_0 = honest.fit(mock_multiperiod_results, M=0) 
results_1 = honest.fit(mock_multiperiod_results, M=1) @@ -674,7 +733,7 @@ def test_m_zero_recovers_standard(self, mock_multiperiod_results): def test_very_large_M(self, mock_multiperiod_results): """Test with very large M value.""" - honest = HonestDiD(method='relative_magnitude', M=100) + honest = HonestDiD(method="relative_magnitude", M=100) results = honest.fit(mock_multiperiod_results) # Should still return valid results @@ -694,19 +753,19 @@ def test_callaway_santanna_universal_base_period(self): cs = CallawaySantAnna(base_period="universal") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='period', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="event_study", ) # Verify reference period exists with NaN SE assert -1 in results.event_study_effects - assert np.isnan(results.event_study_effects[-1]['se']) + assert np.isnan(results.event_study_effects[-1]["se"]) # HonestDiD should work without errors (reference period filtered out) - honest = HonestDiD(method='relative_magnitude', M=1.0) + honest = HonestDiD(method="relative_magnitude", M=1.0) bounds = honest.fit(results) # Should have valid (non-NaN) results @@ -728,24 +787,25 @@ def test_max_pre_violation_excludes_reference_period(self): cs = CallawaySantAnna(base_period="universal") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='period', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="event_study", ) # Verify reference period exists with n_groups=0 assert -1 in results.event_study_effects - assert results.event_study_effects[-1]['n_groups'] == 0 + assert results.event_study_effects[-1]["n_groups"] == 0 # The max pre-violation calculation should exclude the reference period - honest = HonestDiD(method='relative_magnitude', M=1.0) + honest = 
HonestDiD(method="relative_magnitude", M=1.0) # Get pre_periods excluding reference (n_groups=0) real_pre_periods = [ - t for t in results.event_study_effects - if t < 0 and results.event_study_effects[t].get('n_groups', 1) > 0 + t + for t in results.event_study_effects + if t < 0 and results.event_study_effects[t].get("n_groups", 1) > 0 ] # If there are real pre-periods, max_violation should be > 0 @@ -754,9 +814,7 @@ def test_max_pre_violation_excludes_reference_period(self): max_violation = honest._estimate_max_pre_violation(results, real_pre_periods) # Max violation should reflect actual pre-period coefficients, not 0 # The actual effects are non-zero due to sampling variation - assert max_violation > 0, ( - "max_pre_violation should be > 0 when real pre-periods exist" - ) + assert max_violation > 0, "max_pre_violation should be > 0 when real pre-periods exist" # ============================================================================= @@ -769,8 +827,8 @@ class TestVisualizationNoMatplotlib: def test_sensitivity_results_has_plot_method(self, mock_multiperiod_results): """Test that SensitivityResults has plot method.""" - honest = HonestDiD(method='relative_magnitude') + honest = HonestDiD(method="relative_magnitude") sensitivity = honest.sensitivity_analysis(mock_multiperiod_results) - assert hasattr(sensitivity, 'plot') + assert hasattr(sensitivity, "plot") assert callable(sensitivity.plot) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 4a3ae3b..55edfb6 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -32,11 +32,13 @@ def generate_multi_period_data(n_obs: int = 200, seed: int = 42) -> pd.DataFrame y += np.random.randn() * 0.5 - data.append({ - 'outcome': y, - 'treated': treated, - 'period': period, - }) + data.append( + { + "outcome": y, + "treated": treated, + "period": period, + } + ) return pd.DataFrame(data) @@ -67,12 +69,14 @@ def generate_staggered_data( outcomes = unit_fe_expanded + 0.5 * 
times + 2.0 * post + np.random.randn(len(units)) * 0.3 - return pd.DataFrame({ - 'unit': units, - 'time': times, - 'outcome': outcomes, - 'first_treat': first_treat_expanded.astype(int), - }) + return pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded.astype(int), + } + ) class TestPlotEventStudy: @@ -85,10 +89,11 @@ def multi_period_results(self): did = MultiPeriodDiD() return did.fit( data, - outcome='outcome', - treatment='treated', - time='period', - post_periods=[2, 3] + outcome="outcome", + treatment="treated", + time="period", + post_periods=[2, 3], + reference_period=1, ) @pytest.fixture @@ -98,11 +103,11 @@ def cs_results(self): cs = CallawaySantAnna() return cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) def test_plot_from_multi_period_results(self, multi_period_results): @@ -123,11 +128,13 @@ def test_plot_from_dataframe(self): """Test plotting from DataFrame.""" pytest.importorskip("matplotlib") - df = pd.DataFrame({ - 'period': [-2, -1, 0, 1, 2], - 'effect': [0.1, 0.05, 0.0, 0.5, 0.6], - 'se': [0.1, 0.1, 0.0, 0.15, 0.15] - }) + df = pd.DataFrame( + { + "period": [-2, -1, 0, 1, 2], + "effect": [0.1, 0.05, 0.0, 0.5, 0.6], + "se": [0.1, 0.1, 0.0, 0.15, 0.15], + } + ) ax = plot_event_study(df, reference_period=0, show=False) assert ax is not None @@ -139,12 +146,7 @@ def test_plot_from_dict(self): effects = {-2: 0.1, -1: 0.05, 0: 0.0, 1: 0.5, 2: 0.6} se = {-2: 0.1, -1: 0.1, 0: 0.0, 1: 0.15, 2: 0.15} - ax = plot_event_study( - effects=effects, - se=se, - reference_period=0, - show=False - ) + ax = plot_event_study(effects=effects, se=se, reference_period=0, show=False) assert ax is not None def test_plot_customization(self, multi_period_results): @@ -159,7 +161,7 @@ def test_plot_customization(self, 
multi_period_results): color="red", marker="s", markersize=10, - show=False + show=False, ) assert ax.get_title() == "Custom Title" @@ -170,11 +172,7 @@ def test_plot_no_zero_line(self, multi_period_results): """Test disabling zero line.""" pytest.importorskip("matplotlib") - ax = plot_event_study( - multi_period_results, - show_zero_line=False, - show=False - ) + ax = plot_event_study(multi_period_results, show_zero_line=False, show=False) assert ax is not None def test_plot_with_existing_axes(self, multi_period_results): @@ -208,10 +206,7 @@ def test_error_invalid_se_type(self): def test_error_missing_dataframe_columns(self): """Test error with missing DataFrame columns.""" pytest.importorskip("matplotlib") - df = pd.DataFrame({ - 'x': [1, 2, 3], - 'y': [0.1, 0.2, 0.3] - }) + df = pd.DataFrame({"x": [1, 2, 3], "y": [0.1, 0.2, 0.3]}) with pytest.raises(ValueError, match="must have 'period' column"): plot_event_study(df) @@ -236,12 +231,7 @@ def test_plot_with_nan_se_reference_period(self): effects = {-2: 0.1, -1: 0.0, 0: 0.5, 1: 0.6} se = {-2: 0.1, -1: np.nan, 0: 0.15, 1: 0.15} # NaN SE at reference period - ax = plot_event_study( - effects=effects, - se=se, - reference_period=-1, - show=False - ) + ax = plot_event_study(effects=effects, se=se, reference_period=-1, show=False) # Verify the plot was created successfully assert ax is not None @@ -250,7 +240,7 @@ def test_plot_with_nan_se_reference_period(self): # The x-axis should have 4 tick labels xtick_labels = [t.get_text() for t in ax.get_xticklabels()] assert len(xtick_labels) == 4 - assert '-1' in xtick_labels + assert "-1" in xtick_labels plt.close() @@ -268,11 +258,11 @@ def test_plot_cs_universal_base_period(self): cs = CallawaySantAnna(base_period="universal") results = cs.fit( data, - outcome='outcome', - unit='unit', - time='period', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="event_study", ) # 
Should not raise even with NaN SE in reference period @@ -281,7 +271,7 @@ def test_plot_cs_universal_base_period(self): # Verify reference period (-1) is in the plot xtick_labels = [t.get_text() for t in ax.get_xticklabels()] - assert '-1' in xtick_labels + assert "-1" in xtick_labels plt.close() @@ -298,23 +288,23 @@ def test_plot_cs_with_anticipation(self): cs = CallawaySantAnna(base_period="universal", anticipation=1) results = cs.fit( data, - outcome='outcome', - unit='unit', - time='period', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="event_study", ) # Reference period should be at e=-2 (not e=-1) with anticipation=1 assert -2 in results.event_study_effects - assert results.event_study_effects[-2]['n_groups'] == 0 + assert results.event_study_effects[-2]["n_groups"] == 0 ax = plot_event_study(results, show=False) assert ax is not None # Verify -2 is in the plot (the true reference period) xtick_labels = [t.get_text() for t in ax.get_xticklabels()] - assert '-2' in xtick_labels + assert "-2" in xtick_labels plt.close() @@ -333,11 +323,13 @@ def test_plot_event_study_reference_period_normalization(self): import matplotlib.pyplot as plt # Create data where reference period (period=0) has effect=0.3 - df = pd.DataFrame({ - 'period': [-2, -1, 0, 1, 2], - 'effect': [0.1, 0.2, 0.3, 0.5, 0.6], # ref at 0 has effect 0.3 - 'se': [0.1, 0.1, 0.1, 0.1, 0.1] - }) + df = pd.DataFrame( + { + "period": [-2, -1, 0, 1, 2], + "effect": [0.1, 0.2, 0.3, 0.5, 0.6], # ref at 0 has effect 0.3 + "se": [0.1, 0.1, 0.1, 0.1, 0.1], + } + ) ax = plot_event_study(df, reference_period=0, show=False) @@ -346,7 +338,7 @@ def test_plot_event_study_reference_period_normalization(self): y_values = [] for child in ax.get_children(): # Line2D objects with single points are our markers - if hasattr(child, 'get_ydata'): + if hasattr(child, "get_ydata"): ydata = child.get_ydata() if len(ydata) == 
1: y_values.append(float(ydata[0])) @@ -358,13 +350,15 @@ def test_plot_event_study_reference_period_normalization(self): expected_normalized = [-0.2, -0.1, 0.0, 0.2, 0.3] # Check that reference period (0) is at y=0 - assert 0.0 in y_values or any(abs(y) < 0.01 for y in y_values), \ - f"Reference period should be at y=0, got y_values={y_values}" + assert 0.0 in y_values or any( + abs(y) < 0.01 for y in y_values + ), f"Reference period should be at y=0, got y_values={y_values}" # Verify all expected normalized values are present for expected in expected_normalized: - assert any(abs(y - expected) < 0.01 for y in y_values), \ - f"Expected normalized value {expected} not found in {y_values}" + assert any( + abs(y - expected) < 0.01 for y in y_values + ), f"Expected normalized value {expected} not found in {y_values}" # Verify error bars: reference period (y=0) should have NO error bars # while other periods should have error bars @@ -375,7 +369,7 @@ def test_plot_event_study_reference_period_normalization(self): errorbar_x_coords = set() for child in ax.get_children(): # ErrorbarContainer's children include LineCollection for the caps/stems - if hasattr(child, 'get_segments'): + if hasattr(child, "get_segments"): segments = child.get_segments() for seg in segments: # Each segment is [[x1, y1], [x2, y2]] @@ -388,14 +382,16 @@ def test_plot_event_study_reference_period_normalization(self): reference_x = 2 # period 0 is at x-coordinate 2 # Reference period should NOT have error bars (x=2 should not be in errorbar_x_coords) - assert reference_x not in errorbar_x_coords, \ - f"Reference period should have no error bars but found error bar at x={reference_x}" + assert ( + reference_x not in errorbar_x_coords + ), f"Reference period should have no error bars but found error bar at x={reference_x}" # Other periods SHOULD have error bars # At least some of x=0, x=1, x=3, x=4 should have error bars non_ref_x_coords = {0, 1, 3, 4} - assert len(errorbar_x_coords & 
non_ref_x_coords) >= 2, \ - f"Non-reference periods should have error bars, found: {errorbar_x_coords}" + assert ( + len(errorbar_x_coords & non_ref_x_coords) >= 2 + ), f"Non-reference periods should have error bars, found: {errorbar_x_coords}" plt.close() @@ -404,26 +400,23 @@ def test_plot_event_study_no_normalization_without_reference(self): pytest.importorskip("matplotlib") import matplotlib.pyplot as plt - df = pd.DataFrame({ - 'period': [-1, 0, 1], - 'effect': [0.1, 0.3, 0.5], - 'se': [0.1, 0.1, 0.1] - }) + df = pd.DataFrame({"period": [-1, 0, 1], "effect": [0.1, 0.3, 0.5], "se": [0.1, 0.1, 0.1]}) ax = plot_event_study(df, reference_period=None, show=False) # Extract y-values y_values = [] for child in ax.get_children(): - if hasattr(child, 'get_ydata'): + if hasattr(child, "get_ydata"): ydata = child.get_ydata() if len(ydata) == 1: y_values.append(float(ydata[0])) # Without normalization, original values should be preserved for expected in [0.1, 0.3, 0.5]: - assert any(abs(y - expected) < 0.01 for y in y_values), \ - f"Original value {expected} not found in {y_values}" + assert any( + abs(y - expected) < 0.01 for y in y_values + ), f"Original value {expected} not found in {y_values}" plt.close() @@ -432,11 +425,13 @@ def test_plot_event_study_normalization_with_nan_reference(self): pytest.importorskip("matplotlib") import matplotlib.pyplot as plt - df = pd.DataFrame({ - 'period': [-1, 0, 1], - 'effect': [0.1, np.nan, 0.5], # Reference period has NaN effect - 'se': [0.1, 0.1, 0.1] - }) + df = pd.DataFrame( + { + "period": [-1, 0, 1], + "effect": [0.1, np.nan, 0.5], # Reference period has NaN effect + "se": [0.1, 0.1, 0.1], + } + ) # This should not raise and should skip normalization ax = plot_event_study(df, reference_period=0, show=False) @@ -444,15 +439,16 @@ def test_plot_event_study_normalization_with_nan_reference(self): # Extract y-values (NaN effect is skipped in plotting) y_values = [] for child in ax.get_children(): - if hasattr(child, 'get_ydata'): 
+ if hasattr(child, "get_ydata"): ydata = child.get_ydata() if len(ydata) == 1: y_values.append(float(ydata[0])) # Original non-NaN values should be preserved (not normalized) for expected in [0.1, 0.5]: - assert any(abs(y - expected) < 0.01 for y in y_values), \ - f"Original value {expected} not found in {y_values}" + assert any( + abs(y - expected) < 0.01 for y in y_values + ), f"Original value {expected} not found in {y_values}" plt.close() @@ -471,7 +467,7 @@ def test_plot_cs_results_no_auto_normalization(self, cs_results): # Get original effects from results (before any normalization) original_effects = { - period: effect_data['effect'] + period: effect_data["effect"] for period, effect_data in results.event_study_effects.items() } @@ -482,7 +478,7 @@ def test_plot_cs_results_no_auto_normalization(self, cs_results): # Extract plotted y-values y_values = [] for child in ax.get_children(): - if hasattr(child, 'get_ydata'): + if hasattr(child, "get_ydata"): ydata = child.get_ydata() if len(ydata) == 1: y_values.append(float(ydata[0])) @@ -498,9 +494,10 @@ def test_plot_cs_results_no_auto_normalization(self, cs_results): for period, orig_effect in original_effects.items(): if np.isfinite(orig_effect): # Check that original value is present (not normalized) - assert any(abs(y - orig_effect) < 0.05 for y in y_values), \ - f"Original effect {orig_effect:.3f} for period {period} " \ + assert any(abs(y - orig_effect) < 0.05 for y in y_values), ( + f"Original effect {orig_effect:.3f} for period {period} " f"should be plotted without normalization. 
Found y_values: {y_values}" + ) plt.close() @@ -518,7 +515,7 @@ def test_plot_cs_results_explicit_reference_normalizes(self, cs_results): # Get original effects from results original_effects = { - period: effect_data['effect'] + period: effect_data["effect"] for period, effect_data in results.event_study_effects.items() } @@ -528,8 +525,7 @@ def test_plot_cs_results_explicit_reference_normalizes(self, cs_results): # Compute expected normalized effects expected_normalized = { - period: effect - ref_effect - for period, effect in original_effects.items() + period: effect - ref_effect for period, effect in original_effects.items() } # Plot WITH explicit reference_period - this SHOULD normalize @@ -538,21 +534,23 @@ def test_plot_cs_results_explicit_reference_normalizes(self, cs_results): # Extract plotted y-values y_values = [] for child in ax.get_children(): - if hasattr(child, 'get_ydata'): + if hasattr(child, "get_ydata"): ydata = child.get_ydata() if len(ydata) == 1: y_values.append(float(ydata[0])) # The reference period should now be at y=0 (normalized) - assert any(abs(y) < 0.01 for y in y_values), \ - f"Reference period should be normalized to y=0, got y_values={y_values}" + assert any( + abs(y) < 0.01 for y in y_values + ), f"Reference period should be normalized to y=0, got y_values={y_values}" # Verify normalized values are present for period, norm_effect in expected_normalized.items(): if np.isfinite(norm_effect): - assert any(abs(y - norm_effect) < 0.05 for y in y_values), \ - f"Normalized effect {norm_effect:.3f} for period {period} " \ + assert any(abs(y - norm_effect) < 0.05 for y in y_values), ( + f"Normalized effect {norm_effect:.3f} for period {period} " f"not found in {y_values}" + ) # Verify reference period has no error bars (SE was set to NaN) # Find error bar x-coordinates @@ -562,15 +560,16 @@ def test_plot_cs_results_explicit_reference_normalizes(self, cs_results): if ref_x_idx is not None: errorbar_x_coords = set() for child in 
ax.get_children(): - if hasattr(child, 'get_segments'): + if hasattr(child, "get_segments"): segments = child.get_segments() for seg in segments: if len(seg) >= 2: errorbar_x_coords.add(round(seg[0][0], 1)) # Reference period should NOT have error bars - assert ref_x_idx not in errorbar_x_coords, \ - f"Reference period at x={ref_x_idx} should have no error bars" + assert ( + ref_x_idx not in errorbar_x_coords + ), f"Reference period at x={ref_x_idx} should have no error bars" plt.close() @@ -590,18 +589,15 @@ def test_full_workflow_multi_period(self): did = MultiPeriodDiD() results = did.fit( data, - outcome='outcome', - treatment='treated', - time='period', - post_periods=[2, 3] + outcome="outcome", + treatment="treated", + time="period", + post_periods=[2, 3], + reference_period=1, ) # Plot - ax = plot_event_study( - results, - title="Treatment Effects Over Time", - show=False - ) + ax = plot_event_study(results, title="Treatment Effects Over Time", show=False) assert ax is not None plt.close() @@ -618,19 +614,15 @@ def test_full_workflow_callaway_santanna(self): cs = CallawaySantAnna() results = cs.fit( data, - outcome='outcome', - unit='unit', - time='time', - first_treat='first_treat', - aggregate='event_study' + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", ) # Plot - ax = plot_event_study( - results, - title="Staggered DiD Event Study", - show=False - ) + ax = plot_event_study(results, title="Staggered DiD Event Study", show=False) assert ax is not None plt.close() From a2931e67fba9911aa322198ae83f9055db632afd Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 13:25:31 -0500 Subject: [PATCH 2/8] Address PR #125 review feedback: validation hardening and docs fix - Upgrade post-period reference_period from warning to ValueError (P1) - Add warning when <2 pre-periods available for parallel trends (P1) - Add absorbing treatment validation when unit param provided (P2) - Fix staggered detection 
false positives on unbalanced panels (P2) - Fix REGISTRY.md SE default documentation (HC1, not cluster-robust) Co-Authored-By: Claude Opus 4.5 --- diff_diff/estimators.py | 55 ++++++++++---- docs/methodology/REGISTRY.md | 4 +- tests/test_estimators.py | 136 ++++++++++++++++++++++++++++++++--- 3 files changed, 173 insertions(+), 22 deletions(-) diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index 36cfa00..635c91a 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -796,11 +796,32 @@ def fit( # type: ignore[override] if unit not in data.columns: raise ValueError(f"Unit column '{unit}' not found in data") - # Check for staggered treatment timing - treated_mask = data[treatment] == 1 - if treated_mask.any(): - treatment_timing = data.loc[treated_mask].groupby(unit)[time].min() - if treatment_timing.nunique() > 1: + # Check for staggered treatment timing and absorbing treatment + unit_time_sorted = data.sort_values([unit, time]) + adoption_times = {} + has_reversal = False + for u, group in unit_time_sorted.groupby(unit): + d_vals = group[treatment].values + # Check for treatment reversal (non-absorbing treatment) + if not has_reversal and len(d_vals) > 1 and np.any(np.diff(d_vals) < 0): + warnings.warn( + f"Treatment reversal detected (unit '{u}' transitions from " + f"treated to untreated). MultiPeriodDiD assumes treatment is " + f"an absorbing state (once treated, always treated). 
" + f"Treatment reversals violate this assumption and may " + f"produce unreliable estimates.", + UserWarning, + stacklevel=2, + ) + has_reversal = True + # Only use units with observed 0→1 transition for adoption timing + # (skip units that are always treated — can't determine adoption time) + if 0 in d_vals and 1 in d_vals: + adoption_times[u] = group.loc[group[treatment] == 1, time].iloc[0] + + if len(adoption_times) > 0: + unique_adoption = len(set(adoption_times.values())) + if unique_adoption > 1: warnings.warn( "Treatment timing varies across units (staggered adoption " "detected). MultiPeriodDiD assumes simultaneous adoption " @@ -832,6 +853,16 @@ def fit( # type: ignore[override] if len(pre_periods) == 0: raise ValueError("Must have at least one pre-treatment period") + if len(pre_periods) < 2: + warnings.warn( + "Only one pre-treatment period available. At least 2 pre-periods " + "are needed to assess parallel trends. The treatment effect estimate " + "is still valid, but pre-period coefficients for parallel trends " + "testing are not available.", + UserWarning, + stacklevel=2, + ) + # Validate post_periods are in the data for p in post_periods: if p not in all_periods: @@ -856,15 +887,15 @@ def fit( # type: ignore[override] elif reference_period not in all_periods: raise ValueError(f"Reference period '{reference_period}' not found in time column") - # Warn if reference period is a post-treatment period + # Disallow post-period reference (downstream logic assumes reference is pre-period) if reference_period in post_periods: - warnings.warn( + raise ValueError( f"reference_period={reference_period} is a post-treatment period. " - f"The reference period should typically be a pre-treatment period " - f"(e.g., the last pre-period). Post-period references alter the " - f"interpretation of all coefficients.", - UserWarning, - stacklevel=2, + f"The reference period must be a pre-treatment period " + f"(e.g., the last pre-period {pre_periods[-1]}). 
" + f"Post-period references are not supported because the reference " + f"period is excluded from estimation, which would bias avg_att " + f"and break downstream inference." ) # Validate fixed effects and absorb columns diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 82755d4..c2d1b22 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -147,8 +147,8 @@ Var(ATT_avg) = 1'V1 / |post|² where V is the VCV sub-matrix for post-treatment δ_e coefficients. *Standard errors:* -- Default: Cluster-robust at unit level (accounts for within-unit serial correlation) -- Alternative: HC1 heteroskedasticity-robust (for cross-sectional data) +- Default: HC1 heteroskedasticity-robust (same as DifferenceInDifferences base class) +- Alternative: Cluster-robust at unit level via `cluster` parameter (recommended for panel data) - Optional: Wild cluster bootstrap (complex for multi-coefficient testing; requires joint bootstrap distribution) - Degrees of freedom adjusted for absorbed fixed effects diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 8091741..21d1a56 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -2004,12 +2004,9 @@ def test_reference_period_stored_in_results(self, panel_data): ) assert results.reference_period == 1 - def test_reference_period_in_post_warns(self, panel_data): - """Setting reference_period to a post-period should emit warning.""" - import warnings - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + def test_reference_period_in_post_raises(self, panel_data): + """Setting reference_period to a post-period should raise ValueError.""" + with pytest.raises(ValueError, match="post-treatment period"): MultiPeriodDiD().fit( panel_data, outcome="outcome", @@ -2018,8 +2015,6 @@ def test_reference_period_in_post_warns(self, panel_data): post_periods=[3, 4, 5], reference_period=4, ) - post_ref_warnings = [x for x in w if 
"post-treatment period" in str(x.message)] - assert len(post_ref_warnings) > 0, "Expected warning about post-period reference" def test_staggered_treatment_warning(self): """Staggered treatment timing with unit param should emit warning.""" @@ -2160,6 +2155,131 @@ def test_to_dict_has_reference_period(self, panel_data): assert "reference_period" in d assert d["reference_period"] == 2 + def test_single_pre_period_warns(self): + """Single pre-period should warn but still produce valid results.""" + np.random.seed(42) + data = [] + for unit_id in range(40): + is_treated = unit_id < 20 + for period in range(3): # period 0 = pre, periods 1,2 = post + y = 10.0 + period * 0.5 + (2.0 if is_treated and period >= 1 else 0) + y += np.random.normal(0, 0.3) + data.append( + { + "unit": unit_id, + "period": period, + "treated": int(is_treated and period >= 1), + "outcome": y, + } + ) + df = pd.DataFrame(data) + # Make treatment a proper absorbing indicator + for uid in range(20): + df.loc[(df["unit"] == uid), "treated"] = 1 + for uid in range(20, 40): + df.loc[(df["unit"] == uid), "treated"] = 0 + + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = MultiPeriodDiD().fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[1, 2], + reference_period=0, + ) + pre_period_warnings = [x for x in w if "Only one pre-treatment period" in str(x.message)] + assert len(pre_period_warnings) > 0, "Expected warning about single pre-period" + # Results should still be valid + assert np.isfinite(results.avg_att) + + def test_treatment_reversal_warns(self): + """Treatment reversal (D goes 1→0) should emit warning when unit provided.""" + np.random.seed(42) + data = [] + for unit_id in range(40): + for period in range(6): + is_treated = unit_id < 20 + # Unit 0 has treatment reversal: treated in periods 2-3, untreated in 4-5 + if unit_id == 0: + d = 1 if 2 <= period <= 3 else 0 + elif is_treated: + d = 1 
if period >= 3 else 0 + else: + d = 0 + y = 10.0 + period * 0.5 + (2.0 if d else 0) + np.random.normal(0, 0.3) + data.append( + { + "unit": unit_id, + "period": period, + "treated": d, + "outcome": y, + } + ) + df = pd.DataFrame(data) + + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MultiPeriodDiD().fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + unit="unit", + ) + reversal_warnings = [x for x in w if "Treatment reversal" in str(x.message)] + assert len(reversal_warnings) > 0, "Expected warning about treatment reversal" + + def test_staggered_no_false_positive_unbalanced(self): + """Unbalanced panel with simultaneous treatment should not trigger staggered warning.""" + np.random.seed(42) + data = [] + for unit_id in range(40): + is_treated = unit_id < 20 + # Some treated units enter the panel late (already treated) + if is_treated and unit_id < 5: + start_period = 4 # Enter after treatment starts at period 3 + else: + start_period = 0 + for period in range(start_period, 8): + d = 1 if is_treated and period >= 3 else 0 + y = 10.0 + period * 0.5 + (2.0 if d else 0) + np.random.normal(0, 0.3) + data.append( + { + "unit": unit_id, + "period": period, + "treated": d, + "outcome": y, + } + ) + df = pd.DataFrame(data) + + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + MultiPeriodDiD().fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5, 6, 7], + reference_period=2, + unit="unit", + ) + staggered_warnings = [x for x in w if "staggered" in str(x.message).lower()] + assert len(staggered_warnings) == 0, ( + "Should NOT warn about staggered adoption when all units adopt simultaneously " + "(some just enter the panel late)" + ) + class TestSyntheticDiD: """Tests for SyntheticDiD estimator.""" From 9440e094e52b66dedf481d695fde2b91a6fb5361 Mon Sep 17 
00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 14:27:07 -0500 Subject: [PATCH 3/8] Address PR #125 review round 2: NaN propagation and docs fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Guard period-level p_value and conf_int when SE is non-finite/zero (previously conf_int returned misleading zero-width interval) - Add test assertions for full NaN propagation on unidentified periods - Fix REGISTRY.md: "requires ≥2 pre-periods" → "warns when <2" - Fix CHANGELOG.md: "Warning" → "ValueError" for post-period reference Co-Authored-By: Claude Opus 4.5 --- CHANGELOG.md | 2 +- diff_diff/estimators.py | 11 ++++++++--- docs/methodology/REGISTRY.md | 5 +++-- tests/test_estimators.py | 7 +++++++ 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2685221..ee6a939 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `reference_period` and `interaction_indices` attributes on `MultiPeriodDiDResults` - `pre_period_effects` and `post_period_effects` convenience properties on results - Pre-period section in `summary()` output with reference period indicator -- Warning when `reference_period` is set to a post-treatment period +- `ValueError` when `reference_period` is set to a post-treatment period - Staggered adoption warning when treatment timing varies across units (with `unit` param) - Informative KeyError when accessing reference period via `get_effect()` diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index 635c91a..c4da46f 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -1016,9 +1016,14 @@ def fit( # type: ignore[override] idx = interaction_indices[period] effect = coefficients[idx] se = np.sqrt(vcov[idx, idx]) - t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan - p_value = compute_p_value(t_stat, df=df) - conf_int = 
compute_confidence_interval(effect, se, self.alpha, df=df) + if np.isfinite(se) and se > 0: + t_stat = effect / se + p_value = compute_p_value(t_stat, df=df) + conf_int = compute_confidence_interval(effect, se, self.alpha, df=df) + else: + t_stat = np.nan + p_value = np.nan + conf_int = (np.nan, np.nan) period_effects[period] = PeriodEffect( period=period, diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index c2d1b22..991e57b 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -95,8 +95,9 @@ CallawaySantAnna or SunAbraham instead. *Assumption checks / warnings:* - Treatment indicator must be binary (0/1) with variation in both groups -- Requires at least 2 pre-treatment and 1 post-treatment period - (need ≥2 pre-periods to test parallel trends) +- Requires at least 1 pre-treatment and 1 post-treatment period +- Warns when only 1 pre-period available (≥2 needed to test parallel trends; + ATT is still valid but pre-trends assessment is not possible) - Reference period defaults to last pre-treatment period (e=-1 convention) - Warns if treatment timing varies across units (suggests CallawaySantAnna) - Treatment must be an absorbing state (once treated, always treated) diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 21d1a56..7d56b86 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -1845,6 +1845,13 @@ def test_avg_att_nan_when_period_effect_nan(self, multi_period_data): pe_3 = results.period_effects[3] assert np.isnan(pe_3.effect), "Period 3 effect should be NaN (unidentified)" + # All inference fields for the unidentified period should be NaN + assert np.isnan(pe_3.se), "Period 3 SE should be NaN (unidentified)" + assert np.isnan(pe_3.t_stat), "Period 3 t_stat should be NaN (unidentified)" + assert np.isnan(pe_3.p_value), "Period 3 p_value should be NaN (unidentified)" + assert np.isnan(pe_3.conf_int[0]), "Period 3 CI lower should be NaN (unidentified)" + assert 
np.isnan(pe_3.conf_int[1]), "Period 3 CI upper should be NaN (unidentified)" + # avg_att should be NaN because one period effect is NaN (R-style NA propagation) assert np.isnan(results.avg_att), "avg_att should be NaN when any period effect is NaN" assert np.isnan(results.avg_se), "avg_se should be NaN when avg_att is NaN" From 8207057580c6d184fad01b81df1fbd4837792137 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 15:09:48 -0500 Subject: [PATCH 4/8] Address PR #125 review round 3: D_it warning and HonestDiD NaN filtering - Add time-varying treatment warning when `unit` is provided and treatment varies within units (guides users toward ever-treated indicator D_i) - Filter non-finite pre-period effects in HonestDiD _extract_event_study_params and _estimate_max_pre_violation (prevents NaN propagation into bounds) - Update REGISTRY.md: D_i requirement and staggered check conditional on unit - Add tests for D_it warning, NaN filtering, and all-NaN-pre-periods error Co-Authored-By: Claude Opus 4.5 --- CHANGELOG.md | 4 + diff_diff/estimators.py | 21 +++++- diff_diff/honest_did.py | 25 ++++++- docs/methodology/REGISTRY.md | 5 +- tests/test_estimators.py | 40 ++++++++++ tests/test_honest_did.py | 139 +++++++++++++++++++++++++++++++++-- 6 files changed, 224 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee6a939..69dd68d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - t_stat uses `np.isfinite(se) and se > 0` guard (consistent with other estimators) ### Added +- Time-varying treatment warning when `unit` is provided and treatment varies + within units (guides users toward ever-treated indicator D_i) - `unit` parameter to `MultiPeriodDiD.fit()` for staggered adoption detection - `reference_period` and `interaction_indices` attributes on `MultiPeriodDiDResults` - `pre_period_effects` and `post_period_effects` convenience properties on 
results @@ -34,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 The `variance_method` field has also been removed from `TROPResults`. ### Fixed +- HonestDiD: filter non-finite pre-period effects from MultiPeriodDiD results + (prevents NaN propagation into sensitivity bounds) - HonestDiD VCV extraction: now uses interaction sub-VCV instead of full regression VCV (via `interaction_indices` period → column index mapping) diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index c4da46f..f45d589 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -741,7 +741,10 @@ def fit( # type: ignore[override] outcome : str Name of the outcome variable column. treatment : str - Name of the treatment group indicator column (0/1). + Name of the treatment group indicator column (0/1). Should be a + time-invariant ever-treated indicator (D_i = 1 for all periods of + treated units). If treatment is time-varying (D_it), pre-period + interaction coefficients will be unidentified. time : str Name of the time period column (can have multiple values). post_periods : list @@ -831,6 +834,22 @@ def fit( # type: ignore[override] stacklevel=2, ) + # Check for time-varying treatment (D_it instead of D_i) + # If any unit has a 0→1 transition, the treatment column is D_it. + # MultiPeriodDiD expects a time-invariant ever-treated indicator. + warnings.warn( + "Treatment indicator varies within units (time-varying " + "treatment detected). MultiPeriodDiD's event-study " + "specification expects a time-invariant ever-treated " + "indicator (D_i = 1 for all periods of eventually-treated " + "units). With time-varying treatment, pre-period " + "interaction coefficients will be unidentified. 
Consider: " + f"df['ever_treated'] = df.groupby('{unit}')['{treatment}']" + ".transform('max')", + UserWarning, + stacklevel=2, + ) + # Get all unique time periods all_periods = sorted(data[time].unique()) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index ddae7ff..577a2f5 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -563,8 +563,19 @@ def _extract_event_study_params( pre_periods = results.pre_periods post_periods = results.post_periods - # Extract all estimated effects in chronological order - all_estimated = sorted(results.period_effects.keys()) + # Filter out periods with non-finite effects/SEs (e.g. rank-deficient) + all_estimated = sorted( + p + for p in results.period_effects.keys() + if np.isfinite(results.period_effects[p].effect) + and np.isfinite(results.period_effects[p].se) + ) + + if not all_estimated: + raise ValueError( + "No period effects with finite estimates found. " "Cannot compute HonestDiD bounds." + ) + effects = [results.period_effects[p].effect for p in all_estimated] ses = [results.period_effects[p].se for p in all_estimated] @@ -572,6 +583,13 @@ def _extract_event_study_params( num_pre_periods = sum(1 for p in all_estimated if p in pre_periods) num_post_periods = sum(1 for p in all_estimated if p in post_periods) + if num_pre_periods == 0: + raise ValueError( + "No pre-period effects with finite estimates found. " + "HonestDiD requires at least one identified pre-period " + "coefficient." + ) + # Extract proper sub-VCV for interaction terms if ( results.vcov is not None @@ -1236,10 +1254,11 @@ def _estimate_max_pre_violation(self, results: Any, pre_periods: List) -> float: """ if isinstance(results, MultiPeriodDiDResults): # Pre-period effects are now in period_effects directly + # Filter out non-finite effects (e.g. 
from rank-deficient designs) pre_effects = [ abs(results.period_effects[p].effect) for p in pre_periods - if p in results.period_effects + if p in results.period_effects and np.isfinite(results.period_effects[p].effect) ] if pre_effects: return max(pre_effects) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 991e57b..fcf0a7e 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -99,7 +99,10 @@ CallawaySantAnna or SunAbraham instead. - Warns when only 1 pre-period available (≥2 needed to test parallel trends; ATT is still valid but pre-trends assessment is not possible) - Reference period defaults to last pre-treatment period (e=-1 convention) -- Warns if treatment timing varies across units (suggests CallawaySantAnna) +- Treatment indicator should be time-invariant ever-treated (D_i); + warns when time-varying D_it detected (requires `unit` parameter) +- Warns if treatment timing varies across units when `unit` is provided + (suggests CallawaySantAnna or SunAbraham instead) - Treatment must be an absorbing state (once treated, always treated) *Estimator equation (target specification):* diff --git a/tests/test_estimators.py b/tests/test_estimators.py index 7d56b86..e99444c 100644 --- a/tests/test_estimators.py +++ b/tests/test_estimators.py @@ -2244,6 +2244,46 @@ def test_treatment_reversal_warns(self): reversal_warnings = [x for x in w if "Treatment reversal" in str(x.message)] assert len(reversal_warnings) > 0, "Expected warning about treatment reversal" + def test_time_varying_treatment_warning(self): + """Time-varying D_it (0 pre, 1 post) should emit warning about ever-treated indicator.""" + np.random.seed(42) + data = [] + for unit_id in range(40): + is_treated = unit_id < 20 + for period in range(6): + # D_it: 0 in pre-periods, 1 in post-periods for treated units + d = 1 if is_treated and period >= 3 else 0 + y = 10.0 + period * 0.5 + (2.0 if d else 0) + np.random.normal(0, 0.3) + data.append( + { + 
"unit": unit_id, + "period": period, + "treated": d, + "outcome": y, + } + ) + df = pd.DataFrame(data) + + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = MultiPeriodDiD().fit( + df, + outcome="outcome", + treatment="treated", + time="period", + post_periods=[3, 4, 5], + reference_period=2, + unit="unit", + ) + dit_warnings = [x for x in w if "time-varying" in str(x.message).lower()] + assert len(dit_warnings) > 0, "Expected warning about time-varying treatment" + assert "ever_treated" in str(dit_warnings[0].message) + # Results should still be produced (but may have NaN due to rank deficiency) + assert results is not None + assert len(results.period_effects) > 0 + def test_staggered_no_false_positive_unbalanced(self): """Unbalanced panel with simultaneous treatment should not trigger staggered warning.""" np.random.seed(42) diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index a60a0bf..ffa49b1 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -693,8 +693,12 @@ def test_multiperiod_sub_vcov_extraction(self, simple_panel_data): class TestEdgeCases: """Tests for edge cases and error handling.""" - def test_single_post_period(self): - """Test with single post-period.""" + def test_single_post_period_no_pre_effects_raises(self): + """Test with single post-period and no pre-period effects raises ValueError. + + HonestDiD requires pre-period coefficients for sensitivity analysis. + A results object with only post-period effects is not usable. 
+ """ period_effects = { 4: PeriodEffect( period=4, effect=5.0, se=0.5, t_stat=10.0, p_value=0.0001, conf_int=(4.02, 5.98) @@ -717,9 +721,8 @@ def test_single_post_period(self): ) honest = HonestDiD(method="relative_magnitude", M=1.0) - bounds = honest.fit(results) - - assert isinstance(bounds, HonestDiDResults) + with pytest.raises(ValueError, match="No pre-period effects with finite"): + honest.fit(results) def test_m_zero_recovers_standard(self, mock_multiperiod_results): """Test that M=0 gives tighter bounds.""" @@ -816,6 +819,132 @@ def test_max_pre_violation_excludes_reference_period(self): # The actual effects are non-zero due to sampling variation assert max_violation > 0, "max_pre_violation should be > 0 when real pre-periods exist" + def test_honest_did_filters_nan_pre_period_effects(self): + """HonestDiD should filter NaN pre-period effects from MultiPeriodDiDResults. + + When MultiPeriodDiD produces NaN effects (e.g. from rank-deficient designs + with time-varying treatment), HonestDiD should skip those periods rather + than propagating NaN into sensitivity bounds. 
+ """ + # Create results with one NaN pre-period (simulating rank deficiency) + period_effects = { + 0: PeriodEffect( + period=0, + effect=np.nan, + se=np.nan, + t_stat=np.nan, + p_value=np.nan, + conf_int=(np.nan, np.nan), + ), + 1: PeriodEffect( + period=1, + effect=0.1, + se=0.3, + t_stat=0.33, + p_value=0.74, + conf_int=(-0.49, 0.69), + ), + # Reference period (2) omitted + 3: PeriodEffect( + period=3, + effect=2.5, + se=0.4, + t_stat=6.25, + p_value=0.0001, + conf_int=(1.72, 3.28), + ), + 4: PeriodEffect( + period=4, + effect=2.8, + se=0.4, + t_stat=7.0, + p_value=0.0001, + conf_int=(2.02, 3.58), + ), + } + + # Build VCV with NaN row/col for period 0 (rank-deficient) + interaction_indices = {0: 0, 1: 1, 3: 2, 4: 3} + vcov_with_nan = np.full((4, 4), 0.0) + vcov_with_nan[0, :] = np.nan + vcov_with_nan[:, 0] = np.nan + vcov_with_nan[1, 1] = 0.09 + vcov_with_nan[2, 2] = 0.16 + vcov_with_nan[3, 3] = 0.16 + + results = MultiPeriodDiDResults( + period_effects=period_effects, + avg_att=2.65, + avg_se=0.4, + avg_t_stat=6.625, + avg_p_value=0.0001, + avg_conf_int=(1.87, 3.43), + n_obs=400, + n_treated=200, + n_control=200, + pre_periods=[0, 1, 2], + post_periods=[3, 4], + vcov=vcov_with_nan, + reference_period=2, + interaction_indices=interaction_indices, + ) + + # _extract_event_study_params should filter out period 0 (NaN) + beta_hat, sigma, num_pre, num_post, pre_p, post_p = _extract_event_study_params(results) + assert len(beta_hat) == 3 # periods 1, 3, 4 (period 0 filtered) + assert num_pre == 1 # only period 1 + assert num_post == 2 # periods 3, 4 + assert np.all(np.isfinite(beta_hat)) + assert np.all(np.isfinite(sigma)) + + # _estimate_max_pre_violation should ignore the NaN period + honest = HonestDiD(method="relative_magnitude", M=1.0) + max_viol = honest._estimate_max_pre_violation(results, [0, 1]) + assert np.isfinite(max_viol) + assert max_viol == pytest.approx(0.1, abs=1e-10) # only period 1's |effect| + + def test_honest_did_all_pre_nan_raises(self): + 
"""HonestDiD should raise ValueError when all pre-period effects are NaN.""" + period_effects = { + 0: PeriodEffect( + period=0, + effect=np.nan, + se=np.nan, + t_stat=np.nan, + p_value=np.nan, + conf_int=(np.nan, np.nan), + ), + # Reference period (1) omitted + 2: PeriodEffect( + period=2, + effect=2.5, + se=0.4, + t_stat=6.25, + p_value=0.0001, + conf_int=(1.72, 3.28), + ), + } + + results = MultiPeriodDiDResults( + period_effects=period_effects, + avg_att=2.5, + avg_se=0.4, + avg_t_stat=6.25, + avg_p_value=0.0001, + avg_conf_int=(1.72, 3.28), + n_obs=200, + n_treated=100, + n_control=100, + pre_periods=[0, 1], + post_periods=[2], + vcov=np.diag([0.16]), + reference_period=1, + interaction_indices={0: 0, 2: 1}, + ) + + with pytest.raises(ValueError, match="No pre-period effects with finite"): + _extract_event_study_params(results) + # ============================================================================= # Tests for Visualization (without matplotlib) From d1de9cb68ee8a4ad5e412d23badc39e9040cc6ed Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 15:32:39 -0500 Subject: [PATCH 5/8] Address PR #125 review round 4: HonestDiD post-period NaN guard Add num_post == 0 guard in HonestDiD.fit() to raise ValueError when all post-period effects are non-finite, preventing silent computation with empty arrays. Covers both MultiPeriodDiD and CallawaySantAnna paths. Co-Authored-By: Claude Opus 4.5 --- CHANGELOG.md | 5 +- diff_diff/honest_did.py | 7 +++ tests/test_honest_did.py | 98 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69dd68d..b5e8a1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,8 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 The `variance_method` field has also been removed from `TROPResults`. 
### Fixed -- HonestDiD: filter non-finite pre-period effects from MultiPeriodDiD results - (prevents NaN propagation into sensitivity bounds) +- HonestDiD: filter non-finite period effects from MultiPeriodDiD results + (prevents NaN propagation into sensitivity bounds; raises ValueError + when no finite pre- or post-period effects remain) - HonestDiD VCV extraction: now uses interaction sub-VCV instead of full regression VCV (via `interaction_indices` period → column index mapping) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 577a2f5..76ef713 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -1111,6 +1111,13 @@ def fit( # Update num_post to match actual data num_post = len(beta_post) + if num_post == 0: + raise ValueError( + "No post-period effects with finite estimates found. " + "HonestDiD requires at least one identified post-period " + "coefficient to compute bounds." + ) + # Set up weighting vector if self.l_vec is None: l_vec = np.ones(num_post) / num_post # Uniform weights diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index ffa49b1..ea52e18 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -945,6 +945,104 @@ def test_honest_did_all_pre_nan_raises(self): with pytest.raises(ValueError, match="No pre-period effects with finite"): _extract_event_study_params(results) + def test_honest_did_all_post_nan_raises(self): + """HonestDiD should raise ValueError when all post-period effects are NaN. + + When MultiPeriodDiD produces NaN for all post-period effects (e.g. from + severe rank deficiency), HonestDiD.fit() should raise rather than + silently computing with an empty weight vector. 
+ """ + period_effects = { + 0: PeriodEffect( + period=0, + effect=0.1, + se=0.3, + t_stat=0.33, + p_value=0.74, + conf_int=(-0.49, 0.69), + ), + # Reference period (1) omitted + 2: PeriodEffect( + period=2, + effect=np.nan, + se=np.nan, + t_stat=np.nan, + p_value=np.nan, + conf_int=(np.nan, np.nan), + ), + 3: PeriodEffect( + period=3, + effect=np.nan, + se=np.nan, + t_stat=np.nan, + p_value=np.nan, + conf_int=(np.nan, np.nan), + ), + } + + interaction_indices = {0: 0, 2: 1, 3: 2} + vcov = np.full((3, 3), 0.0) + vcov[0, 0] = 0.09 + vcov[1, :] = np.nan + vcov[:, 1] = np.nan + vcov[2, :] = np.nan + vcov[:, 2] = np.nan + + results = MultiPeriodDiDResults( + period_effects=period_effects, + avg_att=np.nan, + avg_se=np.nan, + avg_t_stat=np.nan, + avg_p_value=np.nan, + avg_conf_int=(np.nan, np.nan), + n_obs=300, + n_treated=150, + n_control=150, + pre_periods=[0, 1], + post_periods=[2, 3], + vcov=vcov, + reference_period=1, + interaction_indices=interaction_indices, + ) + + honest = HonestDiD(method="relative_magnitude", M=1.0) + with pytest.raises(ValueError, match="No post-period effects with finite"): + honest.fit(results) + + def test_honest_did_cs_all_post_nan_raises(self): + """HonestDiD should raise ValueError when all CS post-period effects have NaN SEs. + + When CallawaySantAnnaResults has non-finite SEs for all t>=0 event-study + effects, HonestDiD.fit() should raise rather than computing with empty + post-period arrays. 
+ """ + from diff_diff.staggered_results import CallawaySantAnnaResults + + # Create CS results with valid pre-periods but NaN SEs in post-periods + cs_results = CallawaySantAnnaResults( + group_time_effects={}, + overall_att=np.nan, + overall_se=np.nan, + overall_t_stat=np.nan, + overall_p_value=np.nan, + overall_conf_int=(np.nan, np.nan), + groups=[2004], + time_periods=[2000, 2001, 2002, 2003], + n_obs=400, + n_treated_units=200, + n_control_units=200, + ) + cs_results.event_study_effects = { + -2: {"effect": 0.1, "se": 0.3, "n_groups": 2}, + -1: {"effect": 0.05, "se": 0.25, "n_groups": 2}, + 0: {"effect": 2.0, "se": np.nan, "n_groups": 0}, + 1: {"effect": 2.5, "se": np.nan, "n_groups": 0}, + } + + honest = HonestDiD(method="relative_magnitude", M=1.0) + with pytest.raises(ValueError, match="No post-period effects with finite"): + honest.fit(cs_results) + # ============================================================================= # Tests for Visualization (without matplotlib) From 66b27546da306fc4b3e89a3e0423633fbdc09e9d Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 16:01:30 -0500 Subject: [PATCH 6/8] Address PR #125 review round 5: avg_se Inf guard and HonestDiD ordering - Align avg_se guard with per-period pattern (np.isfinite check prevents avg_t_stat=0 / avg_p_value=1 when variance is infinite) - Use explicit pre-then-post ordering in HonestDiD extraction instead of sorted() (prevents misclassification when period labels don't sort chronologically) - Add test with non-monotone period labels (pre=[5,6,7], post=[1,2]) Co-Authored-By: Claude Opus 4.5 --- CHANGELOG.md | 4 +++ diff_diff/estimators.py | 2 +- diff_diff/honest_did.py | 11 +++++-- tests/test_honest_did.py | 66 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b5e8a1d..c013c6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,10 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0 when no finite pre- or post-period effects remain) - HonestDiD VCV extraction: now uses interaction sub-VCV instead of full regression VCV (via `interaction_indices` period → column index mapping) +- MultiPeriodDiD: `avg_se` guard now checks `np.isfinite()` (matches per-period pattern; + prevents `avg_t_stat=0` / `avg_p_value=1` when variance is infinite) +- HonestDiD: extraction now uses explicit pre-then-post ordering instead of sorted period + labels (prevents misclassification when period labels don't sort chronologically) ## [2.2.0] - 2026-01-27 diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index f45d589..5f7f797 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -1086,7 +1086,7 @@ def fit( # type: ignore[override] avg_conf_int = (np.nan, np.nan) else: avg_se = float(np.sqrt(avg_var)) - if avg_se > 0: + if np.isfinite(avg_se) and avg_se > 0: avg_t_stat = avg_att / avg_se avg_p_value = compute_p_value(avg_t_stat, df=df) avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 76ef713..cc792bf 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -563,13 +563,18 @@ def _extract_event_study_params( pre_periods = results.pre_periods post_periods = results.post_periods - # Filter out periods with non-finite effects/SEs (e.g. 
rank-deficient) - all_estimated = sorted( + # Filter periods with finite effects/SEs, maintaining pre-then-post order + # (sorted() alone would fail if period labels don't sort chronologically) + finite_periods = { p for p in results.period_effects.keys() if np.isfinite(results.period_effects[p].effect) and np.isfinite(results.period_effects[p].se) - ) + } + + pre_estimated = sorted(p for p in finite_periods if p in pre_periods) + post_estimated = sorted(p for p in finite_periods if p in post_periods) + all_estimated = pre_estimated + post_estimated if not all_estimated: raise ValueError( diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index ea52e18..050ae27 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -1043,6 +1043,72 @@ def test_honest_did_cs_all_post_nan_raises(self): with pytest.raises(ValueError, match="No post-period effects with finite"): honest.fit(cs_results) + def test_honest_did_nonmonotone_period_labels(self): + """HonestDiD extraction should handle period labels where sorted order + doesn't separate pre/post (e.g. pre=[5,6], post=[1,2]). + + The extraction must place pre-period effects before post-period effects + in beta_hat regardless of label values. 
+ """ + # Pre-periods 5, 6, 7 (reference=7 omitted), post-periods 1, 2 + # sorted() would give [1, 2, 5, 6] — post before pre — which is wrong + period_effects = { + 5: PeriodEffect( + period=5, effect=0.1, se=0.3, t_stat=0.33, p_value=0.74, conf_int=(-0.49, 0.69) + ), + 6: PeriodEffect( + period=6, effect=0.2, se=0.35, t_stat=0.57, p_value=0.57, conf_int=(-0.49, 0.89) + ), + 1: PeriodEffect( + period=1, effect=2.5, se=0.4, t_stat=6.25, p_value=0.0001, conf_int=(1.72, 3.28) + ), + 2: PeriodEffect( + period=2, effect=2.8, se=0.45, t_stat=6.22, p_value=0.0001, conf_int=(1.92, 3.68) + ), + } + + # VCV column mapping: period -> index in regression VCV + interaction_indices = {5: 0, 6: 1, 1: 2, 2: 3} + + # Distinct diagonal entries so we can verify VCV block extraction + vcov = np.diag([0.09, 0.1225, 0.16, 0.2025]) + + results = MultiPeriodDiDResults( + period_effects=period_effects, + avg_att=2.65, + avg_se=0.42, + avg_t_stat=6.31, + avg_p_value=0.0001, + avg_conf_int=(1.83, 3.47), + n_obs=400, + n_treated=200, + n_control=200, + pre_periods=[5, 6, 7], + post_periods=[1, 2], + vcov=vcov, + reference_period=7, + interaction_indices=interaction_indices, + ) + + beta_hat, sigma, num_pre, num_post, pre_p, post_p = _extract_event_study_params(results) + + # Pre-periods: 5, 6 (7 is reference, omitted) + assert num_pre == 2 + # Post-periods: 1, 2 + assert num_post == 2 + + # beta_hat must be [pre_5, pre_6, post_1, post_2] + assert beta_hat[0] == pytest.approx(0.1) # period 5 + assert beta_hat[1] == pytest.approx(0.2) # period 6 + assert beta_hat[2] == pytest.approx(2.5) # period 1 + assert beta_hat[3] == pytest.approx(2.8) # period 2 + + # sigma blocks must match: pre block = diag(0.09, 0.1225), post block = diag(0.16, 0.2025) + assert sigma[0, 0] == pytest.approx(0.09) # period 5 variance + assert sigma[1, 1] == pytest.approx(0.1225) # period 6 variance + assert sigma[2, 2] == pytest.approx(0.16) # period 1 variance + assert sigma[3, 3] == pytest.approx(0.2025) # period 2 
variance + # ============================================================================= # Tests for Visualization (without matplotlib) From 3cb1c857280069541281f8482a1a328e04478628 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 16:36:57 -0500 Subject: [PATCH 7/8] Address PR #125 review round 6: stale comment and FutureWarning text Co-Authored-By: Claude Opus 4.5 --- diff_diff/estimators.py | 9 ++++----- diff_diff/honest_did.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index 5f7f797..367e660 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -892,13 +892,12 @@ def fit( # type: ignore[override] # Default: last pre-period (e=-1 convention, matches fixest) if len(pre_periods) > 1: warnings.warn( - f"The default reference_period is changing from the first " + f"The default reference_period has changed from the first " f"pre-period ({pre_periods[0]}) to the last pre-period " - f"({pre_periods[-1]}) to match the standard e=-1 convention. " + f"({pre_periods[-1]}) to match the standard e=-1 convention " + f"(as used by fixest, did, etc.). " f"To silence this warning, pass " - f"reference_period={pre_periods[-1]} explicitly. " - f"In a future version, the default will be the last " - f"pre-period.", + f"reference_period={pre_periods[-1]} explicitly.", FutureWarning, stacklevel=2, ) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index cc792bf..da6a136 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -1091,8 +1091,8 @@ def fit( _extract_event_study_params(results) ) - # beta_hat from MultiPeriodDiDResults already contains only post-periods - # Check if we have the right number of coefficients + # beta_hat contains [pre-period effects, post-period effects] in order. + # Extract just the post-period effects for HonestDiD bounds. 
if len(beta_hat) == num_post: # Already just post-period effects beta_post = beta_hat From ecddfb248a6585104db3d5850d47e41807d6bfca Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 1 Feb 2026 17:06:55 -0500 Subject: [PATCH 8/8] Address PR #125 review round 7: fix sorted() ordering for HonestDiD Replace sorted() with order-preserving list comprehensions in _extract_event_study_params to maintain chronological period ordering. sorted() could break smoothness-based bounds (DeltaSD, DeltaSDRM) when period labels don't sort chronologically by default (e.g., string labels). Co-Authored-By: Claude Opus 4.5 --- diff_diff/honest_did.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index da6a136..d2a5417 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -564,7 +564,6 @@ def _extract_event_study_params( post_periods = results.post_periods # Filter periods with finite effects/SEs, maintaining pre-then-post order - # (sorted() alone would fail if period labels don't sort chronologically) finite_periods = { p for p in results.period_effects.keys() @@ -572,8 +571,8 @@ def _extract_event_study_params( and np.isfinite(results.period_effects[p].se) } - pre_estimated = sorted(p for p in finite_periods if p in pre_periods) - post_estimated = sorted(p for p in finite_periods if p in post_periods) + pre_estimated = [p for p in pre_periods if p in finite_periods] + post_estimated = [p for p in post_periods if p in finite_periods] all_estimated = pre_estimated + post_estimated if not all_estimated: