Merged
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **TROP `variance_method` parameter** — Jackknife variance estimation removed.
Bootstrap (the only method specified in Athey et al. 2025) is now always used.
The `variance_method` field has also been removed from `TROPResults`.
+ - **TROP `max_loocv_samples` parameter** — Control observation subsampling removed
+   from LOOCV tuning parameter selection. Equation 5 of Athey et al. (2025) explicitly
+   sums over ALL control observations where D=0; the previous subsampling (default 100)
+   was not specified in the paper. LOOCV now uses all control observations, making
+   tuning fully deterministic. Inner LOOCV loops in the Rust backend are parallelized
+   to compensate for the increased observation count.

## [2.2.0] - 2026-01-27

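The changelog entry above describes an objective that sums over every control observation instead of a random subsample. A minimal sketch of that idea follows, assuming the pseudo treatment effect for a held-out cell is the observed value minus a held-out prediction. The names `loocv_score`, `predict_held_out`, and `mean_of_other_controls` are hypothetical and are not part of `diff_diff`; the toy predictor only stands in for the real estimator.

```python
import numpy as np


def loocv_score(Y, D, predict_held_out):
    """Sum of squared held-out residuals over ALL observed control cells (D == 0).

    `predict_held_out(Y, D, t, i)` is a hypothetical callback returning the
    prediction for cell (t, i) from a fit that excludes that cell.
    """
    control_obs = [
        (t, i)
        for t in range(Y.shape[0])
        for i in range(Y.shape[1])
        if D[t, i] == 0 and not np.isnan(Y[t, i])
    ]
    # An empty control set must not score 0.0, or it would beat real candidates.
    if not control_obs:
        return np.inf
    score = 0.0
    for t, i in control_obs:
        tau = Y[t, i] - predict_held_out(Y, D, t, i)
        score += tau ** 2
    return score


def mean_of_other_controls(Y, D, t, i):
    """Toy stand-in predictor: mean of all other observed control cells."""
    mask = (D == 0) & ~np.isnan(Y)
    mask[t, i] = False
    return Y[mask].mean()


Y = np.array([[1.0, 2.0], [3.0, 4.0]])
D = np.array([[0, 0], [0, 1]])
print(loocv_score(Y, D, mean_of_other_controls))
```

Because no random generator is involved, repeated calls on the same inputs give the same score, which is the determinism the entry refers to.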
37 changes: 2 additions & 35 deletions diff_diff/trop.py
@@ -394,11 +394,6 @@ class TROP:
Significance level for confidence intervals.
n_bootstrap : int, default=200
Number of bootstrap replications for variance estimation.
- max_loocv_samples : int, default=100
-     Maximum control observations to use in LOOCV for tuning parameter
-     selection. Subsampling is used for computational tractability as
-     noted in the paper. Increase for more precise tuning at the cost
-     of computational time.
seed : int, optional
Random seed for reproducibility.

@@ -429,15 +424,6 @@ class TROP:
"""

# Class constants
- DEFAULT_LOOCV_MAX_SAMPLES: int = 100
- """Maximum control observations to use in LOOCV (for computational tractability).
-
- As noted in the paper's footnote, LOOCV is subsampled for computational
- tractability. This constant controls the maximum number of control observations
- used in each LOOCV evaluation. Increase for more precise tuning at the cost
- of computational time.
- """
-
CONVERGENCE_TOL_SVD: float = 1e-10
"""Tolerance for singular value truncation in soft-thresholding.

@@ -455,7 +441,6 @@ def __init__(
tol: float = 1e-6,
alpha: float = 0.05,
n_bootstrap: int = 200,
- max_loocv_samples: int = 100,
seed: Optional[int] = None,
):
# Validate method parameter
@@ -475,7 +460,6 @@ def __init__(
self.tol = tol
self.alpha = alpha
self.n_bootstrap = n_bootstrap
- self.max_loocv_samples = max_loocv_samples
self.seed = seed

# Validate that time/unit grids do not contain inf.
@@ -1359,8 +1343,7 @@ def _fit_joint(
result = _rust_loocv_grid_search_joint(
Y, D.astype(np.float64), control_mask_u8,
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
- self.max_loocv_samples, self.max_iter, self.tol,
- self.seed if self.seed is not None else 0
+ self.max_iter, self.tol,
)
# Unpack result - 7 values including optional first_failed_obs
best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
@@ -1407,13 +1390,6 @@ def _fit_joint(
if control_mask[t, i] and not np.isnan(Y[t, i])
]

- # Subsample if needed (sample indices to avoid ValueError on list of tuples)
- rng = np.random.default_rng(self.seed)
- max_loocv = min(self.max_loocv_samples, len(control_obs))
- if len(control_obs) > max_loocv:
-     indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
-     control_obs = [control_obs[idx] for idx in indices]
-
# Grid search with true LOOCV
for lambda_time_val in self.lambda_time_grid:
for lambda_unit_val in self.lambda_unit_grid:
@@ -1898,8 +1874,7 @@ def fit(
Y, D.astype(np.float64), control_mask_u8,
time_dist_matrix,
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
- self.max_loocv_samples, self.max_iter, self.tol,
- self.seed if self.seed is not None else 0
+ self.max_iter, self.tol,
)
# Unpack result - 7 values including optional first_failed_obs
best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
@@ -2579,13 +2554,6 @@ def _loocv_score_obs_specific(
control_obs = [(t, i) for t in range(n_periods) for i in range(n_units)
if control_mask[t, i] and not np.isnan(Y[t, i])]

- # Subsample for computational tractability (as noted in paper's footnote)
- rng = np.random.default_rng(self.seed)
- max_loocv = min(self.max_loocv_samples, len(control_obs))
- if len(control_obs) > max_loocv:
-     indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
-     control_obs = [control_obs[idx] for idx in indices]
-
# Empty control set check: if no control observations, return infinity
# A score of 0.0 would incorrectly "win" over legitimate parameters
if len(control_obs) == 0:
@@ -2877,7 +2845,6 @@ def get_params(self) -> Dict[str, Any]:
"tol": self.tol,
"alpha": self.alpha,
"n_bootstrap": self.n_bootstrap,
- "max_loocv_samples": self.max_loocv_samples,
"seed": self.seed,
}

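Since `max_loocv_samples` is removed from `__init__` and `get_params`, a parameter dictionary saved before this change would presumably be rejected if passed back to the constructor unchanged. One possible way to handle that on the caller side is sketched below. `filter_supported_kwargs` and `TROPLike` are hypothetical illustrations, not part of `diff_diff`; `TROPLike` only mimics a constructor without the removed argument.

```python
import inspect


def filter_supported_kwargs(cls, params):
    """Split `params` into those `cls.__init__` accepts and the names dropped."""
    accepted = inspect.signature(cls.__init__).parameters
    kept = {k: v for k, v in params.items() if k in accepted}
    dropped = sorted(set(params) - set(kept))
    return kept, dropped


class TROPLike:
    """Hypothetical stand-in for an estimator whose constructor lost a parameter."""

    def __init__(self, alpha=0.05, n_bootstrap=200, seed=None):
        self.alpha = alpha
        self.n_bootstrap = n_bootstrap
        self.seed = seed


# Parameters as they might have been saved before the removal.
old_params = {"alpha": 0.1, "n_bootstrap": 200, "max_loocv_samples": 100, "seed": 1}
kept, dropped = filter_supported_kwargs(TROPLike, old_params)
model = TROPLike(**kept)
print(dropped)
```

Reporting the dropped names rather than discarding them silently keeps the behavior change visible to whoever reloads old configurations.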
7 changes: 2 additions & 5 deletions docs/methodology/REGISTRY.md
@@ -540,9 +540,6 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
- `λ_nn=∞`: Factor model disabled (L=0), because infinite penalty; converted to `1e10` internally
- **Note**: `λ_nn=0` means NO regularization (full-rank L), which is the OPPOSITE of "disabled"
- **Validation**: `lambda_time_grid` and `lambda_unit_grid` must not contain inf. A `ValueError` is raised if they do, guiding users to use 0.0 for uniform weights per Eq. 3.
- - **Subsampling**: max_loocv_samples (default 100) for computational tractability
-   - This subsamples control observations, NOT parameter combinations
-   - Increases precision at cost of computation; increase for more precise tuning
- **LOOCV failure handling** (Equation 5 compliance):
- If ANY LOOCV fit fails for a parameter combination, Q(λ) = ∞
- A warning is emitted on the first failure with the observation (t, i) and λ values
@@ -556,14 +553,14 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
- Rank selection: automatic via cross-validation, information criterion, or elbow
- Zero singular values: handled by soft-thresholding
- Extreme distances: weights regularized to prevent degeneracy
- - LOOCV fit failures: returns Q(λ) = ∞ on first failure (per Equation 5 requirement that Q sums over ALL D==0 cells); if all parameter combinations fail, falls back to defaults (1.0, 1.0, 0.1)
+ - LOOCV fit failures: returns Q(λ) = ∞ on first failure (per Equation 5 requirement that Q sums over ALL control observations where D==0); if all parameter combinations fail, falls back to defaults (1.0, 1.0, 0.1)
- **λ_nn=∞ implementation**: Only λ_nn uses infinity (converted to 1e10 for computation):
- λ_nn=∞ → 1e10 (large penalty → L≈0, factor model disabled)
- Conversion applied to grid values during LOOCV (including Rust backend)
- Conversion applied to selected values for point estimation
- Conversion applied to selected values for variance estimation (ensures SE matches ATT)
- **Results storage**: `TROPResults` stores *original* λ_nn value (inf), while computations use 1e10. λ_time and λ_unit store their selected values directly (0.0 = uniform).
- - **Empty control observations**: If LOOCV control observations become empty (edge case during subsampling), returns Q(λ) = ∞ with warning. A score of 0.0 would incorrectly "win" over legitimate parameters.
+ - **Empty control observations**: If no valid control observations exist, returns Q(λ) = ∞ with warning. A score of 0.0 would incorrectly "win" over legitimate parameters.
- **Infinite LOOCV score handling**: If best LOOCV score is infinite, `best_lambda` is set to None, triggering defaults fallback
- Validation: requires at least 2 periods before first treatment
- **D matrix validation**: Treatment indicator must be an absorbing state (monotonic non-decreasing per unit)
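
The registry rules above (infinite `λ_nn` converted to `1e10` for computation, an infinite score for any failed combination, and the `(1.0, 1.0, 0.1)` fallback when nothing succeeds) can be sketched as a small grid search. This is an illustration under stated assumptions, not the library's implementation: `select_lambda` and `score_fn` are hypothetical names, and `score_fn` is assumed to return `math.inf` itself whenever any LOOCV fit fails.

```python
import math
from itertools import product

DEFAULTS = (1.0, 1.0, 0.1)  # fallback values quoted in the registry text


def select_lambda(time_grid, unit_grid, nn_grid, score_fn):
    """Return (best_lambda, best_score) over the grid.

    `score_fn(lt, lu, ln)` is a hypothetical scorer; it is assumed to return
    math.inf when any LOOCV fit fails for that combination.
    """
    best_score, best_lambda = math.inf, None
    for lt, lu, ln in product(time_grid, unit_grid, nn_grid):
        # Only lambda_nn may be infinite; compute with a large finite penalty.
        ln_eff = 1e10 if math.isinf(ln) else ln
        score = score_fn(lt, lu, ln_eff)
        if score < best_score:
            # Store the ORIGINAL grid value (possibly inf), not the converted one.
            best_score, best_lambda = score, (lt, lu, ln)
    if best_lambda is None:
        # Every combination scored inf: fall back to the defaults.
        return DEFAULTS, math.inf
    return best_lambda, best_score


seen_nn = []


def toy_score(lt, lu, ln):
    """Toy scorer that records the lambda_nn it receives and fails at 1e10."""
    seen_nn.append(ln)
    return math.inf if ln >= 1e10 else lt + lu + ln


best, score = select_lambda([0.0, 1.0], [0.0], [0.1, math.inf], toy_score)
print(best, score)
```

Using a strict `<` comparison means an infinite score can never be selected, so the fallback branch is reached exactly when every combination fails.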