diff --git a/CHANGELOG.md b/CHANGELOG.md index 73e2e81..fdb93e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **TROP `variance_method` parameter** — Jackknife variance estimation removed. Bootstrap (the only method specified in Athey et al. 2025) is now always used. The `variance_method` field has also been removed from `TROPResults`. +- **TROP `max_loocv_samples` parameter** — Control observation subsampling removed + from LOOCV tuning parameter selection. Equation 5 of Athey et al. (2025) explicitly + sums over ALL control observations where D=0; the previous subsampling (default 100) + was not specified in the paper. LOOCV now uses all control observations, making + tuning fully deterministic. Inner LOOCV loops in the Rust backend are parallelized + to compensate for the increased observation count. ## [2.2.0] - 2026-01-27 diff --git a/diff_diff/trop.py b/diff_diff/trop.py index df07bff..179412c 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -394,11 +394,6 @@ class TROP: Significance level for confidence intervals. n_bootstrap : int, default=200 Number of bootstrap replications for variance estimation. - max_loocv_samples : int, default=100 - Maximum control observations to use in LOOCV for tuning parameter - selection. Subsampling is used for computational tractability as - noted in the paper. Increase for more precise tuning at the cost - of computational time. seed : int, optional Random seed for reproducibility. @@ -429,15 +424,6 @@ class TROP: """ # Class constants - DEFAULT_LOOCV_MAX_SAMPLES: int = 100 - """Maximum control observations to use in LOOCV (for computational tractability). - - As noted in the paper's footnote, LOOCV is subsampled for computational - tractability. This constant controls the maximum number of control observations - used in each LOOCV evaluation. Increase for more precise tuning at the cost - of computational time. - """ - CONVERGENCE_TOL_SVD: float = 1e-10 """Tolerance for singular value truncation in soft-thresholding. @@ -455,7 +441,6 @@ def __init__( tol: float = 1e-6, alpha: float = 0.05, n_bootstrap: int = 200, - max_loocv_samples: int = 100, seed: Optional[int] = None, ): # Validate method parameter @@ -475,7 +460,6 @@ def __init__( self.tol = tol self.alpha = alpha self.n_bootstrap = n_bootstrap - self.max_loocv_samples = max_loocv_samples self.seed = seed # Validate that time/unit grids do not contain inf. @@ -1359,8 +1343,7 @@ def _fit_joint( result = _rust_loocv_grid_search_joint( Y, D.astype(np.float64), control_mask_u8, lambda_time_arr, lambda_unit_arr, lambda_nn_arr, - self.max_loocv_samples, self.max_iter, self.tol, - self.seed if self.seed is not None else 0 + self.max_iter, self.tol, ) # Unpack result - 7 values including optional first_failed_obs best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result @@ -1407,13 +1390,6 @@ def _fit_joint( if control_mask[t, i] and not np.isnan(Y[t, i]) ] - # Subsample if needed (sample indices to avoid ValueError on list of tuples) - rng = np.random.default_rng(self.seed) - max_loocv = min(self.max_loocv_samples, len(control_obs)) - if len(control_obs) > max_loocv: - indices = rng.choice(len(control_obs), size=max_loocv, replace=False) - control_obs = [control_obs[idx] for idx in indices] - # Grid search with true LOOCV for lambda_time_val in self.lambda_time_grid: for lambda_unit_val in self.lambda_unit_grid: @@ -1898,8 +1874,7 @@ def fit( Y, D.astype(np.float64), control_mask_u8, time_dist_matrix, lambda_time_arr, lambda_unit_arr, lambda_nn_arr, - self.max_loocv_samples, self.max_iter, self.tol, - self.seed if self.seed is not None else 0 + self.max_iter, self.tol, ) # Unpack result - 7 values including optional first_failed_obs best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result @@ -2579,13 +2554,6 @@ def _loocv_score_obs_specific( control_obs = [(t, i) for t in range(n_periods) for i in range(n_units) if control_mask[t, i] and not np.isnan(Y[t, i])] - # Subsample for computational tractability (as noted in paper's footnote) - rng = np.random.default_rng(self.seed) - max_loocv = min(self.max_loocv_samples, len(control_obs)) - if len(control_obs) > max_loocv: - indices = rng.choice(len(control_obs), size=max_loocv, replace=False) - control_obs = [control_obs[idx] for idx in indices] - # Empty control set check: if no control observations, return infinity # A score of 0.0 would incorrectly "win" over legitimate parameters if len(control_obs) == 0: @@ -2877,7 +2845,6 @@ def get_params(self) -> Dict[str, Any]: "tol": self.tol, "alpha": self.alpha, "n_bootstrap": self.n_bootstrap, - "max_loocv_samples": self.max_loocv_samples, "seed": self.seed, } diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 1bf16b9..8416aaa 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -540,9 +540,6 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - `λ_nn=∞`: Factor model disabled (L=0), because infinite penalty; converted to `1e10` internally - **Note**: `λ_nn=0` means NO regularization (full-rank L), which is the OPPOSITE of "disabled" - **Validation**: `lambda_time_grid` and `lambda_unit_grid` must not contain inf. A `ValueError` is raised if they do, guiding users to use 0.0 for uniform weights per Eq. 3. -- **Subsampling**: max_loocv_samples (default 100) for computational tractability - - This subsamples control observations, NOT parameter combinations - - Increases precision at cost of computation; increase for more precise tuning - **LOOCV failure handling** (Equation 5 compliance): - If ANY LOOCV fit fails for a parameter combination, Q(λ) = ∞ - A warning is emitted on the first failure with the observation (t, i) and λ values @@ -556,14 +553,14 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - Rank selection: automatic via cross-validation, information criterion, or elbow - Zero singular values: handled by soft-thresholding - Extreme distances: weights regularized to prevent degeneracy -- LOOCV fit failures: returns Q(λ) = ∞ on first failure (per Equation 5 requirement that Q sums over ALL D==0 cells); if all parameter combinations fail, falls back to defaults (1.0, 1.0, 0.1) +- LOOCV fit failures: returns Q(λ) = ∞ on first failure (per Equation 5 requirement that Q sums over ALL control observations where D==0); if all parameter combinations fail, falls back to defaults (1.0, 1.0, 0.1) - **λ_nn=∞ implementation**: Only λ_nn uses infinity (converted to 1e10 for computation): - λ_nn=∞ → 1e10 (large penalty → L≈0, factor model disabled) - Conversion applied to grid values during LOOCV (including Rust backend) - Conversion applied to selected values for point estimation - Conversion applied to selected values for variance estimation (ensures SE matches ATT) - **Results storage**: `TROPResults` stores *original* λ_nn value (inf), while computations use 1e10. λ_time and λ_unit store their selected values directly (0.0 = uniform). -- **Empty control observations**: If LOOCV control observations become empty (edge case during subsampling), returns Q(λ) = ∞ with warning. A score of 0.0 would incorrectly "win" over legitimate parameters. +- **Empty control observations**: If no valid control observations exist, returns Q(λ) = ∞ with warning. A score of 0.0 would incorrectly "win" over legitimate parameters. - **Infinite LOOCV score handling**: If best LOOCV score is infinite, `best_lambda` is set to None, triggering defaults fallback - Validation: requires at least 2 periods before first treatment - **D matrix validation**: Treatment indicator must be an absorbing state (monotonic non-decreasing per unit) diff --git a/rust/src/trop.rs b/rust/src/trop.rs index b7be5f1..9e275d5 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -304,10 +304,8 @@ fn cycling_parameter_search( /// * `lambda_time_grid` - Grid of time decay parameters /// * `lambda_unit_grid` - Grid of unit distance parameters /// * `lambda_nn_grid` - Grid of nuclear norm parameters -/// * `max_loocv_samples` - Maximum control observations to evaluate /// * `max_iter` - Maximum iterations for model estimation /// * `tol` - Convergence tolerance -/// * `seed` - Random seed for subsampling /// /// # Returns /// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted, first_failed_obs) @@ -315,7 +313,7 @@ fn cycling_parameter_search( /// allowing Python to emit warnings when >10% of fits fail. /// first_failed_obs is Some((t, i)) if a fit failed during final score computation, None otherwise. #[pyfunction] -#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))] +#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_iter, tol))] #[allow(clippy::too_many_arguments)] pub fn loocv_grid_search<'py>( _py: Python<'py>, @@ -326,10 +324,8 @@ pub fn loocv_grid_search<'py>( lambda_time_grid: PyReadonlyArray1<'py, f64>, lambda_unit_grid: PyReadonlyArray1<'py, f64>, lambda_nn_grid: PyReadonlyArray1<'py, f64>, - max_loocv_samples: usize, max_iter: usize, tol: f64, - seed: u64, ) -> PyResult<(f64, f64, f64, f64, usize, usize, Option<(usize, usize)>)> { let y_arr = y.as_array(); let d_arr = d.as_array(); @@ -360,8 +356,6 @@ pub fn loocv_grid_search<'py>( let control_obs = get_control_observations( &y_arr, &control_mask_arr, - max_loocv_samples, - seed, ); let n_attempted = control_obs.len(); @@ -409,16 +403,11 @@ pub fn loocv_grid_search<'py>( Ok((best_time, best_unit, best_nn, best_score, n_valid, n_attempted, first_failed)) } -/// Get sampled control observations for LOOCV. +/// Get all valid control observations for LOOCV. fn get_control_observations( y: &ArrayView2, control_mask: &ArrayView2, - max_samples: usize, - seed: u64, ) -> Vec<(usize, usize)> { - use rand::prelude::*; - use rand_xoshiro::Xoshiro256PlusPlus; - let n_periods = y.nrows(); let n_units = y.ncols(); @@ -432,13 +421,6 @@ fn get_control_observations( } } - // Subsample if needed - if obs.len() > max_samples { - let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed); - obs.shuffle(&mut rng); - obs.truncate(max_samples); - } - obs } @@ -463,48 +445,54 @@ fn loocv_score_for_params( let n_periods = y.nrows(); let n_units = y.ncols(); - let mut tau_sq_sum = 0.0; - let mut n_valid = 0usize; + // Parallelize over control observations — each per-observation computation + // is independent (compute weight matrix, fit model, extract τ²). + // with_min_len(64) prevents scheduling overhead from dominating on small panels. + let (tau_sq_sum, n_valid, first_failed) = control_obs + .par_iter() + .with_min_len(64) + .fold( + || (0.0f64, 0usize, None::<(usize, usize)>), + |(sum, valid, first_fail), &(t, i)| { + let weight_matrix = compute_weight_matrix( + y, + d, + n_periods, + n_units, + i, + t, + lambda_time, + lambda_unit, + time_dist, + ); - for &(t, i) in control_obs { - // Compute observation-specific weight matrix - let weight_matrix = compute_weight_matrix( - y, - d, - n_periods, - n_units, - i, - t, - lambda_time, - lambda_unit, - time_dist, + match estimate_model( + y, + control_mask, + &weight_matrix.view(), + lambda_nn, + n_periods, + n_units, + max_iter, + tol, + Some((t, i)), + ) { + Some((alpha, beta, l)) => { + let tau = y[[t, i]] - alpha[i] - beta[t] - l[[t, i]]; + (sum + tau * tau, valid + 1, first_fail) + } + None => (sum, valid, first_fail.or(Some((t, i)))), + } + }, + ) + .reduce( + || (0.0, 0, None), + |(s1, v1, f1), (s2, v2, f2)| (s1 + s2, v1 + v2, f1.or(f2)), ); - // Estimate model excluding this observation - match estimate_model( - y, - control_mask, - &weight_matrix.view(), - lambda_nn, - n_periods, - n_units, - max_iter, - tol, - Some((t, i)), - ) { - Some((alpha, beta, l)) => { - // Pseudo treatment effect: τ = Y - α - β - L - let tau = y[[t, i]] - alpha[i] - beta[t] - l[[t, i]]; - tau_sq_sum += tau * tau; - n_valid += 1; - } - None => { - // Per Equation 5: Q(λ) must sum over ALL D==0 cells - // Any failure means this λ cannot produce valid estimates for all cells - // Return the failed observation (t, i) for warning metadata - return (f64::INFINITY, n_valid, Some((t, i))); - } - } + // Per Equation 5: if ANY fit fails, this λ combination is invalid + if first_failed.is_some() { + return (f64::INFINITY, n_valid, first_failed); } if n_valid == 0 { @@ -1404,45 +1392,52 @@ fn loocv_score_joint( let n_periods = y.nrows(); let n_units = y.ncols(); - let mut tau_sq_sum = 0.0; - let mut n_valid = 0usize; - // Compute global weights (same for all LOOCV iterations) let delta = compute_joint_weights(y, d, lambda_time, lambda_unit, treated_periods); - for &(t_ex, i_ex) in control_obs { - // Create modified delta with excluded observation zeroed out - let mut delta_ex = delta.clone(); - delta_ex[[t_ex, i_ex]] = 0.0; - - // Fit joint model excluding this observation - let result = if lambda_nn >= 1e10 { - solve_joint_no_lowrank(y, d, &delta_ex.view()) - .map(|(mu, alpha, beta, tau)| { - let l = Array2::::zeros((n_periods, n_units)); - (mu, alpha, beta, l, tau) - }) - } else { - solve_joint_with_lowrank(y, d, &delta_ex.view(), lambda_nn, max_iter, tol) - }; - - match result { - Some((mu, alpha, beta, l, _tau)) => { - // Pseudo treatment effect: τ = Y - μ - α - β - L - let y_ti = if y[[t_ex, i_ex]].is_finite() { - y[[t_ex, i_ex]] + // Parallelize over control observations — each per-observation computation + // is independent (clone delta, zero one entry, fit model, extract τ²). + // with_min_len(64) prevents scheduling overhead from dominating on small panels. + let (tau_sq_sum, n_valid, first_failed) = control_obs + .par_iter() + .with_min_len(64) + .fold( + || (0.0f64, 0usize, None::<(usize, usize)>), + |(sum, valid, first_fail), &(t_ex, i_ex)| { + let mut delta_ex = delta.clone(); + delta_ex[[t_ex, i_ex]] = 0.0; + + let result = if lambda_nn >= 1e10 { + solve_joint_no_lowrank(y, d, &delta_ex.view()) + .map(|(mu, alpha, beta, tau)| { + let l = Array2::::zeros((n_periods, n_units)); + (mu, alpha, beta, l, tau) + }) } else { - continue; + solve_joint_with_lowrank(y, d, &delta_ex.view(), lambda_nn, max_iter, tol) }; - let tau_loocv = y_ti - mu - alpha[i_ex] - beta[t_ex] - l[[t_ex, i_ex]]; - tau_sq_sum += tau_loocv * tau_loocv; - n_valid += 1; - } - None => { - // Any failure means this λ combination is invalid per Equation 5 - return (f64::INFINITY, n_valid, Some((t_ex, i_ex))); - } - } + + match result { + Some((mu, alpha, beta, l, _tau)) => { + if y[[t_ex, i_ex]].is_finite() { + let tau_loocv = y[[t_ex, i_ex]] - mu - alpha[i_ex] - beta[t_ex] - l[[t_ex, i_ex]]; + (sum + tau_loocv * tau_loocv, valid + 1, first_fail) + } else { + (sum, valid, first_fail) + } + } + None => (sum, valid, first_fail.or(Some((t_ex, i_ex)))), + } + }, + ) + .reduce( + || (0.0, 0, None), + |(s1, v1, f1), (s2, v2, f2)| (s1 + s2, v1 + v2, f1.or(f2)), + ); + + // Per Equation 5: if ANY fit fails, this λ combination is invalid + if first_failed.is_some() { + return (f64::INFINITY, n_valid, first_failed); } if n_valid == 0 { @@ -1464,15 +1459,13 @@ fn loocv_score_joint( /// * `lambda_time_grid` - Grid of time decay parameters /// * `lambda_unit_grid` - Grid of unit distance parameters /// * `lambda_nn_grid` - Grid of nuclear norm parameters -/// * `max_loocv_samples` - Maximum control observations to evaluate /// * `max_iter` - Maximum iterations for model estimation /// * `tol` - Convergence tolerance -/// * `seed` - Random seed for subsampling /// /// # Returns /// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted, first_failed_obs) #[pyfunction] -#[pyo3(signature = (y, d, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))] +#[pyo3(signature = (y, d, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_iter, tol))] #[allow(clippy::too_many_arguments)] pub fn loocv_grid_search_joint<'py>( _py: Python<'py>, @@ -1482,10 +1475,8 @@ pub fn loocv_grid_search_joint<'py>( lambda_time_grid: PyReadonlyArray1<'py, f64>, lambda_unit_grid: PyReadonlyArray1<'py, f64>, lambda_nn_grid: PyReadonlyArray1<'py, f64>, - max_loocv_samples: usize, max_iter: usize, tol: f64, - seed: u64, ) -> PyResult<(f64, f64, f64, f64, usize, usize, Option<(usize, usize)>)> { let y_arr = y.as_array(); let d_arr = d.as_array(); @@ -1527,7 +1518,7 @@ pub fn loocv_grid_search_joint<'py>( let treated_periods = n_periods.saturating_sub(first_treat_period); // Get control observations for LOOCV - let control_obs = get_control_observations(&y_arr, &control_mask_arr, max_loocv_samples, seed); + let control_obs = get_control_observations(&y_arr, &control_mask_arr); let n_attempted = control_obs.len(); // Build grid combinations diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py index 76edf80..8269b5e 100644 --- a/tests/test_rust_backend.py +++ b/tests/test_rust_backend.py @@ -973,7 +973,7 @@ def test_loocv_grid_search_returns_valid_params(self): best_lt, best_lu, best_ln, score, n_valid, n_attempted, first_failed = loocv_grid_search( Y, D, control_mask, time_dist, lambda_time, lambda_unit, lambda_nn, - 50, 100, 1e-6, 42 + 100, 1e-6, ) # Check returned parameters are from the grid @@ -1108,7 +1108,6 @@ def test_trop_produces_valid_results(self): lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=20, - max_loocv_samples=30, seed=42 ) results = trop.fit(df, 'outcome', 'treated', 'unit', 'time') @@ -1160,7 +1159,7 @@ def test_loocv_grid_search_joint_returns_valid_result(self): result = loocv_grid_search_joint( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, - 50, 100, 1e-6, 42 + 100, 1e-6, ) best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, _ = result @@ -1177,7 +1176,7 @@ def test_loocv_grid_search_joint_returns_valid_result(self): assert best_score >= 0 or np.isinf(best_score) def test_loocv_grid_search_joint_reproducible(self): - """Test loocv_grid_search_joint is reproducible with same seed.""" + """Test loocv_grid_search_joint is deterministic (no subsampling).""" from diff_diff._rust_backend import loocv_grid_search_joint np.random.seed(42) @@ -1197,15 +1196,15 @@ def test_loocv_grid_search_joint_reproducible(self): result1 = loocv_grid_search_joint( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, - 30, 50, 1e-6, 42 + 50, 1e-6, ) result2 = loocv_grid_search_joint( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, - 30, 50, 1e-6, 42 + 50, 1e-6, ) - # Same seed should produce same results + # Without subsampling, results should be deterministic assert result1[:4] == result2[:4] def test_bootstrap_trop_variance_joint_shape(self): diff --git a/tests/test_trop.py b/tests/test_trop.py index 196476c..985fe5d 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2231,7 +2231,6 @@ def test_empty_control_obs_returns_infinity(self, simple_panel_data): lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[1.0], - max_loocv_samples=100, seed=42 ) @@ -3168,59 +3167,3 @@ def test_joint_rejects_staggered_adoption(self): with pytest.raises(ValueError, match="staggered adoption"): trop.fit(df, 'outcome', 'treated', 'unit', 'time') - def test_joint_python_loocv_subsampling(self): - """Test that joint method works with Python-only LOOCV when control_obs > max_loocv_samples. - - This tests the fix for PR #113 Round 7 feedback (P1): Python fallback - LOOCV sampling could raise ValueError when control_obs is a list of tuples. - """ - from unittest.mock import patch - import sys - - np.random.seed(42) - # Create data with many control observations (> default max_loocv_samples=500) - n_units, n_periods = 30, 25 # 30*25 = 750 observations, most are control - n_treated = 3 - n_post = 3 - - data = [] - for i in range(n_units): - is_treated = i < n_treated - for t in range(n_periods): - post = t >= (n_periods - n_post) - y = 10.0 + i * 0.1 + t * 0.1 + np.random.randn() * 0.5 - treatment_indicator = 1 if (is_treated and post) else 0 - if treatment_indicator: - y += 2.0 - data.append({ - 'unit': i, - 'time': t, - 'outcome': y, - 'treated': treatment_indicator, - }) - - df = pd.DataFrame(data) - - # Patch to force Python backend and set small max_loocv_samples - trop_module = sys.modules['diff_diff.trop'] - - with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): - - # Use small max_loocv_samples to trigger subsampling - trop_est = TROP( - method="joint", - lambda_time_grid=[1.0], - lambda_unit_grid=[1.0], - lambda_nn_grid=[0.0], - max_loocv_samples=100, # Force subsampling (control_obs > 100) - n_bootstrap=0, - seed=42 - ) - - # This should not raise ValueError - results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time') - - assert isinstance(results, TROPResults) - assert np.isfinite(results.att)