diff --git a/R/RcppExports.R b/R/RcppExports.R index 25253d0..806fcb8 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -25,6 +25,14 @@ run_simulation_parallel <- function(pairwise_samples, main_samples, draw_indices .Call(`_bgms_run_simulation_parallel`, pairwise_samples, main_samples, draw_indices, no_states, no_variables, no_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } +sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0) { + .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda) +} + +sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, edge_selection_start = -1L) { + .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start) +} + compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { .Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda) } diff --git a/R/bgm.R b/R/bgm.R index e0cca53..c89f3fb 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -400,6 +400,7 @@ bgm = function( chains = 4, cores = parallel::detectCores(), display_progress = c("per-chain", "total", "none"), + backend = c("legacy", "new"), seed = NULL, standardize = FALSE, interaction_scale, @@ -508,6 +509,23 @@ bgm = function( edge_selection = model$edge_selection edge_prior = model$edge_prior inclusion_probability = model$inclusion_probability + is_continuous = model$is_continuous + + # Block NUTS/HMC for the Gaussian model -------------------------------------- + if(is_continuous) { + user_chose_method = length(update_method_input) == 1 + if(user_chose_method && update_method %in% c("nuts", "hamiltonian-mc")) { + stop(paste0( + "The Gaussian model (variable_type = 'continuous') only supports ", + "update_method = 'adaptive-metropolis'. ", + "Got '", update_method, "'." 
+ )) + } + update_method = "adaptive-metropolis" + if(!hasArg(target_accept)) { + target_accept = 0.44 + } + } # Check Gibbs input ----------------------------------------------------------- check_positive_integer(iter, "iter") @@ -551,9 +569,105 @@ bgm = function( )) } + # Check backend --------------------------------------------------------------- + backend = match.arg(backend) + # Check display_progress ------------------------------------------------------ progress_type = progress_type_from_display_progress(display_progress) +# Setting the seed + if(missing(seed) || is.null(seed)) { + seed = sample.int(.Machine$integer.max, 1) + } + + if(!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) { + stop("Argument 'seed' must be a single non-negative integer.") + } + + seed <- as.integer(seed) + + data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x) + + # ========================================================================== + # Gaussian (continuous) path + # ========================================================================== + if (is_continuous) { + num_variables = ncol(x) + + # Handle missing data for continuous variables + if (na_action == "listwise") { + missing_rows = apply(x, 1, anyNA) + if (all(missing_rows)) { + stop("All rows in x contain at least one missing response.\n", + "You could try option na_action = 'impute'.") + } + if (sum(missing_rows) > 0) { + warning(sum(missing_rows), " row(s) with missing observations removed (na_action = 'listwise').", + call. = FALSE) + } + x = x[!missing_rows, , drop = FALSE] + na_impute = FALSE + } else { + stop("Imputation is not yet supported for the Gaussian model. ", + "Use na_action = 'listwise'.") + } + + indicator = matrix(1L, nrow = num_variables, ncol = num_variables) + + out_raw = sample_ggm( + inputFromR = list(X = x), + prior_inclusion_prob = matrix(inclusion_probability, + nrow = num_variables, ncol = num_variables), + initial_edge_indicators = indicator, + no_iter = iter, + no_warmup = warmup, + no_chains = chains, + edge_selection = edge_selection, + seed = seed, + no_threads = cores, + progress_type = progress_type, + edge_prior = edge_prior, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda + ) + + out = transform_ggm_backend_output(out_raw, num_variables) + + userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) + if (userInterrupt) { + warning("Stopped sampling after user interrupt, results are likely uninterpretable.") + } + + output = prepare_output_ggm( + out = out, x = x, iter = iter, + data_columnnames = data_columnnames, + warmup = warmup, + edge_selection = edge_selection, edge_prior = edge_prior, + inclusion_probability = inclusion_probability, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda, + na_action = na_action, na_impute = na_impute, + variable_type = variable_type, + update_method = update_method, + target_accept = target_accept, + num_chains = chains + ) + + return(output) + } + + # ========================================================================== + # Ordinal / 
Blume-Capel path + # ========================================================================== + # Format the data input ------------------------------------------------------- data = reformat_data( x = x, @@ -592,6 +706,9 @@ bgm = function( } } + # Save data before Blume-Capel centering (needed by the new backend) + x_raw = x + # Precompute the sufficient statistics for the two Blume-Capel parameters ----- blume_capel_stats = matrix(0, nrow = 2, ncol = num_variables) if(any(!variable_bool)) { @@ -691,6 +808,47 @@ bgm = function( seed <- as.integer(seed) + if (backend == "new") { + input_list = list( + observations = x_raw, + num_categories = num_categories, + is_ordinal_variable = variable_bool, + baseline_category = baseline_category, + main_alpha = main_alpha, + main_beta = main_beta, + pairwise_scale = pairwise_scale + ) + + out_raw = sample_omrf( + inputFromR = input_list, + prior_inclusion_prob = matrix(inclusion_probability, + nrow = num_variables, ncol = num_variables), + initial_edge_indicators = indicator, + no_iter = iter, + no_warmup = warmup, + no_chains = chains, + no_threads = cores, + progress_type = progress_type, + edge_selection = edge_selection, + sampler_type = update_method, + seed = seed, + edge_prior = edge_prior, + na_impute = na_impute, + missing_index = missing_index, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda, + target_acceptance = target_accept, + max_tree_depth = nuts_max_depth, + num_leapfrogs = hmc_num_leapfrogs + ) + + out = transform_new_backend_output(out_raw, num_thresholds) + } else { + out = run_bgm_parallel( observations = x, num_categories = num_categories, pairwise_scale = pairwise_scale, edge_prior = edge_prior, @@ -716,6 +874,8 @@ bgm = function( pairwise_scaling_factors = pairwise_scaling_factors ) + } # end backend branch + userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) if(userInterrupt) { warning("Stopped sampling after user interrupt, results are likely uninterpretable.") @@ -723,7 +883,7 @@ bgm = function( output <- tryCatch( prepare_output_bgm( out = out, x = x, num_categories = num_categories, iter = iter, - data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = data_columnnames, is_ordinal_variable = variable_bool, warmup = warmup, pairwise_scale = pairwise_scale, standardize = standardize, pairwise_scaling_factors = pairwise_scaling_factors, @@ -756,7 +916,7 @@ bgm = function( # Main output handler in the wrapper function output = prepare_output_bgm( out = out, x = x, num_categories = num_categories, iter = iter, - data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = data_columnnames, is_ordinal_variable = variable_bool, warmup = warmup, pairwise_scale = pairwise_scale, standardize = standardize, pairwise_scaling_factors = pairwise_scaling_factors, diff --git a/R/function_input_utils.R b/R/function_input_utils.R index 4dde1a1..4c6aa3a 100644 --- a/R/function_input_utils.R +++ b/R/function_input_utils.R @@ -36,24 +36,30 @@ check_model = function(x, dirichlet_alpha = dirichlet_alpha, lambda = lambda) { # Check variable type input --------------------------------------------------- + is_continuous = FALSE if(length(variable_type) == 1) { 
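+    # A single type string applies to every column: "continuous" routes bgm()
+    # to the Gaussian (GGM) path, while "ordinal" and "blume-capel" keep the
+    # ordinal Markov random field (omrf) path.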
variable_input = variable_type variable_type = try( match.arg( arg = variable_type, - choices = c("ordinal", "blume-capel") + choices = c("ordinal", "blume-capel", "continuous") ), silent = TRUE ) if(inherits(variable_type, what = "try-error")) { stop(paste0( - "The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", + "The bgm function supports variables of type ordinal, blume-capel, ", + "and continuous, but not of type ", variable_input, "." )) } - variable_bool = (variable_type == "ordinal") - variable_bool = rep(variable_bool, ncol(x)) + if(variable_type == "continuous") { + is_continuous = TRUE + variable_bool = rep(TRUE, ncol(x)) + } else { + variable_bool = (variable_type == "ordinal") + variable_bool = rep(variable_bool, ncol(x)) + } } else { if(length(variable_type) != ncol(x)) { stop(paste0( @@ -62,41 +68,54 @@ check_model = function(x, )) } - variable_input = unique(variable_type) - variable_type = try(match.arg( - arg = variable_type, - choices = c("ordinal", "blume-capel"), - several.ok = TRUE - ), silent = TRUE) - - if(inherits(variable_type, what = "try-error")) { + has_continuous = any(variable_type == "continuous") + if(has_continuous && !all(variable_type == "continuous")) { stop(paste0( - "The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input, collapse = ", "), "." + "When using continuous variables, all variables must be of type ", + "'continuous'. Mixtures of continuous and ordinal/blume-capel ", + "variables are not supported." )) } + if(has_continuous) { + is_continuous = TRUE + variable_bool = rep(TRUE, ncol(x)) + } else { + variable_input = unique(variable_type) + variable_type = try(match.arg( + arg = variable_type, + choices = c("ordinal", "blume-capel"), + several.ok = TRUE + ), silent = TRUE) - num_types = sapply(variable_input, function(type) { - tmp = try( - match.arg( - arg = type, - choices = c("ordinal", "blume-capel") - ), - silent = TRUE - ) - inherits(tmp, what = "try-error") - }) + if(inherits(variable_type, what = "try-error")) { + stop(paste0( + "The bgm function supports variables of type ordinal, blume-capel, ", + "and continuous, but not of type ", + paste0(variable_input, collapse = ", "), "." + )) + } - if(length(variable_type) != ncol(x)) { - stop(paste0( - "The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input[num_types], collapse = ", "), "." - )) - } + num_types = sapply(variable_input, function(type) { + tmp = try( + match.arg( + arg = type, + choices = c("ordinal", "blume-capel") + ), + silent = TRUE + ) + inherits(tmp, what = "try-error") + }) - variable_bool = (variable_type == "ordinal") + if(length(variable_type) != ncol(x)) { + stop(paste0( + "The bgm function supports variables of type ordinal, blume-capel, ", + "and continuous, but not of type ", + paste0(variable_input[num_types], collapse = ", "), "." 
+ )) + } + + variable_bool = (variable_type == "ordinal") + } } # Check Blume-Capel variable input -------------------------------------------- @@ -316,7 +335,8 @@ check_model = function(x, baseline_category = baseline_category, edge_selection = edge_selection, edge_prior = edge_prior, - inclusion_probability = theta + inclusion_probability = theta, + is_continuous = is_continuous )) } diff --git a/R/output_utils.R b/R/output_utils.R index 1ae67cf..43e32c5 100644 --- a/R/output_utils.R +++ b/R/output_utils.R @@ -178,6 +178,226 @@ prepare_output_bgm = function( } +prepare_output_ggm = function( + out, x, iter, data_columnnames, + warmup, edge_selection, edge_prior, inclusion_probability, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda, + na_action, na_impute, + variable_type, update_method, target_accept, + num_chains +) { + num_variables = ncol(x) + + arguments = list( + num_variables = num_variables, + num_cases = nrow(x), + na_impute = na_impute, + variable_type = variable_type, + iter = iter, + warmup = warmup, + edge_selection = edge_selection, + edge_prior = edge_prior, + inclusion_probability = inclusion_probability, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda, + na_action = na_action, + version = packageVersion("bgms"), + update_method = update_method, + target_accept = target_accept, + num_chains = num_chains, + data_columnnames = data_columnnames, + no_variables = num_variables, + is_continuous = TRUE + ) + + results = list() + + # Parameter names: diagonal = "Var (precision)", off-diagonal = "Var1-Var2" + diag_names = paste0(data_columnnames, " (precision)") + + edge_names = character() + for (i in 1:(num_variables - 1)) { + for (j in (i + 1):num_variables) { + edge_names = c(edge_names, paste0(data_columnnames[i], "-", data_columnnames[j])) + } + } + + # Summarize MCMC chains + summary_list = summarize_fit(out, edge_selection = edge_selection) + main_summary = summary_list$main[, -1] + pairwise_summary = summary_list$pairwise[, -1] + + rownames(main_summary) = diag_names + rownames(pairwise_summary) = edge_names + + results$posterior_summary_main = main_summary + results$posterior_summary_pairwise = pairwise_summary + + if (edge_selection) { + indicator_summary = summarize_indicator(out, param_names = edge_names)[, -1] + rownames(indicator_summary) = edge_names + results$posterior_summary_indicator = indicator_summary + + has_sbm = identical(edge_prior, "Stochastic-Block") && + "allocations" %in% names(out[[1]]) + + if (has_sbm) { + sbm_convergence = summarize_alloc_pairs( + allocations = lapply(out, `[[`, "allocations"), + node_names = data_columnnames + ) + results$posterior_summary_pairwise_allocations = sbm_convergence$sbm_summary + } + } + + # Posterior mean matrices + results$posterior_mean_main = matrix(main_summary$mean, + nrow = num_variables, ncol = 1, + dimnames = list(data_columnnames, "precision_diag")) + + results$posterior_mean_pairwise = matrix(0, + nrow = num_variables, ncol = num_variables, + dimnames = list(data_columnnames, data_columnnames)) + results$posterior_mean_pairwise[lower.tri(results$posterior_mean_pairwise)] = pairwise_summary$mean + results$posterior_mean_pairwise = results$posterior_mean_pairwise + t(results$posterior_mean_pairwise) + + if 
(edge_selection) { + indicator_means = indicator_summary$mean + results$posterior_mean_indicator = matrix(0, + nrow = num_variables, ncol = num_variables, + dimnames = list(data_columnnames, data_columnnames)) + results$posterior_mean_indicator[upper.tri(results$posterior_mean_indicator)] = indicator_means + results$posterior_mean_indicator[lower.tri(results$posterior_mean_indicator)] = + t(results$posterior_mean_indicator)[lower.tri(results$posterior_mean_indicator)] + + if (has_sbm) { + sbm_convergence = summarize_alloc_pairs( + allocations = lapply(out, `[[`, "allocations"), + node_names = data_columnnames + ) + results$posterior_mean_coclustering_matrix = sbm_convergence$co_occur_matrix + + sbm_summary = posterior_summary_SBM( + allocations = lapply(out, `[[`, "allocations"), + arguments = arguments + ) + results$posterior_mean_allocations = sbm_summary$allocations_mean + results$posterior_mode_allocations = sbm_summary$allocations_mode + results$posterior_num_blocks = sbm_summary$blocks + } + } + + results$arguments = arguments + class(results) = "bgms" + + results$raw_samples = list( + main = lapply(out, function(chain) chain$main_samples), + pairwise = lapply(out, function(chain) chain$pairwise_samples), + indicator = if (edge_selection) lapply(out, function(chain) chain$indicator_samples) else NULL, + allocations = if (edge_selection && identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) lapply(out, `[[`, "allocations") else NULL, + nchains = length(out), + niter = nrow(out[[1]]$main_samples), + parameter_names = list( + main = diag_names, + pairwise = edge_names, + indicator = if (edge_selection) edge_names else NULL, + allocations = if (identical(edge_prior, "Stochastic-Block")) data_columnnames else NULL + ) + ) + + return(results) +} + + +# Transform sample_ggm output to match the old backend format. +# +# The GGM backend returns a `samples` matrix (params x iters) where params +# are the upper triangle of the precision matrix stored column-by-column: +# (0,0), (0,1), (1,1), (0,2), (1,2), (2,2), ... +# We split these into diagonal elements ("main") and off-diagonal ("pairwise"). +transform_ggm_backend_output = function(out, p) { + # Build index maps for upper triangle (column-major) + diag_idx = integer(p) + offdiag_idx = integer(p * (p - 1) / 2) + pos = 0L + d = 0L + od = 0L + for (j in seq_len(p)) { + for (i in seq_len(j)) { + pos = pos + 1L + if (i == j) { + d = d + 1L + diag_idx[d] = pos + } else { + od = od + 1L + offdiag_idx[od] = pos + } + } + } + + lapply(out, function(chain) { + samples_t = t(chain$samples) # (params x iters) -> (iters x params) + + res = list( + main_samples = samples_t[, diag_idx, drop = FALSE], + pairwise_samples = samples_t[, offdiag_idx, drop = FALSE], + userInterrupt = isTRUE(chain$userInterrupt), + chain_id = chain$chain_id + ) + + if (!is.null(chain$indicator_samples)) { + indic_t = t(chain$indicator_samples) # (params x iters) -> (iters x params) + res$indicator_samples = indic_t[, offdiag_idx, drop = FALSE] + } + + if (!is.null(chain$allocation_samples)) { + res$allocations = t(chain$allocation_samples) # (variables x iters) -> (iters x variables) + } + + res + }) +} + + +# Transform sample_omrf output to match the old backend format. +# +# The new backend returns a flat `samples` matrix (params x iters) containing +# all main + pairwise parameters concatenated. The old backend stores separate +# `main_samples` and `pairwise_samples` matrices (iters x params). NUTS fields +# also differ in naming. 
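+# For example, with 5 ordinal variables contributing 2 threshold parameters
+# each, the first 5 * 2 = 10 columns (after transposing to iters x params) are
+# main effects and the remaining choose(5, 2) = 10 are pairwise interactions.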
This function bridges the gap so that +# `prepare_output_bgm()` can process both backends identically. +transform_new_backend_output = function(out, num_thresholds) { + lapply(out, function(chain) { + samples_t = t(chain$samples) # (params x iters) -> (iters x params) + n_params = ncol(samples_t) + + res = list( + main_samples = samples_t[, seq_len(num_thresholds), drop = FALSE], + pairwise_samples = samples_t[, seq(num_thresholds + 1, n_params), drop = FALSE], + userInterrupt = isTRUE(chain$userInterrupt), + chain_id = chain$chain_id + ) + + if (!is.null(chain$indicator_samples)) { + res$indicator_samples = t(chain$indicator_samples) + } + + # Rename NUTS diagnostics to match old backend convention (trailing __) + if (!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth + if (!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if (!is.null(chain$energy)) res[["energy__"]] = chain$energy + + res + }) +} + + # Generate names for bgmCompare parameters generate_param_names_bgmCompare = function( data_columnnames, diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index ff1811b..b82898c 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -179,6 +179,67 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// sample_ggm +Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda); +RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); + Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); + Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); + Rcpp::traits::input_parameter< const int >::type no_chains(no_chainsSEXP); + Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); + Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); + Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); + Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); + Rcpp::traits::input_parameter< const double >::type 
beta_bernoulli_alpha_between(beta_bernoulli_alpha_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta_between(beta_bernoulli_beta_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type dirichlet_alpha(dirichlet_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); + rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda)); + return rcpp_result_gen; +END_RCPP +} +// sample_omrf +Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const int edge_selection_start); +RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP edge_selection_startSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); + Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); + Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); + Rcpp::traits::input_parameter< const int >::type no_chains(no_chainsSEXP); + Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type sampler_type(sampler_typeSEXP); + Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); + Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); + Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); + Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); + Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_nullable(missing_index_nullableSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); + Rcpp::traits::input_parameter< 
const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha_between(beta_bernoulli_alpha_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta_between(beta_bernoulli_beta_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type dirichlet_alpha(dirichlet_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); + Rcpp::traits::input_parameter< const double >::type target_acceptance(target_acceptanceSEXP); + Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); + Rcpp::traits::input_parameter< const int >::type num_leapfrogs(num_leapfrogsSEXP); + Rcpp::traits::input_parameter< const int >::type edge_selection_start(edge_selection_startSEXP); + rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start)); + return rcpp_result_gen; +END_RCPP +} // compute_Vn_mfm_sbm arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double dirichlet_alpha, arma::uword t_max, double lambda); RcppExport SEXP _bgms_compute_Vn_mfm_sbm(SEXP no_variablesSEXP, SEXP dirichlet_alphaSEXP, SEXP t_maxSEXP, SEXP lambdaSEXP) { @@ -201,6 +262,8 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 7}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 9}, {"_bgms_run_simulation_parallel", (DL_FUNC) &_bgms_run_simulation_parallel, 12}, + {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 17}, + {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/mcmc/adaptive_gradient_sampler.h b/src/mcmc/adaptive_gradient_sampler.h new file mode 100644 index 0000000..02a05ef --- /dev/null +++ b/src/mcmc/adaptive_gradient_sampler.h @@ -0,0 +1,156 @@ +#pragma once + +#include +#include +#include +#include +#include "mcmc/base_sampler.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/mcmc_adaptation.h" +#include "mcmc/sampler_config.h" +#include "models/base_model.h" + +/** + * AdaptiveGradientSampler - Base for gradient-based MCMC with warmup adaptation + * + * Shared warmup logic for NUTS and HMC: + * Stage 1: step-size adaptation only + * Stage 2: mass matrix estimation in doubling windows + step-size re-tuning + * Stage 3: final step-size averaging + */ +class AdaptiveGradientSampler : public BaseSampler { +public: + AdaptiveGradientSampler(double step_size, double target_acceptance, int n_warmup) + : step_size_(step_size), + target_acceptance_(target_acceptance), + n_warmup_(n_warmup), + warmup_iteration_(0), + initialized_(false), + step_adapter_(step_size) + { + build_warmup_schedule(n_warmup); + } + + SamplerResult warmup_step(BaseModel& model) override { + if (!initialized_) { + initialize(model); + initialized_ = true; + } + + SamplerResult result = do_gradient_step(model); + + step_adapter_.update(result.accept_prob, target_acceptance_); + step_size_ = step_adapter_.current(); + + if (in_stage2()) { + arma::vec full_params = model.get_full_vectorized_parameters(); + 
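+        // Stage 2: accumulate draws for the diagonal mass matrix. At the end of
+        // each adaptation window the accumulated variance becomes the new inverse
+        // mass, the accumulator is reset, and the step size is re-initialised;
+        // windows start at 25 iterations and double until stage 3 (the final 10%
+        // of warmup) begins.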
mass_accumulator_->update(full_params); + + if (at_window_end()) { + inv_mass_ = mass_accumulator_->variance(); + mass_accumulator_->reset(); + model.set_inv_mass(inv_mass_); + + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + arma::vec active_inv_mass = model.get_active_inv_mass(); + + double new_eps = heuristic_initial_step_size( + theta, log_post, grad_fn, active_inv_mass, rng, + 0.625, step_size_); + step_size_ = new_eps; + step_adapter_.restart(new_eps); + } + } + + warmup_iteration_++; + return result; + } + + void finalize_warmup() override { + step_size_ = step_adapter_.averaged(); + } + + SamplerResult sample_step(BaseModel& model) override { + return do_gradient_step(model); + } + + double get_step_size() const { return step_size_; } + double get_averaged_step_size() const { return step_adapter_.averaged(); } + const arma::vec& get_inv_mass() const { return inv_mass_; } + +protected: + virtual SamplerResult do_gradient_step(BaseModel& model) = 0; + + double step_size_; + double target_acceptance_; + +private: + void build_warmup_schedule(int n_warmup) { + stage1_end_ = static_cast<int>(0.075 * n_warmup); + stage3_start_ = n_warmup - static_cast<int>(0.10 * n_warmup); + + window_ends_.clear(); + int cur = stage1_end_; + int wsize = 25; + + while (cur < stage3_start_) { + int win = std::min(wsize, stage3_start_ - cur); + window_ends_.push_back(cur + win); + cur += win; + wsize = std::min(wsize * 2, stage3_start_ - cur); + } + } + + bool in_stage2() const { + return warmup_iteration_ >= stage1_end_ && warmup_iteration_ < stage3_start_; + } + + bool at_window_end() const { + for (int end : window_ends_) { + if (warmup_iteration_ + 1 == end) return true; + } + return false; + } + + void initialize(BaseModel& model) { + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + + inv_mass_ = arma::ones(model.full_parameter_dimension()); + model.set_inv_mass(inv_mass_); + + mass_accumulator_ = std::make_unique<DiagMassMatrixAccumulator>( + static_cast<size_t>(model.full_parameter_dimension())); + + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + + step_size_ = heuristic_initial_step_size( + theta, log_post, grad_fn, rng, target_acceptance_); + + step_adapter_.restart(step_size_); + } + + int n_warmup_; + int warmup_iteration_; + bool initialized_; + + DualAveraging step_adapter_; + arma::vec inv_mass_; + std::unique_ptr<DiagMassMatrixAccumulator> mass_accumulator_; + + int stage1_end_; + int stage3_start_; + std::vector<int> window_ends_; +}; diff --git a/src/mcmc/base_sampler.h b/src/mcmc/base_sampler.h new file mode 100644 index 0000000..576840c --- /dev/null +++ b/src/mcmc/base_sampler.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include "mcmc/mcmc_utils.h" +#include "mcmc/sampler_config.h" +#include "models/base_model.h" + +/** + * BaseSampler - Abstract base class for MCMC samplers + * + * Provides a unified interface for all MCMC sampling algorithms: + * - Random-walk Metropolis-Hastings + * - Hamiltonian Monte Carlo (HMC) + * - No-U-Turn Sampler (NUTS) + * + * All samplers follow the same workflow: + * 1. warmup_step() during warmup (may adapt parameters) + * 2. 
finalize_warmup() after warmup completes + * 3. sample_step() during sampling (fixed parameters) + * + * The sampler asks the model for logp/gradient evaluations but owns + * the sampling algorithm logic. + */ +class BaseSampler { +public: + virtual ~BaseSampler() = default; + + /** + * Perform one step during warmup phase + * + * During warmup, samplers may adapt their parameters (step size, + * proposal covariance, mass matrix, etc.) + * + * @param model The model to sample from + * @return SamplerResult with new state and diagnostics + */ + virtual SamplerResult warmup_step(BaseModel& model) = 0; + + /** + * Finalize warmup phase + * + * Called after all warmup iterations complete. Samplers should + * fix their adapted parameters for the sampling phase. + */ + virtual void finalize_warmup() {} + + /** + * Perform one step during sampling phase + * + * Sampling steps use fixed parameters (no adaptation). + * + * @param model The model to sample from + * @return SamplerResult with new state and diagnostics + */ + virtual SamplerResult sample_step(BaseModel& model) = 0; + + /** + * Check if this sampler produces NUTS-style diagnostics + * (tree depth, divergences, energy) + */ + virtual bool has_nuts_diagnostics() const { return false; } +}; diff --git a/src/mcmc/chain_result.h b/src/mcmc/chain_result.h new file mode 100644 index 0000000..d43f756 --- /dev/null +++ b/src/mcmc/chain_result.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include + +/** + * ChainResultNew - Storage for a single MCMC chain's output + * + * Holds samples, diagnostics, and error state for one chain. + * Designed for use with both MH and NUTS/HMC samplers. + */ +class ChainResultNew { + +public: + ChainResultNew() = default; + + // Error handling + bool error = false; + bool userInterrupt = false; + std::string error_msg; + + // Chain identifier + int chain_id = 0; + + // Parameter samples (param_dim × n_iter) + arma::mat samples; + + // Edge indicator samples (n_edges × n_iter), only if edge_selection = true + arma::imat indicator_samples; + bool has_indicators = false; + + // SBM allocation samples (n_variables × n_iter), only if SBM edge prior + arma::imat allocation_samples; + bool has_allocations = false; + + // NUTS/HMC diagnostics (n_iter), only if using NUTS/HMC + arma::ivec treedepth_samples; + arma::ivec divergent_samples; + arma::vec energy_samples; + bool has_nuts_diagnostics = false; + + /** + * Reserve storage for samples + * @param param_dim Number of parameters per sample + * @param n_iter Number of sampling iterations + */ + void reserve(const size_t param_dim, const size_t n_iter) { + samples.set_size(param_dim, n_iter); + } + + /** + * Reserve storage for edge indicator samples + * @param n_edges Number of edges (p * (p - 1) / 2) + * @param n_iter Number of sampling iterations + */ + void reserve_indicators(const size_t n_edges, const size_t n_iter) { + indicator_samples.set_size(n_edges, n_iter); + has_indicators = true; + } + + /** + * Reserve storage for SBM allocation samples + * @param n_variables Number of variables + * @param n_iter Number of sampling iterations + */ + void reserve_allocations(const size_t n_variables, const size_t n_iter) { + allocation_samples.set_size(n_variables, n_iter); + has_allocations = true; + } + + /** + * Reserve storage for NUTS diagnostics + * @param n_iter Number of sampling iterations + */ + void reserve_nuts_diagnostics(const size_t n_iter) { + treedepth_samples.set_size(n_iter); + divergent_samples.set_size(n_iter); + energy_samples.set_size(n_iter); + 
has_nuts_diagnostics = true; + } + + /** + * Store a parameter sample + * @param iter Iteration index (0-based) + * @param sample Parameter vector + */ + void store_sample(const size_t iter, const arma::vec& sample) { + samples.col(iter) = sample; + } + + /** + * Store edge indicator sample + * @param iter Iteration index (0-based) + * @param indicators Edge indicator vector + */ + void store_indicators(const size_t iter, const arma::ivec& indicators) { + indicator_samples.col(iter) = indicators; + } + + /** + * Store SBM allocation sample + * @param iter Iteration index (0-based) + * @param allocations Allocation vector (1-based cluster labels) + */ + void store_allocations(const size_t iter, const arma::ivec& allocations) { + allocation_samples.col(iter) = allocations; + } + + /** + * Store NUTS diagnostics for one iteration + * @param iter Iteration index (0-based) + * @param tree_depth Tree depth from NUTS + * @param divergent Whether a divergence occurred + * @param energy Final Hamiltonian energy + */ + void store_nuts_diagnostics(const size_t iter, int tree_depth, bool divergent, double energy) { + treedepth_samples(iter) = tree_depth; + divergent_samples(iter) = divergent ? 1 : 0; + energy_samples(iter) = energy; + } +}; + diff --git a/src/mcmc/hmc_sampler.h b/src/mcmc/hmc_sampler.h new file mode 100644 index 0000000..458dcb1 --- /dev/null +++ b/src/mcmc/hmc_sampler.h @@ -0,0 +1,42 @@ +#pragma once + +#include "mcmc/adaptive_gradient_sampler.h" +#include "mcmc/mcmc_hmc.h" + +/** + * HMCSampler - Hamiltonian Monte Carlo + * + * Fixed-length leapfrog integration. Inherits warmup adaptation + * (step size + diagonal mass matrix) from AdaptiveGradientSampler. + */ +class HMCSampler : public AdaptiveGradientSampler { +public: + explicit HMCSampler(const SamplerConfig& config) + : AdaptiveGradientSampler(config.initial_step_size, config.target_acceptance, config.no_warmup), + num_leapfrogs_(config.num_leapfrogs) + {} + +protected: + SamplerResult do_gradient_step(BaseModel& model) override { + arma::vec theta = model.get_vectorized_parameters(); + arma::vec inv_mass = model.get_active_inv_mass(); + SafeRNG& rng = model.get_rng(); + + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + + SamplerResult result = hmc_sampler( + theta, step_size_, log_post, grad_fn, + num_leapfrogs_, inv_mass, rng); + + model.set_vectorized_parameters(result.state); + return result; + } + +private: + int num_leapfrogs_; +}; diff --git a/src/mcmc/mcmc_adaptation.h b/src/mcmc/mcmc_adaptation.h index 9cefb21..fec3441 100644 --- a/src/mcmc/mcmc_adaptation.h +++ b/src/mcmc/mcmc_adaptation.h @@ -94,7 +94,7 @@ class DiagMassMatrixAccumulator { // === Dynamic Warmup Schedule with Adaptive Windows === -// +// // For edge_selection = FALSE: // Stage 1 (init), Stage 2 (doubling windows), Stage 3a (terminal) // total_warmup = user-specified warmup @@ -108,7 +108,7 @@ class DiagMassMatrixAccumulator { // Warning types: // 0 = none // 1 = warmup extremely short (< 50) -// 2 = core stages using proportional fallback +// 2 = core stages using proportional fallback // 3 = limited proposal SD tuning (edge_selection && warmup < 300) // 4 = Stage 3b skipped (would have < 20 iterations) // @@ -228,7 +228,7 @@ struct WarmupSchedule { bool in_stage3b(int i) const { return !stage3b_skipped && i >= stage3b_start && i < stage3c_start; } bool 
in_stage3c(int i) const { return enable_selection && !stage3b_skipped && i >= stage3c_start && i < total_warmup; } bool sampling (int i) const { return i >= total_warmup; } - + bool has_warning() const { return warning_type > 0; } bool warmup_extremely_short() const { return warning_type == 1; } bool using_proportional_fallback() const { return warning_type == 2; } diff --git a/src/mcmc/mcmc_leapfrog.cpp b/src/mcmc/mcmc_leapfrog.cpp index ceadcfe..cb7a889 100644 --- a/src/mcmc/mcmc_leapfrog.cpp +++ b/src/mcmc/mcmc_leapfrog.cpp @@ -81,10 +81,11 @@ std::pair leapfrog_memo( arma::vec r_half = r; arma::vec theta_new = theta; - auto grad1 = memo.cached_grad(theta_new); + const arma::vec& grad1 = memo.cached_grad(theta_new); r_half += 0.5 * eps * grad1; + theta_new += eps * (inv_mass_diag % r_half); - auto grad2 = memo.cached_grad(theta_new); + const arma::vec& grad2 = memo.cached_grad(theta_new); r_half += 0.5 * eps * grad2; return {theta_new, r_half}; diff --git a/src/mcmc/mcmc_memoization.h b/src/mcmc/mcmc_memoization.h index c28574e..a5ad694 100644 --- a/src/mcmc/mcmc_memoization.h +++ b/src/mcmc/mcmc_memoization.h @@ -1,143 +1,77 @@ #pragma once #include -#include // for std::unordered_map -#include // for std::function - - - -/** - * Struct: VecHash - * - * A hash function is a way of converting complex input data (like a vector of numbers) - * into a single fixed-size integer value. This hash acts like a digital fingerprint — - * it helps uniquely identify the input and allows fast lookup in data structures - * like hash tables. - * - * In C++, hash tables are implemented using containers like std::unordered_map, - * which require a hash function and an equality comparator for the key type. - * Since arma::vec (a vector of doubles) isn't natively supported as a key, - * this struct defines how to compute a hash for it. - * - * Why use hashing here? - * - We cache (memoize) values like log-posterior or gradient at specific points. - * - To retrieve them quickly later, we need a fast way to index by vector. - * - Hashing gives constant-time access (on average) by turning the vector into a hash code. - * - * This implementation: - * - Initializes a seed based on the vector size. - * - Iteratively incorporates each element's hash into the seed using bitwise operations - * and a constant based on the golden ratio to reduce collisions. - * - * The result is a unique and repeatable hash value that lets us efficiently - * store and retrieve cached evaluations. - * - * Provides a custom hash function for arma::vec so that it can be used as a key - * in standard hash-based containers like std::unordered_map. - * - * Since arma::vec is not hashable by default in C++, this functor allows you - * to compute a combined hash from the individual elements of the vector. - * - * Hashing Strategy: - * - Starts with an initial seed equal to the vector's length. - * - Iteratively combines the hash of each element into the seed using - * a standard hashing recipe (XOR, bit-shift mix, and golden ratio). - * - * This ensures that vectors with similar but not identical elements - * produce distinct hash codes, reducing the risk of collisions. 
- */ -struct VecHash { - std::size_t operator()(const arma::vec& v) const { - std::size_t seed = v.n_elem; - for (size_t i = 0; i < v.n_elem; i++) { - seed ^= std::hash{}(v[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; - } -}; - - - -/** - * Struct: VecEqual - * - * Provides a custom equality comparator for arma::vec, to be used in conjunction - * with VecHash in unordered_map and similar containers. - * - * By default, arma::vec does not implement equality suitable for use as a key - * in hashed containers. This comparator solves that by relying on Armadillo's - * approx_equal function. - * - * Comparison Strategy: - * - Uses "absdiff" mode to check if the absolute difference between vectors - * is below a specified tolerance. - * - The tolerance is hardcoded as 1e-12, assuming numerical precision errors - * should not cause false inequality. - * - * This is crucial for numerical stability, especially when the same vector - * might be computed multiple times with slight rounding differences. - */ -struct VecEqual { - bool operator()(const arma::vec& a, const arma::vec& b) const { - return arma::approx_equal(a, b, "absdiff", 1e-12); - } -}; - +#include /** * Class: Memoizer * - * A utility class that caches (memoizes) evaluations of the log-posterior and its gradient - * at specific parameter values (theta) to avoid redundant and costly recomputation. + * Single-entry cache for joint log-posterior and gradient evaluations. * - * Usage: - * - Constructed with a log-posterior function and its gradient. - * - Uses hash-based caching (via unordered_map) keyed by arma::vec inputs. - * That means: - * - Each parameter vector (theta) is used as a key. - * - Results are stored in a lookup table so that repeated evaluations can be skipped. - * - A custom VecHash function generates a hash code for arma::vec. - * - A custom VecEqual comparator defines when two vectors are considered equal (with a small tolerance). - * This enables fast retrieval (average constant time) of previously computed values for log_post and grad. + * In NUTS, the typical access pattern within a leapfrog step is: + * 1. cached_grad(theta) — compute gradient (and cache logp as side-effect) + * 2. cached_log_post(theta) — retrieve the already-cached logp * - * Methods: - * - cached_log_post(const arma::vec& theta): - * Returns the cached log-posterior if available, otherwise computes, stores, and returns it. + * A single-entry cache is optimal here because each leapfrog step produces + * a new unique theta: hash-map lookups would almost never hit, and hashing + * an arma::vec element-by-element is expensive. * - * - cached_grad(const arma::vec& theta): - * Returns the cached gradient if available, otherwise computes, stores, and returns it. - * - * Internal Details: - * - Relies on VecHash and VecEqual to use arma::vec as map keys. - * - Assumes high-precision comparison for theta keys (absdiff tolerance 1e-12). + * The joint evaluation function computes both logp and gradient together + * (since models often share most of the computation between the two). */ class Memoizer { public: - std::function log_post; - std::function grad; + using JointFn = std::function(const arma::vec&)>; + + JointFn joint_fn; - std::unordered_map logp_cache; - std::unordered_map grad_cache; + // Single-entry cache + arma::vec cached_theta; + double cached_logp_val; + arma::vec cached_grad_val; + bool has_cache = false; + /** + * Construct from separate log_post and grad functions. + * Calls them independently (backward-compatible). 
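+   * Illustrative access pattern inside one leapfrog step:
+   *   const arma::vec& g = memo.cached_grad(theta);   // single joint evaluation
+   *   double lp = memo.cached_log_post(theta);        // served from the cache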
+ */ Memoizer( const std::function& lp, const std::function& gr - ) : log_post(lp), grad(gr) {} + ) : joint_fn([lp, gr](const arma::vec& theta) -> std::pair { + arma::vec g = gr(theta); + double v = lp(theta); + return {v, std::move(g)}; + }) {} + + /** + * Construct from a joint function that computes both at once. + */ + explicit Memoizer(JointFn jf) : joint_fn(std::move(jf)) {} double cached_log_post(const arma::vec& theta) { - auto it = logp_cache.find(theta); - if (it != logp_cache.end()) return it->second; - double val = log_post(theta); - logp_cache[theta] = val; - return val; + ensure_cached(theta); + return cached_logp_val; + } + + const arma::vec& cached_grad(const arma::vec& theta) { + ensure_cached(theta); + return cached_grad_val; } - arma::vec cached_grad(const arma::vec& theta) { - auto it = grad_cache.find(theta); - if (it != grad_cache.end()) return it->second; - arma::vec val = grad(theta); - grad_cache[theta] = val; - return val; +private: + void ensure_cached(const arma::vec& theta) { + if (has_cache && + theta.n_elem == cached_theta.n_elem && + std::memcmp(theta.memptr(), cached_theta.memptr(), + theta.n_elem * sizeof(double)) == 0) { + return; + } + auto [lp, gr] = joint_fn(theta); + cached_theta = theta; + cached_logp_val = lp; + cached_grad_val = std::move(gr); + has_cache = true; } }; \ No newline at end of file diff --git a/src/mcmc/mcmc_nuts.cpp b/src/mcmc/mcmc_nuts.cpp index e47c770..070b6de 100644 --- a/src/mcmc/mcmc_nuts.cpp +++ b/src/mcmc/mcmc_nuts.cpp @@ -260,5 +260,83 @@ SamplerResult nuts_sampler( diag->divergent = any_divergence; diag->energy = energy; + return {theta, accept_prob, diag}; +} + + +SamplerResult nuts_sampler_joint( + const arma::vec& init_theta, + double step_size, + const std::function(const arma::vec&)>& joint_fn, + const arma::vec& inv_mass_diag, + SafeRNG& rng, + int max_depth +) { + Memoizer memo(joint_fn); + bool any_divergence = false; + + arma::vec r0 = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, init_theta.n_elem); + auto logp0 = memo.cached_log_post(init_theta); + double kin0 = kinetic_energy(r0, inv_mass_diag); + double joint0 = logp0 - kin0; + double log_u = log(runif(rng)) + joint0; + arma::vec theta_min = init_theta, r_min = r0; + arma::vec theta_plus = init_theta, r_plus = r0; + arma::vec theta = init_theta; + arma::vec r = r0; + int j = 0; + int n = 1, s = 1; + + double alpha = 0.5; + int n_alpha = 1; + + while (s == 1 && j < max_depth) { + int v = runif(rng) < 0.5 ? 
-1 : 1; + + BuildTreeResult result; + if (v == -1) { + result = build_tree( + theta_min, r_min, log_u, v, j, step_size, init_theta, r0, logp0, kin0, + memo, inv_mass_diag, rng + ); + theta_min = result.theta_min; + r_min = result.r_min; + } else { + result = build_tree( + theta_plus, r_plus, log_u, v, j, step_size, init_theta, r0, logp0, kin0, + memo, inv_mass_diag, rng + ); + theta_plus = result.theta_plus; + r_plus = result.r_plus; + } + + any_divergence = any_divergence || result.divergent; + alpha = result.alpha; + n_alpha = result.n_alpha; + + if (result.s_prime == 1) { + double prob = static_cast<double>(result.n_prime) / static_cast<double>(n); + if (runif(rng) < prob) { + theta = result.theta_prime; + r = result.r_prime; + } + } + bool no_uturn = !is_uturn(theta_min, theta_plus, r_min, r_plus, inv_mass_diag); + s = result.s_prime * no_uturn; + n += result.n_prime; + j++; + } + + double accept_prob = alpha / static_cast<double>(n_alpha); + + auto logp_final = memo.cached_log_post(theta); + double kin_final = kinetic_energy(r, inv_mass_diag); + double energy = -logp_final + kin_final; + + auto diag = std::make_shared(); + diag->tree_depth = j; + diag->divergent = any_divergence; + diag->energy = energy; + return {theta, accept_prob, diag}; } \ No newline at end of file diff --git a/src/mcmc/mcmc_nuts.h b/src/mcmc/mcmc_nuts.h index 099e1c4..a730684 100644 --- a/src/mcmc/mcmc_nuts.h +++ b/src/mcmc/mcmc_nuts.h @@ -56,4 +56,12 @@ SamplerResult nuts_sampler(const arma::vec& init_theta, const std::function<arma::vec(const arma::vec&)>& grad, const arma::vec& inv_mass_diag, SafeRNG& rng, - int max_depth = 10); \ No newline at end of file + int max_depth = 10); + +SamplerResult nuts_sampler_joint( + const arma::vec& init_theta, + double step_size, + const std::function<std::pair<double, arma::vec>(const arma::vec&)>& joint_fn, + const arma::vec& inv_mass_diag, + SafeRNG& rng, + int max_depth = 10); \ No newline at end of file diff --git a/src/mcmc/mcmc_runner.cpp b/src/mcmc/mcmc_runner.cpp new file mode 100644 index 0000000..a4a4c14 --- /dev/null +++ b/src/mcmc/mcmc_runner.cpp @@ -0,0 +1,237 @@ +#include "mcmc/mcmc_runner.h" + +#include "mcmc/nuts_sampler.h" +#include "mcmc/hmc_sampler.h" +#include "mcmc/mh_sampler.h" + + +std::unique_ptr<BaseSampler> create_sampler(const SamplerConfig& config) { + if (config.sampler_type == "nuts") { + return std::make_unique<NUTSSampler>(config, config.no_warmup); + } else if (config.sampler_type == "hmc" || config.sampler_type == "hamiltonian-mc") { + return std::make_unique<HMCSampler>(config); + } else if (config.sampler_type == "mh" || config.sampler_type == "adaptive-metropolis") { + return std::make_unique<MHSampler>(config); + } else { + Rcpp::stop("Unknown sampler_type: '%s'", config.sampler_type.c_str()); + } +} + + +void run_mcmc_chain( + ChainResultNew& chain_result, + BaseModel& model, + BaseEdgePrior& edge_prior, + const SamplerConfig& config, + const int chain_id, + ProgressManager& pm +) { + chain_result.chain_id = chain_id + 1; + + const int edge_start = config.get_edge_selection_start(); + + auto sampler = create_sampler(config); + + // Warmup phase + for (int iter = 0; iter < config.no_warmup; ++iter) { + + if (config.na_impute && model.has_missing_data()) { + model.impute_missing(); + } + + if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { + model.update_edge_indicators(); + } + + sampler->warmup_step(model); + + if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { + edge_prior.update( + model.get_edge_indicators(), + model.get_inclusion_probability(), + model.get_num_variables(), + model.get_num_pairwise(), 
+ model.get_rng() + ); + } + + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + return; + } + } + + sampler->finalize_warmup(); + + // Activate edge selection mode + if (config.edge_selection && model.has_edge_selection()) { + model.set_edge_selection_active(true); + model.initialize_graph(); + } + + // Sampling phase + for (int iter = 0; iter < config.no_iter; ++iter) { + + if (config.na_impute && model.has_missing_data()) { + model.impute_missing(); + } + + if (config.edge_selection && model.has_edge_selection()) { + model.update_edge_indicators(); + } + + SamplerResult result = sampler->sample_step(model); + + if (config.edge_selection && model.has_edge_selection()) { + edge_prior.update( + model.get_edge_indicators(), + model.get_inclusion_probability(), + model.get_num_variables(), + model.get_num_pairwise(), + model.get_rng() + ); + } + + if (chain_result.has_nuts_diagnostics && sampler->has_nuts_diagnostics()) { + auto* diag = dynamic_cast(result.diagnostics.get()); + if (diag) { + chain_result.store_nuts_diagnostics(iter, diag->tree_depth, diag->divergent, diag->energy); + } + } + + chain_result.store_sample(iter, model.get_full_vectorized_parameters()); + + if (chain_result.has_indicators) { + chain_result.store_indicators(iter, model.get_vectorized_indicator_parameters()); + } + + if (chain_result.has_allocations && edge_prior.has_allocations()) { + chain_result.store_allocations(iter, edge_prior.get_allocations()); + } + + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + return; + } + } +} + + +void MCMCChainRunner::operator()(std::size_t begin, std::size_t end) { + for (std::size_t i = begin; i < end; ++i) { + ChainResultNew& chain_result = results_[i]; + BaseModel& model = *models_[i]; + BaseEdgePrior& edge_prior = *edge_priors_[i]; + model.set_seed(config_.seed + static_cast(i)); + + try { + run_mcmc_chain(chain_result, model, edge_prior, config_, static_cast(i), pm_); + } catch (std::exception& e) { + chain_result.error = true; + chain_result.error_msg = e.what(); + } catch (...) 
{ + chain_result.error = true; + chain_result.error_msg = "Unknown error"; + } + } +} + + +std::vector run_mcmc_sampler( + BaseModel& model, + BaseEdgePrior& edge_prior, + const SamplerConfig& config, + const int no_chains, + const int no_threads, + ProgressManager& pm +) { + const bool has_nuts_diag = (config.sampler_type == "nuts"); + const bool has_sbm_alloc = edge_prior.has_allocations() || + (config.edge_selection && dynamic_cast(&edge_prior) != nullptr); + + std::vector results(no_chains); + for (int c = 0; c < no_chains; ++c) { + results[c].reserve(model.full_parameter_dimension(), config.no_iter); + + if (config.edge_selection) { + size_t n_edges = model.get_vectorized_indicator_parameters().n_elem; + results[c].reserve_indicators(n_edges, config.no_iter); + } + + if (has_sbm_alloc) { + results[c].reserve_allocations(model.get_num_variables(), config.no_iter); + } + + if (has_nuts_diag) { + results[c].reserve_nuts_diagnostics(config.no_iter); + } + } + + if (no_threads > 1) { + std::vector> models; + std::vector> edge_priors; + models.reserve(no_chains); + edge_priors.reserve(no_chains); + for (int c = 0; c < no_chains; ++c) { + models.push_back(model.clone()); + models[c]->set_seed(config.seed + c); + edge_priors.push_back(edge_prior.clone()); + } + + MCMCChainRunner runner(results, models, edge_priors, config, pm); + tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); + RcppParallel::parallelFor(0, static_cast(no_chains), runner); + + } else { + model.set_seed(config.seed); + for (int c = 0; c < no_chains; ++c) { + auto chain_model = model.clone(); + chain_model->set_seed(config.seed + c); + auto chain_edge_prior = edge_prior.clone(); + run_mcmc_chain(results[c], *chain_model, *chain_edge_prior, config, c, pm); + } + } + + return results; +} + + +Rcpp::List convert_results_to_list(const std::vector& results) { + Rcpp::List output(results.size()); + + for (size_t i = 0; i < results.size(); ++i) { + const ChainResultNew& chain = results[i]; + Rcpp::List chain_list; + + chain_list["chain_id"] = chain.chain_id; + + if (chain.error) { + chain_list["error"] = true; + chain_list["error_msg"] = chain.error_msg; + } else { + chain_list["error"] = false; + chain_list["samples"] = chain.samples; + chain_list["userInterrupt"] = chain.userInterrupt; + + if (chain.has_indicators) { + chain_list["indicator_samples"] = chain.indicator_samples; + } + + if (chain.has_allocations) { + chain_list["allocation_samples"] = chain.allocation_samples; + } + + if (chain.has_nuts_diagnostics) { + chain_list["treedepth"] = chain.treedepth_samples; + chain_list["divergent"] = chain.divergent_samples; + chain_list["energy"] = chain.energy_samples; + } + } + + output[i] = chain_list; + } + + return output; +} diff --git a/src/mcmc/mcmc_runner.h b/src/mcmc/mcmc_runner.h new file mode 100644 index 0000000..df7e78f --- /dev/null +++ b/src/mcmc/mcmc_runner.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "models/base_model.h" +#include "mcmc/chain_result.h" +#include "priors/edge_prior.h" +#include "utils/progress_manager.h" +#include "mcmc/sampler_config.h" +#include "mcmc/base_sampler.h" +#include "mcmc/mcmc_utils.h" + + +/// Create a sampler matching config.sampler_type. +std::unique_ptr create_sampler(const SamplerConfig& config); + +/// Run a single MCMC chain (warmup + sampling) writing into chain_result. 
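+/// Per-iteration order in both phases: impute missing responses (when na_impute
+/// is set), update edge indicators (once edge selection is active), advance the
+/// sampler by one step, then update the edge-prior hyperparameters.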
+void run_mcmc_chain( + ChainResultNew& chain_result, + BaseModel& model, + BaseEdgePrior& edge_prior, + const SamplerConfig& config, + int chain_id, + ProgressManager& pm +); + + +/// Worker struct for TBB parallel chain execution. +struct MCMCChainRunner : public RcppParallel::Worker { + std::vector& results_; + std::vector>& models_; + std::vector>& edge_priors_; + const SamplerConfig& config_; + ProgressManager& pm_; + + MCMCChainRunner( + std::vector& results, + std::vector>& models, + std::vector>& edge_priors, + const SamplerConfig& config, + ProgressManager& pm + ) : + results_(results), + models_(models), + edge_priors_(edge_priors), + config_(config), + pm_(pm) + {} + + void operator()(std::size_t begin, std::size_t end); +}; + + +/// Run multi-chain MCMC (parallel or sequential based on no_threads). +std::vector run_mcmc_sampler( + BaseModel& model, + BaseEdgePrior& edge_prior, + const SamplerConfig& config, + int no_chains, + int no_threads, + ProgressManager& pm +); + +/// Convert chain results to Rcpp::List for return to R. +Rcpp::List convert_results_to_list(const std::vector& results); diff --git a/src/mcmc/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp index 96ca6a0..d500c58 100644 --- a/src/mcmc/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -59,7 +59,7 @@ double heuristic_initial_step_size( double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum and evaluate arma::vec r = arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); @@ -131,7 +131,7 @@ double heuristic_initial_step_size( ) { double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum from N(0, M) where M = diag(1/inv_mass_diag) arma::vec r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); diff --git a/src/mcmc/mh_sampler.h b/src/mcmc/mh_sampler.h new file mode 100644 index 0000000..3902f73 --- /dev/null +++ b/src/mcmc/mh_sampler.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include "mcmc/base_sampler.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/sampler_config.h" +#include "models/base_model.h" + +/** + * MHSampler - Metropolis-Hastings sampler + * + * Delegates to the model's component-wise MH updates. The model + * handles proposal adaptation internally during warmup. + * + * This is a thin wrapper that provides a uniform interface consistent + * with other samplers (NUTS, HMC), but the actual sampling logic + * (component-wise updates, Gibbs sweeps, etc.) is model-specific. + */ +class MHSampler : public BaseSampler { +public: + /** + * Construct MH sampler with configuration + * @param config Sampler configuration + */ + explicit MHSampler(const SamplerConfig& config) + : no_warmup_(config.no_warmup), + warmup_iteration_(0) + {} + + /** + * Perform one MH step during warmup + * + * The model handles proposal adaptation internally. + */ + SamplerResult warmup_step(BaseModel& model) override { + warmup_iteration_++; + model.do_one_mh_step(); + + SamplerResult result; + result.accept_prob = 1.0; // Not tracked for component-wise MH + return result; + } + + /** + * Finalize warmup phase + * + * Nothing to do - model handles adaptation internally. 
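+   *
+   * (For example, the GGM model's AdaptiveProposal stops adapting on its own
+   * once its adaptation window has elapsed; see
+   * AdaptiveProposal::increment_iteration().)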
+ */ + void finalize_warmup() override { + // Model handles proposal finalization internally + } + + /** + * Perform one MH step during sampling + */ + SamplerResult sample_step(BaseModel& model) override { + model.do_one_mh_step(); + + SamplerResult result; + result.accept_prob = 1.0; + return result; + } + +private: + int no_warmup_; + int warmup_iteration_; +}; diff --git a/src/mcmc/nuts_sampler.h b/src/mcmc/nuts_sampler.h new file mode 100644 index 0000000..d27a895 --- /dev/null +++ b/src/mcmc/nuts_sampler.h @@ -0,0 +1,44 @@ +#pragma once + +#include "mcmc/adaptive_gradient_sampler.h" +#include "mcmc/mcmc_nuts.h" + +/** + * NUTSSampler - No-U-Turn Sampler + * + * Adaptive tree-depth leapfrog integration. Inherits warmup adaptation + * (step size + diagonal mass matrix) from AdaptiveGradientSampler. + */ +class NUTSSampler : public AdaptiveGradientSampler { +public: + explicit NUTSSampler(const SamplerConfig& config, int n_warmup = 1000) + : AdaptiveGradientSampler(config.initial_step_size, config.target_acceptance, n_warmup), + max_tree_depth_(config.max_tree_depth) + {} + + bool has_nuts_diagnostics() const override { return true; } + +protected: + SamplerResult do_gradient_step(BaseModel& model) override { + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + + auto joint_fn = [&model](const arma::vec& params) + -> std::pair { + return model.logp_and_gradient(params); + }; + + arma::vec active_inv_mass = model.get_active_inv_mass(); + + SamplerResult result = nuts_sampler_joint( + theta, step_size_, joint_fn, + active_inv_mass, rng, max_tree_depth_ + ); + + model.set_vectorized_parameters(result.state); + return result; + } + +private: + int max_tree_depth_; +}; diff --git a/src/mcmc/sampler_config.h b/src/mcmc/sampler_config.h new file mode 100644 index 0000000..55a2e62 --- /dev/null +++ b/src/mcmc/sampler_config.h @@ -0,0 +1,48 @@ +#pragma once + +#include + +/** + * SamplerConfig - Configuration for MCMC sampling + * + * Holds all settings for the generic MCMC runner, including: + * - Sampler type selection (MH, NUTS, HMC) + * - Iteration counts + * - NUTS/HMC specific parameters + * - Edge selection settings + */ +struct SamplerConfig { + // Sampler type: "adaptive_metropolis", "nuts", "hmc" + std::string sampler_type = "adaptive_metropolis"; + + // Iteration counts + int no_iter = 1000; + int no_warmup = 500; + + // NUTS/HMC parameters + int max_tree_depth = 10; + int num_leapfrogs = 10; // For HMC only + double initial_step_size = 0.1; + double target_acceptance = 0.8; + + // Edge selection settings + bool edge_selection = false; + int edge_selection_start = -1; // -1 = no_warmup (default, start at sampling) + + // Missing data imputation + bool na_impute = false; + + // Random seed + int seed = 42; + + // Constructor with defaults + SamplerConfig() = default; + + // Get actual edge selection start iteration + int get_edge_selection_start() const { + if (edge_selection_start < 0) { + return no_warmup; // Default: start at beginning of sampling + } + return edge_selection_start; + } +}; diff --git a/src/models/adaptive_metropolis.h b/src/models/adaptive_metropolis.h new file mode 100644 index 0000000..578f5d1 --- /dev/null +++ b/src/models/adaptive_metropolis.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include + +class AdaptiveProposal { + +public: + + AdaptiveProposal(size_t num_params, size_t adaptation_window = 50, double target_accept = 0.44) { + proposal_sds_ = arma::vec(num_params, arma::fill::ones) * 0.25; // Initial SD, need to 
tweak this somehow? + acceptance_counts_ = arma::ivec(num_params, arma::fill::zeros); + adaptation_window_ = adaptation_window; + target_accept_ = target_accept; + } + + double get_proposal_sd(size_t param_index) const { + validate_index(param_index); + return proposal_sds_[param_index]; + } + + void update_proposal_sd(size_t param_index) { + + if (!adapting_) { + return; + } + + double current_sd = get_proposal_sd(param_index); + double observed_acceptance_probability = acceptance_counts_[param_index] / static_cast(iterations_ + 1); + double rm_weight = std::pow(iterations_, -decay_rate_); + + // Robbins-Monro update step + double updated_sd = current_sd + (observed_acceptance_probability - target_accept_) * rm_weight; + updated_sd = std::clamp(updated_sd, rm_lower_bound, rm_upper_bound); + + proposal_sds_(param_index) = updated_sd; + } + + void increment_accepts(size_t param_index) { + validate_index(param_index); + acceptance_counts_[param_index]++; + } + + void increment_iteration() { + iterations_++; + if (iterations_ >= adaptation_window_) { + adapting_ = false; + } + } + +private: + arma::vec proposal_sds_; + arma::ivec acceptance_counts_; + int iterations_ = 0, + adaptation_window_; + double target_accept_ = 0.44, + decay_rate_ = 0.75, + rm_lower_bound = 0.001, + rm_upper_bound = 2.0; + bool adapting_ = true; + + void validate_index(size_t index) const { + if (index >= proposal_sds_.n_elem) { + throw std::out_of_range("Parameter index out of range"); + } + } + +}; diff --git a/src/models/base_model.h b/src/models/base_model.h new file mode 100644 index 0000000..53ff6e3 --- /dev/null +++ b/src/models/base_model.h @@ -0,0 +1,133 @@ +#pragma once + +#include +#include +#include + +// Forward declarations +struct SamplerResult; +struct SafeRNG; + +class BaseModel { +public: + virtual ~BaseModel() = default; + + // Capability queries + virtual bool has_gradient() const { return false; } + virtual bool has_adaptive_mh() const { return false; } + virtual bool has_nuts() const { return has_gradient(); } + virtual bool has_edge_selection() const { return false; } + + // Core methods (to be overridden by derived classes) + virtual double logp(const arma::vec& parameters) = 0; + + virtual arma::vec gradient(const arma::vec& parameters) { + if (!has_gradient()) { + throw std::runtime_error("Gradient not implemented for this model"); + } + throw std::runtime_error("Gradient method must be implemented in derived class"); + } + + virtual std::pair logp_and_gradient( + const arma::vec& parameters) { + if (!has_gradient()) { + throw std::runtime_error("Gradient not implemented for this model"); + } + return {logp(parameters), gradient(parameters)}; + } + + // For Metropolis-Hastings (model handles parameter groups internally) + virtual void do_one_mh_step() { + throw std::runtime_error("do_one_mh_step method must be implemented in derived class"); + } + + // Edge selection (for models with spike-and-slab priors) + virtual void update_edge_indicators() { + throw std::runtime_error("update_edge_indicators not implemented for this model"); + } + + virtual arma::vec get_vectorized_parameters() const { + throw std::runtime_error("get_vectorized_parameters method must be implemented in derived class"); + } + + virtual void set_vectorized_parameters(const arma::vec& parameters) { + throw std::runtime_error("set_vectorized_parameters method must be implemented in derived class"); + } + + virtual arma::ivec get_vectorized_indicator_parameters() { + throw 
std::runtime_error("get_vectorized_indicator_parameters method must be implemented in derived class"); + } + + // Full parameter dimension (for fixed-size output, includes all possible params) + virtual size_t full_parameter_dimension() const { + return parameter_dimension(); // Default: same as active dimension + } + + // Get full vectorized parameters (zeros for inactive, for consistent output) + virtual arma::vec get_full_vectorized_parameters() const { + throw std::runtime_error("get_full_vectorized_parameters must be implemented in derived class"); + } + + // Return dimensionality of the active parameter space + virtual size_t parameter_dimension() const = 0; + + virtual void set_seed(int seed) { + throw std::runtime_error("set_seed method must be implemented in derived class"); + } + + virtual std::unique_ptr clone() const { + throw std::runtime_error("clone method must be implemented in derived class"); + } + + // RNG access for samplers + virtual SafeRNG& get_rng() { + throw std::runtime_error("get_rng method must be implemented in derived class"); + } + + // Step size for gradient-based samplers + virtual void set_step_size(double step_size) { step_size_ = step_size; } + virtual double get_step_size() const { return step_size_; } + + // Inverse mass matrix for HMC/NUTS + virtual void set_inv_mass(const arma::vec& inv_mass) { inv_mass_ = inv_mass; } + virtual const arma::vec& get_inv_mass() const { return inv_mass_; } + + // Get active inverse mass (for models with edge selection, may be subset) + virtual arma::vec get_active_inv_mass() const { return inv_mass_; } + + // Edge selection activation + virtual void set_edge_selection_active(bool active) { + (void)active; // Default: no-op + } + + // Initialize graph structure for edge selection + virtual void initialize_graph() { + // Default: no-op + } + + // Missing data imputation + virtual bool has_missing_data() const { return false; } + virtual void impute_missing() { + // Default: no-op + } + + // Edge prior support: models expose these so the external edge prior + // class can read current indicators and update inclusion probabilities. 
+ virtual const arma::imat& get_edge_indicators() const { + throw std::runtime_error("get_edge_indicators not implemented for this model"); + } + virtual arma::mat& get_inclusion_probability() { + throw std::runtime_error("get_inclusion_probability not implemented for this model"); + } + virtual int get_num_variables() const { + throw std::runtime_error("get_num_variables not implemented for this model"); + } + virtual int get_num_pairwise() const { + throw std::runtime_error("get_num_pairwise not implemented for this model"); + } + +protected: + BaseModel() = default; + double step_size_ = 0.1; + arma::vec inv_mass_; +}; diff --git a/src/models/ggm/cholupdate.cpp b/src/models/ggm/cholupdate.cpp new file mode 100644 index 0000000..81331be --- /dev/null +++ b/src/models/ggm/cholupdate.cpp @@ -0,0 +1,125 @@ +#include "models/ggm/cholupdate.h" + +extern "C" { + +// from mgcv: https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1876-L1883 +static inline double hypote(double x, double y) { +/* stable computation of sqrt(x^2 + y^2) */ + double t; + x = fabs(x);y=fabs(y); + if (y>x) { t = x;x = y; y = t;} + if (x==0) return(y); else t = y/x; + return(x*sqrt(1+t*t)); +} /* hypote */ + +// from mgcv: https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1956 +void chol_up(double *R,double *u, int *n,int *up,double *eps) { +/* Rank 1 update of a cholesky factor. Works as follows: + + [up=1] R'R + uu' = [u,R'][u,R']' = [u,R']Q'Q[u,R']', and then uses Givens rotations to + construct Q such that Q[u,R']' = [0,R1']'. Hence R1'R1 = R'R + uu'. The construction + operates from first column to last. + + [up=0] uses an almost identical sequence, but employs hyperbolic rotations + in place of Givens. See Golub and van Loan (2013, 4e 6.5.4) + + Givens rotations are of form [c,-s] where c = cos(theta), s = sin(theta). + [s,c] + + Assumes R upper triangular, and that it is OK to use first two columns + below diagonal as temporary strorage for Givens rotations (the storage is + needed to ensure algorithm is column oriented). + + For downdate returns a negative value in R[1] (R[1,0]) if not +ve definite. 
+*/ + double c0,s0,*c,*s,z,*x,z0,*c1; + int j,j1,n1; + n1 = *n - 1; + if (*up) for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R */ + z = *u; /* initial element of u */ + x = R + *n * j; /* current column */ + c = R + 2;s = R + *n + 2; /* Storage for first n-2 Givens rotations */ + for (c1=c+j1;c R[j,j] */ + z0 = hypote(z,*x); /* sqrt(z^2+R[j,j]^2) */ + c0 = *x/z0; s0 = z/z0; /* need to zero z */ + /* now apply this rotation and this column is finished (so no need to update z) */ + *x = s0 * z + c0 * *x; + } else for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R for down-dating */ + z = *u; /* initial element of u */ + x = R + *n * j; /* current column */ + c = R + 2;s = R + *n + 2; /* Storage for first n-2 hyperbolic rotations */ + for (c1=c+j1;c R[j,j] */ + z0 = z / *x; /* sqrt(z^2+R[j,j]^2) */ + if (fabs(z0)>=1) { /* downdate not +ve def */ + //Rprintf("j = %d d = %g ",j,z0); + if (*n>1) R[1] = -2.0; + return; /* signals error */ + } + if (z0 > 1 - *eps) z0 = 1 - *eps; + c0 = 1/sqrt(1-z0*z0);s0 = c0 * z0; + /* now apply this rotation and this column is finished (so no need to update z) */ + *x = -s0 * z + c0 * *x; + } + + /* now zero c and s storage */ + c = R + 2;s = R + *n + 2; + for (x = c + *n - 2;c + +void cholesky_update( arma::mat& R, arma::vec& u, double eps = 1e-12); +void cholesky_downdate(arma::mat& R, arma::vec& u, double eps = 1e-12); diff --git a/src/models/ggm/ggm_model.cpp b/src/models/ggm/ggm_model.cpp new file mode 100644 index 0000000..3d7068c --- /dev/null +++ b/src/models/ggm/ggm_model.cpp @@ -0,0 +1,425 @@ +#include "models/ggm/ggm_model.h" +#include "models/adaptive_metropolis.h" +#include "rng/rng_utils.h" +#include "models/ggm/cholupdate.h" + +double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { + return(A(ii, jj) - A(ii, i) * A(jj, i) / A(i, i)); +} + +void GGMModel::get_constants(size_t i, size_t j) { + + double logdet_omega = get_log_det(cholesky_of_precision_); + + double log_adj_omega_ii = logdet_omega + std::log(std::abs(covariance_matrix_(i, i))); + double log_adj_omega_ij = logdet_omega + std::log(std::abs(covariance_matrix_(i, j))); + double log_adj_omega_jj = logdet_omega + std::log(std::abs(covariance_matrix_(j, j))); + + double inv_omega_sub_j1j1 = compute_inv_submatrix_i(covariance_matrix_, i, j, j); + double log_abs_inv_omega_sub_jj = log_adj_omega_ii + std::log(std::abs(inv_omega_sub_j1j1)); + double Phi_q1q = (2 * std::signbit(covariance_matrix_(i, j)) - 1) * std::exp( + (log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2) + ); + double Phi_q1q1 = std::exp((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); + + constants_[0] = Phi_q1q; + constants_[1] = Phi_q1q1; + constants_[2] = precision_matrix_(i, j) - Phi_q1q * Phi_q1q1; + constants_[3] = Phi_q1q1; + constants_[4] = precision_matrix_(j, j) - Phi_q1q * Phi_q1q; + constants_[5] = constants_[4] + constants_[2] * constants_[2] / (constants_[3] * constants_[3]); + +} + +double GGMModel::constrained_diagonal(const double x) const { + if (x == 0) { + return constants_[5]; + } else { + return constants_[4] + std::pow((x - constants_[2]) / constants_[3], 2); + } +} + +double GGMModel::get_log_det(arma::mat triangular_A) const { + // log-determinant of A'A where A is upper-triangular Cholesky factor + return 2 * arma::accu(arma::log(triangular_A.diag())); +} + +double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { + + double logdet_omega = get_log_det(phi); + // 
TODO: why not just dot(omega, suf_stat_)? + double trace_prod = arma::accu(omega % suf_stat_); + + double log_likelihood = n_ * (p_ * log(2 * arma::datum::pi) / 2 + logdet_omega / 2) - trace_prod / 2; + + return log_likelihood; +} + +double GGMModel::log_density_impl_edge(size_t i, size_t j) const { + + // Log-likelihood ratio (not the full log-likelihood) + + double Ui2 = precision_matrix_(i, j) - precision_proposal_(i, j); + double Uj2 = (precision_matrix_(j, j) - precision_proposal_(j, j)) / 2; + + double cc11 = 0 + covariance_matrix_(j, j); + double cc12 = 1 - (covariance_matrix_(i, j) * Ui2 + covariance_matrix_(j, j) * Uj2); + double cc22 = 0 + Ui2 * Ui2 * covariance_matrix_(i, i) + 2 * Ui2 * Uj2 * covariance_matrix_(i, j) + Uj2 * Uj2 * covariance_matrix_(j, j); + + double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); + // logdet - (logdet(aOmega_prop) - logdet(aOmega)) + + double trace_prod = -2 * (suf_stat_(j, j) * Uj2 + suf_stat_(i, j) * Ui2); + + double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; + return log_likelihood_ratio; + +} + +double GGMModel::log_density_impl_diag(size_t j) const { + // same as above but for i == j, so Ui2 = 0 + double Uj2 = (precision_matrix_(j, j) - precision_proposal_(j, j)) / 2; + + double cc11 = 0 + covariance_matrix_(j, j); + double cc12 = 1 - covariance_matrix_(j, j) * Uj2; + double cc22 = 0 + Uj2 * Uj2 * covariance_matrix_(j, j); + + double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); + double trace_prod = -2 * suf_stat_(j, j) * Uj2; + + double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; + return log_likelihood_ratio; + +} + +void GGMModel::update_edge_parameter(size_t i, size_t j) { + + if (edge_indicators_(i, j) == 0) { + return; // Edge is not included; skip update + } + + get_constants(i, j); + double Phi_q1q = constants_[0]; + double Phi_q1q1 = constants_[1]; + + size_t e = j * (j + 1) / 2 + i; // parameter index in vectorized form (column-major upper triangle) + double proposal_sd = proposal_.get_proposal_sd(e); + + double phi_prop = rnorm(rng_, Phi_q1q, proposal_sd); + double omega_prop_q1q = constants_[2] + constants_[3] * phi_prop; + double omega_prop_qq = constrained_diagonal(omega_prop_q1q); + + // form full proposal matrix for Omega + precision_proposal_ = precision_matrix_; + precision_proposal_(i, j) = omega_prop_q1q; + precision_proposal_(j, i) = omega_prop_q1q; + precision_proposal_(j, j) = omega_prop_qq; + + double ln_alpha = log_density_impl_edge(i, j); + + ln_alpha += R::dcauchy(precision_proposal_(i, j), 0.0, pairwise_scale_, true); + ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, pairwise_scale_, true); + + if (std::log(runif(rng_)) < ln_alpha) { + // accept proposal + proposal_.increment_accepts(e); + + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); + + + precision_matrix_(i, j) = omega_prop_q1q; + precision_matrix_(j, i) = omega_prop_q1q; + precision_matrix_(j, j) = omega_prop_qq; + + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + + } + + proposal_.update_proposal_sd(e); +} + +void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) +{ + + v2_[0] = omega_ij_old - precision_proposal_(i, j); + v2_[1] = (omega_jj_old - precision_proposal_(j, j)) / 2; + + vf1_[i] = v1_[0]; + vf1_[j] = v1_[1]; + vf2_[i] = v2_[0]; + vf2_[j] = v2_[1]; + + // we now have + // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) + + u1_ = (vf1_ + vf2_) / sqrt(2); + u2_ = (vf1_ - vf2_) / 
sqrt(2); + + // update phi (2x O(p^2)) + cholesky_update(cholesky_of_precision_, u1_); + cholesky_downdate(cholesky_of_precision_, u2_); + + // update inverse (2x O(p^2)) + arma::inv(inv_cholesky_of_precision_, arma::trimatu(cholesky_of_precision_)); + covariance_matrix_ = inv_cholesky_of_precision_ * inv_cholesky_of_precision_.t(); + + // reset for next iteration + vf1_[i] = 0.0; + vf1_[j] = 0.0; + vf2_[i] = 0.0; + vf2_[j] = 0.0; + +} + +void GGMModel::update_diagonal_parameter(size_t i) { + // Implementation of diagonal parameter update + // 1-3) from before + double logdet_omega = get_log_det(cholesky_of_precision_); + double logdet_omega_sub_ii = logdet_omega + std::log(covariance_matrix_(i, i)); + + size_t e = i * (i + 3) / 2; // parameter index in vectorized form (column-major upper triangle, i==j) + double proposal_sd = proposal_.get_proposal_sd(e); + + double theta_curr = (logdet_omega - logdet_omega_sub_ii) / 2; + double theta_prop = rnorm(rng_, theta_curr, proposal_sd); + + //4) Replace and rebuild omega + precision_proposal_ = precision_matrix_; + precision_proposal_(i, i) = precision_matrix_(i, i) - std::exp(theta_curr) * std::exp(theta_curr) + std::exp(theta_prop) * std::exp(theta_prop); + + double ln_alpha = log_density_impl_diag(i); + + ln_alpha += R::dgamma(exp(theta_prop), 1.0, 1.0, true); + ln_alpha -= R::dgamma(exp(theta_curr), 1.0, 1.0, true); + ln_alpha += theta_prop - theta_curr; // Jacobian adjustment ? + + if (std::log(runif(rng_)) < ln_alpha) { + + proposal_.increment_accepts(e); + + double omega_ii = precision_matrix_(i, i); + precision_matrix_(i, i) = precision_proposal_(i, i); + + cholesky_update_after_diag(omega_ii, i); + + } + + proposal_.update_proposal_sd(e); +} + +void GGMModel::cholesky_update_after_diag(double omega_ii_old, size_t i) +{ + + double delta = omega_ii_old - precision_proposal_(i, i); + + bool s = delta > 0; + vf1_(i) = std::sqrt(std::abs(delta)); + + if (s) + cholesky_downdate(cholesky_of_precision_, vf1_); + else + cholesky_update(cholesky_of_precision_, vf1_); + + // update inverse (2x O(p^2)) + arma::inv(inv_cholesky_of_precision_, arma::trimatu(cholesky_of_precision_)); + covariance_matrix_ = inv_cholesky_of_precision_ * inv_cholesky_of_precision_.t(); + + // reset for next iteration + vf1_(i) = 0.0; +} + + +void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { + + size_t e = j * (j + 1) / 2 + i; // parameter index in vectorized form (column-major upper triangle) + double proposal_sd = proposal_.get_proposal_sd(e); + + if (edge_indicators_(i, j) == 1) { + // Propose to turn OFF the edge + precision_proposal_ = precision_matrix_; + precision_proposal_(i, j) = 0.0; + precision_proposal_(j, i) = 0.0; + + // Update diagonal to preserve positive-definiteness + get_constants(i, j); + precision_proposal_(j, j) = constrained_diagonal(0.0); + + // double ln_alpha = log_likelihood(precision_proposal_) - log_likelihood(); + double ln_alpha = log_density_impl_edge(i, j); + // { + // double ln_alpha_ref = log_likelihood(precision_proposal_) - log_likelihood(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; + // precision_matrix_.print(Rcpp::Rcout, "Current omega:"); + // precision_proposal_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } + + + ln_alpha += std::log(1.0 - inclusion_probability_(i, 
j)) - std::log(inclusion_probability_(i, j)); + + ln_alpha += R::dnorm(precision_matrix_(i, j) / constants_[3], 0.0, proposal_sd, true) - std::log(constants_[3]); + ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, pairwise_scale_, true); + + if (std::log(runif(rng_)) < ln_alpha) { + + // Store old values for Cholesky update + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); + + // Update omega + precision_matrix_(i, j) = 0.0; + precision_matrix_(j, i) = 0.0; + precision_matrix_(j, j) = precision_proposal_(j, j); + + // Update edge indicator + edge_indicators_(i, j) = 0; + edge_indicators_(j, i) = 0; + + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + + } + + } else { + // Propose to turn ON the edge + double epsilon = rnorm(rng_, 0.0, proposal_sd); + + // Get constants for current state (with edge OFF) + get_constants(i, j); + double omega_prop_ij = constants_[3] * epsilon; + double omega_prop_jj = constrained_diagonal(omega_prop_ij); + + precision_proposal_ = precision_matrix_; + precision_proposal_(i, j) = omega_prop_ij; + precision_proposal_(j, i) = omega_prop_ij; + precision_proposal_(j, j) = omega_prop_jj; + + // double ln_alpha = log_likelihood(precision_proposal_) - log_likelihood(); + double ln_alpha = log_density_impl_edge(i, j); + // { + // double ln_alpha_ref = log_likelihood(precision_proposal_) - log_likelihood(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; + // precision_matrix_.print(Rcpp::Rcout, "Current omega:"); + // precision_proposal_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } + ln_alpha += std::log(inclusion_probability_(i, j)) - std::log(1.0 - inclusion_probability_(i, j)); + + // Prior change: add slab (Cauchy prior) + ln_alpha += R::dcauchy(omega_prop_ij, 0.0, pairwise_scale_, true); + + // Proposal term: proposed edge value given it was generated from truncated normal + ln_alpha -= R::dnorm(omega_prop_ij / constants_[3], 0.0, proposal_sd, true) - std::log(constants_[3]); + + if (std::log(runif(rng_)) < ln_alpha) { + // Accept: turn ON the edge + proposal_.increment_accepts(e); + + // Store old values for Cholesky update + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); + + // Update omega + precision_matrix_(i, j) = omega_prop_ij; + precision_matrix_(j, i) = omega_prop_ij; + precision_matrix_(j, j) = omega_prop_jj; + + // Update edge indicator + edge_indicators_(i, j) = 1; + edge_indicators_(j, i) = 1; + + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + + } + } +} + +void GGMModel::do_one_mh_step() { + + // Update off-diagonals (upper triangle) + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + update_edge_parameter(i, j); + } + } + + // Update diagonals + for (size_t i = 0; i < p_; ++i) { + update_diagonal_parameter(i); + } + + if (edge_selection_active_) { + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + update_edge_indicator_parameter_pair(i, j); + } + } + } + + // could also be called in the main MCMC loop + proposal_.increment_iteration(); +} + +void GGMModel::initialize_graph() { + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + double p = inclusion_probability_(i, j); + int draw = 
(runif(rng_) < p) ? 1 : 0; + edge_indicators_(i, j) = draw; + edge_indicators_(j, i) = draw; + if (!draw) { + precision_proposal_ = precision_matrix_; + precision_proposal_(i, j) = 0.0; + precision_proposal_(j, i) = 0.0; + get_constants(i, j); + precision_proposal_(j, j) = constrained_diagonal(0.0); + + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); + precision_matrix_(j, j) = precision_proposal_(j, j); + precision_matrix_(i, j) = 0.0; + precision_matrix_(j, i) = 0.0; + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + } + } + } +} + + +GGMModel createGGMModelFromR( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection, + const double pairwise_scale +) { + + if (inputFromR.containsElementNamed("n") && inputFromR.containsElementNamed("suf_stat")) { + int n = Rcpp::as(inputFromR["n"]); + arma::mat suf_stat = Rcpp::as(inputFromR["suf_stat"]); + return GGMModel( + n, + suf_stat, + prior_inclusion_prob, + initial_edge_indicators, + edge_selection, + pairwise_scale + ); + } else if (inputFromR.containsElementNamed("X")) { + arma::mat X = Rcpp::as(inputFromR["X"]); + return GGMModel( + X, + prior_inclusion_prob, + initial_edge_indicators, + edge_selection, + pairwise_scale + ); + } else { + throw std::invalid_argument("Input list must contain either 'X' or both 'n' and 'suf_stat'."); + } + +} diff --git a/src/models/ggm/ggm_model.h b/src/models/ggm/ggm_model.h new file mode 100644 index 0000000..7946e77 --- /dev/null +++ b/src/models/ggm/ggm_model.h @@ -0,0 +1,250 @@ +#pragma once + +#include +#include +#include "models/base_model.h" +#include "models/adaptive_metropolis.h" +#include "rng/rng_utils.h" + + +/** + * GGMModel - Gaussian Graphical Model + * + * Bayesian inference on the precision matrix (inverse covariance) of a + * multivariate Gaussian via element-wise Metropolis-Hastings. Edge + * selection uses a spike-and-slab prior with Cauchy slab. + * + * The Cholesky factor of the precision matrix is maintained incrementally + * through rank-1 updates/downdates after each element change. + */ +class GGMModel : public BaseModel { +public: + + // Construct from raw observations + GGMModel( + const arma::mat& observations, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true, + const double pairwise_scale = 2.5 + ) : n_(observations.n_rows), + p_(observations.n_cols), + // TODO: we need to adjust the algorithm to also sample the means! + dim_((p_ * (p_ + 1)) / 2), + suf_stat_(observations.t() * observations), + inclusion_probability_(inclusion_probability), + edge_selection_(edge_selection), + pairwise_scale_(pairwise_scale), + proposal_(AdaptiveProposal(dim_, 500)), + precision_matrix_(arma::eye(p_, p_)), + cholesky_of_precision_(arma::eye(p_, p_)), + inv_cholesky_of_precision_(arma::eye(p_, p_)), + covariance_matrix_(arma::eye(p_, p_)), + edge_indicators_(initial_edge_indicators), + vectorized_parameters_(dim_), + vectorized_indicator_parameters_(edge_selection_ ? 
dim_ : 0), + precision_proposal_(arma::mat(p_, p_, arma::fill::none)) + {} + + // Construct from sufficient statistics + GGMModel( + const int n, + const arma::mat& suf_stat, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true, + const double pairwise_scale = 2.5 + ) : n_(n), + p_(suf_stat.n_cols), + dim_((p_ * (p_ + 1)) / 2), + suf_stat_(suf_stat), + inclusion_probability_(inclusion_probability), + edge_selection_(edge_selection), + pairwise_scale_(pairwise_scale), + proposal_(AdaptiveProposal(dim_, 500)), + precision_matrix_(arma::eye(p_, p_)), + cholesky_of_precision_(arma::eye(p_, p_)), + inv_cholesky_of_precision_(arma::eye(p_, p_)), + covariance_matrix_(arma::eye(p_, p_)), + edge_indicators_(initial_edge_indicators), + vectorized_parameters_(dim_), + vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), + precision_proposal_(arma::mat(p_, p_, arma::fill::none)) + {} + + GGMModel(const GGMModel& other) + : BaseModel(other), + dim_(other.dim_), + suf_stat_(other.suf_stat_), + n_(other.n_), + p_(other.p_), + inclusion_probability_(other.inclusion_probability_), + edge_selection_(other.edge_selection_), + pairwise_scale_(other.pairwise_scale_), + precision_matrix_(other.precision_matrix_), + cholesky_of_precision_(other.cholesky_of_precision_), + inv_cholesky_of_precision_(other.inv_cholesky_of_precision_), + covariance_matrix_(other.covariance_matrix_), + edge_indicators_(other.edge_indicators_), + vectorized_parameters_(other.vectorized_parameters_), + vectorized_indicator_parameters_(other.vectorized_indicator_parameters_), + proposal_(other.proposal_), + rng_(other.rng_), + precision_proposal_(other.precision_proposal_) + {} + + + void set_adaptive_proposal(AdaptiveProposal proposal) { + proposal_ = proposal; + } + + bool has_gradient() const override { return false; } + bool has_adaptive_mh() const override { return true; } + bool has_edge_selection() const override { return edge_selection_; } + + void set_edge_selection_active(bool active) override { + edge_selection_active_ = active; + } + + void initialize_graph() override; + + // GGM handles edge indicator updates inside do_one_mh_step() + void update_edge_indicators() override {} + + // GGM uses component-wise MH; logp is unused. 
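+  // Acceptance ratios are instead computed from the likelihood-ratio helpers
+  // log_density_impl_edge() / log_density_impl_diag(); the full data
+  // log-likelihood remains available through log_likelihood().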
+ double logp(const arma::vec& parameters) override { return 0.0; } + + double log_likelihood(const arma::mat& omega) const { return log_density_impl(omega, arma::chol(omega)); }; + double log_likelihood() const { return log_density_impl(precision_matrix_, cholesky_of_precision_); } + + void do_one_mh_step() override; + + size_t parameter_dimension() const override { return dim_; } + size_t full_parameter_dimension() const override { return dim_; } + + void set_seed(int seed) override { + rng_ = SafeRNG(seed); + } + + arma::vec get_vectorized_parameters() const override { + return extract_upper_triangle(); + } + + arma::vec get_full_vectorized_parameters() const override { + return extract_upper_triangle(); + } + + arma::ivec get_vectorized_indicator_parameters() override { + size_t e = 0; + for (size_t j = 0; j < p_; ++j) { + for (size_t i = 0; i <= j; ++i) { + vectorized_indicator_parameters_(e) = edge_indicators_(i, j); + ++e; + } + } + return vectorized_indicator_parameters_; + } + + SafeRNG& get_rng() override { return rng_; } + + const arma::imat& get_edge_indicators() const override { + return edge_indicators_; + } + + arma::mat& get_inclusion_probability() override { + return inclusion_probability_; + } + + int get_num_variables() const override { + return static_cast(p_); + } + + int get_num_pairwise() const override { + return static_cast(p_ * (p_ - 1) / 2); + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + + arma::vec extract_upper_triangle() const { + arma::vec result(dim_); + size_t e = 0; + for (size_t j = 0; j < p_; ++j) { + for (size_t i = 0; i <= j; ++i) { + result(e) = precision_matrix_(i, j); + ++e; + } + } + return result; + } + + // Data + size_t n_; + size_t p_; + size_t dim_; + arma::mat suf_stat_; + arma::mat inclusion_probability_; + bool edge_selection_; + bool edge_selection_active_ = false; + double pairwise_scale_; + + // Parameters + arma::mat precision_matrix_, cholesky_of_precision_, inv_cholesky_of_precision_, covariance_matrix_; + arma::imat edge_indicators_; + arma::vec vectorized_parameters_; + arma::ivec vectorized_indicator_parameters_; + + AdaptiveProposal proposal_; + SafeRNG rng_; + + // Scratch space + arma::mat precision_proposal_; + + // Workspace for conditional precision reparametrization. + // [0] Phi_q1q, [1] Phi_q1q1, [2] omega_ij - Phi_q1q*Phi_q1q1, + // [3] Phi_q1q1, [4] omega_jj - Phi_q1q^2, [5] constrained diagonal at x=0. + std::array constants_{}; + + // Work vectors for rank-2 Cholesky update. + // A symmetric rank-2 update A + vf1*vf2' + vf2*vf1' is decomposed into + // two rank-1 updates via u1 = (vf1+vf2)/sqrt(2), u2 = (vf1-vf2)/sqrt(2). + arma::vec v1_ = {0, -1}; + arma::vec v2_ = {0, 0}; + arma::vec vf1_ = arma::zeros(p_); + arma::vec vf2_ = arma::zeros(p_); + arma::vec u1_ = arma::zeros(p_); + arma::vec u2_ = arma::zeros(p_); + + // MH updates + void update_edge_parameter(size_t i, size_t j); + void update_diagonal_parameter(size_t i); + void update_edge_indicator_parameter_pair(size_t i, size_t j); + + // Helpers + void get_constants(size_t i, size_t j); + double compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const; + + // Conditional precision constraint: returns the required diagonal + // value omega_jj that keeps the precision matrix positive definite + // after changing the off-diagonal element to x. 
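+  // With constants_ filled by get_constants(i, j) this reduces to
+  //   omega_jj(x) = constants_[4] + ((x - constants_[2]) / constants_[3])^2,
+  // with constants_[5] caching the value at x = 0 (edge switched off).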
+ double constrained_diagonal(const double x) const; + + double log_density_impl(const arma::mat& omega, const arma::mat& phi) const; + double log_density_impl_edge(size_t i, size_t j) const; + double log_density_impl_diag(size_t j) const; + double get_log_det(arma::mat triangular_A) const; + void cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j); + void cholesky_update_after_diag(double omega_ii_old, size_t i); +}; + + +GGMModel createGGMModelFromR( + const Rcpp::List& inputFromR, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true, + const double pairwise_scale = 2.5 +); diff --git a/src/models/omrf/omrf_model.cpp b/src/models/omrf/omrf_model.cpp new file mode 100644 index 0000000..7c724dc --- /dev/null +++ b/src/models/omrf/omrf_model.cpp @@ -0,0 +1,1312 @@ +#include +#include "models/omrf/omrf_model.h" +#include "models/adaptive_metropolis.h" +#include "rng/rng_utils.h" +#include "mcmc/mcmc_hmc.h" +#include "mcmc/mcmc_nuts.h" +#include "mcmc/mcmc_rwm.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/mcmc_adaptation.h" +#include "mcmc/mcmc_runner.h" +#include "math/explog_switch.h" +#include "utils/common_helpers.h" +#include "utils/variable_helpers.h" + + +// ============================================================================= +// Constructor +// ============================================================================= + +OMRFModel::OMRFModel( + const arma::imat& observations, + const arma::ivec& num_categories, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const arma::uvec& is_ordinal_variable, + const arma::ivec& baseline_category, + double main_alpha, + double main_beta, + double pairwise_scale, + bool edge_selection +) : + n_(observations.n_rows), + p_(observations.n_cols), + observations_(observations), + num_categories_(num_categories), + is_ordinal_variable_(is_ordinal_variable), + baseline_category_(baseline_category), + inclusion_probability_(inclusion_probability), + main_alpha_(main_alpha), + main_beta_(main_beta), + pairwise_scale_(pairwise_scale), + edge_selection_(edge_selection), + edge_selection_active_(false), + proposal_(AdaptiveProposal(1, 500)), // Will be resized later + step_size_(0.1), + has_missing_(false), + gradient_cache_valid_(false) +{ + // Initialize parameter dimensions + num_main_ = count_num_main_effects_internal(); + num_pairwise_ = (p_ * (p_ - 1)) / 2; + + // Initialize parameters + int max_cats = num_categories_.max(); + main_effects_ = arma::zeros(p_, max_cats); + pairwise_effects_ = arma::zeros(p_, p_); + edge_indicators_ = initial_edge_indicators; + + // Initialize proposal SDs + proposal_sd_main_ = arma::ones(p_, max_cats) * 0.5; + proposal_sd_pairwise_ = arma::ones(p_, p_) * 0.5; + + // Initialize adaptive proposal + proposal_ = AdaptiveProposal(num_main_ + num_pairwise_, 500); + + // Initialize mass matrix + inv_mass_ = arma::ones(num_main_ + num_pairwise_); + + // Pre-compute observations as double (for efficient matrix operations) + observations_double_ = arma::conv_to::from(observations_); + + // Compute sufficient statistics + compute_sufficient_statistics(); + + // Initialize residual matrix + update_residual_matrix(); + + // Build interaction index + build_interaction_index(); +} + + +// ============================================================================= +// Copy constructor +// ============================================================================= + 
+OMRFModel::OMRFModel(const OMRFModel& other) + : BaseModel(other), + n_(other.n_), + p_(other.p_), + observations_(other.observations_), + observations_double_(other.observations_double_), + num_categories_(other.num_categories_), + is_ordinal_variable_(other.is_ordinal_variable_), + baseline_category_(other.baseline_category_), + counts_per_category_(other.counts_per_category_), + blume_capel_stats_(other.blume_capel_stats_), + pairwise_stats_(other.pairwise_stats_), + residual_matrix_(other.residual_matrix_), + main_effects_(other.main_effects_), + pairwise_effects_(other.pairwise_effects_), + edge_indicators_(other.edge_indicators_), + inclusion_probability_(other.inclusion_probability_), + main_alpha_(other.main_alpha_), + main_beta_(other.main_beta_), + pairwise_scale_(other.pairwise_scale_), + edge_selection_(other.edge_selection_), + edge_selection_active_(other.edge_selection_active_), + num_main_(other.num_main_), + num_pairwise_(other.num_pairwise_), + proposal_(other.proposal_), + proposal_sd_main_(other.proposal_sd_main_), + proposal_sd_pairwise_(other.proposal_sd_pairwise_), + rng_(other.rng_), + step_size_(other.step_size_), + inv_mass_(other.inv_mass_), + has_missing_(other.has_missing_), + missing_index_(other.missing_index_), + grad_obs_cache_(other.grad_obs_cache_), + index_matrix_cache_(other.index_matrix_cache_), + gradient_cache_valid_(other.gradient_cache_valid_), + interaction_index_(other.interaction_index_) +{ +} + + +// ============================================================================= +// Sufficient statistics computation +// ============================================================================= + +void OMRFModel::compute_sufficient_statistics() { + int max_cats = num_categories_.max(); + + // Category counts for ordinal variables + counts_per_category_ = arma::zeros(max_cats + 1, p_); + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + for (size_t i = 0; i < n_; ++i) { + int cat = observations_(i, v); + if (cat >= 0 && cat <= num_categories_(v)) { + counts_per_category_(cat, v)++; + } + } + } + } + + // Blume-Capel statistics (linear and quadratic sums) + blume_capel_stats_ = arma::zeros(2, p_); + for (size_t v = 0; v < p_; ++v) { + if (!is_ordinal_variable_(v)) { + int baseline = baseline_category_(v); + for (size_t i = 0; i < n_; ++i) { + int s = observations_(i, v) - baseline; + blume_capel_stats_(0, v) += s; // linear + blume_capel_stats_(1, v) += s * s; // quadratic + } + } + } + + // Pairwise statistics (X^T X) - use pre-computed transformed observations + arma::mat ps = observations_double_.t() * observations_double_; + pairwise_stats_ = arma::conv_to::from(ps); +} + + +size_t OMRFModel::count_num_main_effects_internal() const { + size_t count = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + count += num_categories_(v); + } else { + count += 2; // linear and quadratic for Blume-Capel + } + } + return count; +} + + +void OMRFModel::build_interaction_index() { + interaction_index_ = arma::zeros(num_pairwise_, 3); + int idx = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + interaction_index_(idx, 0) = idx; + interaction_index_(idx, 1) = v1; + interaction_index_(idx, 2) = v2; + idx++; + } + } +} + + +void OMRFModel::update_residual_matrix() { + residual_matrix_ = observations_double_ * pairwise_effects_; +} + + +void OMRFModel::update_residual_columns(int var1, int var2, double delta) { + residual_matrix_.col(var1) += delta * 
observations_double_.col(var2); + residual_matrix_.col(var2) += delta * observations_double_.col(var1); +} + + +void OMRFModel::set_pairwise_effects(const arma::mat& pairwise_effects) { + pairwise_effects_ = pairwise_effects; + update_residual_matrix(); + invalidate_gradient_cache(); +} + + +// ============================================================================= +// BaseModel interface implementation +// ============================================================================= + +size_t OMRFModel::parameter_dimension() const { + // Count active parameters: main effects + included pairwise effects + size_t active = num_main_; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + active++; + } + } + } + return active; +} + + +void OMRFModel::set_seed(int seed) { + rng_ = SafeRNG(seed); +} + + +std::unique_ptr OMRFModel::clone() const { + return std::make_unique(*this); +} + + +void OMRFModel::set_adaptive_proposal(AdaptiveProposal proposal) { + proposal_ = proposal; +} + + +// ============================================================================= +// Parameter vectorization +// ============================================================================= + +arma::vec OMRFModel::vectorize_parameters() const { + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + arma::vec param_vec(num_main_ + num_active); + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + param_vec(offset++) = main_effects_(v, c); + } + } else { + param_vec(offset++) = main_effects_(v, 0); // linear + param_vec(offset++) = main_effects_(v, 1); // quadratic + } + } + + // Active pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + param_vec(offset++) = pairwise_effects_(v1, v2); + } + } + } + + return param_vec; +} + + +void OMRFModel::unvectorize_parameters(const arma::vec& param_vec) { + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + main_effects_(v, c) = param_vec(offset++); + } + } else { + main_effects_(v, 0) = param_vec(offset++); // linear + main_effects_(v, 1) = param_vec(offset++); // quadratic + } + } + + // Active pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + double val = param_vec(offset++); + pairwise_effects_(v1, v2) = val; + pairwise_effects_(v2, v1) = val; + } + } + } + + update_residual_matrix(); + invalidate_gradient_cache(); +} + + +void OMRFModel::unvectorize_to_temps( + const arma::vec& parameters, + arma::mat& temp_main, + arma::mat& temp_pairwise, + arma::mat& temp_residual +) const { + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + temp_main(v, c) = parameters(offset++); + } + } else { + temp_main(v, 0) = parameters(offset++); + temp_main(v, 1) = parameters(offset++); + } + } + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, 
v2) == 1) { + temp_pairwise(v1, v2) = parameters(offset++); + temp_pairwise(v2, v1) = temp_pairwise(v1, v2); + } + } + } + + temp_residual = observations_double_ * temp_pairwise; +} + + +arma::vec OMRFModel::get_vectorized_parameters() const { + return vectorize_parameters(); +} + + +void OMRFModel::set_vectorized_parameters(const arma::vec& parameters) { + unvectorize_parameters(parameters); +} + + +arma::vec OMRFModel::get_full_vectorized_parameters() const { + // Fixed-size vector: all main effects + ALL pairwise effects + arma::vec param_vec(num_main_ + num_pairwise_); + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + param_vec(offset++) = main_effects_(v, c); + } + } else { + param_vec(offset++) = main_effects_(v, 0); // linear + param_vec(offset++) = main_effects_(v, 1); // quadratic + } + } + + // ALL pairwise effects (zeros for inactive edges) + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + param_vec(offset++) = pairwise_effects_(v1, v2); + } + } + + return param_vec; +} + + +arma::ivec OMRFModel::get_vectorized_indicator_parameters() { + arma::ivec indicators(num_pairwise_); + int idx = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + indicators(idx++) = edge_indicators_(v1, v2); + } + } + return indicators; +} + + +arma::vec OMRFModel::get_active_inv_mass() const { + if (!edge_selection_active_) { + return inv_mass_; + } + + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + arma::vec active_inv_mass(num_main_ + num_active); + active_inv_mass.head(num_main_) = inv_mass_.head(num_main_); + + int offset_full = num_main_; + int offset_active = num_main_; + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + active_inv_mass(offset_active) = inv_mass_(offset_full); + offset_active++; + } + offset_full++; + } + } + + return active_inv_mass; +} + + +void OMRFModel::vectorize_parameters_into(arma::vec& param_vec) const { + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + // Resize if needed (should rarely happen after first call) + size_t needed_size = num_main_ + num_active; + if (param_vec.n_elem != needed_size) { + param_vec.set_size(needed_size); + } + + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + param_vec(offset++) = main_effects_(v, c); + } + } else { + param_vec(offset++) = main_effects_(v, 0); // linear + param_vec(offset++) = main_effects_(v, 1); // quadratic + } + } + + // Active pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + param_vec(offset++) = pairwise_effects_(v1, v2); + } + } + } +} + + +void OMRFModel::get_active_inv_mass_into(arma::vec& active_inv_mass) const { + if (!edge_selection_active_) { + // No edge selection - just use full inv_mass + if (active_inv_mass.n_elem != inv_mass_.n_elem) { + active_inv_mass.set_size(inv_mass_.n_elem); + 
} + active_inv_mass = inv_mass_; + return; + } + + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + size_t needed_size = num_main_ + num_active; + if (active_inv_mass.n_elem != needed_size) { + active_inv_mass.set_size(needed_size); + } + + active_inv_mass.head(num_main_) = inv_mass_.head(num_main_); + + int offset_full = num_main_; + int offset_active = num_main_; + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + active_inv_mass(offset_active) = inv_mass_(offset_full); + offset_active++; + } + offset_full++; + } + } +} + + +// ============================================================================= +// Log-pseudoposterior computation +// ============================================================================= + +double OMRFModel::logp(const arma::vec& parameters) { + arma::mat temp_main = main_effects_; + arma::mat temp_pairwise = pairwise_effects_; + arma::mat temp_residual; + unvectorize_to_temps(parameters, temp_main, temp_pairwise, temp_residual); + return log_pseudoposterior_with_state(temp_main, temp_pairwise, temp_residual); +} + + +double OMRFModel::log_pseudoposterior_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat +) const { + double log_post = 0.0; + + auto log_beta_prior = [this](double x) { + return x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); + }; + + // Main effect contributions (priors and sufficient statistics) + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + + if (is_ordinal_variable_(v)) { + for (int c = 0; c < num_cats; ++c) { + log_post += log_beta_prior(main_eff(v, c)); + log_post += main_eff(v, c) * counts_per_category_(c + 1, v); + } + } else { + log_post += log_beta_prior(main_eff(v, 0)); + log_post += log_beta_prior(main_eff(v, 1)); + log_post += main_eff(v, 0) * blume_capel_stats_(0, v); + log_post += main_eff(v, 1) * blume_capel_stats_(1, v); + } + } + + // Log-denominator contributions using vectorized helpers + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + arma::vec residual_score = residual_mat.col(v); + arma::vec bound = num_cats * residual_score; + + arma::vec denom(n_, arma::fill::zeros); + if (is_ordinal_variable_(v)) { + // Extract main effect parameters for this variable + arma::vec main_effect_param = main_eff.row(v).cols(0, num_cats - 1).t(); + denom = compute_denom_ordinal(residual_score, main_effect_param, bound); + } else { + int ref = baseline_category_(v); + double lin_effect = main_eff(v, 0); + double quad_effect = main_eff(v, 1); + // This updates bound in-place + denom = compute_denom_blume_capel(residual_score, lin_effect, quad_effect, ref, num_cats, bound); + } + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); + } + + // Pairwise effect contributions: sufficient statistics + Cauchy prior + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + double effect = pairwise_eff(v1, v2); + // Sufficient statistics term (data likelihood contribution) + log_post += 2.0 * pairwise_stats_(v1, v2) * effect; + // Cauchy prior using R's dcauchy for consistency + log_post += R::dcauchy(effect, 0.0, pairwise_scale_, true); + } + } + } + + return log_post; +} + + +double 
OMRFModel::log_pseudoposterior_internal() const { + return log_pseudoposterior_with_state(main_effects_, pairwise_effects_, residual_matrix_); +} + + +double OMRFModel::log_pseudoposterior_main_component(int variable, int category, int parameter) const { + double log_post = 0.0; + + auto log_beta_prior = [this](double x) { + return x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); + }; + + int num_cats = num_categories_(variable); + + if (is_ordinal_variable_(variable)) { + log_post += log_beta_prior(main_effects_(variable, category)); + log_post += main_effects_(variable, category) * counts_per_category_(category + 1, variable); + + arma::vec residual_score = residual_matrix_.col(variable); + arma::vec main_param = main_effects_.row(variable).cols(0, num_cats - 1).t(); + arma::vec bound = num_cats * residual_score; + + arma::vec denom = compute_denom_ordinal(residual_score, main_param, bound); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); + } else { + log_post += log_beta_prior(main_effects_(variable, parameter)); + log_post += main_effects_(variable, parameter) * blume_capel_stats_(parameter, variable); + + arma::vec residual_score = residual_matrix_.col(variable); + arma::vec bound(n_); + + arma::vec denom = compute_denom_blume_capel( + residual_score, main_effects_(variable, 0), main_effects_(variable, 1), + baseline_category_(variable), num_cats, bound + ); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); + } + + return log_post; +} + + +double OMRFModel::log_pseudoposterior_pairwise_component(int var1, int var2) const { + double log_post = 2.0 * pairwise_effects_(var1, var2) * pairwise_stats_(var1, var2); + + for (int var : {var1, var2}) { + int num_cats = num_categories_(var); + arma::vec residual_score = residual_matrix_.col(var); + + if (is_ordinal_variable_(var)) { + arma::vec main_param = main_effects_.row(var).cols(0, num_cats - 1).t(); + arma::vec bound = num_cats * residual_score; + arma::vec denom = compute_denom_ordinal(residual_score, main_param, bound); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); + } else { + arma::vec bound(n_); + arma::vec denom = compute_denom_blume_capel( + residual_score, main_effects_(var, 0), main_effects_(var, 1), + baseline_category_(var), num_cats, bound + ); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); + } + } + + if (edge_indicators_(var1, var2) == 1) { + log_post += R::dcauchy(pairwise_effects_(var1, var2), 0.0, pairwise_scale_, true); + } + + return log_post; +} + + +double OMRFModel::compute_log_likelihood_ratio_for_variable( + int variable, + const arma::vec& interacting_score, + double proposed_state, + double current_state +) const { + int num_cats = num_categories_(variable); + + // Residual without the current interaction contribution + arma::vec rest_base = residual_matrix_.col(variable) - current_state * interacting_score; + + if (is_ordinal_variable_(variable)) { + arma::vec main_param = main_effects_.row(variable).cols(0, num_cats - 1).t(); + + arma::vec rest_curr = rest_base + current_state * interacting_score; + arma::vec bound_curr = num_cats * rest_curr; + arma::vec denom_curr = compute_denom_ordinal(rest_curr, main_param, bound_curr); + + arma::vec rest_prop = rest_base + proposed_state * interacting_score; + arma::vec bound_prop = num_cats * rest_prop; + arma::vec denom_prop = compute_denom_ordinal(rest_prop, main_param, bound_prop); + + return arma::accu(bound_curr + ARMA_MY_LOG(denom_curr)) + - arma::accu(bound_prop + ARMA_MY_LOG(denom_prop)); + } else { + arma::vec 
rest_curr = rest_base + current_state * interacting_score; + arma::vec bound_curr(n_); + arma::vec denom_curr = compute_denom_blume_capel( + rest_curr, main_effects_(variable, 0), main_effects_(variable, 1), + baseline_category_(variable), num_cats, bound_curr + ); + double log_ratio = arma::accu(ARMA_MY_LOG(denom_curr) + bound_curr); + + arma::vec rest_prop = rest_base + proposed_state * interacting_score; + arma::vec bound_prop(n_); + arma::vec denom_prop = compute_denom_blume_capel( + rest_prop, main_effects_(variable, 0), main_effects_(variable, 1), + baseline_category_(variable), num_cats, bound_prop + ); + log_ratio -= arma::accu(ARMA_MY_LOG(denom_prop) + bound_prop); + + return log_ratio; + } +} + + +double OMRFModel::log_pseudolikelihood_ratio_interaction( + int variable1, + int variable2, + double proposed_state, + double current_state +) const { + double delta = proposed_state - current_state; + double log_ratio = 2.0 * delta * pairwise_stats_(variable1, variable2); + + // Contribution from variable1 (interacting with variable2's observations) + log_ratio += compute_log_likelihood_ratio_for_variable( + variable1, observations_double_.col(variable2), + proposed_state, current_state); + + // Contribution from variable2 (interacting with variable1's observations) + log_ratio += compute_log_likelihood_ratio_for_variable( + variable2, observations_double_.col(variable1), + proposed_state, current_state); + + return log_ratio; +} + + +// ============================================================================= +// Gradient computation +// ============================================================================= + +void OMRFModel::ensure_gradient_cache() { + if (gradient_cache_valid_) return; + + // Compute observed gradient and index matrix (constant during MCMC) + int num_active = 0; + index_matrix_cache_ = arma::zeros(p_, p_); + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + index_matrix_cache_(v1, v2) = num_main_ + num_active; + num_active++; + } + } + } + + grad_obs_cache_ = arma::zeros(num_main_ + num_active); + + // Observed statistics for main effects + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + grad_obs_cache_(offset++) = counts_per_category_(c + 1, v); + } + } else { + grad_obs_cache_(offset++) = blume_capel_stats_(0, v); + grad_obs_cache_(offset++) = blume_capel_stats_(1, v); + } + } + + // Observed statistics for pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + grad_obs_cache_(offset++) = 2.0 * pairwise_stats_(v1, v2); + } + } + } + + gradient_cache_valid_ = true; +} + + +arma::vec OMRFModel::gradient(const arma::vec& parameters) { + arma::mat temp_main = main_effects_; + arma::mat temp_pairwise = pairwise_effects_; + arma::mat temp_residual; + unvectorize_to_temps(parameters, temp_main, temp_pairwise, temp_residual); + return gradient_with_state(temp_main, temp_pairwise, temp_residual); +} + + +arma::vec OMRFModel::gradient_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat +) const { + // Start with cached observed gradient + arma::vec gradient = grad_obs_cache_; + + // Expected statistics for main and pairwise effects + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + arma::vec 
residual_score = residual_mat.col(v); + arma::vec bound = num_cats * residual_score; + + if (is_ordinal_variable_(v)) { + // Extract main effect parameters for this variable + arma::vec main_param = main_eff.row(v).cols(0, num_cats - 1).t(); + + // Use optimized helper function + arma::mat probs = compute_probs_ordinal(main_param, residual_score, bound, num_cats); + + // Main effects gradient + for (int c = 0; c < num_cats; ++c) { + gradient(offset + c) -= arma::accu(probs.col(c + 1)); + } + + // Pairwise effects gradient + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + + arma::vec expected_value = arma::zeros(n_); + for (int c = 1; c <= num_cats; ++c) { + expected_value += c * probs.col(c) % observations_double_.col(j); + } + + int location = (v < j) ? index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += num_cats; + } else { + int ref = baseline_category_(v); + double lin_eff = main_eff(v, 0); + double quad_eff = main_eff(v, 1); + + // Use optimized helper function (updates bound in-place) + arma::mat probs = compute_probs_blume_capel(residual_score, lin_eff, quad_eff, ref, num_cats, bound); + + arma::vec score = arma::regspace(0, num_cats) - static_cast(ref); + arma::vec sq_score = arma::square(score); + + // Main effects gradient + gradient(offset) -= arma::accu(probs * score); + gradient(offset + 1) -= arma::accu(probs * sq_score); + + // Pairwise effects gradient + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + + arma::vec expected_value = arma::zeros(n_); + for (int c = 0; c <= num_cats; ++c) { + int s = c - ref; + expected_value += s * probs.col(c) % observations_double_.col(j); + } + + int location = (v < j) ? index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += 2; + } + } + + // Prior gradients for main effects (Beta-prime prior) + offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_pars = is_ordinal_variable_(v) ? num_categories_(v) : 2; + for (int c = 0; c < num_pars; ++c) { + double x = main_eff(v, c); + double prob = 1.0 / (1.0 + std::exp(-x)); + gradient(offset + c) += main_alpha_ - (main_alpha_ + main_beta_) * prob; + } + offset += num_pars; + } + + // Cauchy prior gradient for pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + int idx = index_matrix_cache_(v1, v2); + double x = pairwise_eff(v1, v2); + gradient(idx) -= 2.0 * x / (pairwise_scale_ * pairwise_scale_ + x * x); + } + } + } + + return gradient; +} + + +arma::vec OMRFModel::gradient_internal() const { + return gradient_with_state(main_effects_, pairwise_effects_, residual_matrix_); +} + + +std::pair OMRFModel::logp_and_gradient(const arma::vec& parameters) { + ensure_gradient_cache(); + + arma::mat temp_main(main_effects_.n_rows, main_effects_.n_cols, arma::fill::none); + arma::mat temp_pairwise(p_, p_, arma::fill::zeros); + arma::mat temp_residual; + unvectorize_to_temps(parameters, temp_main, temp_pairwise, temp_residual); + + // Initialize gradient from cached observed statistics + arma::vec gradient = grad_obs_cache_; + double log_post = 0.0; + + // Merged per-variable loop: compute probability table ONCE per variable + // and derive both logp and gradient contributions from it. 
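  // Sign conventions used below: log_pseudoposterior_with_state() subtracts
  // accu(bound + log(denom)) per variable. For an ordinal variable, category 0
  // carries neither a main-effect nor a residual term, so its normalized
  // probability is probs(:, 0) = exp(-bound) / denom, giving
  // bound + log(denom) = -log(probs(:, 0)); the ordinal branch therefore adds
  // accu(log(probs.col(0))) directly. The Blume-Capel branch applies the same
  // idea at the reference category, which leaves the extra ref * r term that
  // is subtracted explicitly there.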
+ int offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + arma::vec residual_score = temp_residual.col(v); + arma::vec bound = num_cats * residual_score; + + if (is_ordinal_variable_(v)) { + arma::vec main_param = temp_main.row(v).cols(0, num_cats - 1).t(); + + // Prior + sufficient statistics for logp + for (int c = 0; c < num_cats; ++c) { + double x = temp_main(v, c); + log_post += x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); + log_post += x * counts_per_category_(c + 1, v); + } + + // Compute probability table ONCE (replaces separate denom + probs) + arma::mat probs = compute_probs_ordinal(main_param, residual_score, bound, num_cats); + + // Log-partition: bound + log(denom) = -log(probs(:, 0)) + log_post += arma::accu(ARMA_MY_LOG(probs.col(0))); + + // Gradient: main effects + for (int c = 0; c < num_cats; ++c) { + gradient(offset + c) -= arma::accu(probs.col(c + 1)); + } + + // Gradient: pairwise effects + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + arma::vec expected_value = arma::zeros(n_); + for (int c = 1; c <= num_cats; ++c) { + expected_value += c * probs.col(c) % observations_double_.col(j); + } + int location = (v < j) ? index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += num_cats; + } else { + int ref = baseline_category_(v); + double lin_eff = temp_main(v, 0); + double quad_eff = temp_main(v, 1); + + // Prior + sufficient statistics for logp + log_post += lin_eff * main_alpha_ - std::log1p(std::exp(lin_eff)) * (main_alpha_ + main_beta_); + log_post += quad_eff * main_alpha_ - std::log1p(std::exp(quad_eff)) * (main_alpha_ + main_beta_); + log_post += lin_eff * blume_capel_stats_(0, v); + log_post += quad_eff * blume_capel_stats_(1, v); + + // Compute probability table ONCE (bound is updated in-place) + arma::mat probs = compute_probs_blume_capel(residual_score, lin_eff, quad_eff, ref, num_cats, bound); + + // Log-partition: bound + log(denom) = ref * r - log(probs(:, ref)) + log_post -= static_cast(ref) * arma::accu(residual_score) + - arma::accu(ARMA_MY_LOG(probs.col(ref))); + + // Gradient: main effects + arma::vec score = arma::regspace(0, num_cats) - static_cast(ref); + arma::vec sq_score = arma::square(score); + gradient(offset) -= arma::accu(probs * score); + gradient(offset + 1) -= arma::accu(probs * sq_score); + + // Gradient: pairwise effects + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + arma::vec expected_value = arma::zeros(n_); + for (int c = 0; c <= num_cats; ++c) { + int s = c - ref; + expected_value += s * probs.col(c) % observations_double_.col(j); + } + int location = (v < j) ? index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += 2; + } + } + + // Gradient: main effect prior (Beta-prime) + offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_pars = is_ordinal_variable_(v) ? 
num_categories_(v) : 2; + for (int c = 0; c < num_pars; ++c) { + double x = temp_main(v, c); + double prob = 1.0 / (1.0 + std::exp(-x)); + gradient(offset + c) += main_alpha_ - (main_alpha_ + main_beta_) * prob; + } + offset += num_pars; + } + + // Pairwise: sufficient statistics + Cauchy prior (logp) and gradient + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + double effect = temp_pairwise(v1, v2); + log_post += 2.0 * pairwise_stats_(v1, v2) * effect; + log_post += R::dcauchy(effect, 0.0, pairwise_scale_, true); + + int idx = index_matrix_cache_(v1, v2); + gradient(idx) -= 2.0 * effect / (pairwise_scale_ * pairwise_scale_ + effect * effect); + } + } + } + + return {log_post, std::move(gradient)}; +} + + +// ============================================================================= +// Metropolis-Hastings updates +// ============================================================================= + +void OMRFModel::update_main_effect_parameter(int variable, int category, int parameter) { + double& current = is_ordinal_variable_(variable) + ? main_effects_(variable, category) + : main_effects_(variable, parameter); + + double proposal_sd = is_ordinal_variable_(variable) + ? proposal_sd_main_(variable, category) + : proposal_sd_main_(variable, parameter); + + int cat_for_log = is_ordinal_variable_(variable) ? category : -1; + int par_for_log = is_ordinal_variable_(variable) ? -1 : parameter; + + double current_value = current; + double proposed_value = rnorm(rng_, current_value, proposal_sd); + + // Evaluate log-posterior at proposed + current = proposed_value; + double lp_proposed = log_pseudoposterior_main_component(variable, cat_for_log, par_for_log); + + // Evaluate log-posterior at current + current = current_value; + double lp_current = log_pseudoposterior_main_component(variable, cat_for_log, par_for_log); + + double log_accept = lp_proposed - lp_current; + if (MY_LOG(runif(rng_)) < log_accept) { + current = proposed_value; + } +} + + +void OMRFModel::update_pairwise_effect(int var1, int var2) { + if (edge_indicators_(var1, var2) == 0) return; + + double current_value = pairwise_effects_(var1, var2); + double proposal_sd = proposal_sd_pairwise_(var1, var2); + double proposed_value = rnorm(rng_, current_value, proposal_sd); + + double log_accept = log_pseudolikelihood_ratio_interaction( + var1, var2, proposed_value, current_value); + + // Cauchy prior ratio + log_accept += R::dcauchy(proposed_value, 0.0, pairwise_scale_, true) + - R::dcauchy(current_value, 0.0, pairwise_scale_, true); + + double accept_prob = std::min(1.0, MY_EXP(log_accept)); + + if (runif(rng_) < accept_prob) { + double delta = proposed_value - current_value; + pairwise_effects_(var1, var2) = proposed_value; + pairwise_effects_(var2, var1) = proposed_value; + update_residual_columns(var1, var2, delta); + } +} + + +void OMRFModel::update_edge_indicator(int var1, int var2) { + double current_state = pairwise_effects_(var1, var2); + double proposal_sd = proposal_sd_pairwise_(var1, var2); + + bool proposing_addition = (edge_indicators_(var1, var2) == 0); + double proposed_state = proposing_addition ? 
rnorm(rng_, current_state, proposal_sd) : 0.0; + + double log_accept = log_pseudolikelihood_ratio_interaction(var1, var2, proposed_state, current_state); + + double incl_prob = inclusion_probability_(var1, var2); + + if (proposing_addition) { + log_accept += R::dcauchy(proposed_state, 0.0, pairwise_scale_, true); + log_accept -= R::dnorm(proposed_state, current_state, proposal_sd, true); + log_accept += MY_LOG(incl_prob) - MY_LOG(1.0 - incl_prob); + } else { + log_accept -= R::dcauchy(current_state, 0.0, pairwise_scale_, true); + log_accept += R::dnorm(current_state, proposed_state, proposal_sd, true); + log_accept -= MY_LOG(incl_prob) - MY_LOG(1.0 - incl_prob); + } + + if (MY_LOG(runif(rng_)) < log_accept) { + int updated = 1 - edge_indicators_(var1, var2); + edge_indicators_(var1, var2) = updated; + edge_indicators_(var2, var1) = updated; + + double delta = proposed_state - current_state; + pairwise_effects_(var1, var2) = proposed_state; + pairwise_effects_(var2, var1) = proposed_state; + + update_residual_columns(var1, var2, delta); + } +} + + +// ============================================================================= +// Main update methods +// ============================================================================= + +void OMRFModel::do_one_mh_step() { + // Update main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + update_main_effect_parameter(v, c, -1); + } + } else { + for (int p = 0; p < 2; ++p) { + update_main_effect_parameter(v, -1, p); + } + } + } + + // Update pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + update_pairwise_effect(v1, v2); + } + } + + invalidate_gradient_cache(); + proposal_.increment_iteration(); +} + + +void OMRFModel::update_edge_indicators() { + for (size_t idx = 0; idx < num_pairwise_; ++idx) { + int var1 = interaction_index_(idx, 1); + int var2 = interaction_index_(idx, 2); + update_edge_indicator(var1, var2); + } +} + + +void OMRFModel::initialize_graph() { + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + double p = inclusion_probability_(v1, v2); + int draw = (runif(rng_) < p) ? 
1 : 0; + edge_indicators_(v1, v2) = draw; + edge_indicators_(v2, v1) = draw; + if (!draw) { + pairwise_effects_(v1, v2) = 0.0; + pairwise_effects_(v2, v1) = 0.0; + } + } + } + update_residual_matrix(); + invalidate_gradient_cache(); +} + + + +void OMRFModel::impute_missing() { + if (!has_missing_) return; + + const int num_variables = p_; + const int num_missings = missing_index_.n_rows; + const int max_num_categories = num_categories_.max(); + + arma::vec category_probabilities(max_num_categories + 1); + + for (int miss = 0; miss < num_missings; miss++) { + const int person = missing_index_(miss, 0); + const int variable = missing_index_(miss, 1); + + const double residual_score = residual_matrix_(person, variable); + const int num_cats = num_categories_(variable); + const bool is_ordinal = is_ordinal_variable_(variable); + + double cumsum = 0.0; + + if (is_ordinal) { + cumsum = 1.0; + category_probabilities[0] = cumsum; + for (int cat = 0; cat < num_cats; cat++) { + const int score = cat + 1; + const double exponent = main_effects_(variable, cat) + score * residual_score; + cumsum += MY_EXP(exponent); + category_probabilities[score] = cumsum; + } + } else { + const int ref = baseline_category_(variable); + + cumsum = MY_EXP( + main_effects_(variable, 0) * ref + main_effects_(variable, 1) * ref * ref + ); + category_probabilities[0] = cumsum; + + for (int cat = 0; cat <= num_cats; cat++) { + const int score = cat - ref; + const double exponent = + main_effects_(variable, 0) * score + + main_effects_(variable, 1) * score * score + + score * residual_score; + cumsum += MY_EXP(exponent); + category_probabilities[cat] = cumsum; + } + } + + // Sample from categorical distribution via inverse transform + const double u = runif(rng_) * cumsum; + int sampled_score = 0; + while (u > category_probabilities[sampled_score]) { + sampled_score++; + } + + int new_value = sampled_score; + if (!is_ordinal) + new_value -= baseline_category_(variable); + const int old_value = observations_(person, variable); + + if (new_value != old_value) { + observations_(person, variable) = new_value; + observations_double_(person, variable) = static_cast(new_value); + + if (is_ordinal) { + counts_per_category_(old_value, variable)--; + counts_per_category_(new_value, variable)++; + } else { + const int delta = new_value - old_value; + const int delta_sq = new_value * new_value - old_value * old_value; + blume_capel_stats_(0, variable) += delta; + blume_capel_stats_(1, variable) += delta_sq; + } + + // Incrementally update residuals across all variables + for (int var = 0; var < num_variables; var++) { + const double delta_score = (new_value - old_value) * pairwise_effects_(var, variable); + residual_matrix_(person, var) += delta_score; + } + } + } + + // Recompute pairwise sufficient statistics + arma::mat ps = observations_double_.t() * observations_double_; + pairwise_stats_ = arma::conv_to::from(ps); +} + + +void OMRFModel::set_missing_data(const arma::imat& missing_index) { + missing_index_ = missing_index; + has_missing_ = (missing_index.n_rows > 0 && missing_index.n_cols == 2); +} + + +// ============================================================================= +// Factory function +// ============================================================================= + +OMRFModel createOMRFModelFromR( + const Rcpp::List& inputFromR, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + bool edge_selection +) { + arma::imat observations = Rcpp::as(inputFromR["observations"]); + 
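  // Expected contents of inputFromR: "observations", "num_categories",
  // "is_ordinal_variable", and "baseline_category" are required; the prior
  // hyperparameters "main_alpha", "main_beta", and "pairwise_scale" are
  // optional and default to 1.0, 1.0, and 2.5 below.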
arma::ivec num_categories = Rcpp::as(inputFromR["num_categories"]); + arma::uvec is_ordinal_variable = Rcpp::as(inputFromR["is_ordinal_variable"]); + arma::ivec baseline_category = Rcpp::as(inputFromR["baseline_category"]); + + double main_alpha = inputFromR.containsElementNamed("main_alpha") + ? Rcpp::as(inputFromR["main_alpha"]) : 1.0; + double main_beta = inputFromR.containsElementNamed("main_beta") + ? Rcpp::as(inputFromR["main_beta"]) : 1.0; + double pairwise_scale = inputFromR.containsElementNamed("pairwise_scale") + ? Rcpp::as(inputFromR["pairwise_scale"]) : 2.5; + + return OMRFModel( + observations, + num_categories, + inclusion_probability, + initial_edge_indicators, + is_ordinal_variable, + baseline_category, + main_alpha, + main_beta, + pairwise_scale, + edge_selection + ); +} + + diff --git a/src/models/omrf/omrf_model.h b/src/models/omrf/omrf_model.h new file mode 100644 index 0000000..d8a0ab2 --- /dev/null +++ b/src/models/omrf/omrf_model.h @@ -0,0 +1,441 @@ +#pragma once + +#include +#include +#include "models/base_model.h" +#include "models/adaptive_metropolis.h" +#include "rng/rng_utils.h" +#include "mcmc/mcmc_utils.h" +#include "utils/common_helpers.h" + +/** + * OMRFModel - Ordinal Markov Random Field Model + * + * A class-based implementation of the OMRF model for Bayesian inference on + * ordinal and Blume-Capel variables. This class encapsulates: + * - Parameter storage (main effects, pairwise effects, edge indicators) + * - Sufficient statistics computation + * - Log-pseudoposterior and gradient evaluations + * - Adaptive Metropolis-Hastings updates for individual parameters + * - NUTS/HMC updates for joint parameter sampling + * - Edge selection (spike-and-slab) with asymmetric proposals + * + * Inherits from BaseModel for compatibility with the generic MCMC framework. 
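 *
 * Illustrative usage (a sketch only; in the package the model is driven by
 * the generic MCMC runner, and obs, num_cats, incl_prob, init_edges,
 * is_ordinal, and baseline are placeholder inputs):
 *
 *   OMRFModel model(obs, num_cats, incl_prob, init_edges, is_ordinal, baseline);
 *   model.set_seed(42);
 *   for (int it = 0; it < 1000; ++it) {
 *     model.do_one_mh_step();               // adaptive random-walk sweep
 *     if (model.has_edge_selection())
 *       model.update_edge_indicators();     // spike-and-slab edge moves
 *   }
 *   arma::mat pairwise = model.get_pairwise_effects();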
+ */ +class OMRFModel : public BaseModel { +public: + + /** + * Constructor from raw observations + * + * @param observations Integer matrix of categorical observations (persons × variables) + * @param num_categories Number of categories per variable + * @param inclusion_probability Prior inclusion probabilities for edges + * @param initial_edge_indicators Initial edge inclusion matrix + * @param is_ordinal_variable Indicator (1 = ordinal, 0 = Blume-Capel) + * @param baseline_category Reference categories for Blume-Capel variables + * @param main_alpha Beta prior hyperparameter α for main effects + * @param main_beta Beta prior hyperparameter β for main effects + * @param pairwise_scale Scale parameter of Cauchy prior on interactions + * @param edge_selection Enable edge selection (spike-and-slab) + */ + OMRFModel( + const arma::imat& observations, + const arma::ivec& num_categories, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const arma::uvec& is_ordinal_variable, + const arma::ivec& baseline_category, + double main_alpha = 1.0, + double main_beta = 1.0, + double pairwise_scale = 2.5, + bool edge_selection = true + ); + + /** + * Copy constructor for cloning (required for parallel chains) + */ + OMRFModel(const OMRFModel& other); + + // ========================================================================= + // BaseModel interface implementation + // ========================================================================= + + bool has_gradient() const override { return true; } + bool has_adaptive_mh() const override { return true; } + bool has_edge_selection() const override { return edge_selection_; } + bool has_missing_data() const override { return has_missing_; } + + /** + * Compute log-pseudoposterior for given parameter vector + */ + double logp(const arma::vec& parameters) override; + + /** + * Compute gradient of log-pseudoposterior + */ + arma::vec gradient(const arma::vec& parameters) override; + + /** + * Combined log-posterior and gradient evaluation (more efficient) + */ + std::pair logp_and_gradient(const arma::vec& parameters) override; + + /** + * Perform one adaptive MH step (updates all parameters) + */ + void do_one_mh_step() override; + + /** + * Return dimensionality of active parameter space + */ + size_t parameter_dimension() const override; + + /** + * Set random seed for reproducibility + */ + void set_seed(int seed) override; + + /** + * Get vectorized parameters (main effects + active pairwise effects) + */ + arma::vec get_vectorized_parameters() const override; + + /** + * Set parameters from vectorized form + */ + void set_vectorized_parameters(const arma::vec& parameters) override; + + /** + * Get vectorized edge indicators + */ + arma::ivec get_vectorized_indicator_parameters() override; + + /** + * Clone the model for parallel execution + */ + std::unique_ptr clone() const override; + + /** + * Get RNG for samplers + */ + SafeRNG& get_rng() override { return rng_; } + + // ========================================================================= + // OMRF-specific methods + // ========================================================================= + + /** + * Set the adaptive proposal mechanism + */ + void set_adaptive_proposal(AdaptiveProposal proposal); + + /** + * Update edge indicators via Metropolis-Hastings + */ + void update_edge_indicators() override; + + /** + * Initialize random graph structure (for starting edge selection) + */ + void initialize_graph() override; + + /** + * Impute missing values 
(if any) + */ + void impute_missing() override; + + /** + * Set missing data information + */ + void set_missing_data(const arma::imat& missing_index); + + // ========================================================================= + // Accessors + // ========================================================================= + + const arma::mat& get_main_effects() const { return main_effects_; } + const arma::mat& get_pairwise_effects() const { return pairwise_effects_; } + const arma::imat& get_edge_indicators() const override { return edge_indicators_; } + arma::mat& get_inclusion_probability() override { return inclusion_probability_; } + const arma::mat& get_residual_matrix() const { return residual_matrix_; } + + void set_main_effects(const arma::mat& main_effects) { main_effects_ = main_effects; } + void set_pairwise_effects(const arma::mat& pairwise_effects); + void set_edge_indicators(const arma::imat& edge_indicators) { edge_indicators_ = edge_indicators; } + + int get_num_variables() const override { return static_cast(p_); } + int get_num_pairwise() const override { return static_cast(num_pairwise_); } + size_t num_variables() const { return p_; } + size_t num_observations() const { return n_; } + size_t num_main_effects() const { return num_main_; } + size_t num_pairwise_effects() const { return num_pairwise_; } + + // Shorthand accessors (for interface compatibility) + size_t get_p() const { return p_; } + size_t get_n() const { return n_; } + + // Adaptation control + void set_step_size(double step_size) { step_size_ = step_size; } + double get_step_size() const { return step_size_; } + void set_inv_mass(const arma::vec& inv_mass) { inv_mass_ = inv_mass; } + const arma::vec& get_inv_mass() const { return inv_mass_; } + + /** + * Get full dimension (main + ALL pairwise, regardless of edge indicators) + * Used for fixed-size sample storage + */ + size_t full_parameter_dimension() const override { return num_main_ + num_pairwise_; } + + /** + * Get all parameters in a fixed-size vector (inactive edges are 0) + * Used for sample storage to avoid dimension changes + */ + arma::vec get_full_vectorized_parameters() const override; + + // Proposal SD access (for external adaptation) + arma::mat& get_proposal_sd_main() { return proposal_sd_main_; } + arma::mat& get_proposal_sd_pairwise() { return proposal_sd_pairwise_; } + + // Control edge selection phase + void set_edge_selection_active(bool active) override { edge_selection_active_ = active; } + bool is_edge_selection_active() const { return edge_selection_active_; } + +private: + // ========================================================================= + // Data members + // ========================================================================= + + // Data + size_t n_; // Number of observations + size_t p_; // Number of variables + arma::imat observations_; // Categorical observations (n × p) + arma::mat observations_double_; // Observations as double (for efficient matrix ops) + arma::ivec num_categories_; // Categories per variable + arma::uvec is_ordinal_variable_; // 1 = ordinal, 0 = Blume-Capel + arma::ivec baseline_category_; // Reference category for Blume-Capel + + // Sufficient statistics + arma::imat counts_per_category_; // Category counts (max_cats+1 × p) + arma::imat blume_capel_stats_; // [linear_sum, quadratic_sum] for BC vars (2 × p) + arma::imat pairwise_stats_; // X^T X + arma::mat residual_matrix_; // X * pairwise_effects (n × p) + + // Parameters + arma::mat main_effects_; // Main effect parameters (p × 
max_cats) + arma::mat pairwise_effects_; // Pairwise interaction strengths (p × p, symmetric) + arma::imat edge_indicators_; // Edge inclusion indicators (p × p, symmetric binary) + + // Priors + arma::mat inclusion_probability_; // Prior inclusion probabilities + double main_alpha_; // Beta prior α + double main_beta_; // Beta prior β + double pairwise_scale_; // Cauchy scale for pairwise effects + + // Model configuration + bool edge_selection_; // Enable edge selection + bool edge_selection_active_; // Currently in edge selection phase + + // Dimension tracking + size_t num_main_; // Total number of main effect parameters + size_t num_pairwise_; // Number of possible pairwise effects + + // Adaptive proposals + AdaptiveProposal proposal_; + arma::mat proposal_sd_main_; + arma::mat proposal_sd_pairwise_; + + // RNG + SafeRNG rng_; + + // NUTS/HMC settings + double step_size_; + arma::vec inv_mass_; + + // Missing data handling + bool has_missing_; + arma::imat missing_index_; + + // Cached gradient components + arma::vec grad_obs_cache_; + arma::imat index_matrix_cache_; + bool gradient_cache_valid_; + + // Interaction indexing (for edge updates) + arma::imat interaction_index_; + + // ========================================================================= + // Private helper methods + // ========================================================================= + + /** + * Compute sufficient statistics from observations + */ + void compute_sufficient_statistics(); + + /** + * Count total number of main effect parameters + */ + size_t count_num_main_effects_internal() const; + + /** + * Build interaction index matrix + */ + void build_interaction_index(); + + /** + * Update residual matrix after pairwise effects change + */ + void update_residual_matrix(); + + /** + * Incrementally update two residual columns after a single pairwise effect change + */ + void update_residual_columns(int var1, int var2, double delta); + + /** + * Invalidate gradient cache (call after parameter changes) + */ + void invalidate_gradient_cache() { gradient_cache_valid_ = false; } + + /** + * Ensure gradient cache is valid + */ + void ensure_gradient_cache(); + + // ------------------------------------------------------------------------- + // Log-posterior components + // ------------------------------------------------------------------------- + + /** + * Full log-pseudoposterior (internal, uses current state) + */ + double log_pseudoposterior_internal() const; + + /** + * Full log-pseudoposterior with external state (avoids modifying model) + */ + double log_pseudoposterior_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat + ) const; + + /** + * Log-posterior for single main effect component + */ + double log_pseudoposterior_main_component(int variable, int category, int parameter) const; + + /** + * Log-posterior for single pairwise interaction + */ + double log_pseudoposterior_pairwise_component(int var1, int var2) const; + + /** + * Log-likelihood ratio for variable update + */ + double compute_log_likelihood_ratio_for_variable( + int variable, + const arma::vec& interacting_score, + double proposed_state, + double current_state + ) const; + + /** + * Log-pseudolikelihood ratio for interaction update + */ + double log_pseudolikelihood_ratio_interaction( + int variable1, + int variable2, + double proposed_state, + double current_state + ) const; + + // ------------------------------------------------------------------------- + // Gradient components 
+ // ------------------------------------------------------------------------- + + /** + * Compute gradient with current state + */ + arma::vec gradient_internal() const; + + /** + * Compute gradient with external state (avoids modifying model) + */ + arma::vec gradient_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat + ) const; + + // ------------------------------------------------------------------------- + // Parameter vectorization + // ------------------------------------------------------------------------- + + /** + * Flatten parameters to vector + */ + arma::vec vectorize_parameters() const; + + /** + * Flatten parameters into pre-allocated vector (avoids allocation) + */ + void vectorize_parameters_into(arma::vec& param_vec) const; + + /** + * Unflatten vector to parameter matrices + */ + void unvectorize_parameters(const arma::vec& param_vec); + + /** + * Unvectorize a parameter vector into temporary main/pairwise matrices, + * then compute the corresponding residual matrix. + */ + void unvectorize_to_temps( + const arma::vec& parameters, + arma::mat& temp_main, + arma::mat& temp_pairwise, + arma::mat& temp_residual + ) const; + + /** + * Extract active inverse mass (only for included edges) + */ + arma::vec get_active_inv_mass() const; + + /** + * Extract active inverse mass into pre-allocated vector (avoids allocation) + */ + void get_active_inv_mass_into(arma::vec& active_inv_mass) const; + + // ------------------------------------------------------------------------- + // Metropolis updates + // ------------------------------------------------------------------------- + + /** + * Update single main effect parameter via RWM + */ + void update_main_effect_parameter(int variable, int category, int parameter); + + /** + * Update single pairwise effect via RWM + */ + void update_pairwise_effect(int var1, int var2); + + /** + * Update single edge indicator (spike-and-slab) + */ + void update_edge_indicator(int var1, int var2); +}; + + +/** + * Factory function to create OMRFModel from R inputs + */ +OMRFModel createOMRFModelFromR( + const Rcpp::List& inputFromR, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + bool edge_selection = true +); diff --git a/src/models/skeleton_model.cpp b/src/models/skeleton_model.cpp new file mode 100644 index 0000000..0720745 --- /dev/null +++ b/src/models/skeleton_model.cpp @@ -0,0 +1,230 @@ +// /** +// * ================================================== +// * Header dependencies for SkeletonVariables +// * ================================================== +// * +// * This file defines a template ("skeleton") for variable-type classes +// * in the bgms codebase. The included headers constitute the minimal +// * set of dependencies required by any variable model, independent of +// * its specific statistical formulation. 
+// * +// * Each include serves a specific purpose: +// * - : ownership and cloning via std::unique_ptr +// * - base_model.h : abstract interface for variable models +// * - adaptiveMetropolis.h : Metropolis–Hastings proposal mechanism +// * - rng_utils.h : reproducible random number generation +// */ + +// #include + +// #include "models/base_model.h" +// #include "models/adaptive_metropolis.h" +// #include "rng/rng_utils.h" + + +// /** +// * ================================================== +// * SkeletonVariables class +// * ================================================== +// * +// * A class in C++ represents a concrete type that bundles together: +// * - internal state, including observed data, model parameters, +// * and auxiliary objects that maintain sampling state, +// * - and functions (methods) that act on this state to evaluate +// * the log posterior, its gradient, and to perform sampling updates. +// * +// * In the bgms codebase, each statistical variable model is implemented +// * as a C++ class that stores the model state and provides the methods +// * required for inference. +// * +// * SkeletonVariables defines a template for such implementation classes. +// * It specifies the structure and interface that concrete implementations +// * of variable models must follow, without imposing a particular +// * statistical formulation. +// * +// * SkeletonVariables inherits from BaseModel, which defines the common +// * interface for all variable-model implementations in bgms: +// * +// * class SkeletonVariables : public BaseModel +// * +// * Inheriting from BaseModel means that SkeletonVariables must provide +// * a fixed set of functions (such as log_posterior and do_one_mh_step) +// * that the rest of the codebase relies on. As a result, code elsewhere +// * in bgms can interact with SkeletonVariables through the BaseModel +// * interface, without needing to know which specific variable model +// * implementation is being used. +// */ +// class SkeletonVariables : public BaseModel { + +// /* +// * The 'public:' label below marks the beginning of the part of the class +// * that is accessible from outside the class. +// * +// * Functions and constructors declared under 'public' are intended to be +// * called by other components of the bgms codebase, such as samplers, +// * model-selection routines, and result containers. +// * +// * Together, these public members define how a variable-model +// * implementation can be created, queried, and updated by external code. +// */ +// public: + +// /* +// * Constructors are responsible for establishing the complete internal +// * state of the object. +// * +// * A SkeletonVariables object represents a fully specified variable-model +// * implementation at a given point in an inference procedure. Therefore, +// * all information required to evaluate the log posterior and to perform +// * sampling updates must be stored within the object itself. +// * +// * After construction, the object is expected to be immediately usable: +// * no additional initialization steps are required before it can be +// * queried, updated, or copied. +// * +// * The constructor below uses a constructor initializer list to define +// * how base classes and data members are constructed. The initializer +// * list is evaluated before the constructor body runs and is used to: +// * - construct the BaseModel subobject, +// * - initialize data members directly from constructor arguments. 
+// * +// * This ensures that all base-class and member invariants are established +// * before any additional derived quantities are computed in the +// * constructor body. +// */ +// SkeletonVariables( +// const arma::mat& observations, +// const arma::mat& inclusion_probability, +// const arma::imat& initial_edge_indicators, +// const bool edge_selection = true +// ) +// : BaseModel(edge_selection), +// observations_(observations), +// inclusion_probability_(inclusion_probability), +// edge_indicators_(initial_edge_indicators) +// { +// /* +// * The constructor body initializes derived state that depends on +// * already-constructed members, such as dimensions, parameter vectors, +// * and sampling-related objects. +// */ + +// n_ = observations_.n_rows; +// p_ = observations_.n_cols; + +// // Dimension of the parameter vector (model-specific). +// // For the skeleton, we assume one parameter per variable. +// dim_ = p_; + +// // Initialize parameter vector +// parameters_.zeros(dim_); + +// // Initialize adaptive Metropolis–Hastings sampler +// adaptive_mh_ = AdaptiveMetropolis(dim_); +// } + +// /* +// * Copy constructor. +// * +// * The copy constructor creates a new SkeletonVariables object that is an +// * exact copy of an existing one. This includes not only the observed data +// * and model parameters, but also all internal state required for inference, +// * such as sampler state and random number generator state. +// * +// * Copying is required in bgms because variable-model objects are duplicated +// * during inference, for example when running multiple +// * chains, or storing and restoring model states. +// * +// * The copy is performed using a constructor initializer list to ensure +// * that the BaseModel subobject and all data members are constructed +// * directly from their counterparts in the source object. +// * +// * After construction, the new object is independent from the original +// * but represents the same model state. 
+// */ +// SkeletonVariables(const SkeletonVariables& other) +// : BaseModel(other), +// observations_(other.observations_), +// inclusion_probability_(other.inclusion_probability_), +// edge_indicators_(other.edge_indicators_), +// n_(other.n_), +// p_(other.p_), +// dim_(other.dim_), +// parameters_(other.parameters_), +// adaptive_mh_(other.adaptive_mh_), +// rng_(other.rng_) +// {} + +// // -------------------------------------------------- +// // Polymorphic copy +// // -------------------------------------------------- +// std::unique_ptr clone() const override { +// return std::make_unique(*this); +// } + +// // -------------------------------------------------- +// // Capabilities +// // -------------------------------------------------- +// bool has_log_posterior() const override { return true; } +// bool has_gradient() const override { return true; } +// bool has_adaptive_mh() const override { return true; } + +// // -------------------------------------------------- +// // Log posterior (LIKELIHOOD + PRIOR) +// // -------------------------------------------------- +// double log_posterior(const arma::vec& parameters) override { +// // Skeleton: flat log-density +// // Real models will: +// // - unpack parameters +// // - compute likelihood +// // - add priors +// return 0.0; +// } + +// // -------------------------------------------------- +// // Gradient of LOG (LIKELIHOOD * PRIOR) +// // -------------------------------------------------- +// void gradient(const arma::vec& parameters) override { +// // Skeleton: +// } + +// // -------------------------------------------------- +// // One Metropolis–Hastings step +// // -------------------------------------------------- +// void do_one_mh_step(arma::vec& parameters) override { +// // Skeleton: +// } + +// // -------------------------------------------------- +// // Required interface +// // -------------------------------------------------- +// size_t parameter_dimension() const override { +// return dim_; +// } + +// void set_seed(int seed) override { +// rng_ = SafeRNG(seed); +// } + +// protected: +// // -------------------------------------------------- +// // Data +// // -------------------------------------------------- +// arma::mat observations_; +// arma::mat inclusion_probability_; +// arma::imat edge_indicators_; + +// // -------------------------------------------------- +// // Dimensions +// // -------------------------------------------------- +// size_t n_ = 0; // number of observations +// size_t p_ = 0; // number of variables +// size_t dim_ = 0; // dimension of parameter vector + +// // -------------------------------------------------- +// // Parameters & MCMC machinery +// // -------------------------------------------------- +// arma::vec parameters_; +// AdaptiveMetropolis adaptive_mh_; +// SafeRNG rng_; +// }; diff --git a/src/priors/edge_prior.h b/src/priors/edge_prior.h new file mode 100644 index 0000000..c17c838 --- /dev/null +++ b/src/priors/edge_prior.h @@ -0,0 +1,258 @@ +#pragma once + +#include +#include +#include "../rng/rng_utils.h" +#include "../utils/common_helpers.h" +#include "sbm_edge_prior.h" +#include "../sbm_edge_prior_interface.h" + + +/** + * Abstract base class for edge inclusion priors. + * + * The edge prior updates the inclusion probability matrix based on the + * current edge indicators. This is independent of the model type (GGM, OMRF, + * etc.), so it is implemented as a separate class hierarchy. 
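 *
 * A minimal sketch of one runner step (names are placeholders; the actual
 * loop lives in the MCMC runner):
 *
 *   model.update_edge_indicators();
 *   edge_prior->update(model.get_edge_indicators(),
 *                      model.get_inclusion_probability(),
 *                      model.get_num_variables(),
 *                      model.get_num_pairwise(),
 *                      rng);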
+ * + * The MCMC runner calls update() after each edge indicator update, passing + * the current edge indicators and inclusion probability matrix. The edge + * prior modifies inclusion_probability in place. + */ +class BaseEdgePrior { +public: + virtual ~BaseEdgePrior() = default; + + virtual void update( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + int num_pairwise, + SafeRNG& rng + ) = 0; + + virtual std::unique_ptr clone() const = 0; + + virtual bool has_allocations() const { return false; } + virtual arma::ivec get_allocations() const { return arma::ivec(); } +}; + + +/** + * Bernoulli edge prior (fixed inclusion probabilities, no update needed). + */ +class BernoulliEdgePrior : public BaseEdgePrior { +public: + void update( + const arma::imat& /*edge_indicators*/, + arma::mat& /*inclusion_probability*/, + int /*num_variables*/, + int /*num_pairwise*/, + SafeRNG& /*rng*/ + ) override { + // No-op: inclusion probabilities are fixed + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } +}; + + +/** + * Beta-Bernoulli edge prior. + * + * Draws a shared inclusion probability from Beta(alpha + #included, + * beta + #excluded) and assigns it to all edges. + */ +class BetaBernoulliEdgePrior : public BaseEdgePrior { +public: + BetaBernoulliEdgePrior(double alpha = 1.0, double beta = 1.0) + : alpha_(alpha), beta_(beta) {} + + void update( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + int num_pairwise, + SafeRNG& rng + ) override { + int num_edges_included = 0; + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + num_edges_included += edge_indicators(i, j); + } + } + + double prob = rbeta(rng, + alpha_ + num_edges_included, + beta_ + num_pairwise - num_edges_included + ); + + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + inclusion_probability(i, j) = prob; + inclusion_probability(j, i) = prob; + } + } + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + double alpha_; + double beta_; +}; + + +/** + * Stochastic Block Model (MFM-SBM) edge prior. + * + * Maintains cluster allocations and block-level inclusion probabilities. + * Each edge's inclusion probability depends on its endpoints' cluster + * assignments. + */ +class StochasticBlockEdgePrior : public BaseEdgePrior { +public: + StochasticBlockEdgePrior( + double beta_bernoulli_alpha, + double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, + double dirichlet_alpha, + double lambda + ) : beta_bernoulli_alpha_(beta_bernoulli_alpha), + beta_bernoulli_beta_(beta_bernoulli_beta), + beta_bernoulli_alpha_between_(beta_bernoulli_alpha_between), + beta_bernoulli_beta_between_(beta_bernoulli_beta_between), + dirichlet_alpha_(dirichlet_alpha), + lambda_(lambda), + initialized_(false) + {} + + /** + * Initialize SBM state from the current edge indicators. Called + * automatically on first update(). + */ + void initialize( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + SafeRNG& rng + ) { + cluster_allocations_.set_size(num_variables); + cluster_allocations_[0] = 0; + cluster_allocations_[1] = 1; + for (int i = 2; i < num_variables; i++) { + cluster_allocations_[i] = (runif(rng) > 0.5) ? 
1 : 0; + } + + cluster_prob_ = block_probs_mfm_sbm( + cluster_allocations_, + arma::conv_to::from(edge_indicators), + num_variables, + beta_bernoulli_alpha_, beta_bernoulli_beta_, + beta_bernoulli_alpha_between_, beta_bernoulli_beta_between_, + rng + ); + + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + inclusion_probability(i, j) = cluster_prob_(cluster_allocations_[i], cluster_allocations_[j]); + inclusion_probability(j, i) = inclusion_probability(i, j); + } + } + + log_Vn_ = compute_Vn_mfm_sbm( + num_variables, dirichlet_alpha_, num_variables + 10, lambda_); + + initialized_ = true; + } + + void update( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + int /*num_pairwise*/, + SafeRNG& rng + ) override { + if (!initialized_) { + initialize(edge_indicators, inclusion_probability, num_variables, rng); + } + + cluster_allocations_ = block_allocations_mfm_sbm( + cluster_allocations_, num_variables, log_Vn_, cluster_prob_, + arma::conv_to::from(edge_indicators), dirichlet_alpha_, + beta_bernoulli_alpha_, beta_bernoulli_beta_, + beta_bernoulli_alpha_between_, beta_bernoulli_beta_between_, rng + ); + + cluster_prob_ = block_probs_mfm_sbm( + cluster_allocations_, + arma::conv_to::from(edge_indicators), num_variables, + beta_bernoulli_alpha_, beta_bernoulli_beta_, + beta_bernoulli_alpha_between_, beta_bernoulli_beta_between_, rng + ); + + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + inclusion_probability(i, j) = cluster_prob_(cluster_allocations_[i], cluster_allocations_[j]); + inclusion_probability(j, i) = inclusion_probability(i, j); + } + } + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + bool has_allocations() const override { return initialized_; } + + arma::ivec get_allocations() const override { + return arma::conv_to::from(cluster_allocations_) + 1; // 1-based + } + +private: + double beta_bernoulli_alpha_; + double beta_bernoulli_beta_; + double beta_bernoulli_alpha_between_; + double beta_bernoulli_beta_between_; + double dirichlet_alpha_; + double lambda_; + + bool initialized_; + arma::uvec cluster_allocations_; + arma::mat cluster_prob_; + arma::vec log_Vn_; +}; + + +/** + * Factory: create an edge prior from an EdgePrior enum and hyperparameters. 
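 *
 * Illustrative use (a sketch; hyperparameter values are arbitrary and the
 * surrounding objects edge_indicators, inclusion_probability, p, and rng are
 * placeholders):
 *
 *   auto bb  = create_edge_prior(Beta_Bernoulli, 2.0, 2.0);
 *   auto sbm = create_edge_prior(Stochastic_Block, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0);
 *   bb->update(edge_indicators, inclusion_probability, p, p * (p - 1) / 2, rng);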
+ */ +inline std::unique_ptr create_edge_prior( + EdgePrior type, + double beta_bernoulli_alpha = 1.0, + double beta_bernoulli_beta = 1.0, + double beta_bernoulli_alpha_between = 1.0, + double beta_bernoulli_beta_between = 1.0, + double dirichlet_alpha = 1.0, + double lambda = 1.0 +) { + switch (type) { + case Beta_Bernoulli: + return std::make_unique( + beta_bernoulli_alpha, beta_bernoulli_beta); + case Stochastic_Block: + return std::make_unique( + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda); + case Bernoulli: + case Not_Applicable: + default: + return std::make_unique(); + } +} diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp new file mode 100644 index 0000000..481e92b --- /dev/null +++ b/src/sample_ggm.cpp @@ -0,0 +1,70 @@ +#include +#include +#include +#include +#include + +#include "models/ggm/ggm_model.h" +#include "utils/progress_manager.h" +#include "utils/common_helpers.h" +#include "priors/edge_prior.h" +#include "mcmc/chain_result.h" +#include "mcmc/mcmc_runner.h" +#include "mcmc/sampler_config.h" + +// [[Rcpp::export]] +Rcpp::List sample_ggm( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const int no_iter, + const int no_warmup, + const int no_chains, + const bool edge_selection, + const int seed, + const int no_threads, + const int progress_type, + const std::string& edge_prior = "Bernoulli", + const double beta_bernoulli_alpha = 1.0, + const double beta_bernoulli_beta = 1.0, + const double beta_bernoulli_alpha_between = 1.0, + const double beta_bernoulli_beta_between = 1.0, + const double dirichlet_alpha = 1.0, + const double lambda = 1.0 +) { + + // Create model from R input + GGMModel model = createGGMModelFromR( + inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + + // Configure sampler - GGM only supports MH + SamplerConfig config; + config.sampler_type = "mh"; + config.no_iter = no_iter; + config.no_warmup = no_warmup; + config.edge_selection = edge_selection; + config.seed = seed; + + // Set up progress manager + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + + // Create edge prior + EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior); + auto edge_prior_obj = create_edge_prior( + edge_prior_enum, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda + ); + + // Run MCMC using unified infrastructure + std::vector results = run_mcmc_sampler( + model, *edge_prior_obj, config, no_chains, no_threads, pm); + + // Convert to R list format + Rcpp::List output = convert_results_to_list(results); + + pm.finish(); + + return output; +} \ No newline at end of file diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp new file mode 100644 index 0000000..9525f99 --- /dev/null +++ b/src/sample_omrf.cpp @@ -0,0 +1,122 @@ +/** + * sample_omrf.cpp - R interface for OMRF model sampling + * + * Uses the unified MCMC runner infrastructure to sample from OMRF models. + * Supports MH, NUTS, and HMC samplers with optional edge selection. 
+ */ +#include +#include +#include + +#include "models/omrf/omrf_model.h" +#include "utils/progress_manager.h" +#include "utils/common_helpers.h" +#include "priors/edge_prior.h" +#include "mcmc/chain_result.h" +#include "mcmc/mcmc_runner.h" +#include "mcmc/sampler_config.h" + +/** + * R-exported function to sample from an OMRF model + * + * @param inputFromR List with model specification + * @param prior_inclusion_prob Prior inclusion probabilities (p x p matrix) + * @param initial_edge_indicators Initial edge indicators (p x p integer matrix) + * @param no_iter Number of post-warmup iterations + * @param no_warmup Number of warmup iterations + * @param no_chains Number of parallel chains + * @param edge_selection Whether to do edge selection (spike-and-slab) + * @param sampler_type "mh", "nuts", or "hmc" + * @param seed Random seed + * @param no_threads Number of threads for parallel execution + * @param progress_type Progress bar type + * @param edge_prior Edge prior type: "Bernoulli", "Beta-Bernoulli", "Stochastic-Block" + * @param na_impute Whether to impute missing data + * @param missing_index Matrix of missing data indices (n_missing x 2, 0-based) + * @param beta_bernoulli_alpha Beta-Bernoulli alpha hyperparameter + * @param beta_bernoulli_beta Beta-Bernoulli beta hyperparameter + * @param beta_bernoulli_alpha_between SBM between-cluster alpha + * @param beta_bernoulli_beta_between SBM between-cluster beta + * @param dirichlet_alpha Dirichlet alpha for SBM + * @param lambda Lambda for SBM + * @param target_acceptance Target acceptance rate for NUTS/HMC (default: 0.8) + * @param max_tree_depth Maximum tree depth for NUTS (default: 10) + * @param num_leapfrogs Number of leapfrog steps for HMC (default: 10) + * @param edge_selection_start Iteration to start edge selection (-1 = no_warmup/2) + * + * @return List with per-chain results including samples and diagnostics + */ +// [[Rcpp::export]] +Rcpp::List sample_omrf( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const int no_iter, + const int no_warmup, + const int no_chains, + const bool edge_selection, + const std::string& sampler_type, + const int seed, + const int no_threads, + const int progress_type, + const std::string& edge_prior = "Bernoulli", + const bool na_impute = false, + const Rcpp::Nullable missing_index_nullable = R_NilValue, + const double beta_bernoulli_alpha = 1.0, + const double beta_bernoulli_beta = 1.0, + const double beta_bernoulli_alpha_between = 1.0, + const double beta_bernoulli_beta_between = 1.0, + const double dirichlet_alpha = 1.0, + const double lambda = 1.0, + const double target_acceptance = 0.8, + const int max_tree_depth = 10, + const int num_leapfrogs = 10, + const int edge_selection_start = -1 +) { + // Create model from R input + OMRFModel model = createOMRFModelFromR( + inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + + // Set up missing data imputation + if (na_impute && missing_index_nullable.isNotNull()) { + arma::imat missing_index = Rcpp::as( + Rcpp::IntegerMatrix(missing_index_nullable.get())); + model.set_missing_data(missing_index); + } + + // Create edge prior + EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior); + auto edge_prior_obj = create_edge_prior( + edge_prior_enum, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda + ); + + // Configure sampler + SamplerConfig config; + 
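  // Translate the R-level arguments into the runner's SamplerConfig
  // (edge_selection_start = -1 is interpreted as no_warmup / 2, as noted
  // in the parameter documentation above).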
config.sampler_type = sampler_type; + config.no_iter = no_iter; + config.no_warmup = no_warmup; + config.edge_selection = edge_selection; + config.edge_selection_start = edge_selection_start; + config.seed = seed; + config.target_acceptance = target_acceptance; + config.max_tree_depth = max_tree_depth; + config.num_leapfrogs = num_leapfrogs; + config.na_impute = na_impute; + + // Set up progress manager + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + + // Run MCMC using unified infrastructure + std::vector results = run_mcmc_sampler( + model, *edge_prior_obj, config, no_chains, no_threads, pm); + + // Convert to R list format + Rcpp::List output = convert_results_to_list(results); + + pm.finish(); + + return output; +} diff --git a/tests/testthat/test-ggm.R b/tests/testthat/test-ggm.R new file mode 100644 index 0000000..a497342 --- /dev/null +++ b/tests/testthat/test-ggm.R @@ -0,0 +1,229 @@ +test_that("GGM runs and returns bgms object", { + testthat::skip_on_cran() + + set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + colnames(x) <- paste0("V", 1:4) + + fit <- bgm( + x = x, + variable_type = "continuous", + iter = 200, + warmup = 1000, + chains = 1, + display_progress = "none", + seed = 123 + ) + + expect_s3_class(fit, "bgms") + expect_true(fit$arguments$is_continuous) + expect_equal(fit$arguments$num_variables, 4) + expect_equal(fit$arguments$iter, 200) + expect_equal(fit$raw_samples$niter, 200) + expect_equal(fit$raw_samples$nchains, 1) +}) + +test_that("GGM output has correct dimensions", { + testthat::skip_on_cran() + + p <- 5 + set.seed(42) + x <- matrix(rnorm(100 * p), nrow = 100, ncol = p) + colnames(x) <- paste0("V", 1:p) + + fit <- bgm( + x = x, + variable_type = "continuous", + edge_selection = TRUE, + iter = 100, + warmup = 1000, + chains = 1, + display_progress = "none", + seed = 42 + ) + + # main: p diagonal precision elements + expect_equal(nrow(fit$posterior_summary_main), p) + expect_equal(nrow(fit$posterior_mean_main), p) + + # pairwise: p*(p-1)/2 off-diagonal elements + n_edges <- p * (p - 1) / 2 + expect_equal(nrow(fit$posterior_summary_pairwise), n_edges) + expect_equal(nrow(fit$posterior_mean_pairwise), p) + expect_equal(ncol(fit$posterior_mean_pairwise), p) + + # indicators + expect_equal(nrow(fit$posterior_summary_indicator), n_edges) + expect_equal(nrow(fit$posterior_mean_indicator), p) + expect_equal(ncol(fit$posterior_mean_indicator), p) + + # raw samples + expect_equal(ncol(fit$raw_samples$main[[1]]), p) + expect_equal(ncol(fit$raw_samples$pairwise[[1]]), n_edges) + expect_equal(nrow(fit$raw_samples$main[[1]]), 100) +}) + +test_that("GGM without edge selection works", { + testthat::skip_on_cran() + + set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + colnames(x) <- paste0("V", 1:4) + + fit <- bgm( + x = x, + variable_type = "continuous", + edge_selection = FALSE, + iter = 100, + warmup = 1000, + chains = 1, + display_progress = "none", + seed = 99 + ) + + expect_s3_class(fit, "bgms") + expect_null(fit$posterior_summary_indicator) + expect_null(fit$posterior_mean_indicator) +}) + +test_that("GGM rejects NUTS and HMC", { + set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + + expect_error( + bgm(x = x, variable_type = "continuous", update_method = "nuts"), + "only supports.*adaptive-metropolis" + ) + expect_error( + bgm(x = x, variable_type = "continuous", update_method = "hamiltonian-mc"), + "only supports.*adaptive-metropolis" + ) +}) + +test_that("Mixed continuous and ordinal is rejected", { + 
+  set.seed(42)
+  x <- matrix(rnorm(200), nrow = 50, ncol = 4)
+
+  expect_error(
+    bgm(x = x, variable_type = c("continuous", "ordinal", "ordinal", "ordinal")),
+    "all variables must be of type"
+  )
+})
+
+test_that("GGM is reproducible", {
+  testthat::skip_on_cran()
+
+  set.seed(42)
+  x <- matrix(rnorm(200), nrow = 50, ncol = 4)
+  colnames(x) <- paste0("V", 1:4)
+
+  fit1 <- bgm(
+    x = x, variable_type = "continuous",
+    iter = 100, warmup = 1000, chains = 1,
+    display_progress = "none", seed = 321
+  )
+  fit2 <- bgm(
+    x = x, variable_type = "continuous",
+    iter = 100, warmup = 1000, chains = 1,
+    display_progress = "none", seed = 321
+  )
+
+  expect_equal(fit1$raw_samples$main, fit2$raw_samples$main)
+  expect_equal(fit1$raw_samples$pairwise, fit2$raw_samples$pairwise)
+})
+
+test_that("GGM posterior precision diagonal means are positive", {
+  testthat::skip_on_cran()
+
+  set.seed(42)
+  x <- matrix(rnorm(500), nrow = 100, ncol = 5)
+  colnames(x) <- paste0("V", 1:5)
+
+  fit <- bgm(
+    x = x, variable_type = "continuous",
+    edge_selection = FALSE,
+    iter = 500, warmup = 1000, chains = 1,
+    display_progress = "none", seed = 42
+  )
+
+  # Diagonal precision elements should be positive
+  expect_true(all(fit$posterior_summary_main$mean > 0))
+})
+
+
+test_that("GGM posterior compare against simulated data", {
+  testthat::skip_on_cran()
+
+  n <- 1000
+  p <- 10
+  ne <- p * (p - 1) / 2
+  # avoid a test dependency on BDgraph and a random graph structure by using a fixed precision matrix
+  # set.seed(42)
+  # adj <- matrix(0, p, p)
+  # adj[lower.tri(adj)] <- runif(ne) < .5
+  # adj <- adj + t(adj)
+  # omega <- zapsmall(BDgraph::rgwish(1, adj))
+  omega <- structure(c(6.240119, 0, 0, -0.370239, 0, 0, 0, 0, -1.622902,
+    0, 0, 1.905013, 0, -0.194995, 0, 0, -2.468628, -0.557277, 0,
+    0, 0, 0, 5.509142, -7.942389, 1.40081, 0, 0, -0.76775, 0, 0,
+    -0.370239, -0.194995, -7.942389, 15.521405, -3.537489, 0, 4.60785,
+    0, 3.278511, 0, 0, 0, 1.40081, -3.537489, 2.78257, 0, 0, 1.374641,
+    0, -1.198092, 0, 0, 0, 0, 0, 1.350879, 0, 0.230677, -1.357952,
+    0, 0, -2.468628, 0, 4.60785, 0, 0, 15.88698, 0, 1.20017, -1.973919,
+    0, -0.557277, -0.76775, 0, 1.374641, 0.230677, 0, 7.007312, 1.597035,
+    0, -1.622902, 0, 0, 3.278511, 0, -1.357952, 1.20017, 1.597035,
+    13.378039, -4.769958, 0, 0, 0, 0, -1.198092, 0, -1.973919, 0,
+    -4.769958, 5.536877), dim = c(10L, 10L))
+  adj <- omega != 0
+  diag(adj) <- 0
+  covmat <- solve(omega)
+  chol <- chol(covmat)
+
+  set.seed(43)
+  x <- matrix(rnorm(n * p), nrow = n, ncol = p) %*% chol
+  # cov(x) - covmat
+
+  fit_no_vs <- bgm(
+    x = x, variable_type = "continuous",
+    edge_selection = FALSE,
+    iter = 5000, warmup = 1000, chains = 2,
+    display_progress = "none", seed = 42
+  )
+
+  expect_true(cor(fit_no_vs$posterior_summary_main$mean, diag(omega)) > 0.9)
+  expect_true(cor(fit_no_vs$posterior_summary_pairwise$mean, omega[upper.tri(omega)]) > 0.9)
+
+  fit_vs <- bgm(
+    x = x, variable_type = "continuous",
+    edge_selection = TRUE,
+    iter = 10000, warmup = 2000, chains = 2,
+    display_progress = "none", seed = 42
+  )
+
+  expect_true(cor(fit_vs$posterior_summary_main$mean, diag(omega)) > 0.9)
+  expect_true(cor(fit_vs$posterior_summary_pairwise$mean, omega[upper.tri(omega)]) > 0.9)
+  expect_true(cor(fit_vs$posterior_summary_indicator$mean, adj[upper.tri(adj)]) > 0.85)
+
+  # FIXME: the Stochastic-Block variant below has been observed to fail and needs investigation.
+  fit_vs_sbm <- bgm(
+    x = x, variable_type = "continuous",
+    edge_selection = TRUE,
+    edge_prior = "Stochastic-Block",
+    iter = 5000, warmup = 1000, chains = 2,
+    display_progress = "none", seed = 42
+  )
+
+  expect_true(cor(fit_vs_sbm$posterior_summary_main$mean, diag(omega)) > 0.9)
+  expect_true(cor(fit_vs_sbm$posterior_summary_pairwise$mean, omega[upper.tri(omega)]) > 0.9)
+  expect_true(cor(fit_vs_sbm$posterior_summary_indicator$mean, adj[upper.tri(adj)]) > 0.85)
+
+  # SBM-specific output
+  expect_false(is.null(fit_vs_sbm$posterior_mean_coclustering_matrix))
+  expect_equal(nrow(fit_vs_sbm$posterior_mean_coclustering_matrix), p)
+  expect_equal(ncol(fit_vs_sbm$posterior_mean_coclustering_matrix), p)
+  expect_false(is.null(fit_vs_sbm$posterior_num_blocks))
+  expect_false(is.null(fit_vs_sbm$posterior_mode_allocations))
+  expect_false(is.null(fit_vs_sbm$raw_samples$allocations))
+
+})
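For reviewers, a minimal sketch of the R-level call path these tests exercise; it is illustrative only and assumes nothing beyond the bgm() arguments and result fields that appear in this patch:

    # Sketch: fit a Gaussian graphical model (continuous variables) with edge selection
    set.seed(1)
    x <- matrix(rnorm(100 * 5), nrow = 100, ncol = 5)
    fit <- bgm(
      x = x, variable_type = "continuous",
      edge_selection = TRUE,
      iter = 1000, warmup = 1000, chains = 2,
      display_progress = "none", seed = 1
    )
    fit$posterior_summary_pairwise   # summaries of the off-diagonal precision elements
    fit$posterior_summary_indicator  # posterior means of the edge indicators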