From a2b8af584f0a49ac2092af6e5aca9b32db4d3fda Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 11 Nov 2025 21:52:01 +0100 Subject: [PATCH 01/23] ggm compiles but runtime has a weird error --- R/RcppExports.R | 8 ++ src/RcppExports.cpp | 36 +++++++ src/adaptiveMetropolis.h | 55 ++++++++++ src/base_model.cpp | 0 src/base_model.h | 56 ++++++++++ src/chainResultNew.h | 32 ++++++ src/cholupdate.cpp | 129 +++++++++++++++++++++++ src/cholupdate.h | 6 ++ src/ggm_model.cpp | 215 +++++++++++++++++++++++++++++++++++++++ src/ggm_model.h | 134 ++++++++++++++++++++++++ src/sample_ggm.cpp | 185 +++++++++++++++++++++++++++++++++ test_ggm.R | 72 +++++++++++++ 12 files changed, 928 insertions(+) create mode 100644 src/adaptiveMetropolis.h create mode 100644 src/base_model.cpp create mode 100644 src/base_model.h create mode 100644 src/chainResultNew.h create mode 100644 src/cholupdate.cpp create mode 100644 src/cholupdate.h create mode 100644 src/ggm_model.cpp create mode 100644 src/ggm_model.h create mode 100644 src/sample_ggm.cpp create mode 100644 test_ggm.R diff --git a/R/RcppExports.R b/R/RcppExports.R index 25253d09..808269a3 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -9,6 +9,10 @@ run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_ .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type, pairwise_scaling_factors) } +chol_update_arma <- function(R, u, downdate = FALSE, eps = 1e-12) { + 
.Call(`_bgms_chol_update_arma`, R, u, downdate, eps) +} + compute_conditional_probs <- function(observations, predict_vars, interactions, thresholds, no_categories, variable_type, baseline_category) { .Call(`_bgms_compute_conditional_probs`, observations, predict_vars, interactions, thresholds, no_categories, variable_type, baseline_category) } @@ -25,6 +29,10 @@ run_simulation_parallel <- function(pairwise_samples, main_samples, draw_indices .Call(`_bgms_run_simulation_parallel`, pairwise_samples, main_samples, draw_indices, no_states, no_variables, no_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } +sample_ggm <- function(X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) { + .Call(`_bgms_sample_ggm`, X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) +} + compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { .Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda) } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index ff1811bf..91db45cc 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -104,6 +104,20 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// chol_update_arma +arma::mat chol_update_arma(arma::mat& R, arma::vec& u, bool downdate, double eps); +RcppExport SEXP _bgms_chol_update_arma(SEXP RSEXP, SEXP uSEXP, SEXP downdateSEXP, SEXP epsSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< arma::mat& >::type R(RSEXP); + Rcpp::traits::input_parameter< arma::vec& >::type u(uSEXP); + Rcpp::traits::input_parameter< bool >::type downdate(downdateSEXP); + Rcpp::traits::input_parameter< double >::type eps(epsSEXP); + rcpp_result_gen = Rcpp::wrap(chol_update_arma(R, u, downdate, eps)); + return rcpp_result_gen; +END_RCPP +} // 
compute_conditional_probs Rcpp::List compute_conditional_probs(arma::imat observations, arma::ivec predict_vars, arma::mat interactions, arma::mat thresholds, arma::ivec no_categories, Rcpp::StringVector variable_type, arma::ivec baseline_category); RcppExport SEXP _bgms_compute_conditional_probs(SEXP observationsSEXP, SEXP predict_varsSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP no_categoriesSEXP, SEXP variable_typeSEXP, SEXP baseline_categorySEXP) { @@ -179,6 +193,26 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// sample_ggm +Rcpp::List sample_ggm(const arma::mat& X, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, int progress_type); +RcppExport SEXP _bgms_sample_ggm(SEXP XSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); + Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); + Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); + Rcpp::traits::input_parameter< const int >::type no_chains(no_chainsSEXP); + Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); + Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); + Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); + Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); + 
rcpp_result_gen = Rcpp::wrap(sample_ggm(X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)); + return rcpp_result_gen; +END_RCPP +} // compute_Vn_mfm_sbm arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double dirichlet_alpha, arma::uword t_max, double lambda); RcppExport SEXP _bgms_compute_Vn_mfm_sbm(SEXP no_variablesSEXP, SEXP dirichlet_alphaSEXP, SEXP t_maxSEXP, SEXP lambdaSEXP) { @@ -197,10 +231,12 @@ END_RCPP static const R_CallMethodDef CallEntries[] = { {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 38}, {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 35}, + {"_bgms_chol_update_arma", (DL_FUNC) &_bgms_chol_update_arma, 4}, {"_bgms_compute_conditional_probs", (DL_FUNC) &_bgms_compute_conditional_probs, 7}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 7}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 9}, {"_bgms_run_simulation_parallel", (DL_FUNC) &_bgms_run_simulation_parallel, 12}, + {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 10}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/adaptiveMetropolis.h b/src/adaptiveMetropolis.h new file mode 100644 index 00000000..f8df5c63 --- /dev/null +++ b/src/adaptiveMetropolis.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +class AdaptiveProposal { + +public: + + AdaptiveProposal(size_t num_params, size_t adaption_window = 50, double target_accept = 0.44) { + proposal_sds_ = arma::vec(num_params, arma::fill::ones) * 0.25; // Initial SD + // acceptance_counts_ = arma::ivec(num_params, arma::fill::zeros); + // total_proposals_ = arma::ivec(num_params, arma::fill::zeros); + adaptation_window_ = adaption_window; + target_accept_ = target_accept; + } + + double get_proposal_sd(size_t param_index) const { + validate_index(param_index); + return proposal_sds_[param_index]; + } + + void 
update_proposal_sd(size_t param_index, double alpha) { + + if (!adapting_) { + return; + } + + double current_sd = get_proposal_sd(param_index); + double updated_sd = current_sd + std::pow(1.0 / iterations_, 0.6) * (alpha - target_accept_); + // proposal_sds_[param_index] = std::min(20.0, std::max(1.0 / std::sqrt(n), updated_sd)); + proposal_sds_(param_index) = std::min(20.0, updated_sd); + } + + void increment_iteration() { + iterations_++; + if (iterations_ >= adaptation_window_) { + adapting_ = false; + } + } + +private: + arma::vec proposal_sds_; + int iterations_ = 0, + adaptation_window_; + double target_accept_ = 0.44; + bool adapting_ = true; + + void validate_index(size_t index) const { + if (index >= proposal_sds_.n_elem) { + throw std::out_of_range("Parameter index out of range"); + } + } + +}; diff --git a/src/base_model.cpp b/src/base_model.cpp new file mode 100644 index 00000000..e69de29b diff --git a/src/base_model.h b/src/base_model.h new file mode 100644 index 00000000..20cd4930 --- /dev/null +++ b/src/base_model.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include + +class BaseModel { +public: + virtual ~BaseModel() = default; + + // Capability queries + virtual bool has_gradient() const { return false; } + virtual bool has_adaptive_mh() const { return false; } + + // Core methods (to be overridden by derived classes) + virtual double logp(const std::vector& parameters) = 0; + + virtual arma::vec gradient(const std::vector& parameters) { + if (!has_gradient()) { + throw std::runtime_error("Gradient not implemented for this model"); + } + throw std::runtime_error("Gradient method must be implemented in derived class"); + } + + virtual std::pair logp_and_gradient( + const std::vector& parameters) { + if (!has_gradient()) { + throw std::runtime_error("Gradient not implemented for this model"); + } + return {logp(parameters), gradient(parameters)}; + } + + // For Metropolis-Hastings (model handles parameter groups internally) + virtual 
void do_one_mh_step() { + throw std::runtime_error("do_one_mh_step method must be implemented in derived class"); + } + + virtual arma::vec get_vectorized_parameters() { + throw std::runtime_error("get_vectorized_parameters method must be implemented in derived class"); + } + + // Return dimensionality of the parameter space + virtual size_t parameter_dimension() const = 0; + + virtual void set_seed(int seed) { + throw std::runtime_error("set_seed method must be implemented in derived class"); + } + + virtual std::unique_ptr clone() const { + throw std::runtime_error("clone method must be implemented in derived class"); + } + + +protected: + BaseModel() = default; +}; diff --git a/src/chainResultNew.h b/src/chainResultNew.h new file mode 100644 index 00000000..5a6dd855 --- /dev/null +++ b/src/chainResultNew.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +class ChainResultNew { + +public: + ChainResultNew() {} + + bool error; + std::string error_msg; + int chain_id; + bool userInterrupt; + + arma::mat samples; + + void reserve(size_t param_dim, size_t n_iter) { + samples.set_size(param_dim, n_iter); + } + void store_sample(size_t iter, const arma::vec& sample) { + samples.col(iter) = sample; + } + + // arma::imat indicator_samples; + + // other samples + // arma::ivec treedepth_samples; + // arma::ivec divergent_samples; + // arma::vec energy_samples; + // arma::imat allocation_samples; +}; diff --git a/src/cholupdate.cpp b/src/cholupdate.cpp new file mode 100644 index 00000000..5ab8f6eb --- /dev/null +++ b/src/cholupdate.cpp @@ -0,0 +1,129 @@ +#include "cholupdate.h" + +extern "C" { + +// from mgcv: https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1876-L1883 +static inline double hypote(double x, double y) { +/* stable computation of sqrt(x^2 + y^2) */ + double t; + x = fabs(x);y=fabs(y); + if (y>x) { t = x;x = y; y = t;} + if (x==0) return(y); else t = y/x; + return(x*sqrt(1+t*t)); +} /* hypote */ + +// from mgcv: 
https://github.com/cran/mgcv/blob/1b6a4c8374612da27e36420b4459e93acb183f2d/src/mat.c#L1956 +void chol_up(double *R,double *u, int *n,int *up,double *eps) { +/* Rank 1 update of a cholesky factor. Works as follows: + + [up=1] R'R + uu' = [u,R'][u,R']' = [u,R']Q'Q[u,R']', and then uses Givens rotations to + construct Q such that Q[u,R']' = [0,R1']'. Hence R1'R1 = R'R + uu'. The construction + operates from first column to last. + + [up=0] uses an almost identical sequence, but employs hyperbolic rotations + in place of Givens. See Golub and van Loan (2013, 4e 6.5.4) + + Givens rotations are of form [c,-s] where c = cos(theta), s = sin(theta). + [s,c] + + Assumes R upper triangular, and that it is OK to use first two columns + below diagonal as temporary strorage for Givens rotations (the storage is + needed to ensure algorithm is column oriented). + + For downdate returns a negative value in R[1] (R[1,0]) if not +ve definite. +*/ + double c0,s0,*c,*s,z,*x,z0,*c1; + int j,j1,n1; + n1 = *n - 1; + if (*up) for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R */ + z = *u; /* initial element of u */ + x = R + *n * j; /* current column */ + c = R + 2;s = R + *n + 2; /* Storage for first n-2 Givens rotations */ + for (c1=c+j1;c R[j,j] */ + z0 = hypote(z,*x); /* sqrt(z^2+R[j,j]^2) */ + c0 = *x/z0; s0 = z/z0; /* need to zero z */ + /* now apply this rotation and this column is finished (so no need to update z) */ + *x = s0 * z + c0 * *x; + } else for (j1=-1,j=0;j<*n;j++,u++,j1++) { /* loop over columns of R for down-dating */ + z = *u; /* initial element of u */ + x = R + *n * j; /* current column */ + c = R + 2;s = R + *n + 2; /* Storage for first n-2 hyperbolic rotations */ + for (c1=c+j1;c R[j,j] */ + z0 = z / *x; /* sqrt(z^2+R[j,j]^2) */ + if (fabs(z0)>=1) { /* downdate not +ve def */ + //Rprintf("j = %d d = %g ",j,z0); + if (*n>1) R[1] = -2.0; + return; /* signals error */ + } + if (z0 > 1 - *eps) z0 = 1 - *eps; + c0 = 1/sqrt(1-z0*z0);s0 = c0 * z0; + /* now 
apply this rotation and this column is finished (so no need to update z) */ + *x = -s0 * z + c0 * *x; + } + + /* now zero c and s storage */ + c = R + 2;s = R + *n + 2; + for (x = c + *n - 2;c + +void cholesky_update( arma::mat& R, arma::vec& u, double eps = 1e-12); +void cholesky_downdate(arma::mat& R, arma::vec& u, double eps = 1e-12); diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp new file mode 100644 index 00000000..eee45bf7 --- /dev/null +++ b/src/ggm_model.cpp @@ -0,0 +1,215 @@ +#include "ggm_model.h" +#include "adaptiveMetropolis.h" +#include "rng_utils.h" +#include "cholupdate.h" + +double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { + return(A(ii, jj) - A(ii, i) * A(jj, i) / A(i, i)); +} + +void GGMModel::get_constants(size_t i, size_t j) { + + // TODO: helper function? + double logdet_omega = 0.0; + for (size_t i = 0; i < p_; i++) { + logdet_omega += std::log(phi_(i, i)); + } + + double log_adj_omega_ii = logdet_omega + log(abs(inv_omega_(i, i))); + double log_adj_omega_ij = logdet_omega + log(abs(inv_omega_(i, j))); + double log_adj_omega_jj = logdet_omega + log(abs(inv_omega_(j, j))); + + double inv_omega_sub_j1j1 = compute_inv_submatrix_i(inv_omega_, i, j, j); + double log_abs_inv_omega_sub_jj = log_adj_omega_ii + log(abs(inv_omega_sub_j1j1)); + + double Phi_q1q = (-1 * (2 * std::signbit(inv_omega_(i, j)) - 1)) * std::exp( + (log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2) + ); + double Phi_q1q1 = std::exp((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); + + constants_[1] = Phi_q1q; + constants_[2] = Phi_q1q1; + constants_[3] = omega_(i, j) - Phi_q1q * Phi_q1q1; + constants_[4] = Phi_q1q1; + constants_[5] = omega_(j, j) - Phi_q1q * Phi_q1q; + constants_[6] = constants_[5] + constants_[3] * constants_[3] / (constants_[4] * constants_[4]); + +} + +double GGMModel::R(const double x) const { + if (x == 0) { + return constants_[6]; + } else { + return 
constants_[3] + std::pow((x - constants_[3]) / constants_[4], 2); + } +} + +void GGMModel::update_edge_parameter(size_t i, size_t j) { + + if (edge_indicators_(i, j) == 0) { + return; // Edge is not included; skip update + } + + get_constants(i, j); + double Phi_q1q = constants_[1]; + double Phi_q1q1 = constants_[2]; + + size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form + double proposal_sd = proposal_.get_proposal_sd(e); + + double phi_prop = rnorm(rng_, Phi_q1q, proposal_sd); + double omega_prop_q1q = constants_[3] + constants_[4] * phi_prop; + double omega_prop_qq = R(omega_prop_q1q); + + // form full proposal matrix for Omega + omega_prop_ = omega_; // TODO: needs to be a copy! + omega_prop_(i, j) = omega_prop_q1q; + omega_prop_(j, i) = omega_prop_q1q; + omega_prop_(j, j) = omega_prop_qq; + + double ln_alpha = log_density(omega_prop_) - log_density(); + ln_alpha += R::dcauchy(omega_prop_(i, j), 0.0, 2.5, true); + ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); + + double u = runif(rng_); + if (ln_alpha > log(u)) { + // accept proposal + + double omega_ij = omega_(i, j); + double omega_jj = omega_(j, j); + + omega_(i, j) = omega_prop_q1q; + omega_(j, i) = omega_prop_q1q; + omega_(j, j) = omega_prop_qq; + + // TODO: preallocate? 
+ // find v for low rank update + arma::vec v1 = {0, -1}; + arma::vec v2 = {omega_ij - omega_prop_(i, j), (omega_jj - omega_prop_(j, j)) / 2}; + + arma::vec vf1 = arma::zeros(p_); + arma::vec vf2 = arma::zeros(p_); + vf1[i] = v1[1]; + vf1[j] = v1[2]; + vf2[i] = v2[1]; + vf2[j] = v2[2]; + + // we now have + // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) + + arma::vec u1 = (vf1 + vf2) / sqrt(2); + arma::vec u2 = (vf1 - vf2) / sqrt(2); + + // we now have + // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) + // and also + // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) + + // update phi + cholesky_update(phi_, u1); + cholesky_downdate(phi_, u2); + + // update inverse + inv_omega_ = phi_.t() * phi_; + + } + + double alpha = std::min(1.0, std::exp(ln_alpha)); + proposal_.update_proposal_sd(e, alpha); +} + +void GGMModel::do_one_mh_step() { + + // Update off-diagonals (upper triangle) + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + Rcpp::Rcout << "Updating edge parameter (" << i << ", " << j << ")" << std::endl; + update_edge_parameter(i, j); + } + } + + // Update diagonals + for (size_t i = 0; i < p_; ++i) { + Rcpp::Rcout << "Updating diagonal parameter " << i << std::endl; + update_diagonal_parameter(i); + } + + // if (edge_selection_) { + // for (size_t i = 0; i < p_ - 1; ++i) { + // for (size_t j = i + 1; j < p_; ++j) { + // update_edge_indicator_parameter_pair(i, j); + // } + // } + // } + proposal_.increment_iteration(); +} + +double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { + double logdet_omega = 0.0; + for (size_t i = 0; i < p_; i++) { + logdet_omega += std::log(phi(i, i)); + } + + // TODO: does this allocate? 
+ double trace_prod = arma::accu(omega % suf_stat_); + + double log_likelihood = n_ * (p_ * log(2 * arma::datum::pi) / 2 + logdet_omega / 2) - trace_prod / 2; + + return log_likelihood; +} + + +void GGMModel::update_diagonal_parameter(size_t i) { + // Implementation of diagonal parameter update + // 1-3) from before + double logdet_omega = 0.0; + for (size_t i = 0; i < p_; i++) { + logdet_omega += std::log(phi_(i, i)); + } + + double logdet_omega_sub_ii = logdet_omega + std::log(inv_omega_(i, i)); + + size_t e = i * (i + 1) / 2 + i; // parameter index in vectorized form + double proposal_sd = proposal_.get_proposal_sd(e); + + double theta_curr = (logdet_omega - logdet_omega_sub_ii) / 2; + double theta_prop = rnorm(rng_, theta_curr, proposal_sd); + + //4) Replace and rebuild omega + omega_prop_ = omega_; + omega_prop_(i, i) = omega_(i, i) - std::exp(theta_curr) * std::exp(theta_curr) + std::exp(theta_prop) * std::exp(theta_prop); + + // 5) Acceptance ratio + double ln_alpha = log_density(omega_prop_) - log_density(); + ln_alpha += R::dgamma(exp(theta_prop), 1.0, 1.0, true); + ln_alpha -= R::dgamma(exp(theta_curr), 1.0, 1.0, true); + ln_alpha += theta_prop - theta_curr; // Jacobian adjustment ? 
+ + + if (log(runif(rng_)) < ln_alpha) { + + double omega_ii = omega_(i, i); + + arma::vec u(p_, arma::fill::zeros); + double delta = omega_ii - omega_prop_(i, i); + bool s = delta > 0; + u(i) = sqrt(abs(delta)); + + omega_(i, i) = omega_prop_(i, i); + + if (!s) + cholesky_update(phi_, u); + else + cholesky_downdate(phi_, u); + + inv_omega_ = phi_.t() * phi_; + + } + + double alpha = std::min(1.0, std::exp(ln_alpha)); + proposal_.update_proposal_sd(e, alpha); +} + +void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { + // Implementation of edge indicator parameter pair update +} diff --git a/src/ggm_model.h b/src/ggm_model.h new file mode 100644 index 00000000..05081db0 --- /dev/null +++ b/src/ggm_model.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include "base_model.h" +#include "adaptiveMetropolis.h" +#include "rng_utils.h" + +class GGMModel : public BaseModel { +public: + + GGMModel( + const arma::mat& X, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true + ) : n_(X.n_rows), + p_(X.n_cols), + dim_((p_ * (p_ + 1)) / 2), + suf_stat_(X.t() * X), + prior_inclusion_prob_(prior_inclusion_prob), + edge_selection_(edge_selection), + proposal_(AdaptiveProposal(dim_, 500)), + omega_(arma::eye(p_, p_)), + phi_(arma::eye(p_, p_)), + inv_omega_(arma::eye(p_, p_)), + edge_indicators_(initial_edge_indicators), + vectorized_parameters_(dim_), + constants_(6) + {} + + GGMModel(const GGMModel& other) + : BaseModel(other), + dim_(other.dim_), + suf_stat_(other.suf_stat_), + n_(other.n_), + p_(other.p_), + prior_inclusion_prob_(other.prior_inclusion_prob_), + edge_selection_(other.edge_selection_), + omega_(other.omega_), + phi_(other.phi_), + inv_omega_(other.inv_omega_), + edge_indicators_(other.edge_indicators_), + vectorized_parameters_(other.vectorized_parameters_), + proposal_(other.proposal_), + rng_(other.rng_), + omega_prop_(other.omega_prop_), + constants_(other.constants_) + {} 
+ + // // rng_ = SafeRNG(123); + + // } + + void set_adaptive_proposal(AdaptiveProposal proposal) { + proposal_ = proposal; + } + + virtual bool has_gradient() const { return false; } + virtual bool has_adaptive_mh() const override { return true; } + + double logp(const std::vector& parameters) override { + // Implement log probability computation + return 0.0; + } + + // TODO: this can be done more efficiently, no need for the Cholesky! + double log_density(const arma::mat& omega) const { return log_density_impl(omega, arma::chol(omega)); }; + double log_density() const { return log_density_impl(omega_, phi_); } + + void do_one_mh_step() override; + + size_t parameter_dimension() const override { + Rcpp::Rcout << "GGMModel::parameter_dimension() returning: " << dim_ << std::endl; + return dim_; + } + + void set_seed(int seed) override { + rng_ = SafeRNG(seed); + } + + arma::vec get_vectorized_parameters() override { + size_t e = 0; + for (size_t j = 0; j < p_; ++j) { + for (size_t i = 0; i <= j; ++i) { + vectorized_parameters_(e) = omega_(i, j); + e++; + } + } + return vectorized_parameters_; + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); // uses copy constructor + } + +private: + // data + size_t dim_; + arma::mat suf_stat_; + size_t n_; + size_t p_; + arma::mat prior_inclusion_prob_; + bool edge_selection_ = true; + + // parameters + arma::mat omega_, phi_, inv_omega_; + arma::imat edge_indicators_; + arma::vec vectorized_parameters_; + + + AdaptiveProposal proposal_; + SafeRNG rng_; + + // internal helper variables + arma::mat omega_prop_; + arma::vec constants_; // Phi_q1q, Phi_q1q1, c[1], c[2], c[3], c[4] + + // Parameter group updates with optimized likelihood evaluations + void update_edge_parameter(size_t i, size_t j); + void update_diagonal_parameter(size_t i); + void update_edge_indicator_parameter_pair(size_t i, size_t j); + + // Helper methods + void get_constants(size_t i, size_t j); + double 
compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const; + double R(const double x) const; + + double log_density_impl(const arma::mat& omega, const arma::mat& phi) const; + + // double find_reasonable_step_size_edge(const arma::mat& omega, size_t i, size_t j); + // double find_reasonable_step_size_diag(const arma::mat& omega, size_t i); + // double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal); + // double diag_log_ratio(const arma::mat& omega, size_t i, double proposal); +}; diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp new file mode 100644 index 00000000..0d9de749 --- /dev/null +++ b/src/sample_ggm.cpp @@ -0,0 +1,185 @@ +#include +#include +#include +#include +#include + +#include "ggm_model.h" +#include "progress_manager.h" +#include "chainResultNew.h" + +void run_mcmc_sampler_single_thread( + ChainResultNew& chain_result, + BaseModel& model, + const int no_iter, + const int no_warmup, + const int chain_id, + ProgressManager& pm +) { + + size_t i = 0; + for (size_t iter = 0; iter < no_iter + no_warmup; ++iter) { + + model.do_one_mh_step(); + + if (iter >= no_warmup) { + + chain_result.store_sample(i, model.get_vectorized_parameters()); + i++; + } + + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + break; + } + } +} + +struct GGMChainRunner : public RcppParallel::Worker { + std::vector& results_; + std::vector>& models_; + size_t no_iter_; + size_t no_warmup_; + int seed_; + ProgressManager& pm_; + + GGMChainRunner( + std::vector& results, + std::vector>& models, + const size_t no_iter, + const size_t no_warmup, + const int seed, + ProgressManager& pm + ) : + results_(results), + models_(models), + no_iter_(no_iter), + no_warmup_(no_warmup), + seed_(seed), + pm_(pm) + {} + + void operator()(std::size_t begin, std::size_t end) { + for (std::size_t i = begin; i < end; ++i) { + + ChainResultNew& chain_result = results_[i]; + chain_result.chain_id = 
static_cast(i + 1); + chain_result.error = false; + BaseModel& model = *models_[i]; + model.set_seed(seed_ + i); + try { + + run_mcmc_sampler_single_thread(chain_result, model, no_iter_, no_warmup_, i, pm_); + + } catch (std::exception& e) { + chain_result.error = true; + chain_result.error_msg = e.what(); + } catch (...) { + chain_result.error = true; + chain_result.error_msg = "Unknown error"; + } + } + } +}; + +void run_mcmc_sampler_threaded( + std::vector& results, + std::vector>& models, + const int no_iter, + const int no_warmup, + const int seed, + const int no_threads, + ProgressManager& pm +) { + + GGMChainRunner runner(results, models, no_iter, no_warmup, seed, pm); + tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); + RcppParallel::parallelFor(0, results.size(), runner); +} + + +std::vector run_mcmc_sampler( + BaseModel& model, + const int no_iter, + const int no_warmup, + const int no_chains, + const int seed, + const int no_threads, + ProgressManager& pm +) { + + Rcpp::Rcout << "Allocating results objects..." << std::endl; + std::vector results(no_chains); + for (size_t c = 0; c < no_chains; ++c) { + results[c].reserve(model.parameter_dimension(), no_iter); + } + + if (no_threads > 1) { + + Rcpp::Rcout << "Running multi-threaded MCMC sampling..." << std::endl; + std::vector> models; + models.reserve(no_chains); + for (size_t c = 0; c < no_chains; ++c) { + models.push_back(model.clone()); // deep copy via virtual clone + } + run_mcmc_sampler_threaded(results, models, no_iter, no_warmup, seed, no_threads, pm); + } else { + model.set_seed(seed); + Rcpp::Rcout << "Running single-threaded MCMC sampling..." 
<< std::endl; + for (size_t c = 0; c < no_chains; ++c) { + run_mcmc_sampler_single_thread(results[c], model, no_iter, no_warmup, c, pm); + } + } + return results; +} + +Rcpp::List convert_sampler_output_to_ggm_result(const std::vector& results) { + + Rcpp::List output(results.size()); + for (size_t i = 0; i < results.size(); ++i) { + + Rcpp::List chain_i; + chain_i["chain_id"] = results[i].chain_id; + if (results[i].error) { + chain_i["error"] = results[i].error_msg; + } else { + chain_i["samples"] = results[i].samples; + chain_i["userInterrupt"] = results[i].userInterrupt; + + } + output[i] = chain_i; + } + return output; +} + +// [[Rcpp::export]] +Rcpp::List sample_ggm( + const arma::mat& X, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const int no_iter, + const int no_warmup, + const int no_chains, + const bool edge_selection, + const int seed, + const int no_threads, + int progress_type +) { + + // should be done dynamically + // also adaptation method should be specified differently + GGMModel model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection); + + Rcpp::Rcout << "GGMModel::parameter_dimension() returning: " << model.parameter_dimension() << std::endl; + + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + + std::vector output = run_mcmc_sampler(model, no_iter, no_warmup, no_chains, seed, no_threads, pm); + + Rcpp::List ggm_result = convert_sampler_output_to_ggm_result(output); + + pm.finish(); + + return ggm_result; +} \ No newline at end of file diff --git a/test_ggm.R b/test_ggm.R new file mode 100644 index 00000000..60c5ef0f --- /dev/null +++ b/test_ggm.R @@ -0,0 +1,72 @@ +library(bgms) + +# Dimension and true precision +p <- 10 + +adj <- matrix(0, nrow = p, ncol = p) +adj[lower.tri(adj)] <- rbinom(p * (p - 1) / 2, size = 1, prob = 0.3) +adj <- adj + t(adj) +# qgraph::qgraph(adj) +Omega <- BDgraph::rgwish(1, adj = adj, b = p + sample(0:p, 1), D = diag(p)) +Sigma <- 
solve(Omega) +zapsmall(Omega) + +# Data +n <- 1e2 +x <- mvtnorm::rmvnorm(n = n, mean = rep(0, p), sigma = Sigma) + + +# ---- Run MCMC with warmup and sampling ------------------------------------ + +# debugonce(mbgms:::bgm_gaussian) +sampling_results <- bgms:::sample_ggm( + X = x, + prior_inclusion_prob = matrix(.5, p, p), + initial_edge_indicators = adj, + no_iter = 4000, + no_warmup = 4000, + no_chains = 3, + edge_selection = FALSE, + no_threads = 1, + seed = 123, + progress_type = 2 +) + +profvis::profvis({ + sampling_results <- bgm_gaussian( + x = x, + n = n, + n_iter = 400, + n_warmup = 400, + n_phases = 10 + ) +}) + +# Extract results +aveOmega <- sampling_results$aveOmega +aveGamma <- sampling_results$aveGamma +aOmega <- sampling_results$aOmega +aGamma <- sampling_results$aGamma +prob <- sampling_results$prob +proposal_sd <- sampling_results$proposal_sd + +library(patchwork) +library(ggplot2) +df <- data.frame( + true = aveOmega[lower.tri(aveOmega)], + Omega[lower.tri(Omega)], + estimated = aveOmega[lower.tri(aveOmega)], + p_inclusion = aveGamma[lower.tri(aveGamma)] +) +p1 <- ggplot(df, aes(x = true, y = estimated)) + + geom_point(size = 5, alpha = 0.8, shape = 21, fill = "grey") + + geom_abline(slope = 1, intercept = 0, color = "grey") + + labs(x = "True Values Omega", y = "Estimated Values Omega (Posterior Mean)") +p2 <- ggplot(df, aes(x = estimated, y = p_inclusion)) + + geom_point(size = 5, alpha = 0.8, shape = 21, fill = "grey") + + labs( + x = "Estimated Values Omega (Posterior Mean)", + y = "Estimated Inclusion Probabilities" + ) +(p1 + p2) + plot_layout(ncol = 1) & theme_bw(base_size = 20) + From ddeff6e108a81a3c69e06500852a767c1018d2ee Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 12 Nov 2025 16:49:42 +0100 Subject: [PATCH 02/23] almost functional and O(P^2) per update --- src/RcppExports.cpp | 4 +- src/adaptiveMetropolis.h | 37 +++- src/base_model.h | 10 +- src/chainResultNew.h | 12 +- src/ggm_model.cpp | 460 
++++++++++++++++++++++++++++++++------- src/ggm_model.h | 40 +++- src/sample_ggm.cpp | 14 +- test_ggm.R | 50 ++++- 8 files changed, 510 insertions(+), 117 deletions(-) diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 91db45cc..1aebcbfe 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -194,7 +194,7 @@ BEGIN_RCPP END_RCPP } // sample_ggm -Rcpp::List sample_ggm(const arma::mat& X, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, int progress_type); +Rcpp::List sample_ggm(const arma::mat& X, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type); RcppExport SEXP _bgms_sample_ggm(SEXP XSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; @@ -208,7 +208,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); - Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); rcpp_result_gen = Rcpp::wrap(sample_ggm(X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)); return rcpp_result_gen; END_RCPP diff --git a/src/adaptiveMetropolis.h b/src/adaptiveMetropolis.h index f8df5c63..d8d9cb6c 100644 --- a/src/adaptiveMetropolis.h +++ b/src/adaptiveMetropolis.h @@ 
-8,9 +8,8 @@ class AdaptiveProposal { public: AdaptiveProposal(size_t num_params, size_t adaption_window = 50, double target_accept = 0.44) { - proposal_sds_ = arma::vec(num_params, arma::fill::ones) * 0.25; // Initial SD - // acceptance_counts_ = arma::ivec(num_params, arma::fill::zeros); - // total_proposals_ = arma::ivec(num_params, arma::fill::zeros); + proposal_sds_ = arma::vec(num_params, arma::fill::ones) * 0.25; // Initial SD, need to tweak this somehow? + acceptance_counts_ = arma::ivec(num_params, arma::fill::zeros); adaptation_window_ = adaption_window; target_accept_ = target_accept; } @@ -20,16 +19,26 @@ class AdaptiveProposal { return proposal_sds_[param_index]; } - void update_proposal_sd(size_t param_index, double alpha) { + void update_proposal_sd(size_t param_index) { if (!adapting_) { return; } double current_sd = get_proposal_sd(param_index); - double updated_sd = current_sd + std::pow(1.0 / iterations_, 0.6) * (alpha - target_accept_); - // proposal_sds_[param_index] = std::min(20.0, std::max(1.0 / std::sqrt(n), updated_sd)); - proposal_sds_(param_index) = std::min(20.0, updated_sd); + double observed_acceptance_probability = acceptance_counts_[param_index] / static_cast(iterations_ + 1); + double rm_weight = std::pow(iterations_, -decay_rate_); + + // Robbins-Monro update step + double updated_sd = current_sd + (observed_acceptance_probability - target_accept_) * rm_weight; + updated_sd = std::clamp(updated_sd, rm_lower_bound, rm_upper_bound); + + proposal_sds_(param_index) = updated_sd; + } + + void increment_accepts(size_t param_index) { + validate_index(param_index); + acceptance_counts_[param_index]++; } void increment_iteration() { @@ -40,11 +49,15 @@ class AdaptiveProposal { } private: - arma::vec proposal_sds_; - int iterations_ = 0, - adaptation_window_; - double target_accept_ = 0.44; - bool adapting_ = true; + arma::vec proposal_sds_; + arma::ivec acceptance_counts_; + int iterations_ = 0, + adaptation_window_; + double 
target_accept_ = 0.44, + decay_rate_ = 0.75, + rm_lower_bound = 0.001, + rm_upper_bound = 2.0; + bool adapting_ = true; void validate_index(size_t index) const { if (index >= proposal_sds_.n_elem) { diff --git a/src/base_model.h b/src/base_model.h index 20cd4930..10f443dc 100644 --- a/src/base_model.h +++ b/src/base_model.h @@ -13,9 +13,9 @@ class BaseModel { virtual bool has_adaptive_mh() const { return false; } // Core methods (to be overridden by derived classes) - virtual double logp(const std::vector& parameters) = 0; + virtual double logp(const arma::vec& parameters) = 0; - virtual arma::vec gradient(const std::vector& parameters) { + virtual arma::vec gradient(const arma::vec& parameters) { if (!has_gradient()) { throw std::runtime_error("Gradient not implemented for this model"); } @@ -23,7 +23,7 @@ class BaseModel { } virtual std::pair logp_and_gradient( - const std::vector& parameters) { + const arma::vec& parameters) { if (!has_gradient()) { throw std::runtime_error("Gradient not implemented for this model"); } @@ -39,6 +39,10 @@ class BaseModel { throw std::runtime_error("get_vectorized_parameters method must be implemented in derived class"); } + virtual arma::ivec get_vectorized_indicator_parameters() { + throw std::runtime_error("get_vectorized_indicator_parameters method must be implemented in derived class"); + } + // Return dimensionality of the parameter space virtual size_t parameter_dimension() const = 0; diff --git a/src/chainResultNew.h b/src/chainResultNew.h index 5a6dd855..fa269a54 100644 --- a/src/chainResultNew.h +++ b/src/chainResultNew.h @@ -8,17 +8,17 @@ class ChainResultNew { public: ChainResultNew() {} - bool error; + bool error = false, + userInterrupt = false; std::string error_msg; - int chain_id; - bool userInterrupt; + int chain_id; - arma::mat samples; + arma::mat samples; - void reserve(size_t param_dim, size_t n_iter) { + void reserve(const size_t param_dim, const size_t n_iter) { samples.set_size(param_dim, n_iter); } - void 
store_sample(size_t iter, const arma::vec& sample) { + void store_sample(const size_t iter, const arma::vec& sample) { samples.col(iter) = sample; } diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp index eee45bf7..0ffc19b2 100644 --- a/src/ggm_model.cpp +++ b/src/ggm_model.cpp @@ -10,10 +10,7 @@ double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, con void GGMModel::get_constants(size_t i, size_t j) { // TODO: helper function? - double logdet_omega = 0.0; - for (size_t i = 0; i < p_; i++) { - logdet_omega += std::log(phi_(i, i)); - } + double logdet_omega = get_log_det(phi_); double log_adj_omega_ii = logdet_omega + log(abs(inv_omega_(i, i))); double log_adj_omega_ij = logdet_omega + log(abs(inv_omega_(i, j))); @@ -22,7 +19,7 @@ void GGMModel::get_constants(size_t i, size_t j) { double inv_omega_sub_j1j1 = compute_inv_submatrix_i(inv_omega_, i, j, j); double log_abs_inv_omega_sub_jj = log_adj_omega_ii + log(abs(inv_omega_sub_j1j1)); - double Phi_q1q = (-1 * (2 * std::signbit(inv_omega_(i, j)) - 1)) * std::exp( + double Phi_q1q = (2 * std::signbit(inv_omega_(i, j)) - 1) * std::exp( (log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2) ); double Phi_q1q1 = std::exp((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); @@ -40,10 +37,100 @@ double GGMModel::R(const double x) const { if (x == 0) { return constants_[6]; } else { - return constants_[3] + std::pow((x - constants_[3]) / constants_[4], 2); + return constants_[5] + std::pow((x - constants_[3]) / constants_[4], 2); } } +double GGMModel::get_log_det(arma::mat triangular_A) const { + // assume A is an (upper) triangular cholesky factor + // returns the log determinant of A'A + + // TODO: should we just do + // log_det(val, sign, trimatu(A))? 
+ return 2 * arma::accu(arma::log(triangular_A.diag())); +} + +double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { + + double logdet_omega = get_log_det(phi); + // TODO: why not just dot(omega, suf_stat_)? + double trace_prod = arma::accu(omega % suf_stat_); + + double log_likelihood = n_ * (p_ * log(2 * arma::datum::pi) / 2 + logdet_omega / 2) - trace_prod / 2; + + return log_likelihood; +} + +double GGMModel::log_density_impl_edge(size_t i, size_t j) const { + + // this is the log likelihood ratio, not the full log likelihood like GGMModel::log_density_impl + + double Ui2 = omega_(i, j) - omega_prop_(i, j); + // only reached from R + // if (omega_(j, j) == omega_prop_(j, j)) { + // k = i; + // i = j; + // j = k; + // } + double Uj2 = (omega_(j, j) - omega_prop_(j, j)) / 2; + + + // W <- matrix(c(0, 1, 1, 0), 2, 2) + // U0 <- matrix(c(0, -1, Ui2, Uj2)) + // U <- matrix(0, nrow(aOmega), 2) + // U[c(i, j), 1] <- c(0, -1) + // U[c(i, j), 2] <- c(Ui2, Uj2) + // aOmega_prop - (aOmega + U %*% W %*% t(U)) + // det(aOmega_prop) - det(aOmega + U %*% W %*% t(U)) + // det(aOmega_prop) - det(W + t(U) %*% inv_aOmega %*% U) * det(W) * det(aOmega) + // below computes logdet(W + t(U) %*% inv_aOmega %*% U) directly (this is a 2x2 matrix) + + double cc11 = 0 + inv_omega_(j, j); + double cc12 = 1 - (inv_omega_(i, j) * Ui2 + inv_omega_(j, j) * Uj2); + double cc22 = 0 + Ui2 * Ui2 * inv_omega_(i, i) + 2 * Ui2 * Uj2 * inv_omega_(i, j) + Uj2 * Uj2 * inv_omega_(j, j); + + double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); + // logdet - (logdet(aOmega_prop) - logdet(aOmega)) + + double trace_prod = -2 * (suf_stat_(j, j) * Uj2 + suf_stat_(i, j) * Ui2); + + // This function uses the fact that the determinant doesn't change during edge updates. 
+ // double trace_prod = 0.0; + // // TODO: we only need one of the two lines below, but it's not entirely clear which one + // trace_prod += suf_stat_(j, j) * (omega_prop(j, j) - omega(j, j)); + // trace_prod += suf_stat_(i, i) * (omega_prop(i, i) - omega(i, i)); + // trace_prod += 2 * suf_stat_(i, j) * (omega_prop(i, j) - omega(i, j)); + // trace_prod - sum((aOmega_prop - aOmega) * SufStat) + + double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; + return log_likelihood_ratio; + +} + +double GGMModel::log_density_impl_diag(size_t j) const { + // same as above but for i == j, so Ui2 = 0 + double Uj2 = (omega_(j, j) - omega_prop_(j, j)) / 2; + + double cc11 = 0 + inv_omega_(j, j); + double cc12 = 1 - inv_omega_(j, j) * Uj2; + double cc22 = 0 + Uj2 * Uj2 * inv_omega_(j, j); + + double logdet = log(abs(cc11 * cc22 - cc12 * cc12)); + double trace_prod = -2 * suf_stat_(j, j) * Uj2; + + // This function uses the fact that the determinant doesn't change during edge updates. + // double trace_prod = 0.0; + // // TODO: we only need one of the two lines below, but it's not entirely clear which one + // trace_prod += suf_stat_(j, j) * (omega_prop(j, j) - omega(j, j)); + // trace_prod += suf_stat_(i, i) * (omega_prop(i, i) - omega(i, i)); + // trace_prod += 2 * suf_stat_(i, j) * (omega_prop(i, j) - omega(i, j)); + // trace_prod - sum((aOmega_prop - aOmega) * SufStat) + + double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; + return log_likelihood_ratio; + +} + void GGMModel::update_edge_parameter(size_t i, size_t j) { if (edge_indicators_(i, j) == 0) { @@ -67,13 +154,44 @@ void GGMModel::update_edge_parameter(size_t i, size_t j) { omega_prop_(j, i) = omega_prop_q1q; omega_prop_(j, j) = omega_prop_qq; - double ln_alpha = log_density(omega_prop_) - log_density(); + // Rcpp::Rcout << "i: " << i << ", j: " << j << + // ", proposed phi: " << phi_prop << + // ", proposal_sd omega_ij: " << proposal_sd << + // ", proposed omega_ij: " << omega_prop_q1q << + // ", 
proposed omega_jj: " << omega_prop_qq << std::endl; + // constants_.print(Rcpp::Rcout, "Constants:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + + // arma::vec eigval = eig_sym(omega_prop_); + // if (arma::any(eigval <= 0)) { + // Rcpp::Rcout << "Warning: omega_prop_ is not positive definite for edge (" << i << ", " << j << ")" << std::endl; + + // Rcpp::Rcout << + // ", proposed phi: " << phi_prop << + // ", proposal_sd omega_ij: " << proposal_sd << + // ", proposed omega_ij: " << omega_prop_q1q << + // ", proposed omega_jj: " << omega_prop_qq << std::endl; + // constants_.print(Rcpp::Rcout, "Constants:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // omega_.print(Rcpp::Rcout, "Current omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + + // } + + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_edge(i, j); + + if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { + Rcpp::Rcout << "Warning: log density implementations do not match for edge (" << i << ", " << j << ")" << std::endl; + } + ln_alpha += R::dcauchy(omega_prop_(i, j), 0.0, 2.5, true); ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); - double u = runif(rng_); - if (ln_alpha > log(u)) { + if (std::log(runif(rng_)) < ln_alpha) { // accept proposal + proposal_.increment_accepts(e); double omega_ij = omega_(i, j); double omega_jj = omega_(j, j); @@ -89,10 +207,10 @@ void GGMModel::update_edge_parameter(size_t i, size_t j) { arma::vec vf1 = arma::zeros(p_); arma::vec vf2 = arma::zeros(p_); - vf1[i] = v1[1]; - vf1[j] = v1[2]; - vf2[i] = v2[1]; - vf2[j] = v2[2]; + vf1[i] = v1[0]; + vf1[j] = v1[1]; + vf2[i] = v2[0]; + vf2[j] = v2[1]; // we now have // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) @@ -105,68 +223,24 @@ void GGMModel::update_edge_parameter(size_t i, size_t j) { // and also // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 
1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) - // update phi + // update phi (2x O(p^2)) cholesky_update(phi_, u1); cholesky_downdate(phi_, u2); - // update inverse - inv_omega_ = phi_.t() * phi_; + // update inverse (2x O(p^2)) + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); } - double alpha = std::min(1.0, std::exp(ln_alpha)); - proposal_.update_proposal_sd(e, alpha); -} - -void GGMModel::do_one_mh_step() { - - // Update off-diagonals (upper triangle) - for (size_t i = 0; i < p_ - 1; ++i) { - for (size_t j = i + 1; j < p_; ++j) { - Rcpp::Rcout << "Updating edge parameter (" << i << ", " << j << ")" << std::endl; - update_edge_parameter(i, j); - } - } - - // Update diagonals - for (size_t i = 0; i < p_; ++i) { - Rcpp::Rcout << "Updating diagonal parameter " << i << std::endl; - update_diagonal_parameter(i); - } - - // if (edge_selection_) { - // for (size_t i = 0; i < p_ - 1; ++i) { - // for (size_t j = i + 1; j < p_; ++j) { - // update_edge_indicator_parameter_pair(i, j); - // } - // } - // } - proposal_.increment_iteration(); -} - -double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { - double logdet_omega = 0.0; - for (size_t i = 0; i < p_; i++) { - logdet_omega += std::log(phi(i, i)); - } - - // TODO: does this allocate? 
- double trace_prod = arma::accu(omega % suf_stat_); - - double log_likelihood = n_ * (p_ * log(2 * arma::datum::pi) / 2 + logdet_omega / 2) - trace_prod / 2; - - return log_likelihood; + proposal_.update_proposal_sd(e); } void GGMModel::update_diagonal_parameter(size_t i) { // Implementation of diagonal parameter update // 1-3) from before - double logdet_omega = 0.0; - for (size_t i = 0; i < p_; i++) { - logdet_omega += std::log(phi_(i, i)); - } - + double logdet_omega = get_log_det(phi_); double logdet_omega_sub_ii = logdet_omega + std::log(inv_omega_(i, i)); size_t e = i * (i + 1) / 2 + i; // parameter index in vectorized form @@ -179,37 +253,277 @@ void GGMModel::update_diagonal_parameter(size_t i) { omega_prop_ = omega_; omega_prop_(i, i) = omega_(i, i) - std::exp(theta_curr) * std::exp(theta_curr) + std::exp(theta_prop) * std::exp(theta_prop); + // Rcpp::Rcout << "i: " << i << + // ", current theta: " << theta_curr << + // ", proposed theta: " << theta_prop << + // ", proposal_sd: " << proposal_sd << std::endl; + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // 5) Acceptance ratio - double ln_alpha = log_density(omega_prop_) - log_density(); + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_diag(i); + if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { + Rcpp::Rcout << "Warning: log density implementations do not match for diag (" << i << ", " << i << ")" << std::endl; + } + ln_alpha += R::dgamma(exp(theta_prop), 1.0, 1.0, true); ln_alpha -= R::dgamma(exp(theta_curr), 1.0, 1.0, true); ln_alpha += theta_prop - theta_curr; // Jacobian adjustment ? 
+ if (std::log(runif(rng_)) < ln_alpha) { - if (log(runif(rng_)) < ln_alpha) { + proposal_.increment_accepts(e); double omega_ii = omega_(i, i); arma::vec u(p_, arma::fill::zeros); double delta = omega_ii - omega_prop_(i, i); bool s = delta > 0; - u(i) = sqrt(abs(delta)); + u(i) = std::sqrt(std::abs(delta)); omega_(i, i) = omega_prop_(i, i); - if (!s) - cholesky_update(phi_, u); - else + if (s) cholesky_downdate(phi_, u); + else + cholesky_update(phi_, u); + + // update inverse (2x O(p^2)) + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); - inv_omega_ = phi_.t() * phi_; } - double alpha = std::min(1.0, std::exp(ln_alpha)); - proposal_.update_proposal_sd(e, alpha); + proposal_.update_proposal_sd(e); } void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { - // Implementation of edge indicator parameter pair update + + size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form + double proposal_sd = proposal_.get_proposal_sd(e); + + if (edge_indicators_(i, j) == 1) { + // Propose to turn OFF the edge + omega_prop_ = omega_; + omega_prop_(i, j) = 0.0; + omega_prop_(j, i) = 0.0; + + // Update diagonal using R function with omega_ij = 0 + get_constants(i, j); + omega_prop_(j, j) = R(0.0); + + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_edge(i, j); + if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { + Rcpp::Rcout << "Warning: log density indicator implementations do not match for edge (" << i << ", " << j << ")" << std::endl; + } + + ln_alpha += std::log(1.0 - prior_inclusion_prob_(i, j)) - std::log(prior_inclusion_prob_(i, j)); + + ln_alpha += R::dnorm(omega_(i, j) / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); + ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); + + if (std::log(runif(rng_)) < ln_alpha) { + + // Store old values for Cholesky update + double omega_ij_old = omega_(i, j); + double 
omega_jj_old = omega_(j, j); + + // Update omega + omega_(i, j) = 0.0; + omega_(j, i) = 0.0; + omega_(j, j) = omega_prop_(j, j); + + // Update edge indicator + edge_indicators_(i, j) = 0; + edge_indicators_(j, i) = 0; + + // Cholesky update vectors + arma::vec v1 = {0, -1}; + arma::vec v2 = {omega_ij_old - 0.0, (omega_jj_old - omega_(j, j)) / 2}; + + arma::vec vf1 = arma::zeros(p_); + arma::vec vf2 = arma::zeros(p_); + vf1[i] = v1[0]; + vf1[j] = v1[1]; + vf2[i] = v2[0]; + vf2[j] = v2[1]; + + arma::vec u1 = (vf1 + vf2) / sqrt(2); + arma::vec u2 = (vf1 - vf2) / sqrt(2); + + // Update Cholesky factor + cholesky_update(phi_, u1); + cholesky_downdate(phi_, u2); + + // Update inverse + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); + } + + } else { + // Propose to turn ON the edge + double epsilon = rnorm(rng_, 0.0, proposal_sd); + + // Get constants for current state (with edge OFF) + get_constants(i, j); + double omega_prop_ij = constants_[4] * epsilon; + double omega_prop_jj = R(omega_prop_ij); + + omega_prop_ = omega_; + omega_prop_(i, j) = omega_prop_ij; + omega_prop_(j, i) = omega_prop_ij; + omega_prop_(j, j) = omega_prop_jj; + + // double ln_alpha = log_density(omega_prop_) - log_density(); + double ln_alpha = log_density_impl_edge(i, j); + if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { + Rcpp::Rcout << "Warning: log density indicator implementations do not match for edge (" << i << ", " << j << ")" << std::endl; + } + ln_alpha += std::log(prior_inclusion_prob_(i, j)) - std::log(1.0 - prior_inclusion_prob_(i, j)); + + // Prior change: add slab (Cauchy prior) + ln_alpha += R::dcauchy(omega_prop_ij, 0.0, 2.5, true); + + // Proposal term: proposed edge value given it was generated from truncated normal + ln_alpha -= R::dnorm(omega_prop_ij / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); + + // TODO: this can be factored out? 
+ if (std::log(runif(rng_)) < ln_alpha) { + // Accept: turn ON the edge + proposal_.increment_accepts(e); + + // Store old values for Cholesky update + double omega_ij_old = omega_(i, j); + double omega_jj_old = omega_(j, j); + + // Update omega + omega_(i, j) = omega_prop_ij; + omega_(j, i) = omega_prop_ij; + omega_(j, j) = omega_prop_jj; + + // Update edge indicator + edge_indicators_(i, j) = 1; + edge_indicators_(j, i) = 1; + + // Cholesky update vectors + arma::vec v1 = {0, -1}; + arma::vec v2 = {omega_ij_old - omega_(i, j), (omega_jj_old - omega_(j, j)) / 2}; + + arma::vec vf1 = arma::zeros(p_); + arma::vec vf2 = arma::zeros(p_); + vf1[i] = v1[0]; + vf1[j] = v1[1]; + vf2[i] = v2[0]; + vf2[j] = v2[1]; + + arma::vec u1 = (vf1 + vf2) / sqrt(2); + arma::vec u2 = (vf1 - vf2) / sqrt(2); + + // Update Cholesky factor + cholesky_update(phi_, u1); + cholesky_downdate(phi_, u2); + + // Update inverse + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); + } + } +} + +void GGMModel::do_one_mh_step() { + + // Update off-diagonals (upper triangle) + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + // Rcpp::Rcout << "Updating edge parameter (" << i << ", " << j << ")" << std::endl; + update_edge_parameter(i, j); + // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // } + // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), 
"absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); + // } + } + } + + // Update diagonals + for (size_t i = 0; i < p_; ++i) { + // Rcpp::Rcout << "Updating diagonal parameter " << i << std::endl; + update_diagonal_parameter(i); + + // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating diagonal " << i << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating diagonal " << i << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // } + // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating diagonal " << i << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after 
updating diagonal " << i << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); + // } + } + + if (edge_selection_) { + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + // Rcpp::Rcout << "Between model move for edge (" << i << ", " << j << ")" << std::endl; + update_edge_indicator_parameter_pair(i, j); + // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // } + // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; + // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); + // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); + // } + // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { + // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Omega:"); + // phi_.print(Rcpp::Rcout, "Phi:"); + // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); + // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); + // } + } + } + } + + // could also be called in the main MCMC loop + proposal_.increment_iteration(); } diff --git 
a/src/ggm_model.h b/src/ggm_model.h index 05081db0..48158588 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -22,9 +22,12 @@ class GGMModel : public BaseModel { proposal_(AdaptiveProposal(dim_, 500)), omega_(arma::eye(p_, p_)), phi_(arma::eye(p_, p_)), + inv_phi_(arma::eye(p_, p_)), inv_omega_(arma::eye(p_, p_)), edge_indicators_(initial_edge_indicators), vectorized_parameters_(dim_), + vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), + omega_prop_(arma::mat(p_, p_, arma::fill::none)), constants_(6) {} @@ -38,9 +41,11 @@ class GGMModel : public BaseModel { edge_selection_(other.edge_selection_), omega_(other.omega_), phi_(other.phi_), + inv_phi_(other.inv_phi_), inv_omega_(other.inv_omega_), edge_indicators_(other.edge_indicators_), vectorized_parameters_(other.vectorized_parameters_), + vectorized_indicator_parameters_(other.vectorized_indicator_parameters_), proposal_(other.proposal_), rng_(other.rng_), omega_prop_(other.omega_prop_), @@ -55,10 +60,10 @@ class GGMModel : public BaseModel { proposal_ = proposal; } - virtual bool has_gradient() const { return false; } - virtual bool has_adaptive_mh() const override { return true; } + bool has_gradient() const { return false; } + bool has_adaptive_mh() const override { return true; } - double logp(const std::vector& parameters) override { + double logp(const arma::vec& parameters) override { // Implement log probability computation return 0.0; } @@ -70,7 +75,6 @@ class GGMModel : public BaseModel { void do_one_mh_step() override; size_t parameter_dimension() const override { - Rcpp::Rcout << "GGMModel::parameter_dimension() returning: " << dim_ << std::endl; return dim_; } @@ -79,33 +83,47 @@ class GGMModel : public BaseModel { } arma::vec get_vectorized_parameters() override { + // upper triangle of omega_ size_t e = 0; for (size_t j = 0; j < p_; ++j) { for (size_t i = 0; i <= j; ++i) { vectorized_parameters_(e) = omega_(i, j); - e++; + ++e; } } return vectorized_parameters_; } + arma::ivec 
get_vectorized_indicator_parameters() override { + // upper triangle of omega_ + size_t e = 0; + for (size_t j = 0; j < p_; ++j) { + for (size_t i = 0; i <= j; ++i) { + vectorized_indicator_parameters_(e) = edge_indicators_(i, j); + ++e; + } + } + return vectorized_indicator_parameters_; + } + std::unique_ptr clone() const override { return std::make_unique(*this); // uses copy constructor } private: // data - size_t dim_; - arma::mat suf_stat_; size_t n_; size_t p_; + size_t dim_; + arma::mat suf_stat_; arma::mat prior_inclusion_prob_; - bool edge_selection_ = true; + bool edge_selection_; // parameters - arma::mat omega_, phi_, inv_omega_; + arma::mat omega_, phi_, inv_phi_, inv_omega_; arma::imat edge_indicators_; arma::vec vectorized_parameters_; + arma::ivec vectorized_indicator_parameters_; AdaptiveProposal proposal_; @@ -126,7 +144,9 @@ class GGMModel : public BaseModel { double R(const double x) const; double log_density_impl(const arma::mat& omega, const arma::mat& phi) const; - + double log_density_impl_edge(size_t i, size_t j) const; + double log_density_impl_diag(size_t j) const; + double get_log_det(arma::mat triangular_A) const; // double find_reasonable_step_size_edge(const arma::mat& omega, size_t i, size_t j); // double find_reasonable_step_size_diag(const arma::mat& omega, size_t i); // double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal); diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index 0d9de749..8ca13568 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -17,6 +17,7 @@ void run_mcmc_sampler_single_thread( ProgressManager& pm ) { + chain_result.chain_id = chain_id + 1; size_t i = 0; for (size_t iter = 0; iter < no_iter + no_warmup; ++iter) { @@ -25,7 +26,7 @@ void run_mcmc_sampler_single_thread( if (iter >= no_warmup) { chain_result.store_sample(i, model.get_vectorized_parameters()); - i++; + ++i; } pm.update(chain_id); @@ -64,8 +65,6 @@ struct GGMChainRunner : public RcppParallel::Worker { for 
(std::size_t i = begin; i < end; ++i) { ChainResultNew& chain_result = results_[i]; - chain_result.chain_id = static_cast(i + 1); - chain_result.error = false; BaseModel& model = *models_[i]; model.set_seed(seed_ + i); try { @@ -124,12 +123,17 @@ std::vector run_mcmc_sampler( models.push_back(model.clone()); // deep copy via virtual clone } run_mcmc_sampler_threaded(results, models, no_iter, no_warmup, seed, no_threads, pm); + } else { + model.set_seed(seed); Rcpp::Rcout << "Running single-threaded MCMC sampling..." << std::endl; + // TODO: this is actually not correct, each chain should have its own model object + // now chain 2 continues from chain 1 state for (size_t c = 0; c < no_chains; ++c) { run_mcmc_sampler_single_thread(results[c], model, no_iter, no_warmup, c, pm); } + } return results; } @@ -164,15 +168,13 @@ Rcpp::List sample_ggm( const bool edge_selection, const int seed, const int no_threads, - int progress_type + const int progress_type ) { // should be done dynamically // also adaptation method should be specified differently GGMModel model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection); - Rcpp::Rcout << "GGMModel::parameter_dimension() returning: " << model.parameter_dimension() << std::endl; - ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); std::vector output = run_mcmc_sampler(model, no_iter, no_warmup, no_chains, seed, no_threads, pm); diff --git a/test_ggm.R b/test_ggm.R index 60c5ef0f..5d0ca2c4 100644 --- a/test_ggm.R +++ b/test_ggm.R @@ -1,7 +1,7 @@ library(bgms) # Dimension and true precision -p <- 10 +p <- 50 adj <- matrix(0, nrow = p, ncol = p) adj[lower.tri(adj)] <- rbinom(p * (p - 1) / 2, size = 1, prob = 0.3) @@ -12,7 +12,7 @@ Sigma <- solve(Omega) zapsmall(Omega) # Data -n <- 1e2 +n <- 1e3 x <- mvtnorm::rmvnorm(n = n, mean = rep(0, p), sigma = Sigma) @@ -23,15 +23,55 @@ sampling_results <- bgms:::sample_ggm( X = x, prior_inclusion_prob = matrix(.5, p, p), initial_edge_indicators = adj, - no_iter 
= 4000, - no_warmup = 4000, + no_iter = 500, + no_warmup = 500, no_chains = 3, edge_selection = FALSE, no_threads = 1, seed = 123, - progress_type = 2 + progress_type = 1 ) +true_values <- zapsmall(Omega[upper.tri(Omega, TRUE)]) +posterior_means <- rowMeans(sampling_results[[2]]$samples) + +plot(true_values, posterior_means) +abline(0, 1) + +sampling_results2 <- bgms:::sample_ggm( + X = x, + prior_inclusion_prob = matrix(.5, p, p), + initial_edge_indicators = adj, + no_iter = 500, + no_warmup = 500, + no_chains = 3, + edge_selection = TRUE, + no_threads = 1, + seed = 123, + progress_type = 1 +) + +true_values <- zapsmall(Omega[upper.tri(Omega, TRUE)]) +posterior_means <- rowMeans(sampling_results2[[2]]$samples) + +plot(true_values, posterior_means) +abline(0, 1) + +plot(posterior_means, rowMeans(sampling_results2[[2]]$samples != 0)) + + +mmm <- matrix(c( + 1.6735, 0, 0, 0, 0, + 0, 1.0000, 0, 0, -3.4524, + 0, 0, 1.0000, 0, 0, + 0, 0, 0, 1.0000, 0, + 0, -3.4524, 0, 0, 9.6674 +), p, p) +mmm +chol(mmm) +base::isSymmetric(mmm) +eigen(mmm) + profvis::profvis({ sampling_results <- bgm_gaussian( x = x, From 481975c81ffaf5a7b4182a86a422cece3e1cb27d Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Thu, 13 Nov 2025 09:16:24 +0100 Subject: [PATCH 03/23] can be tested --- src/ggm_model.cpp | 60 ++++++++++++++++++++++++++++++----------------- test_ggm.R | 3 ++- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp index 0ffc19b2..db148fba 100644 --- a/src/ggm_model.cpp +++ b/src/ggm_model.cpp @@ -94,14 +94,6 @@ double GGMModel::log_density_impl_edge(size_t i, size_t j) const { double trace_prod = -2 * (suf_stat_(j, j) * Uj2 + suf_stat_(i, j) * Ui2); - // This function uses the fact that the determinant doesn't change during edge updates. 
- // double trace_prod = 0.0; - // // TODO: we only need one of the two lines below, but it's not entirely clear which one - // trace_prod += suf_stat_(j, j) * (omega_prop(j, j) - omega(j, j)); - // trace_prod += suf_stat_(i, i) * (omega_prop(i, i) - omega(i, i)); - // trace_prod += 2 * suf_stat_(i, j) * (omega_prop(i, j) - omega(i, j)); - // trace_prod - sum((aOmega_prop - aOmega) * SufStat) - double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; return log_likelihood_ratio; @@ -115,7 +107,7 @@ double GGMModel::log_density_impl_diag(size_t j) const { double cc12 = 1 - inv_omega_(j, j) * Uj2; double cc22 = 0 + Uj2 * Uj2 * inv_omega_(j, j); - double logdet = log(abs(cc11 * cc22 - cc12 * cc12)); + double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); double trace_prod = -2 * suf_stat_(j, j) * Uj2; // This function uses the fact that the determinant doesn't change during edge updates. @@ -182,9 +174,15 @@ void GGMModel::update_edge_parameter(size_t i, size_t j) { // double ln_alpha = log_density(omega_prop_) - log_density(); double ln_alpha = log_density_impl_edge(i, j); - if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { - Rcpp::Rcout << "Warning: log density implementations do not match for edge (" << i << ", " << j << ")" << std::endl; - } + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Current omega:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } ln_alpha += R::dcauchy(omega_prop_(i, j), 0.0, 2.5, true); ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); @@ -262,9 +260,16 @@ void GGMModel::update_diagonal_parameter(size_t i) { // 5) Acceptance ratio // double 
ln_alpha = log_density(omega_prop_) - log_density(); double ln_alpha = log_density_impl_diag(i); - if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { - Rcpp::Rcout << "Warning: log density implementations do not match for diag (" << i << ", " << i << ")" << std::endl; - } + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for diag (" << i << ", " << i << ")" << std::endl; + // // omega_.print(Rcpp::Rcout, "Current omega:"); + // // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // Rcpp::Rcout << "1e4 * diff: " << 10000 * (ln_alpha - ln_alpha_ref) << std::endl; + // } + // } ln_alpha += R::dgamma(exp(theta_prop), 1.0, 1.0, true); ln_alpha -= R::dgamma(exp(theta_curr), 1.0, 1.0, true); @@ -315,9 +320,16 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { // double ln_alpha = log_density(omega_prop_) - log_density(); double ln_alpha = log_density_impl_edge(i, j); - if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { - Rcpp::Rcout << "Warning: log density indicator implementations do not match for edge (" << i << ", " << j << ")" << std::endl; - } + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Current omega:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } + ln_alpha += std::log(1.0 - prior_inclusion_prob_(i, j)) - std::log(prior_inclusion_prob_(i, j)); @@ -378,9 +390,15 @@ void 
GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { // double ln_alpha = log_density(omega_prop_) - log_density(); double ln_alpha = log_density_impl_edge(i, j); - if (std::abs(ln_alpha - (log_density(omega_prop_) - log_density())) > 1e-6) { - Rcpp::Rcout << "Warning: log density indicator implementations do not match for edge (" << i << ", " << j << ")" << std::endl; - } + // { + // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { + // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; + // omega_.print(Rcpp::Rcout, "Current omega:"); + // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; + // } + // } ln_alpha += std::log(prior_inclusion_prob_(i, j)) - std::log(1.0 - prior_inclusion_prob_(i, j)); // Prior change: add slab (Cauchy prior) diff --git a/test_ggm.R b/test_ggm.R index 5d0ca2c4..5f1995ca 100644 --- a/test_ggm.R +++ b/test_ggm.R @@ -1,7 +1,7 @@ library(bgms) # Dimension and true precision -p <- 50 +p <- 10 adj <- matrix(0, nrow = p, ncol = p) adj[lower.tri(adj)] <- rbinom(p * (p - 1) / 2, size = 1, prob = 0.3) @@ -34,6 +34,7 @@ sampling_results <- bgms:::sample_ggm( true_values <- zapsmall(Omega[upper.tri(Omega, TRUE)]) posterior_means <- rowMeans(sampling_results[[2]]$samples) +cbind(true_values, posterior_means) plot(true_values, posterior_means) abline(0, 1) From c6da79c667052eb029d4eaf5f09b1d1211d35d4b Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 19 Nov 2025 12:55:17 +0100 Subject: [PATCH 04/23] reduce allocations and factor out common cholesky stuff --- src/ggm_model.cpp | 219 ++++++++++++++++++++++++++++++---------------- src/ggm_model.h | 9 ++ 2 files changed, 151 insertions(+), 77 deletions(-) diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp index db148fba..9d475176 100644 
--- a/src/ggm_model.cpp +++ b/src/ggm_model.cpp @@ -12,13 +12,12 @@ void GGMModel::get_constants(size_t i, size_t j) { // TODO: helper function? double logdet_omega = get_log_det(phi_); - double log_adj_omega_ii = logdet_omega + log(abs(inv_omega_(i, i))); - double log_adj_omega_ij = logdet_omega + log(abs(inv_omega_(i, j))); - double log_adj_omega_jj = logdet_omega + log(abs(inv_omega_(j, j))); + double log_adj_omega_ii = logdet_omega + std::log(std::abs(inv_omega_(i, i))); + double log_adj_omega_ij = logdet_omega + std::log(std::abs(inv_omega_(i, j))); + double log_adj_omega_jj = logdet_omega + std::log(std::abs(inv_omega_(j, j))); double inv_omega_sub_j1j1 = compute_inv_submatrix_i(inv_omega_, i, j, j); - double log_abs_inv_omega_sub_jj = log_adj_omega_ii + log(abs(inv_omega_sub_j1j1)); - + double log_abs_inv_omega_sub_jj = log_adj_omega_ii + std::log(std::abs(inv_omega_sub_j1j1)); double Phi_q1q = (2 * std::signbit(inv_omega_(i, j)) - 1) * std::exp( (log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2) ); @@ -191,49 +190,89 @@ void GGMModel::update_edge_parameter(size_t i, size_t j) { // accept proposal proposal_.increment_accepts(e); - double omega_ij = omega_(i, j); - double omega_jj = omega_(j, j); + double omega_ij_old = omega_(i, j); + double omega_jj_old = omega_(j, j); + omega_(i, j) = omega_prop_q1q; omega_(j, i) = omega_prop_q1q; omega_(j, j) = omega_prop_qq; - // TODO: preallocate? - // find v for low rank update - arma::vec v1 = {0, -1}; - arma::vec v2 = {omega_ij - omega_prop_(i, j), (omega_jj - omega_prop_(j, j)) / 2}; + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + + // // TODO: preallocate? 
+ // // find v for low rank update + // arma::vec v1 = {0, -1}; + // arma::vec v2 = {omega_ij - omega_prop_(i, j), (omega_jj - omega_prop_(j, j)) / 2}; - arma::vec vf1 = arma::zeros(p_); - arma::vec vf2 = arma::zeros(p_); - vf1[i] = v1[0]; - vf1[j] = v1[1]; - vf2[i] = v2[0]; - vf2[j] = v2[1]; + // arma::vec vf1 = arma::zeros(p_); + // arma::vec vf2 = arma::zeros(p_); + // vf1[i] = v1[0]; + // vf1[j] = v1[1]; + // vf2[i] = v2[0]; + // vf2[j] = v2[1]; - // we now have - // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) + // // we now have + // // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) - arma::vec u1 = (vf1 + vf2) / sqrt(2); - arma::vec u2 = (vf1 - vf2) / sqrt(2); + // arma::vec u1 = (vf1 + vf2) / sqrt(2); + // arma::vec u2 = (vf1 - vf2) / sqrt(2); - // we now have - // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) - // and also - // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) + // // we now have + // // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) + // // and also + // // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) - // update phi (2x O(p^2)) - cholesky_update(phi_, u1); - cholesky_downdate(phi_, u2); + // // update phi (2x O(p^2)) + // cholesky_update(phi_, u1); + // cholesky_downdate(phi_, u2); - // update inverse (2x O(p^2)) - arma::inv(inv_phi_, arma::trimatu(phi_)); - inv_omega_ = inv_phi_ * inv_phi_.t(); + // // update inverse (2x O(p^2)) + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); } proposal_.update_proposal_sd(e); } +void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) +{ + + v2_[0] = omega_ij_old - omega_prop_(i, j); + v2_[1] = (omega_jj_old - omega_prop_(j, j)) / 2; + + vf1_[i] = v1_[0]; + vf1_[j] = v1_[1]; + vf2_[i] = v2_[0]; + vf2_[j] = v2_[1]; + + // we now have + // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% 
t(vf1)) + + u1_ = (vf1_ + vf2_) / sqrt(2); + u2_ = (vf1_ - vf2_) / sqrt(2); + + // we now have + // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) + // and also + // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) + + // update phi (2x O(p^2)) + cholesky_update(phi_, u1_); + cholesky_downdate(phi_, u2_); + + // update inverse (2x O(p^2)) + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); + + // reset for next iteration + vf1_[i] = 0.0; + vf1_[j] = 0.0; + vf2_[i] = 0.0; + vf2_[j] = 0.0; + +} void GGMModel::update_diagonal_parameter(size_t i) { // Implementation of diagonal parameter update @@ -280,22 +319,24 @@ void GGMModel::update_diagonal_parameter(size_t i) { proposal_.increment_accepts(e); double omega_ii = omega_(i, i); + omega_(i, i) = omega_prop_(i, i); - arma::vec u(p_, arma::fill::zeros); - double delta = omega_ii - omega_prop_(i, i); - bool s = delta > 0; - u(i) = std::sqrt(std::abs(delta)); + cholesky_update_after_diag(omega_ii, i); - omega_(i, i) = omega_prop_(i, i); + // arma::vec u(p_, arma::fill::zeros); + // double delta = omega_ii - omega_prop_(i, i); + // bool s = delta > 0; + // u(i) = std::sqrt(std::abs(delta)); - if (s) - cholesky_downdate(phi_, u); - else - cholesky_update(phi_, u); - // update inverse (2x O(p^2)) - arma::inv(inv_phi_, arma::trimatu(phi_)); - inv_omega_ = inv_phi_ * inv_phi_.t(); + // if (s) + // cholesky_downdate(phi_, u); + // else + // cholesky_update(phi_, u); + + // // update inverse (2x O(p^2)) + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); } @@ -303,6 +344,28 @@ void GGMModel::update_diagonal_parameter(size_t i) { proposal_.update_proposal_sd(e); } +void GGMModel::cholesky_update_after_diag(double omega_ii_old, size_t i) +{ + + double delta = omega_ii_old - omega_prop_(i, i); + + bool s = delta > 0; + vf1_(i) = std::sqrt(std::abs(delta)); + + if (s) + cholesky_downdate(phi_, vf1_); + else + 
cholesky_update(phi_, vf1_); + + // update inverse (2x O(p^2)) + arma::inv(inv_phi_, arma::trimatu(phi_)); + inv_omega_ = inv_phi_ * inv_phi_.t(); + + // reset for next iteration + vf1_(i) = 0.0; +} + + void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form @@ -351,27 +414,28 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { edge_indicators_(i, j) = 0; edge_indicators_(j, i) = 0; - // Cholesky update vectors - arma::vec v1 = {0, -1}; - arma::vec v2 = {omega_ij_old - 0.0, (omega_jj_old - omega_(j, j)) / 2}; + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + // // Cholesky update vectors + // arma::vec v1 = {0, -1}; + // arma::vec v2 = {omega_ij_old - 0.0, (omega_jj_old - omega_(j, j)) / 2}; - arma::vec vf1 = arma::zeros(p_); - arma::vec vf2 = arma::zeros(p_); - vf1[i] = v1[0]; - vf1[j] = v1[1]; - vf2[i] = v2[0]; - vf2[j] = v2[1]; + // arma::vec vf1 = arma::zeros(p_); + // arma::vec vf2 = arma::zeros(p_); + // vf1[i] = v1[0]; + // vf1[j] = v1[1]; + // vf2[i] = v2[0]; + // vf2[j] = v2[1]; - arma::vec u1 = (vf1 + vf2) / sqrt(2); - arma::vec u2 = (vf1 - vf2) / sqrt(2); + // arma::vec u1 = (vf1 + vf2) / sqrt(2); + // arma::vec u2 = (vf1 - vf2) / sqrt(2); - // Update Cholesky factor - cholesky_update(phi_, u1); - cholesky_downdate(phi_, u2); + // // Update Cholesky factor + // cholesky_update(phi_, u1); + // cholesky_downdate(phi_, u2); - // Update inverse - arma::inv(inv_phi_, arma::trimatu(phi_)); - inv_omega_ = inv_phi_ * inv_phi_.t(); + // // Update inverse + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); } } else { @@ -425,27 +489,28 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { edge_indicators_(i, j) = 1; edge_indicators_(j, i) = 1; - // Cholesky update vectors - arma::vec v1 = {0, -1}; - arma::vec v2 = {omega_ij_old - omega_(i, j), (omega_jj_old - omega_(j, j)) / 
2}; + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + // // Cholesky update vectors + // arma::vec v1 = {0, -1}; + // arma::vec v2 = {omega_ij_old - omega_(i, j), (omega_jj_old - omega_(j, j)) / 2}; - arma::vec vf1 = arma::zeros(p_); - arma::vec vf2 = arma::zeros(p_); - vf1[i] = v1[0]; - vf1[j] = v1[1]; - vf2[i] = v2[0]; - vf2[j] = v2[1]; + // arma::vec vf1 = arma::zeros(p_); + // arma::vec vf2 = arma::zeros(p_); + // vf1[i] = v1[0]; + // vf1[j] = v1[1]; + // vf2[i] = v2[0]; + // vf2[j] = v2[1]; - arma::vec u1 = (vf1 + vf2) / sqrt(2); - arma::vec u2 = (vf1 - vf2) / sqrt(2); + // arma::vec u1 = (vf1 + vf2) / sqrt(2); + // arma::vec u2 = (vf1 - vf2) / sqrt(2); - // Update Cholesky factor - cholesky_update(phi_, u1); - cholesky_downdate(phi_, u2); + // // Update Cholesky factor + // cholesky_update(phi_, u1); + // cholesky_downdate(phi_, u2); - // Update inverse - arma::inv(inv_phi_, arma::trimatu(phi_)); - inv_omega_ = inv_phi_ * inv_phi_.t(); + // // Update inverse + // arma::inv(inv_phi_, arma::trimatu(phi_)); + // inv_omega_ = inv_phi_ * inv_phi_.t(); } } } diff --git a/src/ggm_model.h b/src/ggm_model.h index 48158588..48af799e 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -133,6 +133,13 @@ class GGMModel : public BaseModel { arma::mat omega_prop_; arma::vec constants_; // Phi_q1q, Phi_q1q1, c[1], c[2], c[3], c[4] + arma::vec v1_ = {0, -1}; + arma::vec v2_ = {0, 0}; + arma::vec vf1_ = arma::zeros(p_); + arma::vec vf2_ = arma::zeros(p_); + arma::vec u1_ = arma::zeros(p_); + arma::vec u2_ = arma::zeros(p_); + // Parameter group updates with optimized likelihood evaluations void update_edge_parameter(size_t i, size_t j); void update_diagonal_parameter(size_t i); @@ -147,6 +154,8 @@ class GGMModel : public BaseModel { double log_density_impl_edge(size_t i, size_t j) const; double log_density_impl_diag(size_t j) const; double get_log_det(arma::mat triangular_A) const; + void cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, 
size_t i, size_t j); + void cholesky_update_after_diag(double omega_ii_old, size_t i); // double find_reasonable_step_size_edge(const arma::mat& omega, size_t i, size_t j); // double find_reasonable_step_size_diag(const arma::mat& omega, size_t i); // double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal); From 7b776f3d9e9b5b3998276e79bcf52ab1e27290ef Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 2 Dec 2025 09:36:01 +0100 Subject: [PATCH 05/23] R interface for raw data and sufficient statistics --- R/RcppExports.R | 4 ++-- src/RcppExports.cpp | 8 ++++---- src/ggm_model.cpp | 32 ++++++++++++++++++++++++++++++++ src/ggm_model.h | 33 +++++++++++++++++++++++++++++++++ src/sample_ggm.cpp | 5 +++-- 5 files changed, 74 insertions(+), 8 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 808269a3..715c8c12 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -29,8 +29,8 @@ run_simulation_parallel <- function(pairwise_samples, main_samples, draw_indices .Call(`_bgms_run_simulation_parallel`, pairwise_samples, main_samples, draw_indices, no_states, no_variables, no_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } -sample_ggm <- function(X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) { - .Call(`_bgms_sample_ggm`, X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) +sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) { + .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) } compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp 
index 1aebcbfe..aa1b1985 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -194,12 +194,12 @@ BEGIN_RCPP END_RCPP } // sample_ggm -Rcpp::List sample_ggm(const arma::mat& X, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type); -RcppExport SEXP _bgms_sample_ggm(SEXP XSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) { +Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type); +RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP); + Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); @@ -209,7 +209,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); 
- rcpp_result_gen = Rcpp::wrap(sample_ggm(X, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)); + rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)); return rcpp_result_gen; END_RCPP } diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp index 9d475176..6b8e0202 100644 --- a/src/ggm_model.cpp +++ b/src/ggm_model.cpp @@ -610,3 +610,35 @@ void GGMModel::do_one_mh_step() { // could also be called in the main MCMC loop proposal_.increment_iteration(); } + + +GGMModel createGGMFromR( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection +) { + + if (inputFromR.containsElementNamed("n") && inputFromR.containsElementNamed("suf_stat")) { + int n = Rcpp::as(inputFromR["n"]); + arma::mat suf_stat = Rcpp::as(inputFromR["suf_stat"]); + return GGMModel( + n, + suf_stat, + prior_inclusion_prob, + initial_edge_indicators, + edge_selection + ); + } else if (inputFromR.containsElementNamed("X")) { + arma::mat X = Rcpp::as(inputFromR["X"]); + return GGMModel( + X, + prior_inclusion_prob, + initial_edge_indicators, + edge_selection + ); + } else { + throw std::invalid_argument("Input list must contain either 'X' or both 'n' and 'suf_stat'."); + } + +} diff --git a/src/ggm_model.h b/src/ggm_model.h index 48af799e..dcf4a35e 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -5,6 +5,7 @@ #include "adaptiveMetropolis.h" #include "rng_utils.h" + class GGMModel : public BaseModel { public: @@ -31,6 +32,30 @@ class GGMModel : public BaseModel { constants_(6) {} + GGMModel( + const int n, + const arma::mat& suf_stat, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true + ) : n_(n), + p_(suf_stat.n_cols), + dim_((p_ * (p_ + 1)) / 
2), + suf_stat_(suf_stat), + prior_inclusion_prob_(prior_inclusion_prob), + edge_selection_(edge_selection), + proposal_(AdaptiveProposal(dim_, 500)), + omega_(arma::eye(p_, p_)), + phi_(arma::eye(p_, p_)), + inv_phi_(arma::eye(p_, p_)), + inv_omega_(arma::eye(p_, p_)), + edge_indicators_(initial_edge_indicators), + vectorized_parameters_(dim_), + vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), + omega_prop_(arma::mat(p_, p_, arma::fill::none)), + constants_(6) + {} + GGMModel(const GGMModel& other) : BaseModel(other), dim_(other.dim_), @@ -161,3 +186,11 @@ class GGMModel : public BaseModel { // double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal); // double diag_log_ratio(const arma::mat& omega, size_t i, double proposal); }; + + +GGMModel createGGMFromR( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true +); diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index 8ca13568..4462e464 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -159,7 +159,7 @@ Rcpp::List convert_sampler_output_to_ggm_result(const std::vector Date: Mon, 2 Feb 2026 15:29:51 +0100 Subject: [PATCH 06/23] fix include of rng_utils.h --- src/ggm_model.cpp | 2 +- src/ggm_model.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp index 6b8e0202..dc9f8ec6 100644 --- a/src/ggm_model.cpp +++ b/src/ggm_model.cpp @@ -1,6 +1,6 @@ #include "ggm_model.h" #include "adaptiveMetropolis.h" -#include "rng_utils.h" +#include "rng/rng_utils.h" #include "cholupdate.h" double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { diff --git a/src/ggm_model.h b/src/ggm_model.h index dcf4a35e..0ddf91ad 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -3,7 +3,7 @@ #include #include "base_model.h" #include "adaptiveMetropolis.h" -#include 
"rng_utils.h" +#include "rng/rng_utils.h" class GGMModel : public BaseModel { From b8d9e7e9b15cea287eb7dabb3400af6057fc68fd Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 3 Feb 2026 12:10:12 +0100 Subject: [PATCH 07/23] fix include of progressmanager --- src/sample_ggm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index 4462e464..7efb964e 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -5,7 +5,7 @@ #include #include "ggm_model.h" -#include "progress_manager.h" +#include "utils/progress_manager.h" #include "chainResultNew.h" void run_mcmc_sampler_single_thread( From fbec813497d727b6ef536c7fd046d2a94bc55af6 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 3 Feb 2026 13:39:47 +0100 Subject: [PATCH 08/23] rename some things --- src/ggm_model.cpp | 432 ++++++++++----------------------------------- src/ggm_model.h | 112 ++++++++---- src/sample_ggm.cpp | 6 +- 3 files changed, 173 insertions(+), 377 deletions(-) diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp index dc9f8ec6..a263924d 100644 --- a/src/ggm_model.cpp +++ b/src/ggm_model.cpp @@ -3,36 +3,36 @@ #include "rng/rng_utils.h" #include "cholupdate.h" -double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { +double GaussianVariables::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { return(A(ii, jj) - A(ii, i) * A(jj, i) / A(i, i)); } -void GGMModel::get_constants(size_t i, size_t j) { +void GaussianVariables::get_constants(size_t i, size_t j) { // TODO: helper function? 
- double logdet_omega = get_log_det(phi_); + double logdet_omega = get_log_det(cholesky_of_precision_); - double log_adj_omega_ii = logdet_omega + std::log(std::abs(inv_omega_(i, i))); - double log_adj_omega_ij = logdet_omega + std::log(std::abs(inv_omega_(i, j))); - double log_adj_omega_jj = logdet_omega + std::log(std::abs(inv_omega_(j, j))); + double log_adj_omega_ii = logdet_omega + std::log(std::abs(covariance_matrix_(i, i))); + double log_adj_omega_ij = logdet_omega + std::log(std::abs(covariance_matrix_(i, j))); + double log_adj_omega_jj = logdet_omega + std::log(std::abs(covariance_matrix_(j, j))); - double inv_omega_sub_j1j1 = compute_inv_submatrix_i(inv_omega_, i, j, j); + double inv_omega_sub_j1j1 = compute_inv_submatrix_i(covariance_matrix_, i, j, j); double log_abs_inv_omega_sub_jj = log_adj_omega_ii + std::log(std::abs(inv_omega_sub_j1j1)); - double Phi_q1q = (2 * std::signbit(inv_omega_(i, j)) - 1) * std::exp( + double Phi_q1q = (2 * std::signbit(covariance_matrix_(i, j)) - 1) * std::exp( (log_adj_omega_ij - (log_adj_omega_jj + log_abs_inv_omega_sub_jj) / 2) ); double Phi_q1q1 = std::exp((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); constants_[1] = Phi_q1q; constants_[2] = Phi_q1q1; - constants_[3] = omega_(i, j) - Phi_q1q * Phi_q1q1; + constants_[3] = precision_matrix_(i, j) - Phi_q1q * Phi_q1q1; constants_[4] = Phi_q1q1; - constants_[5] = omega_(j, j) - Phi_q1q * Phi_q1q; + constants_[5] = precision_matrix_(j, j) - Phi_q1q * Phi_q1q; constants_[6] = constants_[5] + constants_[3] * constants_[3] / (constants_[4] * constants_[4]); } -double GGMModel::R(const double x) const { +double GaussianVariables::R(const double x) const { if (x == 0) { return constants_[6]; } else { @@ -40,7 +40,7 @@ double GGMModel::R(const double x) const { } } -double GGMModel::get_log_det(arma::mat triangular_A) const { +double GaussianVariables::get_log_det(arma::mat triangular_A) const { // assume A is an (upper) triangular cholesky factor // returns the log 
determinant of A'A @@ -49,7 +49,7 @@ double GGMModel::get_log_det(arma::mat triangular_A) const { return 2 * arma::accu(arma::log(triangular_A.diag())); } -double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { +double GaussianVariables::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { double logdet_omega = get_log_det(phi); // TODO: why not just dot(omega, suf_stat_)? @@ -60,33 +60,16 @@ double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) return log_likelihood; } -double GGMModel::log_density_impl_edge(size_t i, size_t j) const { +double GaussianVariables::log_density_impl_edge(size_t i, size_t j) const { - // this is the log likelihood ratio, not the full log likelihood like GGMModel::log_density_impl + // this is the log likelihood ratio, not the full log likelihood like GaussianVariables::log_density_impl - double Ui2 = omega_(i, j) - omega_prop_(i, j); - // only reached from R - // if (omega_(j, j) == omega_prop_(j, j)) { - // k = i; - // i = j; - // j = k; - // } - double Uj2 = (omega_(j, j) - omega_prop_(j, j)) / 2; + double Ui2 = precision_matrix_(i, j) - precision_proposal_(i, j); + double Uj2 = (precision_matrix_(j, j) - precision_proposal_(j, j)) / 2; - - // W <- matrix(c(0, 1, 1, 0), 2, 2) - // U0 <- matrix(c(0, -1, Ui2, Uj2)) - // U <- matrix(0, nrow(aOmega), 2) - // U[c(i, j), 1] <- c(0, -1) - // U[c(i, j), 2] <- c(Ui2, Uj2) - // aOmega_prop - (aOmega + U %*% W %*% t(U)) - // det(aOmega_prop) - det(aOmega + U %*% W %*% t(U)) - // det(aOmega_prop) - det(W + t(U) %*% inv_aOmega %*% U) * det(W) * det(aOmega) - // below computes logdet(W + t(U) %*% inv_aOmega %*% U) directly (this is a 2x2 matrix) - - double cc11 = 0 + inv_omega_(j, j); - double cc12 = 1 - (inv_omega_(i, j) * Ui2 + inv_omega_(j, j) * Uj2); - double cc22 = 0 + Ui2 * Ui2 * inv_omega_(i, i) + 2 * Ui2 * Uj2 * inv_omega_(i, j) + Uj2 * Uj2 * inv_omega_(j, j); + double cc11 = 0 + covariance_matrix_(j, j); + double 
cc12 = 1 - (covariance_matrix_(i, j) * Ui2 + covariance_matrix_(j, j) * Uj2); + double cc22 = 0 + Ui2 * Ui2 * covariance_matrix_(i, i) + 2 * Ui2 * Uj2 * covariance_matrix_(i, j) + Uj2 * Uj2 * covariance_matrix_(j, j); double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); // logdet - (logdet(aOmega_prop) - logdet(aOmega)) @@ -98,31 +81,23 @@ double GGMModel::log_density_impl_edge(size_t i, size_t j) const { } -double GGMModel::log_density_impl_diag(size_t j) const { +double GaussianVariables::log_density_impl_diag(size_t j) const { // same as above but for i == j, so Ui2 = 0 - double Uj2 = (omega_(j, j) - omega_prop_(j, j)) / 2; + double Uj2 = (precision_matrix_(j, j) - precision_proposal_(j, j)) / 2; - double cc11 = 0 + inv_omega_(j, j); - double cc12 = 1 - inv_omega_(j, j) * Uj2; - double cc22 = 0 + Uj2 * Uj2 * inv_omega_(j, j); + double cc11 = 0 + covariance_matrix_(j, j); + double cc12 = 1 - covariance_matrix_(j, j) * Uj2; + double cc22 = 0 + Uj2 * Uj2 * covariance_matrix_(j, j); double logdet = std::log(std::abs(cc11 * cc22 - cc12 * cc12)); double trace_prod = -2 * suf_stat_(j, j) * Uj2; - // This function uses the fact that the determinant doesn't change during edge updates. 
- // double trace_prod = 0.0; - // // TODO: we only need one of the two lines below, but it's not entirely clear which one - // trace_prod += suf_stat_(j, j) * (omega_prop(j, j) - omega(j, j)); - // trace_prod += suf_stat_(i, i) * (omega_prop(i, i) - omega(i, i)); - // trace_prod += 2 * suf_stat_(i, j) * (omega_prop(i, j) - omega(i, j)); - // trace_prod - sum((aOmega_prop - aOmega) * SufStat) - double log_likelihood_ratio = (n_ * logdet - trace_prod) / 2; return log_likelihood_ratio; } -void GGMModel::update_edge_parameter(size_t i, size_t j) { +void GaussianVariables::update_edge_parameter(size_t i, size_t j) { if (edge_indicators_(i, j) == 0) { return; // Edge is not included; skip update @@ -140,107 +115,41 @@ void GGMModel::update_edge_parameter(size_t i, size_t j) { double omega_prop_qq = R(omega_prop_q1q); // form full proposal matrix for Omega - omega_prop_ = omega_; // TODO: needs to be a copy! - omega_prop_(i, j) = omega_prop_q1q; - omega_prop_(j, i) = omega_prop_q1q; - omega_prop_(j, j) = omega_prop_qq; - - // Rcpp::Rcout << "i: " << i << ", j: " << j << - // ", proposed phi: " << phi_prop << - // ", proposal_sd omega_ij: " << proposal_sd << - // ", proposed omega_ij: " << omega_prop_q1q << - // ", proposed omega_jj: " << omega_prop_qq << std::endl; - // constants_.print(Rcpp::Rcout, "Constants:"); - // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); - - // arma::vec eigval = eig_sym(omega_prop_); - // if (arma::any(eigval <= 0)) { - // Rcpp::Rcout << "Warning: omega_prop_ is not positive definite for edge (" << i << ", " << j << ")" << std::endl; - - // Rcpp::Rcout << - // ", proposed phi: " << phi_prop << - // ", proposal_sd omega_ij: " << proposal_sd << - // ", proposed omega_ij: " << omega_prop_q1q << - // ", proposed omega_jj: " << omega_prop_qq << std::endl; - // constants_.print(Rcpp::Rcout, "Constants:"); - // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); - // omega_.print(Rcpp::Rcout, "Current omega:"); - // phi_.print(Rcpp::Rcout, 
"Phi:"); - // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); - - // } - - // double ln_alpha = log_density(omega_prop_) - log_density(); - double ln_alpha = log_density_impl_edge(i, j); + precision_proposal_ = precision_matrix_; // TODO: needs to be a copy! + precision_proposal_(i, j) = omega_prop_q1q; + precision_proposal_(j, i) = omega_prop_q1q; + precision_proposal_(j, j) = omega_prop_qq; - // { - // double ln_alpha_ref = log_density(omega_prop_) - log_density(); - // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { - // Rcpp::Rcout << "Warning: log density implementations do not match for edge (" << i << ", " << j << ")" << std::endl; - // omega_.print(Rcpp::Rcout, "Current omega:"); - // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); - // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; - // } - // } + // double ln_alpha = log_likelihood(precision_proposal_) - log_likelihood(); + double ln_alpha = log_density_impl_edge(i, j); - ln_alpha += R::dcauchy(omega_prop_(i, j), 0.0, 2.5, true); - ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); + ln_alpha += R::dcauchy(precision_proposal_(i, j), 0.0, 2.5, true); + ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, 2.5, true); if (std::log(runif(rng_)) < ln_alpha) { // accept proposal proposal_.increment_accepts(e); - double omega_ij_old = omega_(i, j); - double omega_jj_old = omega_(j, j); + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); - omega_(i, j) = omega_prop_q1q; - omega_(j, i) = omega_prop_q1q; - omega_(j, j) = omega_prop_qq; + precision_matrix_(i, j) = omega_prop_q1q; + precision_matrix_(j, i) = omega_prop_q1q; + precision_matrix_(j, j) = omega_prop_qq; cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); - // // TODO: preallocate? 
- // // find v for low rank update - // arma::vec v1 = {0, -1}; - // arma::vec v2 = {omega_ij - omega_prop_(i, j), (omega_jj - omega_prop_(j, j)) / 2}; - - // arma::vec vf1 = arma::zeros(p_); - // arma::vec vf2 = arma::zeros(p_); - // vf1[i] = v1[0]; - // vf1[j] = v1[1]; - // vf2[i] = v2[0]; - // vf2[j] = v2[1]; - - // // we now have - // // aOmega_prop - (aOmega + vf1 %*% t(vf2) + vf2 %*% t(vf1)) - - // arma::vec u1 = (vf1 + vf2) / sqrt(2); - // arma::vec u2 = (vf1 - vf2) / sqrt(2); - - // // we now have - // // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) - // // and also - // // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) - - // // update phi (2x O(p^2)) - // cholesky_update(phi_, u1); - // cholesky_downdate(phi_, u2); - - // // update inverse (2x O(p^2)) - // arma::inv(inv_phi_, arma::trimatu(phi_)); - // inv_omega_ = inv_phi_ * inv_phi_.t(); - } proposal_.update_proposal_sd(e); } -void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) +void GaussianVariables::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) { - v2_[0] = omega_ij_old - omega_prop_(i, j); - v2_[1] = (omega_jj_old - omega_prop_(j, j)) / 2; + v2_[0] = omega_ij_old - precision_proposal_(i, j); + v2_[1] = (omega_jj_old - precision_proposal_(j, j)) / 2; vf1_[i] = v1_[0]; vf1_[j] = v1_[1]; @@ -253,18 +162,13 @@ void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_o u1_ = (vf1_ + vf2_) / sqrt(2); u2_ = (vf1_ - vf2_) / sqrt(2); - // we now have - // omega_prop_ - (aOmega + u1 %*% t(u1) - u2 %*% t(u2)) - // and also - // aOmega_prop - (aOmega + cbind(vf1, vf2) %*% matrix(c(0, 1, 1, 0), 2, 2) %*% t(cbind(vf1, vf2))) - // update phi (2x O(p^2)) - cholesky_update(phi_, u1_); - cholesky_downdate(phi_, u2_); + cholesky_update(cholesky_of_precision_, u1_); + cholesky_downdate(cholesky_of_precision_, u2_); // update inverse (2x O(p^2)) - 
arma::inv(inv_phi_, arma::trimatu(phi_)); - inv_omega_ = inv_phi_ * inv_phi_.t(); + arma::inv(inv_cholesky_of_precision_, arma::trimatu(cholesky_of_precision_)); + covariance_matrix_ = inv_cholesky_of_precision_ * inv_cholesky_of_precision_.t(); // reset for next iteration vf1_[i] = 0.0; @@ -274,11 +178,11 @@ void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_o } -void GGMModel::update_diagonal_parameter(size_t i) { +void GaussianVariables::update_diagonal_parameter(size_t i) { // Implementation of diagonal parameter update // 1-3) from before - double logdet_omega = get_log_det(phi_); - double logdet_omega_sub_ii = logdet_omega + std::log(inv_omega_(i, i)); + double logdet_omega = get_log_det(cholesky_of_precision_); + double logdet_omega_sub_ii = logdet_omega + std::log(covariance_matrix_(i, i)); size_t e = i * (i + 1) / 2 + i; // parameter index in vectorized form double proposal_sd = proposal_.get_proposal_sd(e); @@ -287,28 +191,10 @@ void GGMModel::update_diagonal_parameter(size_t i) { double theta_prop = rnorm(rng_, theta_curr, proposal_sd); //4) Replace and rebuild omega - omega_prop_ = omega_; - omega_prop_(i, i) = omega_(i, i) - std::exp(theta_curr) * std::exp(theta_curr) + std::exp(theta_prop) * std::exp(theta_prop); + precision_proposal_ = precision_matrix_; + precision_proposal_(i, i) = precision_matrix_(i, i) - std::exp(theta_curr) * std::exp(theta_curr) + std::exp(theta_prop) * std::exp(theta_prop); - // Rcpp::Rcout << "i: " << i << - // ", current theta: " << theta_curr << - // ", proposed theta: " << theta_prop << - // ", proposal_sd: " << proposal_sd << std::endl; - // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); - - // 5) Acceptance ratio - // double ln_alpha = log_density(omega_prop_) - log_density(); double ln_alpha = log_density_impl_diag(i); - // { - // double ln_alpha_ref = log_density(omega_prop_) - log_density(); - // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { - // Rcpp::Rcout << "Warning: log density 
implementations do not match for diag (" << i << ", " << i << ")" << std::endl; - // // omega_.print(Rcpp::Rcout, "Current omega:"); - // // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); - // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; - // Rcpp::Rcout << "1e4 * diff: " << 10000 * (ln_alpha - ln_alpha_ref) << std::endl; - // } - // } ln_alpha += R::dgamma(exp(theta_prop), 1.0, 1.0, true); ln_alpha -= R::dgamma(exp(theta_curr), 1.0, 1.0, true); @@ -318,124 +204,88 @@ void GGMModel::update_diagonal_parameter(size_t i) { proposal_.increment_accepts(e); - double omega_ii = omega_(i, i); - omega_(i, i) = omega_prop_(i, i); + double omega_ii = precision_matrix_(i, i); + precision_matrix_(i, i) = precision_proposal_(i, i); cholesky_update_after_diag(omega_ii, i); - // arma::vec u(p_, arma::fill::zeros); - // double delta = omega_ii - omega_prop_(i, i); - // bool s = delta > 0; - // u(i) = std::sqrt(std::abs(delta)); - - - // if (s) - // cholesky_downdate(phi_, u); - // else - // cholesky_update(phi_, u); - - // // update inverse (2x O(p^2)) - // arma::inv(inv_phi_, arma::trimatu(phi_)); - // inv_omega_ = inv_phi_ * inv_phi_.t(); - - } proposal_.update_proposal_sd(e); } -void GGMModel::cholesky_update_after_diag(double omega_ii_old, size_t i) +void GaussianVariables::cholesky_update_after_diag(double omega_ii_old, size_t i) { - double delta = omega_ii_old - omega_prop_(i, i); + double delta = omega_ii_old - precision_proposal_(i, i); bool s = delta > 0; vf1_(i) = std::sqrt(std::abs(delta)); if (s) - cholesky_downdate(phi_, vf1_); + cholesky_downdate(cholesky_of_precision_, vf1_); else - cholesky_update(phi_, vf1_); + cholesky_update(cholesky_of_precision_, vf1_); // update inverse (2x O(p^2)) - arma::inv(inv_phi_, arma::trimatu(phi_)); - inv_omega_ = inv_phi_ * inv_phi_.t(); + arma::inv(inv_cholesky_of_precision_, arma::trimatu(cholesky_of_precision_)); + covariance_matrix_ = inv_cholesky_of_precision_ * 
inv_cholesky_of_precision_.t(); // reset for next iteration vf1_(i) = 0.0; } -void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { +void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) { size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form double proposal_sd = proposal_.get_proposal_sd(e); if (edge_indicators_(i, j) == 1) { // Propose to turn OFF the edge - omega_prop_ = omega_; - omega_prop_(i, j) = 0.0; - omega_prop_(j, i) = 0.0; + precision_proposal_ = precision_matrix_; + precision_proposal_(i, j) = 0.0; + precision_proposal_(j, i) = 0.0; // Update diagonal using R function with omega_ij = 0 get_constants(i, j); - omega_prop_(j, j) = R(0.0); + precision_proposal_(j, j) = R(0.0); - // double ln_alpha = log_density(omega_prop_) - log_density(); + // double ln_alpha = log_likelihood(precision_proposal_) - log_likelihood(); double ln_alpha = log_density_impl_edge(i, j); // { - // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // double ln_alpha_ref = log_likelihood(precision_proposal_) - log_likelihood(); // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; - // omega_.print(Rcpp::Rcout, "Current omega:"); - // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // precision_matrix_.print(Rcpp::Rcout, "Current omega:"); + // precision_proposal_.print(Rcpp::Rcout, "Proposed omega:"); // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; // } // } - ln_alpha += std::log(1.0 - prior_inclusion_prob_(i, j)) - std::log(prior_inclusion_prob_(i, j)); + ln_alpha += std::log(1.0 - inclusion_probability_(i, j)) - std::log(inclusion_probability_(i, j)); - ln_alpha += R::dnorm(omega_(i, j) / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); - ln_alpha -= R::dcauchy(omega_(i, j), 0.0, 2.5, true); + ln_alpha += 
R::dnorm(precision_matrix_(i, j) / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); + ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, 2.5, true); if (std::log(runif(rng_)) < ln_alpha) { // Store old values for Cholesky update - double omega_ij_old = omega_(i, j); - double omega_jj_old = omega_(j, j); + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); // Update omega - omega_(i, j) = 0.0; - omega_(j, i) = 0.0; - omega_(j, j) = omega_prop_(j, j); + precision_matrix_(i, j) = 0.0; + precision_matrix_(j, i) = 0.0; + precision_matrix_(j, j) = precision_proposal_(j, j); // Update edge indicator edge_indicators_(i, j) = 0; edge_indicators_(j, i) = 0; cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); - // // Cholesky update vectors - // arma::vec v1 = {0, -1}; - // arma::vec v2 = {omega_ij_old - 0.0, (omega_jj_old - omega_(j, j)) / 2}; - - // arma::vec vf1 = arma::zeros(p_); - // arma::vec vf2 = arma::zeros(p_); - // vf1[i] = v1[0]; - // vf1[j] = v1[1]; - // vf2[i] = v2[0]; - // vf2[j] = v2[1]; - - // arma::vec u1 = (vf1 + vf2) / sqrt(2); - // arma::vec u2 = (vf1 - vf2) / sqrt(2); - - // // Update Cholesky factor - // cholesky_update(phi_, u1); - // cholesky_downdate(phi_, u2); - - // // Update inverse - // arma::inv(inv_phi_, arma::trimatu(phi_)); - // inv_omega_ = inv_phi_ * inv_phi_.t(); + } } else { @@ -447,23 +297,23 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { double omega_prop_ij = constants_[4] * epsilon; double omega_prop_jj = R(omega_prop_ij); - omega_prop_ = omega_; - omega_prop_(i, j) = omega_prop_ij; - omega_prop_(j, i) = omega_prop_ij; - omega_prop_(j, j) = omega_prop_jj; + precision_proposal_ = precision_matrix_; + precision_proposal_(i, j) = omega_prop_ij; + precision_proposal_(j, i) = omega_prop_ij; + precision_proposal_(j, j) = omega_prop_jj; - // double ln_alpha = log_density(omega_prop_) - log_density(); + // double ln_alpha = 
log_likelihood(precision_proposal_) - log_likelihood(); double ln_alpha = log_density_impl_edge(i, j); // { - // double ln_alpha_ref = log_density(omega_prop_) - log_density(); + // double ln_alpha_ref = log_likelihood(precision_proposal_) - log_likelihood(); // if (std::abs(ln_alpha - ln_alpha_ref) > 1e-6) { // Rcpp::Rcout << "Warning: log density implementations do not match for edge indicator (" << i << ", " << j << ")" << std::endl; - // omega_.print(Rcpp::Rcout, "Current omega:"); - // omega_prop_.print(Rcpp::Rcout, "Proposed omega:"); + // precision_matrix_.print(Rcpp::Rcout, "Current omega:"); + // precision_proposal_.print(Rcpp::Rcout, "Proposed omega:"); // Rcpp::Rcout << "ln_alpha: " << ln_alpha << ", ln_alpha_ref: " << ln_alpha_ref << std::endl; // } // } - ln_alpha += std::log(prior_inclusion_prob_(i, j)) - std::log(1.0 - prior_inclusion_prob_(i, j)); + ln_alpha += std::log(inclusion_probability_(i, j)) - std::log(1.0 - inclusion_probability_(i, j)); // Prior change: add slab (Cauchy prior) ln_alpha += R::dcauchy(omega_prop_ij, 0.0, 2.5, true); @@ -477,132 +327,42 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { proposal_.increment_accepts(e); // Store old values for Cholesky update - double omega_ij_old = omega_(i, j); - double omega_jj_old = omega_(j, j); + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); // Update omega - omega_(i, j) = omega_prop_ij; - omega_(j, i) = omega_prop_ij; - omega_(j, j) = omega_prop_jj; + precision_matrix_(i, j) = omega_prop_ij; + precision_matrix_(j, i) = omega_prop_ij; + precision_matrix_(j, j) = omega_prop_jj; // Update edge indicator edge_indicators_(i, j) = 1; edge_indicators_(j, i) = 1; cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); - // // Cholesky update vectors - // arma::vec v1 = {0, -1}; - // arma::vec v2 = {omega_ij_old - omega_(i, j), (omega_jj_old - omega_(j, j)) / 2}; - - // arma::vec vf1 = arma::zeros(p_); - // 
arma::vec vf2 = arma::zeros(p_); - // vf1[i] = v1[0]; - // vf1[j] = v1[1]; - // vf2[i] = v2[0]; - // vf2[j] = v2[1]; - - // arma::vec u1 = (vf1 + vf2) / sqrt(2); - // arma::vec u2 = (vf1 - vf2) / sqrt(2); - - // // Update Cholesky factor - // cholesky_update(phi_, u1); - // cholesky_downdate(phi_, u2); - - // // Update inverse - // arma::inv(inv_phi_, arma::trimatu(phi_)); - // inv_omega_ = inv_phi_ * inv_phi_.t(); + } } } -void GGMModel::do_one_mh_step() { +void GaussianVariables::do_one_mh_step() { // Update off-diagonals (upper triangle) for (size_t i = 0; i < p_ - 1; ++i) { for (size_t j = i + 1; j < p_; ++j) { - // Rcpp::Rcout << "Updating edge parameter (" << i << ", " << j << ")" << std::endl; update_edge_parameter(i, j); - // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; - // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); - // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); - // } - // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating edge (" << i << ", " << j << ")" << std::endl; - // omega_.print(Rcpp::Rcout, "Omega:"); - // phi_.print(Rcpp::Rcout, "Phi:"); - // } - // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; - // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); - // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); - // } - // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating edge (" << i << ", " << j << ")" << std::endl; - // omega_.print(Rcpp::Rcout, 
"Omega:"); - // phi_.print(Rcpp::Rcout, "Phi:"); - // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); - // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); - // } } } // Update diagonals for (size_t i = 0; i < p_; ++i) { - // Rcpp::Rcout << "Updating diagonal parameter " << i << std::endl; update_diagonal_parameter(i); - - // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating diagonal " << i << std::endl; - // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); - // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); - // } - // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating diagonal " << i << std::endl; - // omega_.print(Rcpp::Rcout, "Omega:"); - // phi_.print(Rcpp::Rcout, "Phi:"); - // } - // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating diagonal " << i << std::endl; - // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); - // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); - // } - // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating diagonal " << i << std::endl; - // omega_.print(Rcpp::Rcout, "Omega:"); - // phi_.print(Rcpp::Rcout, "Phi:"); - // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); - // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); - // } } if (edge_selection_) { for (size_t i = 0; i < p_ - 1; ++i) { for (size_t j = i + 1; j < p_; ++j) { - // Rcpp::Rcout << "Between model move for edge (" << i << ", " << j << ")" << std::endl; update_edge_indicator_parameter_pair(i, j); - // if (!arma:: approx_equal(omega_ * inv_omega_, arma::eye(p_, p_), "absdiff", 1e-6)) { - // 
Rcpp::Rcout << "Warning: Omega * Inv(Omega) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; - // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); - // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); - // } - // if (!arma:: approx_equal(omega_, phi_.t() * phi_, "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Omega not equal to Phi.t() * Phi after updating edge (" << i << ", " << j << ")" << std::endl; - // omega_.print(Rcpp::Rcout, "Omega:"); - // phi_.print(Rcpp::Rcout, "Phi:"); - // } - // if (!arma:: approx_equal(phi_ * inv_phi_, arma::eye(p_, p_), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Phi * Inv(Phi) not equal to identity after updating edge (" << i << ", " << j << ")" << std::endl; - // (omega_ * inv_omega_).print(Rcpp::Rcout, "Omega * Inv(Omega):"); - // (phi_ * inv_phi_).print(Rcpp::Rcout, "Phi * Inv(Phi):"); - // } - // if (!arma:: approx_equal(inv_omega_, inv_phi_ * inv_phi_.t(), "absdiff", 1e-6)) { - // Rcpp::Rcout << "Warning: Inv(Omega) not equal to Inv(Phi) * Inv(Phi).t() after updating edge (" << i << ", " << j << ")" << std::endl; - // omega_.print(Rcpp::Rcout, "Omega:"); - // phi_.print(Rcpp::Rcout, "Phi:"); - // inv_omega_.print(Rcpp::Rcout, "Inv(Omega):"); - // inv_phi_.print(Rcpp::Rcout, "Inv(Phi):"); - // } } } } @@ -612,7 +372,7 @@ void GGMModel::do_one_mh_step() { } -GGMModel createGGMFromR( +GaussianVariables createGaussianVariablesFromR( const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, @@ -622,7 +382,7 @@ GGMModel createGGMFromR( if (inputFromR.containsElementNamed("n") && inputFromR.containsElementNamed("suf_stat")) { int n = Rcpp::as(inputFromR["n"]); arma::mat suf_stat = Rcpp::as(inputFromR["suf_stat"]); - return GGMModel( + return GaussianVariables( n, suf_stat, prior_inclusion_prob, @@ -631,7 +391,7 @@ GGMModel createGGMFromR( ); } else if (inputFromR.containsElementNamed("X")) { arma::mat X = 
Rcpp::as(inputFromR["X"]); - return GGMModel( + return GaussianVariables( X, prior_inclusion_prob, initial_edge_indicators, diff --git a/src/ggm_model.h b/src/ggm_model.h index 0ddf91ad..7434ccf7 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -6,74 +6,82 @@ #include "rng/rng_utils.h" -class GGMModel : public BaseModel { +class GaussianVariables : public BaseModel { public: - GGMModel( - const arma::mat& X, - const arma::mat& prior_inclusion_prob, + // constructor from raw data + GaussianVariables( + const arma::mat& observations, + const arma::mat& inclusion_probability, const arma::imat& initial_edge_indicators, const bool edge_selection = true - ) : n_(X.n_rows), - p_(X.n_cols), + ) : n_(observations.n_rows), + p_(observations.n_cols), + // TODO: need to estimate the means! so + 1 dim_((p_ * (p_ + 1)) / 2), - suf_stat_(X.t() * X), - prior_inclusion_prob_(prior_inclusion_prob), + // TODO: need to store sample means! + suf_stat_(observations.t() * observations), + inclusion_probability_(inclusion_probability), edge_selection_(edge_selection), proposal_(AdaptiveProposal(dim_, 500)), - omega_(arma::eye(p_, p_)), - phi_(arma::eye(p_, p_)), - inv_phi_(arma::eye(p_, p_)), - inv_omega_(arma::eye(p_, p_)), + precision_matrix_(arma::eye(p_, p_)), + cholesky_of_precision_(arma::eye(p_, p_)), + inv_cholesky_of_precision_(arma::eye(p_, p_)), + covariance_matrix_(arma::eye(p_, p_)), + edge_indicators_(initial_edge_indicators), + vectorized_parameters_(dim_), vectorized_indicator_parameters_(edge_selection_ ? 
dim_ : 0), - omega_prop_(arma::mat(p_, p_, arma::fill::none)), + precision_proposal_(arma::mat(p_, p_, arma::fill::none)), constants_(6) {} - GGMModel( + // constructor from sufficient statistics + // TODO: needs to implement same TODOs as above constructor + GaussianVariables( const int n, const arma::mat& suf_stat, - const arma::mat& prior_inclusion_prob, + const arma::mat& inclusion_probability, const arma::imat& initial_edge_indicators, const bool edge_selection = true ) : n_(n), p_(suf_stat.n_cols), dim_((p_ * (p_ + 1)) / 2), suf_stat_(suf_stat), - prior_inclusion_prob_(prior_inclusion_prob), + inclusion_probability_(inclusion_probability), edge_selection_(edge_selection), proposal_(AdaptiveProposal(dim_, 500)), - omega_(arma::eye(p_, p_)), - phi_(arma::eye(p_, p_)), - inv_phi_(arma::eye(p_, p_)), - inv_omega_(arma::eye(p_, p_)), + precision_matrix_(arma::eye(p_, p_)), + cholesky_of_precision_(arma::eye(p_, p_)), + inv_cholesky_of_precision_(arma::eye(p_, p_)), + covariance_matrix_(arma::eye(p_, p_)), edge_indicators_(initial_edge_indicators), vectorized_parameters_(dim_), vectorized_indicator_parameters_(edge_selection_ ? 
dim_ : 0), - omega_prop_(arma::mat(p_, p_, arma::fill::none)), + precision_proposal_(arma::mat(p_, p_, arma::fill::none)), constants_(6) {} - GGMModel(const GGMModel& other) + // copy constructor + GaussianVariables(const GaussianVariables& other) : BaseModel(other), dim_(other.dim_), suf_stat_(other.suf_stat_), n_(other.n_), p_(other.p_), - prior_inclusion_prob_(other.prior_inclusion_prob_), + inclusion_probability_(other.inclusion_probability_), edge_selection_(other.edge_selection_), - omega_(other.omega_), - phi_(other.phi_), - inv_phi_(other.inv_phi_), - inv_omega_(other.inv_omega_), + precision_matrix_(other.precision_matrix_), + cholesky_of_precision_(other.cholesky_of_precision_), + inv_cholesky_of_precision_(other.inv_cholesky_of_precision_), + covariance_matrix_(other.covariance_matrix_), edge_indicators_(other.edge_indicators_), vectorized_parameters_(other.vectorized_parameters_), vectorized_indicator_parameters_(other.vectorized_indicator_parameters_), proposal_(other.proposal_), rng_(other.rng_), - omega_prop_(other.omega_prop_), + precision_proposal_(other.precision_proposal_), constants_(other.constants_) {} @@ -94,8 +102,8 @@ class GGMModel : public BaseModel { } // TODO: this can be done more efficiently, no need for the Cholesky! 
- double log_density(const arma::mat& omega) const { return log_density_impl(omega, arma::chol(omega)); }; - double log_density() const { return log_density_impl(omega_, phi_); } + double log_likelihood(const arma::mat& omega) const { return log_density_impl(omega, arma::chol(omega)); }; + double log_likelihood() const { return log_density_impl(precision_matrix_, cholesky_of_precision_); } void do_one_mh_step() override; @@ -108,11 +116,11 @@ class GGMModel : public BaseModel { } arma::vec get_vectorized_parameters() override { - // upper triangle of omega_ + // upper triangle of precision_matrix_ size_t e = 0; for (size_t j = 0; j < p_; ++j) { for (size_t i = 0; i <= j; ++i) { - vectorized_parameters_(e) = omega_(i, j); + vectorized_parameters_(e) = precision_matrix_(i, j); ++e; } } @@ -120,7 +128,7 @@ class GGMModel : public BaseModel { } arma::ivec get_vectorized_indicator_parameters() override { - // upper triangle of omega_ + // upper triangle of precision_matrix_ size_t e = 0; for (size_t j = 0; j < p_; ++j) { for (size_t i = 0; i <= j; ++i) { @@ -132,7 +140,7 @@ class GGMModel : public BaseModel { } std::unique_ptr clone() const override { - return std::make_unique(*this); // uses copy constructor + return std::make_unique(*this); // uses copy constructor } private: @@ -141,21 +149,22 @@ class GGMModel : public BaseModel { size_t p_; size_t dim_; arma::mat suf_stat_; - arma::mat prior_inclusion_prob_; + arma::mat inclusion_probability_; bool edge_selection_; // parameters - arma::mat omega_, phi_, inv_phi_, inv_omega_; + arma::mat precision_matrix_, cholesky_of_precision_, inv_cholesky_of_precision_, covariance_matrix_; arma::imat edge_indicators_; arma::vec vectorized_parameters_; arma::ivec vectorized_indicator_parameters_; AdaptiveProposal proposal_; + SafeRNG rng_; // internal helper variables - arma::mat omega_prop_; + arma::mat precision_proposal_; arma::vec constants_; // Phi_q1q, Phi_q1q1, c[1], c[2], c[3], c[4] arma::vec v1_ = {0, -1}; @@ -187,10 
+196,35 @@ class GGMModel : public BaseModel { // double diag_log_ratio(const arma::mat& omega, size_t i, double proposal); }; +// class MixedVariableTypes : public BaseModel { +// public: + +// function arma::vec gradient(const arma::vec& parameters) override +// { + +// arma::vec grad = arma::zeros(dim_); + +// size_t from = 0; +// size_t to = 0; +// for (const auto& var_type : variable_types_) { +// to += var_type->parameter_dimension(); +// grad += var_type->gradient(arma::span(parameters, from, to - 1)); +// from = to; +// } +// return grad; +// } + +// private: +// std::vector> variable_types_; +// std::vector interactions_; +// size_t dim_; + +// }; + -GGMModel createGGMFromR( +GaussianVariables createGaussianVariablesFromR( const Rcpp::List& inputFromR, - const arma::mat& prior_inclusion_prob, + const arma::mat& inclusion_probability, const arma::imat& initial_edge_indicators, const bool edge_selection = true ); diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index 7efb964e..a9af132a 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -23,6 +23,8 @@ void run_mcmc_sampler_single_thread( model.do_one_mh_step(); + // update hyperparameters (BetaBinomial, SBM, etc.) 
+ if (iter >= no_warmup) { chain_result.store_sample(i, model.get_vectorized_parameters()); @@ -173,8 +175,8 @@ Rcpp::List sample_ggm( // should be done dynamically // also adaptation method should be specified differently - // GGMModel model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection); - GGMModel model = createGGMFromR(inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + // GaussianVariables model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection); + GaussianVariables model = createGaussianVariablesFromR(inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); From 0382d0855fae0637638b0557beedbda9c8caf550 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 3 Feb 2026 13:47:43 +0100 Subject: [PATCH 09/23] add skeleton class --- src/SkeletonVariables.h | 80 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 src/SkeletonVariables.h diff --git a/src/SkeletonVariables.h b/src/SkeletonVariables.h new file mode 100644 index 00000000..51fd9c24 --- /dev/null +++ b/src/SkeletonVariables.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include "base_model.h" +#include "adaptiveMetropolis.h" +#include "rng/rng_utils.h" + + +class SkeletonVariables : public BaseModel { +public: + + // constructor from raw data + SkeletonVariables( + const arma::mat& observations, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true + ) : + {} + + // copy constructor + SkeletonVariables(const SkeletonVariables& other) + : BaseModel(other), + {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); // uses copy constructor + } + + bool has_gradient() const { return false; } + bool has_adaptive_mh() const override { return true; } + + double logp(const arma::vec& parameters) override { + // Implement log 
probability computation + return 0.0; + } + + void do_one_mh_step() override; + + size_t parameter_dimension() const override { + return dim_; + } + + void set_seed(int seed) override { + rng_ = SafeRNG(seed); + } + + // arma::vec get_vectorized_parameters() override { + // // upper triangle of precision_matrix_ + // size_t e = 0; + // for (size_t j = 0; j < p_; ++j) { + // for (size_t i = 0; i <= j; ++i) { + // vectorized_parameters_(e) = precision_matrix_(i, j); + // ++e; + // } + // } + // return vectorized_parameters_; + // } + + // arma::ivec get_vectorized_indicator_parameters() override { + // // upper triangle of precision_matrix_ + // size_t e = 0; + // for (size_t j = 0; j < p_; ++j) { + // for (size_t i = 0; i <= j; ++i) { + // vectorized_indicator_parameters_(e) = edge_indicators_(i, j); + // ++e; + // } + // } + // return vectorized_indicator_parameters_; + // } + + +private: + // data + size_t n_ = 0; + size_t p_ = 0; + size_t dim_ = 0; + + +}; \ No newline at end of file From 192ed8375b6c4879e54408addd3f5bf1cbecb9ca Mon Sep 17 00:00:00 2001 From: MaartenMarsman <52934067+MaartenMarsman@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:01:46 +0100 Subject: [PATCH 10/23] Checkout --- src/skeleton_model.cpp | 231 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 src/skeleton_model.cpp diff --git a/src/skeleton_model.cpp b/src/skeleton_model.cpp new file mode 100644 index 00000000..78fe901a --- /dev/null +++ b/src/skeleton_model.cpp @@ -0,0 +1,231 @@ +/** + * ================================================== + * Header dependencies for SkeletonVariables + * ================================================== + * + * This file defines a template ("skeleton") for variable-type classes + * in the bgms codebase. The included headers constitute the minimal + * set of dependencies required by any variable model, independent of + * its specific statistical formulation. 
+ * + * Each include serves a specific purpose: + * - : ownership and cloning via std::unique_ptr + * - base_model.h : abstract interface for variable models + * - adaptiveMetropolis.h : Metropolis–Hastings proposal mechanism + * - rng_utils.h : reproducible random number generation + */ +#pragma once + +#include + +#include "base_model.h" +#include "mcmc/adaptiveMetropolis.h" +#include "rng/rng_utils.h" + + +/** + * ================================================== + * SkeletonVariables class + * ================================================== + * + * A class in C++ represents a concrete type that bundles together: + * - internal state, including observed data, model parameters, + * and auxiliary objects that maintain sampling state, + * - and functions (methods) that act on this state to evaluate + * the log posterior, its gradient, and to perform sampling updates. + * + * In the bgms codebase, each statistical variable model is implemented + * as a C++ class that stores the model state and provides the methods + * required for inference. + * + * SkeletonVariables defines a template for such implementation classes. + * It specifies the structure and interface that concrete implementations + * of variable models must follow, without imposing a particular + * statistical formulation. + * + * SkeletonVariables inherits from BaseModel, which defines the common + * interface for all variable-model implementations in bgms: + * + * class SkeletonVariables : public BaseModel + * + * Inheriting from BaseModel means that SkeletonVariables must provide + * a fixed set of functions (such as log_posterior and do_one_mh_step) + * that the rest of the codebase relies on. As a result, code elsewhere + * in bgms can interact with SkeletonVariables through the BaseModel + * interface, without needing to know which specific variable model + * implementation is being used. 
+ */ +class SkeletonVariables : public BaseModel { + + /* + * The 'public:' label below marks the beginning of the part of the class + * that is accessible from outside the class. + * + * Functions and constructors declared under 'public' are intended to be + * called by other components of the bgms codebase, such as samplers, + * model-selection routines, and result containers. + * + * Together, these public members define how a variable-model + * implementation can be created, queried, and updated by external code. + */ + public: + + /* + * Constructors are responsible for establishing the complete internal + * state of the object. + * + * A SkeletonVariables object represents a fully specified variable-model + * implementation at a given point in an inference procedure. Therefore, + * all information required to evaluate the log posterior and to perform + * sampling updates must be stored within the object itself. + * + * After construction, the object is expected to be immediately usable: + * no additional initialization steps are required before it can be + * queried, updated, or copied. + * + * The constructor below uses a constructor initializer list to define + * how base classes and data members are constructed. The initializer + * list is evaluated before the constructor body runs and is used to: + * - construct the BaseModel subobject, + * - initialize data members directly from constructor arguments. + * + * This ensures that all base-class and member invariants are established + * before any additional derived quantities are computed in the + * constructor body. 
+ */ + SkeletonVariables( + const arma::mat& observations, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true + ) + : BaseModel(edge_selection), + observations_(observations), + inclusion_probability_(inclusion_probability), + edge_indicators_(initial_edge_indicators) + { + /* + * The constructor body initializes derived state that depends on + * already-constructed members, such as dimensions, parameter vectors, + * and sampling-related objects. + */ + + n_ = observations_.n_rows; + p_ = observations_.n_cols; + + // Dimension of the parameter vector (model-specific). + // For the skeleton, we assume one parameter per variable. + dim_ = p_; + + // Initialize parameter vector + parameters_.zeros(dim_); + + // Initialize adaptive Metropolis–Hastings sampler + adaptive_mh_ = AdaptiveMetropolis(dim_); + } + + /* + * Copy constructor. + * + * The copy constructor creates a new SkeletonVariables object that is an + * exact copy of an existing one. This includes not only the observed data + * and model parameters, but also all internal state required for inference, + * such as sampler state and random number generator state. + * + * Copying is required in bgms because variable-model objects are duplicated + * during inference, for example when running multiple + * chains, or storing and restoring model states. + * + * The copy is performed using a constructor initializer list to ensure + * that the BaseModel subobject and all data members are constructed + * directly from their counterparts in the source object. + * + * After construction, the new object is independent from the original + * but represents the same model state. 
 + */
+ SkeletonVariables(const SkeletonVariables& other)
+ : BaseModel(other),
+ observations_(other.observations_),
+ inclusion_probability_(other.inclusion_probability_),
+ edge_indicators_(other.edge_indicators_),
+ n_(other.n_),
+ p_(other.p_),
+ dim_(other.dim_),
+ parameters_(other.parameters_),
+ adaptive_mh_(other.adaptive_mh_),
+ rng_(other.rng_)
+ {}
+
+ // --------------------------------------------------
+ // Polymorphic copy
+ // --------------------------------------------------
+ std::unique_ptr<BaseModel> clone() const override {
+ return std::make_unique<SkeletonVariables>(*this);
+ }
+
+ // --------------------------------------------------
+ // Capabilities
+ // --------------------------------------------------
+ bool has_log_posterior() const override { return true; }
+ bool has_gradient() const override { return true; }
+ bool has_adaptive_mh() const override { return true; }
+
+ // --------------------------------------------------
+ // Log posterior (LIKELIHOOD + PRIOR)
+ // --------------------------------------------------
+ double log_posterior(const arma::vec& parameters) override {
+ // Skeleton: flat log-density
+ // Real models will:
+ // - unpack parameters
+ // - compute likelihood
+ // - add priors
+ return 0.0;
+ }
+
+ // --------------------------------------------------
+ // Gradient of LOG (LIKELIHOOD * PRIOR)
+ // --------------------------------------------------
+ void gradient(const arma::vec& parameters) override {
+ // Skeleton:
+ }
+
+ // --------------------------------------------------
+ // One Metropolis–Hastings step
+ // --------------------------------------------------
+ void do_one_mh_step(arma::vec& parameters) override {
+ // Skeleton:
+ }
+
+ // --------------------------------------------------
+ // Required interface
+ // --------------------------------------------------
+ size_t parameter_dimension() const override {
+ return dim_;
+ }
+
+ void set_seed(int seed) override {
+ rng_ = SafeRNG(seed);
+ }
+
+protected:
+ //
-------------------------------------------------- + // Data + // -------------------------------------------------- + arma::mat observations_; + arma::mat inclusion_probability_; + arma::imat edge_indicators_; + + // -------------------------------------------------- + // Dimensions + // -------------------------------------------------- + size_t n_ = 0; // number of observations + size_t p_ = 0; // number of variables + size_t dim_ = 0; // dimension of parameter vector + + // -------------------------------------------------- + // Parameters & MCMC machinery + // -------------------------------------------------- + arma::vec parameters_; + AdaptiveMetropolis adaptive_mh_; + SafeRNG rng_; +}; From 5adc5eb85fc1ddcf876501f708816866ad84b8f5 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 4 Feb 2026 15:27:01 +0100 Subject: [PATCH 11/23] mostly functional --- R/RcppExports.R | 8 + src/RcppExports.cpp | 45 ++ src/base_model.cpp | 3 + src/base_model.h | 46 +- src/chainResultNew.h | 85 ++- src/ggm_model.h | 51 +- src/mcmc/base_sampler.h | 62 ++ src/mcmc/hmc_sampler.h | 122 ++++ src/mcmc/mcmc_runner.h | 285 ++++++++ src/mcmc/mh_sampler.h | 69 ++ src/mcmc/nuts_sampler.h | 269 +++++++ src/mcmc/sampler_config.h | 45 ++ src/mixedVariables.cpp | 43 ++ src/mixedVariables.h | 169 +++++ src/omrf_model.cpp | 1424 +++++++++++++++++++++++++++++++++++++ src/omrf_model.h | 415 +++++++++++ src/sample_ggm.cpp | 179 +---- src/sample_omrf.cpp | 85 +++ 18 files changed, 3208 insertions(+), 197 deletions(-) create mode 100644 src/mcmc/base_sampler.h create mode 100644 src/mcmc/hmc_sampler.h create mode 100644 src/mcmc/mcmc_runner.h create mode 100644 src/mcmc/mh_sampler.h create mode 100644 src/mcmc/nuts_sampler.h create mode 100644 src/mcmc/sampler_config.h create mode 100644 src/mixedVariables.cpp create mode 100644 src/mixedVariables.h create mode 100644 src/omrf_model.cpp create mode 100644 src/omrf_model.h create mode 100644 src/sample_omrf.cpp diff --git a/R/RcppExports.R 
b/R/RcppExports.R index 715c8c12..b05ee7b4 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -29,10 +29,18 @@ run_simulation_parallel <- function(pairwise_samples, main_samples, draw_indices .Call(`_bgms_run_simulation_parallel`, pairwise_samples, main_samples, draw_indices, no_states, no_variables, no_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } +sample_omrf_classed <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, edge_selection, sampler_type, seed) { + .Call(`_bgms_sample_omrf_classed`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, edge_selection, sampler_type, seed) +} + sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) { .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) } +sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, edge_selection_start = -1L) { + .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start) +} + compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { .Call(`_bgms_compute_Vn_mfm_sbm`, no_variables, dirichlet_alpha, t_max, lambda) } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index aa1b1985..2bcac928 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -193,6 +193,24 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// sample_omrf_classed +Rcpp::List sample_omrf_classed(const 
Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const bool edge_selection, const std::string& sampler_type, const int seed); +RcppExport SEXP _bgms_sample_omrf_classed(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); + Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); + Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); + Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type sampler_type(sampler_typeSEXP); + Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); + rcpp_result_gen = Rcpp::wrap(sample_omrf_classed(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, edge_selection, sampler_type, seed)); + return rcpp_result_gen; +END_RCPP +} // sample_ggm Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type); RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP 
progress_typeSEXP) { @@ -213,6 +231,31 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// sample_omrf +Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const int edge_selection_start); +RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP edge_selection_startSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); + Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); + Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); + Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); + Rcpp::traits::input_parameter< const int >::type no_chains(no_chainsSEXP); + Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); + Rcpp::traits::input_parameter< const std::string& >::type sampler_type(sampler_typeSEXP); + Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); + Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); + Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< 
const double >::type target_acceptance(target_acceptanceSEXP); + Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); + Rcpp::traits::input_parameter< const int >::type num_leapfrogs(num_leapfrogsSEXP); + Rcpp::traits::input_parameter< const int >::type edge_selection_start(edge_selection_startSEXP); + rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start)); + return rcpp_result_gen; +END_RCPP +} // compute_Vn_mfm_sbm arma::vec compute_Vn_mfm_sbm(arma::uword no_variables, double dirichlet_alpha, arma::uword t_max, double lambda); RcppExport SEXP _bgms_compute_Vn_mfm_sbm(SEXP no_variablesSEXP, SEXP dirichlet_alphaSEXP, SEXP t_maxSEXP, SEXP lambdaSEXP) { @@ -236,7 +279,9 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 7}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 9}, {"_bgms_run_simulation_parallel", (DL_FUNC) &_bgms_run_simulation_parallel, 12}, + {"_bgms_sample_omrf_classed", (DL_FUNC) &_bgms_sample_omrf_classed, 8}, {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 10}, + {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 15}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/base_model.cpp b/src/base_model.cpp index e69de29b..7ffcdce5 100644 --- a/src/base_model.cpp +++ b/src/base_model.cpp @@ -0,0 +1,3 @@ +#include "base_model.h" + +// BaseModel is a header-only abstract class with no cpp implementations needed diff --git a/src/base_model.h b/src/base_model.h index 10f443dc..bee092ad 100644 --- a/src/base_model.h +++ b/src/base_model.h @@ -4,6 +4,10 @@ #include #include +// Forward declarations +struct SamplerResult; +struct SafeRNG; + class BaseModel { public: virtual 
~BaseModel() = default; @@ -11,6 +15,8 @@ class BaseModel { // Capability queries virtual bool has_gradient() const { return false; } virtual bool has_adaptive_mh() const { return false; } + virtual bool has_nuts() const { return has_gradient(); } + virtual bool has_edge_selection() const { return false; } // Core methods (to be overridden by derived classes) virtual double logp(const arma::vec& parameters) = 0; @@ -35,15 +41,34 @@ class BaseModel { throw std::runtime_error("do_one_mh_step method must be implemented in derived class"); } - virtual arma::vec get_vectorized_parameters() { + // Edge selection (for models with spike-and-slab priors) + virtual void update_edge_indicators() { + throw std::runtime_error("update_edge_indicators not implemented for this model"); + } + + virtual arma::vec get_vectorized_parameters() const { throw std::runtime_error("get_vectorized_parameters method must be implemented in derived class"); } + virtual void set_vectorized_parameters(const arma::vec& parameters) { + throw std::runtime_error("set_vectorized_parameters method must be implemented in derived class"); + } + virtual arma::ivec get_vectorized_indicator_parameters() { throw std::runtime_error("get_vectorized_indicator_parameters method must be implemented in derived class"); } - // Return dimensionality of the parameter space + // Full parameter dimension (for fixed-size output, includes all possible params) + virtual size_t full_parameter_dimension() const { + return parameter_dimension(); // Default: same as active dimension + } + + // Get full vectorized parameters (zeros for inactive, for consistent output) + virtual arma::vec get_full_vectorized_parameters() const { + throw std::runtime_error("get_full_vectorized_parameters must be implemented in derived class"); + } + + // Return dimensionality of the active parameter space virtual size_t parameter_dimension() const = 0; virtual void set_seed(int seed) { @@ -54,7 +79,24 @@ class BaseModel { throw 
std::runtime_error("clone method must be implemented in derived class"); } + // RNG access for samplers + virtual SafeRNG& get_rng() { + throw std::runtime_error("get_rng method must be implemented in derived class"); + } + + // Step size for gradient-based samplers + virtual void set_step_size(double step_size) { step_size_ = step_size; } + virtual double get_step_size() const { return step_size_; } + + // Inverse mass matrix for HMC/NUTS + virtual void set_inv_mass(const arma::vec& inv_mass) { inv_mass_ = inv_mass; } + virtual const arma::vec& get_inv_mass() const { return inv_mass_; } + + // Get active inverse mass (for models with edge selection, may be subset) + virtual arma::vec get_active_inv_mass() const { return inv_mass_; } protected: BaseModel() = default; + double step_size_ = 0.1; + arma::vec inv_mass_; }; diff --git a/src/chainResultNew.h b/src/chainResultNew.h index fa269a54..3ef3e4df 100644 --- a/src/chainResultNew.h +++ b/src/chainResultNew.h @@ -3,30 +3,97 @@ #include #include +/** + * ChainResultNew - Storage for a single MCMC chain's output + * + * Holds samples, diagnostics, and error state for one chain. + * Designed for use with both MH and NUTS/HMC samplers. 
+ */ class ChainResultNew { public: ChainResultNew() {} - bool error = false, - userInterrupt = false; + // Error handling + bool error = false; + bool userInterrupt = false; std::string error_msg; - int chain_id; + // Chain identifier + int chain_id = 0; + + // Parameter samples (param_dim × n_iter) arma::mat samples; + // Edge indicator samples (n_edges × n_iter), only if edge_selection = true + arma::imat indicator_samples; + bool has_indicators = false; + + // NUTS/HMC diagnostics (n_iter), only if using NUTS/HMC + arma::ivec treedepth_samples; + arma::ivec divergent_samples; + arma::vec energy_samples; + bool has_nuts_diagnostics = false; + + /** + * Reserve storage for samples + * @param param_dim Number of parameters per sample + * @param n_iter Number of sampling iterations + */ void reserve(const size_t param_dim, const size_t n_iter) { samples.set_size(param_dim, n_iter); } + + /** + * Reserve storage for edge indicator samples + * @param n_edges Number of edges (p * (p - 1) / 2) + * @param n_iter Number of sampling iterations + */ + void reserve_indicators(const size_t n_edges, const size_t n_iter) { + indicator_samples.set_size(n_edges, n_iter); + has_indicators = true; + } + + /** + * Reserve storage for NUTS diagnostics + * @param n_iter Number of sampling iterations + */ + void reserve_nuts_diagnostics(const size_t n_iter) { + treedepth_samples.set_size(n_iter); + divergent_samples.set_size(n_iter); + energy_samples.set_size(n_iter); + has_nuts_diagnostics = true; + } + + /** + * Store a parameter sample + * @param iter Iteration index (0-based) + * @param sample Parameter vector + */ void store_sample(const size_t iter, const arma::vec& sample) { samples.col(iter) = sample; } - // arma::imat indicator_samples; + /** + * Store edge indicator sample + * @param iter Iteration index (0-based) + * @param indicators Edge indicator vector + */ + void store_indicators(const size_t iter, const arma::ivec& indicators) { + indicator_samples.col(iter) = 
indicators; + } - // other samples - // arma::ivec treedepth_samples; - // arma::ivec divergent_samples; - // arma::vec energy_samples; - // arma::imat allocation_samples; + /** + * Store NUTS diagnostics for one iteration + * @param iter Iteration index (0-based) + * @param tree_depth Tree depth from NUTS + * @param divergent Whether a divergence occurred + * @param energy Final Hamiltonian energy + */ + void store_nuts_diagnostics(const size_t iter, int tree_depth, bool divergent, double energy) { + treedepth_samples(iter) = tree_depth; + divergent_samples(iter) = divergent ? 1 : 0; + energy_samples(iter) = energy; + } }; + diff --git a/src/ggm_model.h b/src/ggm_model.h index 7434ccf7..34a67359 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -95,6 +95,7 @@ class GaussianVariables : public BaseModel { bool has_gradient() const { return false; } bool has_adaptive_mh() const override { return true; } + bool has_edge_selection() const override { return edge_selection_; } double logp(const arma::vec& parameters) override { // Implement log probability computation @@ -111,20 +112,39 @@ class GaussianVariables : public BaseModel { return dim_; } + // For GGM, full dimension is the same as parameter dimension (no edge selection filtering) + size_t full_parameter_dimension() const override { + return dim_; + } + void set_seed(int seed) override { rng_ = SafeRNG(seed); } - arma::vec get_vectorized_parameters() override { + arma::vec get_vectorized_parameters() const override { // upper triangle of precision_matrix_ + arma::vec result(dim_); size_t e = 0; for (size_t j = 0; j < p_; ++j) { for (size_t i = 0; i <= j; ++i) { - vectorized_parameters_(e) = precision_matrix_(i, j); + result(e) = precision_matrix_(i, j); ++e; } } - return vectorized_parameters_; + return result; + } + + // For GGM, full and active parameter vectors are the same + arma::vec get_full_vectorized_parameters() const override { + arma::vec result(dim_); + size_t e = 0; + for (size_t j = 0; j < p_; 
++j) { + for (size_t i = 0; i <= j; ++i) { + result(e) = precision_matrix_(i, j); + ++e; + } + } + return result; } arma::ivec get_vectorized_indicator_parameters() override { @@ -196,31 +216,6 @@ class GaussianVariables : public BaseModel { // double diag_log_ratio(const arma::mat& omega, size_t i, double proposal); }; -// class MixedVariableTypes : public BaseModel { -// public: - -// function arma::vec gradient(const arma::vec& parameters) override -// { - -// arma::vec grad = arma::zeros(dim_); - -// size_t from = 0; -// size_t to = 0; -// for (const auto& var_type : variable_types_) { -// to += var_type->parameter_dimension(); -// grad += var_type->gradient(arma::span(parameters, from, to - 1)); -// from = to; -// } -// return grad; -// } - -// private: -// std::vector> variable_types_; -// std::vector interactions_; -// size_t dim_; - -// }; - GaussianVariables createGaussianVariablesFromR( const Rcpp::List& inputFromR, diff --git a/src/mcmc/base_sampler.h b/src/mcmc/base_sampler.h new file mode 100644 index 00000000..4a23157c --- /dev/null +++ b/src/mcmc/base_sampler.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include "mcmc_utils.h" +#include "sampler_config.h" +#include "../base_model.h" + +/** + * BaseSampler - Abstract base class for MCMC samplers + * + * Provides a unified interface for all MCMC sampling algorithms: + * - Random-walk Metropolis-Hastings + * - Hamiltonian Monte Carlo (HMC) + * - No-U-Turn Sampler (NUTS) + * + * All samplers follow the same workflow: + * 1. warmup_step() during warmup (may adapt parameters) + * 2. finalize_warmup() after warmup completes + * 3. sample_step() during sampling (fixed parameters) + * + * The sampler asks the model for logp/gradient evaluations but owns + * the sampling algorithm logic. 
+ */ +class BaseSampler { +public: + virtual ~BaseSampler() = default; + + /** + * Perform one step during warmup phase + * + * During warmup, samplers may adapt their parameters (step size, + * proposal covariance, mass matrix, etc.) + * + * @param model The model to sample from + * @return SamplerResult with new state and diagnostics + */ + virtual SamplerResult warmup_step(BaseModel& model) = 0; + + /** + * Finalize warmup phase + * + * Called after all warmup iterations complete. Samplers should + * fix their adapted parameters for the sampling phase. + */ + virtual void finalize_warmup() {} + + /** + * Perform one step during sampling phase + * + * Sampling steps use fixed parameters (no adaptation). + * + * @param model The model to sample from + * @return SamplerResult with new state and diagnostics + */ + virtual SamplerResult sample_step(BaseModel& model) = 0; + + /** + * Check if this sampler produces NUTS-style diagnostics + * (tree depth, divergences, energy) + */ + virtual bool has_nuts_diagnostics() const { return false; } +}; diff --git a/src/mcmc/hmc_sampler.h b/src/mcmc/hmc_sampler.h new file mode 100644 index 00000000..a4750409 --- /dev/null +++ b/src/mcmc/hmc_sampler.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include "base_sampler.h" +#include "mcmc_utils.h" +#include "mcmc_hmc.h" +#include "sampler_config.h" +#include "../base_model.h" + +/** + * HMCSampler - Hamiltonian Monte Carlo sampler + * + * Uses fixed-length leapfrog integration with optional step size + * adaptation during warmup via dual averaging. + * + * The sampler fully owns all sampling logic. 
The model only provides: + * - logp_and_gradient(theta): Compute log posterior and gradient + * - get_vectorized_parameters(): Get current state as vector + * - set_vectorized_parameters(theta): Update model state from vector + * - get_active_inv_mass(): Get inverse mass diagonal + * - get_rng(): Get random number generator + */ +class HMCSampler : public BaseSampler { +public: + /** + * Construct HMC sampler with configuration + * @param config Sampler configuration + */ + explicit HMCSampler(const SamplerConfig& config) + : step_size_(config.initial_step_size), + target_acceptance_(config.target_acceptance), + num_leapfrogs_(config.num_leapfrogs), + no_warmup_(config.no_warmup), + warmup_iteration_(0) + { + // Initialize dual averaging state + dual_avg_state_.set_size(3); + dual_avg_state_(0) = std::log(step_size_); + dual_avg_state_(1) = std::log(step_size_); + dual_avg_state_(2) = 0.0; + } + + /** + * Perform one HMC step during warmup (with step size adaptation) + */ + SamplerResult warmup_step(BaseModel& model) override { + SamplerResult result = do_hmc_step(model); + + // Update step size via dual averaging + warmup_iteration_++; + update_step_size_with_dual_averaging( + step_size_, + result.accept_prob, + warmup_iteration_, + dual_avg_state_, + target_acceptance_ + ); + step_size_ = std::exp(dual_avg_state_(0)); + + return result; + } + + /** + * Finalize warmup phase (fix step size to averaged value) + */ + void finalize_warmup() override { + step_size_ = std::exp(dual_avg_state_(1)); + } + + /** + * Perform one HMC step during sampling (fixed step size) + */ + SamplerResult sample_step(BaseModel& model) override { + return do_hmc_step(model); + } + + double get_step_size() const { return step_size_; } + double get_averaged_step_size() const { return std::exp(dual_avg_state_(1)); } + +private: + /** + * Execute one HMC step using the model's interface + */ + SamplerResult do_hmc_step(BaseModel& model) { + // Get current state + arma::vec theta = 
model.get_vectorized_parameters(); + arma::vec inv_mass = model.get_active_inv_mass(); + SafeRNG& rng = model.get_rng(); + + // Create log posterior and gradient functions that call the model + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + + // Call the HMC free function + SamplerResult result = hmc_sampler( + theta, + step_size_, + log_post, + grad_fn, + num_leapfrogs_, + inv_mass, + rng + ); + + // Update model state with new parameters + model.set_vectorized_parameters(result.state); + + return result; + } + + double step_size_; + double target_acceptance_; + int num_leapfrogs_; + int no_warmup_; + int warmup_iteration_; + arma::vec dual_avg_state_; +}; diff --git a/src/mcmc/mcmc_runner.h b/src/mcmc/mcmc_runner.h new file mode 100644 index 00000000..475aac57 --- /dev/null +++ b/src/mcmc/mcmc_runner.h @@ -0,0 +1,285 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "../base_model.h" +#include "../chainResultNew.h" +#include "../utils/progress_manager.h" +#include "sampler_config.h" +#include "base_sampler.h" +#include "nuts_sampler.h" +#include "hmc_sampler.h" +#include "mh_sampler.h" +#include "mcmc_utils.h" + + +/** + * Create a sampler based on configuration + * + * Factory function that returns the appropriate sampler type. + * + * @param config Sampler configuration + * @return Unique pointer to the created sampler + */ +inline std::unique_ptr create_sampler(const SamplerConfig& config) { + if (config.sampler_type == "nuts") { + return std::make_unique(config, config.no_warmup); + } else if (config.sampler_type == "hmc") { + return std::make_unique(config); + } else { + return std::make_unique(config); + } +} + + +/** + * Run MCMC sampling for a single chain + * + * Supports MH, NUTS, and HMC samplers with optional edge selection. 
+ * Handles warmup adaptation and diagnostic collection. + * + * @param chain_result Output storage for this chain + * @param model The model to sample from + * @param config Sampler configuration + * @param chain_id Chain identifier (0-based) + * @param pm Progress manager for user feedback + */ +inline void run_mcmc_chain( + ChainResultNew& chain_result, + BaseModel& model, + const SamplerConfig& config, + const int chain_id, + ProgressManager& pm +) { + chain_result.chain_id = chain_id + 1; + + const int edge_start = config.get_edge_selection_start(); + + // Create sampler for this chain + auto sampler = create_sampler(config); + + // ========================================================================= + // Warmup phase + // ========================================================================= + for (int iter = 0; iter < config.no_warmup; ++iter) { + + // model->impute_missing_data(); + + // Edge selection starts after edge_start iterations + if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { + model.update_edge_indicators(); + } + + // Sampler step (unified interface) + sampler->warmup_step(model); + + // Progress and interrupt check + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + return; + } + } + + // Finalize warmup (samplers fix their adapted parameters) + sampler->finalize_warmup(); + + // ========================================================================= + // Sampling phase + // ========================================================================= + for (int iter = 0; iter < config.no_iter; ++iter) { + + // model->impute_missing_data(); + + // Edge selection continues during sampling + if (config.edge_selection && model.has_edge_selection()) { + model.update_edge_indicators(); + } + + // Sampler step (unified interface) + SamplerResult result = sampler->sample_step(model); + + // Store NUTS diagnostics if available + if (chain_result.has_nuts_diagnostics && 
sampler->has_nuts_diagnostics()) { + auto* diag = dynamic_cast(result.diagnostics.get()); + if (diag) { + chain_result.store_nuts_diagnostics(iter, diag->tree_depth, diag->divergent, diag->energy); + } + } + + // Store samples + chain_result.store_sample(iter, model.get_full_vectorized_parameters()); + + // Store edge indicators if applicable + if (chain_result.has_indicators) { + chain_result.store_indicators(iter, model.get_vectorized_indicator_parameters()); + } + + // Progress and interrupt check + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + return; + } + } +} + + +/** + * Worker struct for parallel chain execution + */ +struct MCMCChainRunner : public RcppParallel::Worker { + std::vector& results_; + std::vector>& models_; + const SamplerConfig& config_; + ProgressManager& pm_; + + MCMCChainRunner( + std::vector& results, + std::vector>& models, + const SamplerConfig& config, + ProgressManager& pm + ) : + results_(results), + models_(models), + config_(config), + pm_(pm) + {} + + void operator()(std::size_t begin, std::size_t end) { + for (std::size_t i = begin; i < end; ++i) { + ChainResultNew& chain_result = results_[i]; + BaseModel& model = *models_[i]; + model.set_seed(config_.seed + static_cast(i)); + + try { + run_mcmc_chain(chain_result, model, config_, static_cast(i), pm_); + } catch (std::exception& e) { + chain_result.error = true; + chain_result.error_msg = e.what(); + } catch (...) { + chain_result.error = true; + chain_result.error_msg = "Unknown error"; + } + } + } +}; + + +/** + * Run MCMC sampling with parallel chains + * + * Main entry point for multi-chain MCMC. 
Handles: + * - Chain allocation and model cloning + * - Parallel or sequential execution based on no_threads + * - Result collection + * + * @param model Template model (will be cloned for each chain) + * @param config Sampler configuration + * @param no_chains Number of chains to run + * @param no_threads Number of threads (1 = sequential) + * @param pm Progress manager + * @return Vector of chain results + */ +inline std::vector run_mcmc_sampler( + BaseModel& model, + const SamplerConfig& config, + const int no_chains, + const int no_threads, + ProgressManager& pm +) { + const bool has_nuts_diag = (config.sampler_type == "nuts"); + + // Allocate result storage + std::vector results(no_chains); + for (int c = 0; c < no_chains; ++c) { + results[c].reserve(model.full_parameter_dimension(), config.no_iter); + + if (config.edge_selection) { + size_t n_edges = model.get_vectorized_indicator_parameters().n_elem; + results[c].reserve_indicators(n_edges, config.no_iter); + } + + if (has_nuts_diag) { + results[c].reserve_nuts_diagnostics(config.no_iter); + } + } + + if (no_threads > 1) { + // Multi-threaded execution + std::vector> models; + models.reserve(no_chains); + for (int c = 0; c < no_chains; ++c) { + models.push_back(model.clone()); + models[c]->set_seed(config.seed + c); + } + + MCMCChainRunner runner(results, models, config, pm); + tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); + RcppParallel::parallelFor(0, static_cast(no_chains), runner); + + } else { + // Single-threaded execution + model.set_seed(config.seed); + for (int c = 0; c < no_chains; ++c) { + auto chain_model = model.clone(); + chain_model->set_seed(config.seed + c); + run_mcmc_chain(results[c], *chain_model, config, c, pm); + } + } + + return results; +} + + +/** + * Convert chain results to Rcpp::List format + * + * Creates a standardized output format for both GGM and OMRF models. 
+ * Each chain is a list with: + * - chain_id: Chain identifier + * - samples: Parameter samples matrix (param_dim × n_iter) + * - indicator_samples: Edge indicators (if edge_selection) + * - treedepth / divergent / energy: NUTS diagnostics (if NUTS/HMC) + * - error / error_msg: Error information (if error occurred) + * + * @param results Vector of chain results + * @return Rcpp::List with per-chain output + */ +inline Rcpp::List convert_results_to_list(const std::vector& results) { + Rcpp::List output(results.size()); + + for (size_t i = 0; i < results.size(); ++i) { + const ChainResultNew& chain = results[i]; + Rcpp::List chain_list; + + chain_list["chain_id"] = chain.chain_id; + + if (chain.error) { + chain_list["error"] = true; + chain_list["error_msg"] = chain.error_msg; + } else { + chain_list["error"] = false; + chain_list["samples"] = chain.samples; + chain_list["userInterrupt"] = chain.userInterrupt; + + if (chain.has_indicators) { + chain_list["indicator_samples"] = chain.indicator_samples; + } + + if (chain.has_nuts_diagnostics) { + chain_list["treedepth"] = chain.treedepth_samples; + chain_list["divergent"] = chain.divergent_samples; + chain_list["energy"] = chain.energy_samples; + } + } + + output[i] = chain_list; + } + + return output; +} diff --git a/src/mcmc/mh_sampler.h b/src/mcmc/mh_sampler.h new file mode 100644 index 00000000..8f0ed8b6 --- /dev/null +++ b/src/mcmc/mh_sampler.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include "base_sampler.h" +#include "mcmc_utils.h" +#include "sampler_config.h" +#include "../base_model.h" + +/** + * MHSampler - Metropolis-Hastings sampler + * + * Delegates to the model's component-wise MH updates. The model + * handles proposal adaptation internally during warmup. + * + * This is a thin wrapper that provides a uniform interface consistent + * with other samplers (NUTS, HMC), but the actual sampling logic + * (component-wise updates, Gibbs sweeps, etc.) is model-specific. 
+ */ +class MHSampler : public BaseSampler { +public: + /** + * Construct MH sampler with configuration + * @param config Sampler configuration + */ + explicit MHSampler(const SamplerConfig& config) + : no_warmup_(config.no_warmup), + warmup_iteration_(0) + {} + + /** + * Perform one MH step during warmup + * + * The model handles proposal adaptation internally. + */ + SamplerResult warmup_step(BaseModel& model) override { + warmup_iteration_++; + model.do_one_mh_step(); + + SamplerResult result; + result.state = model.get_full_vectorized_parameters(); + result.accept_prob = 1.0; // Not tracked for component-wise MH + return result; + } + + /** + * Finalize warmup phase + * + * Nothing to do - model handles adaptation internally. + */ + void finalize_warmup() override { + // Model handles proposal finalization internally + } + + /** + * Perform one MH step during sampling + */ + SamplerResult sample_step(BaseModel& model) override { + model.do_one_mh_step(); + + SamplerResult result; + result.state = model.get_full_vectorized_parameters(); + result.accept_prob = 1.0; + return result; + } + +private: + int no_warmup_; + int warmup_iteration_; +}; diff --git a/src/mcmc/nuts_sampler.h b/src/mcmc/nuts_sampler.h new file mode 100644 index 00000000..e1327579 --- /dev/null +++ b/src/mcmc/nuts_sampler.h @@ -0,0 +1,269 @@ +#pragma once + +#include +#include +#include +#include +#include "base_sampler.h" +#include "mcmc_utils.h" +#include "mcmc_nuts.h" +#include "mcmc_adaptation.h" +#include "sampler_config.h" +#include "../base_model.h" + +/** + * NUTSSampler - No-U-Turn Sampler implementation + * + * Provides a clean interface to the NUTS algorithm for any BaseModel + * with gradient support. 
Handles: + * - Step size adaptation via dual averaging during warmup + * - Mass matrix adaptation using Welford's algorithm during warmup + * - Trajectory simulation with the no-U-turn criterion + * - Diagnostics collection (tree depth, divergence, energy) + * + * The sampler fully owns all sampling logic. The model only provides: + * - logp_and_gradient(theta): Compute log posterior and gradient + * - get_vectorized_parameters(): Get current state as vector + * - set_vectorized_parameters(theta): Update model state from vector + * - get_active_inv_mass(): Get inverse mass diagonal (used for initialization) + * - get_rng(): Get random number generator + * + * Warmup Schedule (Stan-style): + * - Stage 1 (7.5%): Initial adaptation, step size only + * - Stage 2 (82.5%): Mass matrix learning in doubling windows + * - Stage 3 (10%): Final step size tuning with fixed mass + * + * Usage: + * NUTSSampler nuts(config, n_warmup); + * for (iter in warmup) { + * auto result = nuts.warmup_step(model); + * } + * nuts.finalize_warmup(); + * for (iter in sampling) { + * auto result = nuts.sample_step(model); + * } + */ +class NUTSSampler : public BaseSampler { +public: + /** + * Construct NUTS sampler with configuration + * @param config Sampler configuration (step size, target acceptance, etc.) 
+ * @param n_warmup Number of warmup iterations (for scheduling mass matrix adaptation) + */ + explicit NUTSSampler(const SamplerConfig& config, int n_warmup = 1000) + : step_size_(config.initial_step_size), + target_acceptance_(config.target_acceptance), + max_tree_depth_(config.max_tree_depth), + no_warmup_(config.no_warmup), + n_warmup_(n_warmup), + warmup_iteration_(0), + initialized_(false), + step_adapter_(config.initial_step_size) + { + build_warmup_schedule(n_warmup); + } + + /** + * Perform one NUTS step during warmup (with step size and mass matrix adaptation) + * @param model The model to sample from + * @return SamplerResult with state and diagnostics + */ + SamplerResult warmup_step(BaseModel& model) override { + // Initialize on first warmup iteration + if (!initialized_) { + initialize(model); + initialized_ = true; + } + + SamplerResult result = do_nuts_step(model); + + // Adapt step size during all warmup phases + step_adapter_.update(result.accept_prob, target_acceptance_); + step_size_ = step_adapter_.current(); + + // During Stage 2, accumulate samples for mass matrix estimation + if (in_stage2()) { + mass_accumulator_->update(result.state); + + // Check if we're at the end of a window + if (at_window_end()) { + // Update mass matrix from accumulated samples + inv_mass_ = mass_accumulator_->inverse_mass(); + mass_accumulator_->reset(); + + // Restart step size adaptation with new mass matrix + step_adapter_.restart(step_size_); + } + } + + warmup_iteration_++; + return result; + } + + /** + * Finalize warmup phase (fix step size to averaged value) + */ + void finalize_warmup() override { + step_size_ = step_adapter_.averaged(); + } + + /** + * Perform one NUTS step during sampling (fixed step size and mass matrix) + * @param model The model to sample from + * @return SamplerResult with state and diagnostics + */ + SamplerResult sample_step(BaseModel& model) override { + return do_nuts_step(model); + } + + /** + * NUTS produces tree depth, 
divergence, and energy diagnostics + */ + bool has_nuts_diagnostics() const override { return true; } + + /** + * Get the current (or final) step size + */ + double get_step_size() const { return step_size_; } + + /** + * Get the averaged step size (for reporting after warmup) + */ + double get_averaged_step_size() const { + return step_adapter_.averaged(); + } + + /** + * Get the current inverse mass matrix diagonal + */ + const arma::vec& get_inv_mass() const { return inv_mass_; } + +private: + /** + * Build Stan-style warmup schedule with doubling windows + */ + void build_warmup_schedule(int n_warmup) { + // Stage 1: 7.5% of warmup + stage1_end_ = static_cast(0.075 * n_warmup); + + // Stage 3 starts at 90% of warmup + stage3_start_ = n_warmup - static_cast(0.10 * n_warmup); + + // Stage 2: build doubling windows between stage1_end and stage3_start + window_ends_.clear(); + int cur = stage1_end_; + int wsize = 25; // Initial window size + + while (cur < stage3_start_) { + int win = std::min(wsize, stage3_start_ - cur); + window_ends_.push_back(cur + win); + cur += win; + wsize = std::min(wsize * 2, stage3_start_ - cur); + } + } + + /** + * Check if we're in Stage 2 (mass matrix learning phase) + */ + bool in_stage2() const { + return warmup_iteration_ >= stage1_end_ && warmup_iteration_ < stage3_start_; + } + + /** + * Check if we're at the end of a Stage 2 window + */ + bool at_window_end() const { + for (int end : window_ends_) { + if (warmup_iteration_ + 1 == end) { + return true; + } + } + return false; + } + + /** + * Initialize step size and mass matrix on first iteration + */ + void initialize(BaseModel& model) { + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + + // Initialize inverse mass to identity (or from model) + inv_mass_ = model.get_active_inv_mass(); + + // Initialize mass matrix accumulator + mass_accumulator_ = std::make_unique( + static_cast(theta.n_elem)); + + // Create log posterior and gradient 
functions + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + + // Use heuristic to find good initial step size + step_size_ = heuristic_initial_step_size( + theta, log_post, grad_fn, rng, target_acceptance_); + + // Restart dual averaging with the heuristic step size + step_adapter_.restart(step_size_); + } + + /** + * Execute one NUTS step using the sampler's learned mass matrix + */ + SamplerResult do_nuts_step(BaseModel& model) { + // Get current state + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + + // Create log posterior and gradient functions that call the model + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + + // Call the NUTS free function with our learned inverse mass + SamplerResult result = nuts_sampler( + theta, + step_size_, + log_post, + grad_fn, + inv_mass_, + rng, + max_tree_depth_ + ); + + // Update model state with new parameters + model.set_vectorized_parameters(result.state); + + return result; + } + + // Configuration + double step_size_; + double target_acceptance_; + int max_tree_depth_; + int no_warmup_; + int n_warmup_; + + // State tracking + int warmup_iteration_; + bool initialized_; + + // Step size adaptation + DualAveraging step_adapter_; + + // Mass matrix adaptation + arma::vec inv_mass_; + std::unique_ptr mass_accumulator_; + + // Warmup schedule + int stage1_end_; + int stage3_start_; + std::vector window_ends_; +}; diff --git a/src/mcmc/sampler_config.h b/src/mcmc/sampler_config.h new file mode 100644 index 00000000..7aee7469 --- /dev/null +++ b/src/mcmc/sampler_config.h @@ -0,0 +1,45 @@ +#pragma once + +#include 
+ +/** + * SamplerConfig - Configuration for MCMC sampling + * + * Holds all settings for the generic MCMC runner, including: + * - Sampler type selection (MH, NUTS, HMC) + * - Iteration counts + * - NUTS/HMC specific parameters + * - Edge selection settings + */ +struct SamplerConfig { + // Sampler type: "adaptive_metropolis", "nuts", "hmc" + std::string sampler_type = "adaptive_metropolis"; + + // Iteration counts + int no_iter = 1000; + int no_warmup = 500; + + // NUTS/HMC parameters + int max_tree_depth = 10; + int num_leapfrogs = 10; // For HMC only + double initial_step_size = 0.1; + double target_acceptance = 0.8; + + // Edge selection settings + bool edge_selection = false; + int edge_selection_start = -1; // -1 = no_warmup / 2 (default) + + // Random seed + int seed = 42; + + // Constructor with defaults + SamplerConfig() = default; + + // Get actual edge selection start iteration + int get_edge_selection_start() const { + if (edge_selection_start < 0) { + return no_warmup / 2; // Default: start at half of warmup + } + return edge_selection_start; + } +}; diff --git a/src/mixedVariables.cpp b/src/mixedVariables.cpp new file mode 100644 index 00000000..e4510128 --- /dev/null +++ b/src/mixedVariables.cpp @@ -0,0 +1,43 @@ +#include "ggm_model.h" +#include "mixedVariables.h" +#include "rng/rng_utils.h" + + +void MixedVariableTypes::instantiate_variable_types(Rcpp::List input_from_R) +{ + // instantiate variable_types_ + for (Rcpp::List var_type_list : input_from_R) { + std::string type = Rcpp::as(var_type_list["type"]); + if (type == "Continuous") { + variable_types_.push_back(std::make_unique( + Rcpp::as(var_type_list["observations"]), + Rcpp::as(var_type_list["inclusion_probability"]), + Rcpp::as(var_type_list["initial_edge_indicators"]), + Rcpp::as(var_type_list["edge_selection"]) + )); + // } else if (type == "Ordinal") { + // variable_types_.push_back(std::make_unique( + // var_type_list["observations"], + // var_type_list["inclusion_probability"], + // 
var_type_list["initial_edge_indicators"], + // var_type_list["edge_selection"] + // )); + // } else if (type == "Blume-Capel") { + // variable_types_.push_back(std::make_unique( + // var_type_list["observations"], + // var_type_list["inclusion_probability"], + // var_type_list["initial_edge_indicators"], + // var_type_list["edge_selection"] + // )); + // } else if (type == "Count") { + // variable_types_.push_back(std::make_unique( + // var_type_list["observations"], + // var_type_list["inclusion_probability"], + // var_type_list["initial_edge_indicators"], + // var_type_list["edge_selection"] + // )); + } else { + throw std::runtime_error("MixedVariableTypes received an unknown variable type in sublist from input_from_R: " + type); + } + } +} \ No newline at end of file diff --git a/src/mixedVariables.h b/src/mixedVariables.h new file mode 100644 index 00000000..66304b06 --- /dev/null +++ b/src/mixedVariables.h @@ -0,0 +1,169 @@ +#pragma once + +#include +#include +#include +#include "base_model.h" + + +// NOTE: work in progress - this class is not yet functional +class MixedVariableTypes : public BaseModel { +public: + + // Constructor + MixedVariableTypes( + Rcpp::List input_from_R, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const bool edge_selection = true + ) + { + + instantiate_variable_types(input_from_R); + // instantiate_variable_interactions(); // TODO: not yet implemented + + dim_ = 0; + for (const auto& var_type : variable_types_) { + dim_ += var_type->parameter_dimension(); + } + // dim_ += interactions_.size(); // TODO: proper dimension calculation + } + + // Capability queries + bool has_gradient() const override + { + for (const auto& var_type : variable_types_) { + if (var_type->has_gradient()) { + return true; + } + } + return false; + } + bool has_adaptive_mh() const override + { + for (const auto& var_type : variable_types_) { + if (var_type->has_adaptive_mh()) { + return true; + } + 
} + return false; + } + + // Return dimensionality of the parameter space + size_t parameter_dimension() const override { + return dim_; + } + + arma::vec get_vectorized_parameters() const override { + arma::vec result(dim_); + size_t current = 0; + for (size_t i = 0; i < variable_types_.size(); ++i) { + arma::vec var_params = variable_types_[i]->get_vectorized_parameters(); + result.subvec(current, current + var_params.n_elem - 1) = var_params; + current += var_params.n_elem; + } + for (size_t i = 0; i < interactions_.size(); ++i) { + const arma::mat& interactions_mat = interactions_[i]; + for (size_t c = 0; c < interactions_mat.n_cols; ++c) { + for (size_t r = 0; r < interactions_mat.n_rows; ++r) { + result(current) = interactions_mat(r, c); + ++current; + } + } + } + return result; + } + + arma::ivec get_vectorized_indicator_parameters() override { + for (size_t i = 0; i < variable_types_.size(); ++i) { + auto& [from, to] = indicator_parameters_indices_[i]; + vectorized_indicator_parameters_.subvec(from, to) = variable_types_[i]->get_vectorized_indicator_parameters(); + } + size_t current = indicator_parameters_indices_.empty() ? 0 : indicator_parameters_indices_.back().second + 1; + for (size_t i = 0; i < interactions_indicators_.size(); ++i) { + const arma::imat& indicator_mat = interactions_indicators_[i]; + for (size_t c = 0; c < indicator_mat.n_cols; ++c) { + for (size_t r = 0; r < indicator_mat.n_rows; ++r) { + vectorized_indicator_parameters_(current) = indicator_mat(r, c); + ++current; + } + } + } + + return vectorized_indicator_parameters_; + } + + + double logp(const arma::vec& parameters) override + { + double total_logp = 0.0; + for (size_t i = 0; i < variable_types_.size(); ++i) { + auto& [from, to] = parameters_indices_[i]; + // need to do some transformation here! 
+ arma::vec var_params = parameters.subvec(from, to); + total_logp += variable_types_[i]->logp(var_params); + } + // interactions log-probability can be added here if needed + return total_logp; + } + + arma::vec gradient(const arma::vec& parameters) override { + + // TODO: only should call the gradient for variable types that have it + // the rest are assumed to be constant, so have gradient zero + arma::vec total_gradient = arma::zeros(parameters.n_elem); + for (size_t i = 0; i < variable_types_.size(); ++i) + { + if (!variable_types_[i]->has_gradient()) { + continue; + } + auto& [from, to] = parameters_indices_[i]; + arma::vec var_params = parameters.subvec(from, to); + // maybe need to do some transformation here! + arma::vec var_gradient = variable_types_[i]->gradient(var_params); + total_gradient.subvec(from, to) = var_gradient; + } + + return total_gradient; + } + + std::pair logp_and_gradient( + const arma::vec& parameters) override { + if (!has_gradient()) { + throw std::runtime_error("Gradient not implemented for this model"); + } + return {logp(parameters), gradient(parameters)}; + } + + void do_one_mh_step() override { + for (auto& var_type : variable_types_) { + var_type->do_one_mh_step(); + } + } + + void set_seed(int seed) override { + for (auto& var_type : variable_types_) { + var_type->set_seed(seed); + } + } + + std::unique_ptr clone() const override { + throw std::runtime_error("clone method not yet implemented for MixedVariableTypes"); + } + + +private: + std::vector> variable_types_; + std::vector interactions_; + std::vector interactions_indicators_; + size_t dim_; + arma::vec vectorized_parameters_; + arma::ivec vectorized_indicator_parameters_; + arma::ivec indices_from_; + arma::ivec indices_to_; + std::vector> parameters_indices_; + std::vector> indicator_parameters_indices_; + + void instantiate_variable_types(const Rcpp::List input_from_R); + +}; diff --git a/src/omrf_model.cpp b/src/omrf_model.cpp new file mode 100644 index 
00000000..7367dba9 --- /dev/null +++ b/src/omrf_model.cpp @@ -0,0 +1,1424 @@ +#include +#include "omrf_model.h" +#include "adaptiveMetropolis.h" +#include "rng/rng_utils.h" +#include "mcmc/mcmc_hmc.h" +#include "mcmc/mcmc_nuts.h" +#include "mcmc/mcmc_rwm.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/mcmc_adaptation.h" +#include "mcmc/mcmc_runner.h" +#include "math/explog_switch.h" +#include "utils/common_helpers.h" +#include "utils/variable_helpers.h" + + +// ============================================================================= +// Constructor +// ============================================================================= + +OMRFModel::OMRFModel( + const arma::imat& observations, + const arma::ivec& num_categories, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const arma::uvec& is_ordinal_variable, + const arma::ivec& baseline_category, + double main_alpha, + double main_beta, + double pairwise_scale, + bool edge_selection +) : + n_(observations.n_rows), + p_(observations.n_cols), + observations_(observations), + num_categories_(num_categories), + is_ordinal_variable_(is_ordinal_variable), + baseline_category_(baseline_category), + inclusion_probability_(inclusion_probability), + main_alpha_(main_alpha), + main_beta_(main_beta), + pairwise_scale_(pairwise_scale), + edge_selection_(edge_selection), + edge_selection_active_(false), + proposal_(AdaptiveProposal(1, 500)), // Will be resized later + step_size_(0.1), + has_missing_(false), + gradient_cache_valid_(false) +{ + // Initialize parameter dimensions + num_main_ = count_num_main_effects_internal(); + num_pairwise_ = (p_ * (p_ - 1)) / 2; + + // Initialize parameters + int max_cats = num_categories_.max(); + main_effects_ = arma::zeros(p_, max_cats); + pairwise_effects_ = arma::zeros(p_, p_); + edge_indicators_ = initial_edge_indicators; + + // Initialize proposal SDs + proposal_sd_main_ = arma::ones(p_, max_cats) * 0.5; + proposal_sd_pairwise_ = 
arma::ones(p_, p_) * 0.5; + + // Initialize adaptive proposal + proposal_ = AdaptiveProposal(num_main_ + num_pairwise_, 500); + + // Initialize mass matrix + inv_mass_ = arma::ones(num_main_ + num_pairwise_); + + // Pre-compute observations as double (for efficient matrix operations) + observations_double_ = arma::conv_to::from(observations_); + + // Compute sufficient statistics + compute_sufficient_statistics(); + + // Initialize residual matrix + update_residual_matrix(); + + // Build interaction index + build_interaction_index(); +} + + +// ============================================================================= +// Copy constructor +// ============================================================================= + +OMRFModel::OMRFModel(const OMRFModel& other) + : BaseModel(other), + n_(other.n_), + p_(other.p_), + observations_(other.observations_), + observations_double_(other.observations_double_), + num_categories_(other.num_categories_), + is_ordinal_variable_(other.is_ordinal_variable_), + baseline_category_(other.baseline_category_), + counts_per_category_(other.counts_per_category_), + blume_capel_stats_(other.blume_capel_stats_), + pairwise_stats_(other.pairwise_stats_), + residual_matrix_(other.residual_matrix_), + main_effects_(other.main_effects_), + pairwise_effects_(other.pairwise_effects_), + edge_indicators_(other.edge_indicators_), + inclusion_probability_(other.inclusion_probability_), + main_alpha_(other.main_alpha_), + main_beta_(other.main_beta_), + pairwise_scale_(other.pairwise_scale_), + edge_selection_(other.edge_selection_), + edge_selection_active_(other.edge_selection_active_), + num_main_(other.num_main_), + num_pairwise_(other.num_pairwise_), + proposal_(other.proposal_), + proposal_sd_main_(other.proposal_sd_main_), + proposal_sd_pairwise_(other.proposal_sd_pairwise_), + rng_(other.rng_), + step_size_(other.step_size_), + inv_mass_(other.inv_mass_), + has_missing_(other.has_missing_), + missing_index_(other.missing_index_), + 
grad_obs_cache_(other.grad_obs_cache_), + index_matrix_cache_(other.index_matrix_cache_), + gradient_cache_valid_(other.gradient_cache_valid_), + interaction_index_(other.interaction_index_) +{ +} + + +// ============================================================================= +// Sufficient statistics computation +// ============================================================================= + +void OMRFModel::compute_sufficient_statistics() { + int max_cats = num_categories_.max(); + + // Category counts for ordinal variables + counts_per_category_ = arma::zeros(max_cats + 1, p_); + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + for (size_t i = 0; i < n_; ++i) { + int cat = observations_(i, v); + if (cat >= 0 && cat <= num_categories_(v)) { + counts_per_category_(cat, v)++; + } + } + } + } + + // Blume-Capel statistics (linear and quadratic sums) + blume_capel_stats_ = arma::zeros(2, p_); + for (size_t v = 0; v < p_; ++v) { + if (!is_ordinal_variable_(v)) { + int baseline = baseline_category_(v); + for (size_t i = 0; i < n_; ++i) { + int s = observations_(i, v) - baseline; + blume_capel_stats_(0, v) += s; // linear + blume_capel_stats_(1, v) += s * s; // quadratic + } + } + } + + // Pairwise statistics (X^T X) - use pre-computed transformed observations + arma::mat ps = observations_double_.t() * observations_double_; + pairwise_stats_ = arma::conv_to::from(ps); +} + + +size_t OMRFModel::count_num_main_effects_internal() const { + size_t count = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + count += num_categories_(v); + } else { + count += 2; // linear and quadratic for Blume-Capel + } + } + return count; +} + + +void OMRFModel::build_interaction_index() { + interaction_index_ = arma::zeros(num_pairwise_, 3); + int idx = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + interaction_index_(idx, 0) = idx; + interaction_index_(idx, 1) = v1; + interaction_index_(idx, 
2) = v2; + idx++; + } + } +} + + +void OMRFModel::update_residual_matrix() { + // Use pre-computed transformed observations (computed once in constructor) + residual_matrix_ = observations_double_ * pairwise_effects_; +} + + +void OMRFModel::set_pairwise_effects(const arma::mat& pairwise_effects) { + pairwise_effects_ = pairwise_effects; + update_residual_matrix(); + invalidate_gradient_cache(); +} + + +// ============================================================================= +// BaseModel interface implementation +// ============================================================================= + +size_t OMRFModel::parameter_dimension() const { + // Count active parameters: main effects + included pairwise effects + size_t active = num_main_; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + active++; + } + } + } + return active; +} + + +void OMRFModel::set_seed(int seed) { + rng_ = SafeRNG(seed); +} + + +std::unique_ptr OMRFModel::clone() const { + return std::make_unique(*this); +} + + +void OMRFModel::set_adaptive_proposal(AdaptiveProposal proposal) { + proposal_ = proposal; +} + + +// ============================================================================= +// Parameter vectorization +// ============================================================================= + +arma::vec OMRFModel::vectorize_parameters() const { + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + arma::vec param_vec(num_main_ + num_active); + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + param_vec(offset++) = main_effects_(v, c); + } + } else { + param_vec(offset++) = main_effects_(v, 0); // linear + 
param_vec(offset++) = main_effects_(v, 1); // quadratic + } + } + + // Active pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + param_vec(offset++) = pairwise_effects_(v1, v2); + } + } + } + + return param_vec; +} + + +void OMRFModel::unvectorize_parameters(const arma::vec& param_vec) { + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + main_effects_(v, c) = param_vec(offset++); + } + } else { + main_effects_(v, 0) = param_vec(offset++); // linear + main_effects_(v, 1) = param_vec(offset++); // quadratic + } + } + + // Active pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + double val = param_vec(offset++); + pairwise_effects_(v1, v2) = val; + pairwise_effects_(v2, v1) = val; + } + } + } + + update_residual_matrix(); + invalidate_gradient_cache(); +} + + +arma::vec OMRFModel::get_vectorized_parameters() const { + return vectorize_parameters(); +} + + +void OMRFModel::set_vectorized_parameters(const arma::vec& parameters) { + unvectorize_parameters(parameters); +} + + +arma::vec OMRFModel::get_full_vectorized_parameters() const { + // Fixed-size vector: all main effects + ALL pairwise effects + arma::vec param_vec(num_main_ + num_pairwise_); + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + param_vec(offset++) = main_effects_(v, c); + } + } else { + param_vec(offset++) = main_effects_(v, 0); // linear + param_vec(offset++) = main_effects_(v, 1); // quadratic + } + } + + // ALL pairwise effects (zeros for inactive edges) + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + 
param_vec(offset++) = pairwise_effects_(v1, v2); + } + } + + return param_vec; +} + + +arma::ivec OMRFModel::get_vectorized_indicator_parameters() { + arma::ivec indicators(num_pairwise_); + int idx = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + indicators(idx++) = edge_indicators_(v1, v2); + } + } + return indicators; +} + + +arma::vec OMRFModel::get_active_inv_mass() const { + if (!edge_selection_active_) { + return inv_mass_; + } + + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + arma::vec active_inv_mass(num_main_ + num_active); + active_inv_mass.head(num_main_) = inv_mass_.head(num_main_); + + int offset_full = num_main_; + int offset_active = num_main_; + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + active_inv_mass(offset_active) = inv_mass_(offset_full); + offset_active++; + } + offset_full++; + } + } + + return active_inv_mass; +} + + +void OMRFModel::vectorize_parameters_into(arma::vec& param_vec) const { + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + // Resize if needed (should rarely happen after first call) + size_t needed_size = num_main_ + num_active; + if (param_vec.n_elem != needed_size) { + param_vec.set_size(needed_size); + } + + int offset = 0; + + // Main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + param_vec(offset++) = main_effects_(v, c); + } + } else { + param_vec(offset++) = main_effects_(v, 0); // linear + param_vec(offset++) = main_effects_(v, 1); // quadratic + } + } + + // Active 
pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + param_vec(offset++) = pairwise_effects_(v1, v2); + } + } + } +} + + +void OMRFModel::get_active_inv_mass_into(arma::vec& active_inv_mass) const { + if (!edge_selection_active_) { + // No edge selection - just use full inv_mass + if (active_inv_mass.n_elem != inv_mass_.n_elem) { + active_inv_mass.set_size(inv_mass_.n_elem); + } + active_inv_mass = inv_mass_; + return; + } + + // Count active parameters + int num_active = 0; + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + num_active++; + } + } + } + + size_t needed_size = num_main_ + num_active; + if (active_inv_mass.n_elem != needed_size) { + active_inv_mass.set_size(needed_size); + } + + active_inv_mass.head(num_main_) = inv_mass_.head(num_main_); + + int offset_full = num_main_; + int offset_active = num_main_; + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + active_inv_mass(offset_active) = inv_mass_(offset_full); + offset_active++; + } + offset_full++; + } + } +} + + +// ============================================================================= +// Log-pseudoposterior computation +// ============================================================================= + +double OMRFModel::logp(const arma::vec& parameters) { + // Unvectorize into temporary matrices (safe approach) + arma::mat temp_main = main_effects_; + arma::mat temp_pairwise = pairwise_effects_; + + // Unvectorize parameters into temporaries + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + temp_main(v, c) = parameters(offset++); + } + } else { + temp_main(v, 0) = parameters(offset++); + temp_main(v, 1) = parameters(offset++); + } + } 
+ + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + temp_pairwise(v1, v2) = parameters(offset++); + temp_pairwise(v2, v1) = temp_pairwise(v1, v2); + } + } + } + + // Compute residual matrix from temp_pairwise + arma::mat temp_residual = arma::conv_to::from(observations_) * temp_pairwise; + + // Compute log-posterior with temporaries + return log_pseudoposterior_with_state(temp_main, temp_pairwise, temp_residual); +} + + +double OMRFModel::log_pseudoposterior_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat +) const { + double log_post = 0.0; + + auto log_beta_prior = [this](double x) { + return x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); + }; + + // Main effect contributions (priors and sufficient statistics) + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + + if (is_ordinal_variable_(v)) { + for (int c = 0; c < num_cats; ++c) { + log_post += log_beta_prior(main_eff(v, c)); + log_post += main_eff(v, c) * counts_per_category_(c + 1, v); + } + } else { + log_post += log_beta_prior(main_eff(v, 0)); + log_post += log_beta_prior(main_eff(v, 1)); + log_post += main_eff(v, 0) * blume_capel_stats_(0, v); + log_post += main_eff(v, 1) * blume_capel_stats_(1, v); + } + } + + // Log-denominator contributions using vectorized helpers + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + arma::vec residual_score = residual_mat.col(v); + arma::vec bound = num_cats * residual_score; + + arma::vec denom(n_, arma::fill::zeros); + if (is_ordinal_variable_(v)) { + // Extract main effect parameters for this variable + arma::vec main_effect_param = main_eff.row(v).cols(0, num_cats - 1).t(); + denom = compute_denom_ordinal(residual_score, main_effect_param, bound); + } else { + int ref = baseline_category_(v); + double lin_effect = main_eff(v, 0); + double quad_effect = main_eff(v, 
1); + // This updates bound in-place + denom = compute_denom_blume_capel(residual_score, lin_effect, quad_effect, ref, num_cats, bound); + } + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); + } + + // Pairwise effect contributions: sufficient statistics + Cauchy prior + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + double effect = pairwise_eff(v1, v2); + // Sufficient statistics term (data likelihood contribution) + log_post += 2.0 * pairwise_stats_(v1, v2) * effect; + // Cauchy prior using R's dcauchy for consistency + log_post += R::dcauchy(effect, 0.0, pairwise_scale_, true); + } + } + } + + return log_post; +} + + +double OMRFModel::log_pseudoposterior_internal() const { + return log_pseudoposterior_with_state(main_effects_, pairwise_effects_, residual_matrix_); +} + + +double OMRFModel::log_pseudoposterior_main_component(int variable, int category, int parameter) const { + double log_post = 0.0; + + // Lambda for Beta-prime prior on main effects (matches original implementation) + // log p(theta) = alpha * theta - (alpha + beta) * log(1 + exp(theta)) + auto log_beta_prior = [this](double x) { + return x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); + }; + + int num_cats = num_categories_(variable); + arma::vec bound = num_cats * residual_matrix_.col(variable); + + if (is_ordinal_variable_(variable)) { + // Ordinal variable: use category + log_post += log_beta_prior(main_effects_(variable, category)); + log_post += main_effects_(variable, category) * counts_per_category_(category + 1, variable); + + // Log-denominator contribution + for (size_t i = 0; i < n_; ++i) { + double max_val = 0.0; + for (int c = 0; c < num_cats; ++c) { + double val = main_effects_(variable, c) + (c + 1) * residual_matrix_(i, variable); + if (val > max_val) max_val = val; + } + + double denom = std::exp(-max_val); + for (int c = 0; c < num_cats; ++c) { + double val = 
main_effects_(variable, c) + (c + 1) * residual_matrix_(i, variable); + denom += std::exp(val - max_val); + } + log_post -= (max_val + std::log(denom)); + } + } else { + // Blume-Capel: use parameter (0 = linear, 1 = quadratic) + log_post += log_beta_prior(main_effects_(variable, parameter)); + log_post += main_effects_(variable, parameter) * blume_capel_stats_(parameter, variable); + + int baseline = baseline_category_(variable); + for (size_t i = 0; i < n_; ++i) { + double max_val = -std::numeric_limits::infinity(); + for (int c = 0; c <= num_cats; ++c) { + int s = c - baseline; + double val = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * residual_matrix_(i, variable); + if (val > max_val) max_val = val; + } + + double denom = 0.0; + for (int c = 0; c <= num_cats; ++c) { + int s = c - baseline; + double val = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * residual_matrix_(i, variable); + denom += std::exp(val - max_val); + } + log_post -= (max_val + std::log(denom)); + } + } + + return log_post; +} + + +double OMRFModel::log_pseudoposterior_pairwise_component(int var1, int var2) const { + double log_post = 2.0 * pairwise_effects_(var1, var2) * pairwise_stats_(var1, var2); + + // Contribution from both variables' pseudo-likelihoods + for (int var : {var1, var2}) { + int num_cats = num_categories_(var); + int other_var = (var == var1) ? 
var2 : var1; + + for (size_t i = 0; i < n_; ++i) { + double max_val = -std::numeric_limits::infinity(); + + if (is_ordinal_variable_(var)) { + max_val = 0.0; + for (int c = 0; c < num_cats; ++c) { + double val = main_effects_(var, c) + (c + 1) * residual_matrix_(i, var); + if (val > max_val) max_val = val; + } + + double denom = std::exp(-max_val); + for (int c = 0; c < num_cats; ++c) { + double val = main_effects_(var, c) + (c + 1) * residual_matrix_(i, var); + denom += std::exp(val - max_val); + } + log_post -= (max_val + std::log(denom)); + } else { + int baseline = baseline_category_(var); + for (int c = 0; c <= num_cats; ++c) { + int s = c - baseline; + double val = main_effects_(var, 0) * s + main_effects_(var, 1) * s * s + s * residual_matrix_(i, var); + if (val > max_val) max_val = val; + } + + double denom = 0.0; + for (int c = 0; c <= num_cats; ++c) { + int s = c - baseline; + double val = main_effects_(var, 0) * s + main_effects_(var, 1) * s * s + s * residual_matrix_(i, var); + denom += std::exp(val - max_val); + } + log_post -= (max_val + std::log(denom)); + } + } + } + + // Cauchy prior if edge is included + if (edge_indicators_(var1, var2) == 1) { + log_post += R::dcauchy(pairwise_effects_(var1, var2), 0.0, pairwise_scale_, true); + } + + return log_post; +} + + +double OMRFModel::compute_log_likelihood_ratio_for_variable( + int variable, + const arma::vec& interacting_score, + double proposed_state, + double current_state +) const { + double log_ratio = 0.0; + int num_cats = num_categories_(variable); + + for (size_t i = 0; i < n_; ++i) { + double rest_minus = residual_matrix_(i, variable) - current_state * interacting_score(i); + double rest_prop = rest_minus + proposed_state * interacting_score(i); + double rest_curr = rest_minus + current_state * interacting_score(i); + + double max_prop = -std::numeric_limits::infinity(); + double max_curr = -std::numeric_limits::infinity(); + + if (is_ordinal_variable_(variable)) { + max_prop = 0.0; + max_curr 
= 0.0; + for (int c = 0; c < num_cats; ++c) { + double val_prop = main_effects_(variable, c) + (c + 1) * rest_prop; + double val_curr = main_effects_(variable, c) + (c + 1) * rest_curr; + if (val_prop > max_prop) max_prop = val_prop; + if (val_curr > max_curr) max_curr = val_curr; + } + + double denom_prop = std::exp(-max_prop); + double denom_curr = std::exp(-max_curr); + for (int c = 0; c < num_cats; ++c) { + double val_prop = main_effects_(variable, c) + (c + 1) * rest_prop; + double val_curr = main_effects_(variable, c) + (c + 1) * rest_curr; + denom_prop += std::exp(val_prop - max_prop); + denom_curr += std::exp(val_curr - max_curr); + } + + log_ratio += (max_curr + std::log(denom_curr)) - (max_prop + std::log(denom_prop)); + } else { + int baseline = baseline_category_(variable); + for (int c = 0; c <= num_cats; ++c) { + int s = c - baseline; + double val_prop = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_prop; + double val_curr = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_curr; + if (val_prop > max_prop) max_prop = val_prop; + if (val_curr > max_curr) max_curr = val_curr; + } + + double denom_prop = 0.0; + double denom_curr = 0.0; + for (int c = 0; c <= num_cats; ++c) { + int s = c - baseline; + double val_prop = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_prop; + double val_curr = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_curr; + denom_prop += std::exp(val_prop - max_prop); + denom_curr += std::exp(val_curr - max_curr); + } + + log_ratio += (max_curr + std::log(denom_curr)) - (max_prop + std::log(denom_prop)); + } + } + + return log_ratio; +} + + +double OMRFModel::log_pseudolikelihood_ratio_interaction( + int variable1, + int variable2, + double proposed_state, + double current_state +) const { + double delta = proposed_state - current_state; + double log_ratio = 2.0 * delta * pairwise_stats_(variable1, variable2); 
+ + // For Blume-Capel variables, transform observations by subtracting baseline + auto get_transformed_score = [this](int var) -> arma::vec { + arma::vec score = arma::conv_to::from(observations_.col(var)); + if (!is_ordinal_variable_(var)) { + score -= static_cast(baseline_category_(var)); + } + return score; + }; + + // Contribution from variable1 + arma::vec interacting_score = get_transformed_score(variable2); + log_ratio += compute_log_likelihood_ratio_for_variable(variable1, interacting_score, proposed_state, current_state); + + // Contribution from variable2 + interacting_score = get_transformed_score(variable1); + log_ratio += compute_log_likelihood_ratio_for_variable(variable2, interacting_score, proposed_state, current_state); + + return log_ratio; +} + + +// ============================================================================= +// Gradient computation +// ============================================================================= + +void OMRFModel::ensure_gradient_cache() { + if (gradient_cache_valid_) return; + + // Compute observed gradient and index matrix (constant during MCMC) + int num_active = 0; + index_matrix_cache_ = arma::zeros(p_, p_); + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + index_matrix_cache_(v1, v2) = num_main_ + num_active; + num_active++; + } + } + } + + grad_obs_cache_ = arma::zeros(num_main_ + num_active); + + // Observed statistics for main effects + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + grad_obs_cache_(offset++) = counts_per_category_(c + 1, v); + } + } else { + grad_obs_cache_(offset++) = blume_capel_stats_(0, v); + grad_obs_cache_(offset++) = blume_capel_stats_(1, v); + } + } + + // Observed statistics for pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + 
if (edge_indicators_(v1, v2) == 1) { + grad_obs_cache_(offset++) = 2.0 * pairwise_stats_(v1, v2); + } + } + } + + gradient_cache_valid_ = true; +} + + +arma::vec OMRFModel::gradient(const arma::vec& parameters) { + // Unvectorize into temporary matrices (safe approach) + arma::mat temp_main = main_effects_; + arma::mat temp_pairwise = pairwise_effects_; + + // Unvectorize parameters into temporaries + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + temp_main(v, c) = parameters(offset++); + } + } else { + temp_main(v, 0) = parameters(offset++); + temp_main(v, 1) = parameters(offset++); + } + } + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + temp_pairwise(v1, v2) = parameters(offset++); + temp_pairwise(v2, v1) = temp_pairwise(v1, v2); + } + } + } + + // Compute residual matrix from temp_pairwise + arma::mat temp_residual = arma::conv_to::from(observations_) * temp_pairwise; + + return gradient_with_state(temp_main, temp_pairwise, temp_residual); +} + + +arma::vec OMRFModel::gradient_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat +) const { + // Start with cached observed gradient + arma::vec gradient = grad_obs_cache_; + + // Expected statistics for main and pairwise effects + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + arma::vec residual_score = residual_mat.col(v); + arma::vec bound = num_cats * residual_score; + + if (is_ordinal_variable_(v)) { + // Extract main effect parameters for this variable + arma::vec main_param = main_eff.row(v).cols(0, num_cats - 1).t(); + + // Use optimized helper function + arma::mat probs = compute_probs_ordinal(main_param, residual_score, bound, num_cats); + + // Main effects gradient + for (int c = 0; c < num_cats; ++c) { + 
gradient(offset + c) -= arma::accu(probs.col(c + 1)); + } + + // Pairwise effects gradient + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + + arma::vec expected_value = arma::zeros(n_); + for (int c = 1; c <= num_cats; ++c) { + expected_value += c * probs.col(c) % observations_double_.col(j); + } + + int location = (v < j) ? index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += num_cats; + } else { + int ref = baseline_category_(v); + double lin_eff = main_eff(v, 0); + double quad_eff = main_eff(v, 1); + + // Use optimized helper function (updates bound in-place) + arma::mat probs = compute_probs_blume_capel(residual_score, lin_eff, quad_eff, ref, num_cats, bound); + + arma::vec score = arma::regspace(0, num_cats) - static_cast(ref); + arma::vec sq_score = arma::square(score); + + // Main effects gradient + gradient(offset) -= arma::accu(probs * score); + gradient(offset + 1) -= arma::accu(probs * sq_score); + + // Pairwise effects gradient + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + + arma::vec expected_value = arma::zeros(n_); + for (int c = 0; c <= num_cats; ++c) { + int s = c - ref; + expected_value += s * probs.col(c) % observations_double_.col(j); + } + + int location = (v < j) ? index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += 2; + } + } + + // Prior gradients for main effects (Beta-prime prior) + offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_pars = is_ordinal_variable_(v) ? 
num_categories_(v) : 2; + for (int c = 0; c < num_pars; ++c) { + double x = main_eff(v, c); + double prob = 1.0 / (1.0 + std::exp(-x)); + gradient(offset + c) += main_alpha_ - (main_alpha_ + main_beta_) * prob; + } + offset += num_pars; + } + + // Cauchy prior gradient for pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + int idx = index_matrix_cache_(v1, v2); + double x = pairwise_eff(v1, v2); + gradient(idx) -= 2.0 * x / (pairwise_scale_ * pairwise_scale_ + x * x); + } + } + } + + return gradient; +} + + +arma::vec OMRFModel::gradient_internal() const { + return gradient_with_state(main_effects_, pairwise_effects_, residual_matrix_); +} + + +std::pair OMRFModel::logp_and_gradient(const arma::vec& parameters) { + // Ensure gradient cache is initialized + ensure_gradient_cache(); + + // Use the external-state versions + arma::mat temp_main = main_effects_; + arma::mat temp_pairwise = pairwise_effects_; + + // Unvectorize parameters into temporaries + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + temp_main(v, c) = parameters(offset++); + } + } else { + temp_main(v, 0) = parameters(offset++); + temp_main(v, 1) = parameters(offset++); + } + } + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + temp_pairwise(v1, v2) = parameters(offset++); + temp_pairwise(v2, v1) = temp_pairwise(v1, v2); + } + } + } + + arma::mat temp_residual = observations_double_ * temp_pairwise; + + double lp = log_pseudoposterior_with_state(temp_main, temp_pairwise, temp_residual); + arma::vec grad = gradient_with_state(temp_main, temp_pairwise, temp_residual); + + return {lp, grad}; +} + + +// ============================================================================= +// Metropolis-Hastings updates +// 
============================================================================= + +void OMRFModel::update_main_effect_parameter(int variable, int category, int parameter) { + double proposal_sd; + double& current = is_ordinal_variable_(variable) + ? main_effects_(variable, category) + : main_effects_(variable, parameter); + + proposal_sd = is_ordinal_variable_(variable) + ? proposal_sd_main_(variable, category) + : proposal_sd_main_(variable, parameter); + + int cat_for_log = is_ordinal_variable_(variable) ? category : -1; + int par_for_log = is_ordinal_variable_(variable) ? -1 : parameter; + + auto log_post = [&](double theta) { + double old_val = current; + current = theta; + double lp = log_pseudoposterior_main_component(variable, cat_for_log, par_for_log); + current = old_val; + return lp; + }; + + SamplerResult result = rwm_sampler(current, proposal_sd, log_post, rng_); + current = result.state[0]; + invalidate_gradient_cache(); +} + + +void OMRFModel::update_pairwise_effect(int var1, int var2) { + if (edge_indicators_(var1, var2) == 0) return; + + double& value = pairwise_effects_(var1, var2); + double proposal_sd = proposal_sd_pairwise_(var1, var2); + double current = value; + + auto log_post = [&](double theta) { + pairwise_effects_(var1, var2) = theta; + pairwise_effects_(var2, var1) = theta; + update_residual_matrix(); + return log_pseudoposterior_pairwise_component(var1, var2); + }; + + SamplerResult result = rwm_sampler(current, proposal_sd, log_post, rng_); + + value = result.state[0]; + pairwise_effects_(var2, var1) = value; + + if (current != value) { + update_residual_matrix(); + } + invalidate_gradient_cache(); +} + + +void OMRFModel::update_edge_indicator(int var1, int var2) { + double current_state = pairwise_effects_(var1, var2); + double proposal_sd = proposal_sd_pairwise_(var1, var2); + + bool proposing_addition = (edge_indicators_(var1, var2) == 0); + double proposed_state = proposing_addition ? 
rnorm(rng_, current_state, proposal_sd) : 0.0; + + double log_accept = log_pseudolikelihood_ratio_interaction(var1, var2, proposed_state, current_state); + + double incl_prob = inclusion_probability_(var1, var2); + + if (proposing_addition) { + log_accept += R::dcauchy(proposed_state, 0.0, pairwise_scale_, true); + log_accept -= R::dnorm(proposed_state, current_state, proposal_sd, true); + log_accept += MY_LOG(incl_prob) - MY_LOG(1.0 - incl_prob); + } else { + log_accept -= R::dcauchy(current_state, 0.0, pairwise_scale_, true); + log_accept += R::dnorm(current_state, proposed_state, proposal_sd, true); + log_accept -= MY_LOG(incl_prob) - MY_LOG(1.0 - incl_prob); + } + + if (MY_LOG(runif(rng_)) < log_accept) { + int updated = 1 - edge_indicators_(var1, var2); + edge_indicators_(var1, var2) = updated; + edge_indicators_(var2, var1) = updated; + + pairwise_effects_(var1, var2) = proposed_state; + pairwise_effects_(var2, var1) = proposed_state; + + update_residual_matrix(); + invalidate_gradient_cache(); + } +} + + +// ============================================================================= +// Main update methods +// ============================================================================= + +void OMRFModel::do_one_mh_step() { + // Update main effects + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + update_main_effect_parameter(v, c, -1); + } + } else { + for (int p = 0; p < 2; ++p) { + update_main_effect_parameter(v, -1, p); + } + } + } + + // Update pairwise effects + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + update_pairwise_effect(v1, v2); + } + } + + // Update edge indicators if in selection phase + if (edge_selection_active_) { + update_edge_indicators(); + } + + proposal_.increment_iteration(); +} + + +void OMRFModel::update_edge_indicators() { + for (size_t idx = 0; idx < num_pairwise_; ++idx) { + int var1 = 
interaction_index_(idx, 1); + int var2 = interaction_index_(idx, 2); + update_edge_indicator(var1, var2); + } +} + + +void OMRFModel::initialize_graph() { + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + double p = inclusion_probability_(v1, v2); + int draw = (runif(rng_) < p) ? 1 : 0; + edge_indicators_(v1, v2) = draw; + edge_indicators_(v2, v1) = draw; + if (!draw) { + pairwise_effects_(v1, v2) = 0.0; + pairwise_effects_(v2, v1) = 0.0; + } + } + } + update_residual_matrix(); + invalidate_gradient_cache(); +} + + + +void OMRFModel::impute_missing() { + if (!has_missing_) return; + + // For each missing value, sample from conditional distribution + for (size_t m = 0; m < missing_index_.n_rows; ++m) { + int person = missing_index_(m, 0); + int variable = missing_index_(m, 1); + int num_cats = num_categories_(variable); + + arma::vec log_probs; + if (is_ordinal_variable_(variable)) { + log_probs.set_size(num_cats + 1); + log_probs(0) = 0.0; + for (int c = 0; c < num_cats; ++c) { + log_probs(c + 1) = main_effects_(variable, c) + (c + 1) * residual_matrix_(person, variable); + } + } else { + int baseline = baseline_category_(variable); + log_probs.set_size(num_cats + 1); + for (int c = 0; c <= num_cats; ++c) { + int s = c - baseline; + log_probs(c) = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * residual_matrix_(person, variable); + } + } + + // Sample from categorical + double max_val = log_probs.max(); + arma::vec probs = arma::exp(log_probs - max_val); + probs /= arma::sum(probs); + + double u = runif(rng_); + double cumsum = 0.0; + int new_value = 0; + for (size_t c = 0; c < probs.n_elem; ++c) { + cumsum += probs(c); + if (u < cumsum) { + new_value = c; + break; + } + } + + int old_value = observations_(person, variable); + if (new_value != old_value) { + observations_(person, variable) = new_value; + // Update sufficient statistics + compute_sufficient_statistics(); + update_residual_matrix(); + } 
+ } +} + + +// ============================================================================= +// Factory function +// ============================================================================= + +OMRFModel createOMRFModelFromR( + const Rcpp::List& inputFromR, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + bool edge_selection +) { + arma::imat observations = Rcpp::as(inputFromR["observations"]); + arma::ivec num_categories = Rcpp::as(inputFromR["num_categories"]); + arma::uvec is_ordinal_variable = Rcpp::as(inputFromR["is_ordinal_variable"]); + arma::ivec baseline_category = Rcpp::as(inputFromR["baseline_category"]); + + double main_alpha = inputFromR.containsElementNamed("main_alpha") + ? Rcpp::as(inputFromR["main_alpha"]) : 1.0; + double main_beta = inputFromR.containsElementNamed("main_beta") + ? Rcpp::as(inputFromR["main_beta"]) : 1.0; + double pairwise_scale = inputFromR.containsElementNamed("pairwise_scale") + ? Rcpp::as(inputFromR["pairwise_scale"]) : 2.5; + + return OMRFModel( + observations, + num_categories, + inclusion_probability, + initial_edge_indicators, + is_ordinal_variable, + baseline_category, + main_alpha, + main_beta, + pairwise_scale, + edge_selection + ); +} + + +// ============================================================================= +// R interface: sample_omrf_classed +// ============================================================================= + +// [[Rcpp::export]] +Rcpp::List sample_omrf_classed( + const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const int no_iter, + const int no_warmup, + const bool edge_selection, + const std::string& sampler_type, + const int seed +) { + // Create model from R input + OMRFModel model = createOMRFModelFromR( + inputFromR, + prior_inclusion_prob, + initial_edge_indicators, + edge_selection + ); + + // Set random seed + model.set_seed(seed); + + // Storage for samples - use FIXED 
size (all parameters) + int full_dim = model.full_parameter_dimension(); + arma::mat samples(no_iter, full_dim); + arma::imat indicator_samples; + if (edge_selection) { + int num_edges = (model.get_p() * (model.get_p() - 1)) / 2; + indicator_samples.set_size(no_iter, num_edges); + } + + // NUTS/HMC diagnostics + arma::ivec treedepth_samples; + arma::ivec divergent_samples; + arma::vec energy_samples; + bool use_nuts = (sampler_type == "nuts"); + bool use_hmc = (sampler_type == "hmc"); + if (use_nuts || use_hmc) { + treedepth_samples.set_size(no_iter); + divergent_samples.set_size(no_iter); + energy_samples.set_size(no_iter); + } + + // Create sampler configuration + SamplerConfig config; + config.sampler_type = sampler_type; + config.initial_step_size = 0.1; + config.target_acceptance = 0.8; + config.max_tree_depth = 10; + config.num_leapfrogs = 10; + config.no_warmup = no_warmup; + + // Create appropriate sampler + std::unique_ptr sampler = create_sampler(config); + + // Warmup phase + Rcpp::Rcout << "Running warmup (" << no_warmup << " iterations)..." << std::endl; + for (int iter = 0; iter < no_warmup; ++iter) { + SamplerResult result = sampler->warmup_step(model); + + // Edge selection only after initial warmup period + if (edge_selection && iter > no_warmup / 2) { + model.update_edge_indicators(); + } + + // Check for user interrupt + if ((iter + 1) % 100 == 0) { + Rcpp::checkUserInterrupt(); + Rcpp::Rcout << " Warmup iteration " << (iter + 1) << "/" << no_warmup + << " (step_size=" << model.get_step_size() << ")" << std::endl; + } + } + + // Use averaged step size for sampling + sampler->finalize_warmup(); + Rcpp::Rcout << "Warmup complete." << std::endl; + + // Sampling phase + Rcpp::Rcout << "Running sampling (" << no_iter << " iterations)..." 
<< std::endl; + for (int iter = 0; iter < no_iter; ++iter) { + SamplerResult result = sampler->sample_step(model); + + // Extract NUTS/HMC diagnostics if available + if (sampler->has_nuts_diagnostics()) { + if (auto nuts_diag = std::dynamic_pointer_cast(result.diagnostics)) { + treedepth_samples(iter) = nuts_diag->tree_depth; + divergent_samples(iter) = nuts_diag->divergent ? 1 : 0; + energy_samples(iter) = nuts_diag->energy; + } + } + + if (edge_selection) { + model.update_edge_indicators(); + } + + // Store samples - use FULL vectorization (fixed size) + samples.row(iter) = model.get_full_vectorized_parameters().t(); + + if (edge_selection) { + arma::imat indicators = model.get_edge_indicators(); + int idx = 0; + for (int i = 0; i < static_cast(model.get_p()) - 1; ++i) { + for (int j = i + 1; j < static_cast(model.get_p()); ++j) { + indicator_samples(iter, idx++) = indicators(i, j); + } + } + } + + // Check for user interrupt + if ((iter + 1) % 100 == 0) { + Rcpp::checkUserInterrupt(); + Rcpp::Rcout << " Sampling iteration " << (iter + 1) << "/" << no_iter << std::endl; + } + } + + // Build output list + Rcpp::List output; + output["samples"] = samples; + + if (edge_selection) { + output["indicator_samples"] = indicator_samples; + // Compute posterior mean of edge indicators + arma::vec posterior_mean_indicator = arma::mean(arma::conv_to::from(indicator_samples), 0).t(); + output["posterior_mean_indicator"] = posterior_mean_indicator; + } + + if (use_nuts || use_hmc) { + output["treedepth"] = treedepth_samples; + output["divergent"] = divergent_samples; + output["energy"] = energy_samples; + // Get final step size from sampler (NUTSSampler and HMCSampler have get_step_size()) + output["final_step_size"] = 0.0; // Could add getter to sampler if needed + } + + output["sampler_type"] = sampler_type; + output["no_iter"] = no_iter; + output["no_warmup"] = no_warmup; + output["edge_selection"] = edge_selection; + output["num_variables"] = model.get_p(); + 
output["num_observations"] = model.get_n(); + + return output; +} \ No newline at end of file diff --git a/src/omrf_model.h b/src/omrf_model.h new file mode 100644 index 00000000..64f2ebc0 --- /dev/null +++ b/src/omrf_model.h @@ -0,0 +1,415 @@ +#pragma once + +#include +#include +#include "base_model.h" +#include "adaptiveMetropolis.h" +#include "rng/rng_utils.h" +#include "mcmc/mcmc_utils.h" + +/** + * OMRFModel - Ordinal Markov Random Field Model + * + * A class-based implementation of the OMRF model for Bayesian inference on + * ordinal and Blume-Capel variables. This class encapsulates: + * - Parameter storage (main effects, pairwise effects, edge indicators) + * - Sufficient statistics computation + * - Log-pseudoposterior and gradient evaluations + * - Adaptive Metropolis-Hastings updates for individual parameters + * - NUTS/HMC updates for joint parameter sampling + * - Edge selection (spike-and-slab) with asymmetric proposals + * + * Inherits from BaseModel for compatibility with the generic MCMC framework. 
+ */ +class OMRFModel : public BaseModel { +public: + + /** + * Constructor from raw observations + * + * @param observations Integer matrix of categorical observations (persons × variables) + * @param num_categories Number of categories per variable + * @param inclusion_probability Prior inclusion probabilities for edges + * @param initial_edge_indicators Initial edge inclusion matrix + * @param is_ordinal_variable Indicator (1 = ordinal, 0 = Blume-Capel) + * @param baseline_category Reference categories for Blume-Capel variables + * @param main_alpha Beta prior hyperparameter α for main effects + * @param main_beta Beta prior hyperparameter β for main effects + * @param pairwise_scale Scale parameter of Cauchy prior on interactions + * @param edge_selection Enable edge selection (spike-and-slab) + */ + OMRFModel( + const arma::imat& observations, + const arma::ivec& num_categories, + const arma::mat& inclusion_probability, + const arma::imat& initial_edge_indicators, + const arma::uvec& is_ordinal_variable, + const arma::ivec& baseline_category, + double main_alpha = 1.0, + double main_beta = 1.0, + double pairwise_scale = 2.5, + bool edge_selection = true + ); + + /** + * Copy constructor for cloning (required for parallel chains) + */ + OMRFModel(const OMRFModel& other); + + // ========================================================================= + // BaseModel interface implementation + // ========================================================================= + + bool has_gradient() const override { return true; } + bool has_adaptive_mh() const override { return true; } + bool has_edge_selection() const override { return edge_selection_; } + + /** + * Compute log-pseudoposterior for given parameter vector + */ + double logp(const arma::vec& parameters) override; + + /** + * Compute gradient of log-pseudoposterior + */ + arma::vec gradient(const arma::vec& parameters) override; + + /** + * Combined log-posterior and gradient evaluation (more efficient) + 
*/ + std::pair logp_and_gradient(const arma::vec& parameters) override; + + /** + * Perform one adaptive MH step (updates all parameters) + */ + void do_one_mh_step() override; + + /** + * Return dimensionality of active parameter space + */ + size_t parameter_dimension() const override; + + /** + * Set random seed for reproducibility + */ + void set_seed(int seed) override; + + /** + * Get vectorized parameters (main effects + active pairwise effects) + */ + arma::vec get_vectorized_parameters() const override; + + /** + * Set parameters from vectorized form + */ + void set_vectorized_parameters(const arma::vec& parameters) override; + + /** + * Get vectorized edge indicators + */ + arma::ivec get_vectorized_indicator_parameters() override; + + /** + * Clone the model for parallel execution + */ + std::unique_ptr clone() const override; + + /** + * Get RNG for samplers + */ + SafeRNG& get_rng() override { return rng_; } + + // ========================================================================= + // OMRF-specific methods + // ========================================================================= + + /** + * Set the adaptive proposal mechanism + */ + void set_adaptive_proposal(AdaptiveProposal proposal); + + /** + * Update edge indicators via Metropolis-Hastings + */ + void update_edge_indicators() override; + + /** + * Initialize random graph structure (for starting edge selection) + */ + void initialize_graph(); + + /** + * Impute missing values (if any) + */ + void impute_missing(); + + // ========================================================================= + // Accessors + // ========================================================================= + + const arma::mat& get_main_effects() const { return main_effects_; } + const arma::mat& get_pairwise_effects() const { return pairwise_effects_; } + const arma::imat& get_edge_indicators() const { return edge_indicators_; } + const arma::mat& get_residual_matrix() const { return residual_matrix_; } + 
+ void set_main_effects(const arma::mat& main_effects) { main_effects_ = main_effects; } + void set_pairwise_effects(const arma::mat& pairwise_effects); + void set_edge_indicators(const arma::imat& edge_indicators) { edge_indicators_ = edge_indicators; } + + size_t num_variables() const { return p_; } + size_t num_observations() const { return n_; } + size_t num_main_effects() const { return num_main_; } + size_t num_pairwise_effects() const { return num_pairwise_; } + + // Shorthand accessors (for interface compatibility) + size_t get_p() const { return p_; } + size_t get_n() const { return n_; } + + // Adaptation control + void set_step_size(double step_size) { step_size_ = step_size; } + double get_step_size() const { return step_size_; } + void set_inv_mass(const arma::vec& inv_mass) { inv_mass_ = inv_mass; } + const arma::vec& get_inv_mass() const { return inv_mass_; } + + /** + * Get full dimension (main + ALL pairwise, regardless of edge indicators) + * Used for fixed-size sample storage + */ + size_t full_parameter_dimension() const override { return num_main_ + num_pairwise_; } + + /** + * Get all parameters in a fixed-size vector (inactive edges are 0) + * Used for sample storage to avoid dimension changes + */ + arma::vec get_full_vectorized_parameters() const override; + + // Proposal SD access (for external adaptation) + arma::mat& get_proposal_sd_main() { return proposal_sd_main_; } + arma::mat& get_proposal_sd_pairwise() { return proposal_sd_pairwise_; } + + // Control edge selection phase + void set_edge_selection_active(bool active) { edge_selection_active_ = active; } + bool is_edge_selection_active() const { return edge_selection_active_; } + +private: + // ========================================================================= + // Data members + // ========================================================================= + + // Data + size_t n_; // Number of observations + size_t p_; // Number of variables + arma::imat observations_; // 
Categorical observations (n × p) + arma::mat observations_double_; // Observations as double (for efficient matrix ops) + arma::ivec num_categories_; // Categories per variable + arma::uvec is_ordinal_variable_; // 1 = ordinal, 0 = Blume-Capel + arma::ivec baseline_category_; // Reference category for Blume-Capel + + // Sufficient statistics + arma::imat counts_per_category_; // Category counts (max_cats+1 × p) + arma::imat blume_capel_stats_; // [linear_sum, quadratic_sum] for BC vars (2 × p) + arma::imat pairwise_stats_; // X^T X + arma::mat residual_matrix_; // X * pairwise_effects (n × p) + + // Parameters + arma::mat main_effects_; // Main effect parameters (p × max_cats) + arma::mat pairwise_effects_; // Pairwise interaction strengths (p × p, symmetric) + arma::imat edge_indicators_; // Edge inclusion indicators (p × p, symmetric binary) + + // Priors + arma::mat inclusion_probability_; // Prior inclusion probabilities + double main_alpha_; // Beta prior α + double main_beta_; // Beta prior β + double pairwise_scale_; // Cauchy scale for pairwise effects + + // Model configuration + bool edge_selection_; // Enable edge selection + bool edge_selection_active_; // Currently in edge selection phase + + // Dimension tracking + size_t num_main_; // Total number of main effect parameters + size_t num_pairwise_; // Number of possible pairwise effects + + // Adaptive proposals + AdaptiveProposal proposal_; + arma::mat proposal_sd_main_; + arma::mat proposal_sd_pairwise_; + + // RNG + SafeRNG rng_; + + // NUTS/HMC settings + double step_size_; + arma::vec inv_mass_; + + // Missing data handling + bool has_missing_; + arma::imat missing_index_; + + // Cached gradient components + arma::vec grad_obs_cache_; + arma::imat index_matrix_cache_; + bool gradient_cache_valid_; + + // Interaction indexing (for edge updates) + arma::imat interaction_index_; + + // ========================================================================= + // Private helper methods + // 
========================================================================= + + /** + * Compute sufficient statistics from observations + */ + void compute_sufficient_statistics(); + + /** + * Count total number of main effect parameters + */ + size_t count_num_main_effects_internal() const; + + /** + * Build interaction index matrix + */ + void build_interaction_index(); + + /** + * Update residual matrix after pairwise effects change + */ + void update_residual_matrix(); + + /** + * Invalidate gradient cache (call after parameter changes) + */ + void invalidate_gradient_cache() { gradient_cache_valid_ = false; } + + /** + * Ensure gradient cache is valid + */ + void ensure_gradient_cache(); + + // ------------------------------------------------------------------------- + // Log-posterior components + // ------------------------------------------------------------------------- + + /** + * Full log-pseudoposterior (internal, uses current state) + */ + double log_pseudoposterior_internal() const; + + /** + * Full log-pseudoposterior with external state (avoids modifying model) + */ + double log_pseudoposterior_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat + ) const; + + /** + * Log-posterior for single main effect component + */ + double log_pseudoposterior_main_component(int variable, int category, int parameter) const; + + /** + * Log-posterior for single pairwise interaction + */ + double log_pseudoposterior_pairwise_component(int var1, int var2) const; + + /** + * Log-likelihood ratio for variable update + */ + double compute_log_likelihood_ratio_for_variable( + int variable, + const arma::vec& interacting_score, + double proposed_state, + double current_state + ) const; + + /** + * Log-pseudolikelihood ratio for interaction update + */ + double log_pseudolikelihood_ratio_interaction( + int variable1, + int variable2, + double proposed_state, + double current_state + ) const; + + // 
------------------------------------------------------------------------- + // Gradient components + // ------------------------------------------------------------------------- + + /** + * Compute gradient with current state + */ + arma::vec gradient_internal() const; + + /** + * Compute gradient with external state (avoids modifying model) + */ + arma::vec gradient_with_state( + const arma::mat& main_eff, + const arma::mat& pairwise_eff, + const arma::mat& residual_mat + ) const; + + // ------------------------------------------------------------------------- + // Parameter vectorization + // ------------------------------------------------------------------------- + + /** + * Flatten parameters to vector + */ + arma::vec vectorize_parameters() const; + + /** + * Flatten parameters into pre-allocated vector (avoids allocation) + */ + void vectorize_parameters_into(arma::vec& param_vec) const; + + /** + * Unflatten vector to parameter matrices + */ + void unvectorize_parameters(const arma::vec& param_vec); + + /** + * Extract active inverse mass (only for included edges) + */ + arma::vec get_active_inv_mass() const; + + /** + * Extract active inverse mass into pre-allocated vector (avoids allocation) + */ + void get_active_inv_mass_into(arma::vec& active_inv_mass) const; + + // ------------------------------------------------------------------------- + // Metropolis updates + // ------------------------------------------------------------------------- + + /** + * Update single main effect parameter via RWM + */ + void update_main_effect_parameter(int variable, int category, int parameter); + + /** + * Update single pairwise effect via RWM + */ + void update_pairwise_effect(int var1, int var2); + + /** + * Update single edge indicator (spike-and-slab) + */ + void update_edge_indicator(int var1, int var2); +}; + + +/** + * Factory function to create OMRFModel from R inputs + */ +OMRFModel createOMRFModelFromR( + const Rcpp::List& inputFromR, + const arma::mat& 
inclusion_probability, + const arma::imat& initial_edge_indicators, + bool edge_selection = true +); diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index a9af132a..bd635f46 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -7,157 +7,8 @@ #include "ggm_model.h" #include "utils/progress_manager.h" #include "chainResultNew.h" - -void run_mcmc_sampler_single_thread( - ChainResultNew& chain_result, - BaseModel& model, - const int no_iter, - const int no_warmup, - const int chain_id, - ProgressManager& pm -) { - - chain_result.chain_id = chain_id + 1; - size_t i = 0; - for (size_t iter = 0; iter < no_iter + no_warmup; ++iter) { - - model.do_one_mh_step(); - - // update hyperparameters (BetaBinomial, SBM, etc.) - - if (iter >= no_warmup) { - - chain_result.store_sample(i, model.get_vectorized_parameters()); - ++i; - } - - pm.update(chain_id); - if (pm.shouldExit()) { - chain_result.userInterrupt = true; - break; - } - } -} - -struct GGMChainRunner : public RcppParallel::Worker { - std::vector& results_; - std::vector>& models_; - size_t no_iter_; - size_t no_warmup_; - int seed_; - ProgressManager& pm_; - - GGMChainRunner( - std::vector& results, - std::vector>& models, - const size_t no_iter, - const size_t no_warmup, - const int seed, - ProgressManager& pm - ) : - results_(results), - models_(models), - no_iter_(no_iter), - no_warmup_(no_warmup), - seed_(seed), - pm_(pm) - {} - - void operator()(std::size_t begin, std::size_t end) { - for (std::size_t i = begin; i < end; ++i) { - - ChainResultNew& chain_result = results_[i]; - BaseModel& model = *models_[i]; - model.set_seed(seed_ + i); - try { - - run_mcmc_sampler_single_thread(chain_result, model, no_iter_, no_warmup_, i, pm_); - - } catch (std::exception& e) { - chain_result.error = true; - chain_result.error_msg = e.what(); - } catch (...) 
{ - chain_result.error = true; - chain_result.error_msg = "Unknown error"; - } - } - } -}; - -void run_mcmc_sampler_threaded( - std::vector& results, - std::vector>& models, - const int no_iter, - const int no_warmup, - const int seed, - const int no_threads, - ProgressManager& pm -) { - - GGMChainRunner runner(results, models, no_iter, no_warmup, seed, pm); - tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); - RcppParallel::parallelFor(0, results.size(), runner); -} - - -std::vector run_mcmc_sampler( - BaseModel& model, - const int no_iter, - const int no_warmup, - const int no_chains, - const int seed, - const int no_threads, - ProgressManager& pm -) { - - Rcpp::Rcout << "Allocating results objects..." << std::endl; - std::vector results(no_chains); - for (size_t c = 0; c < no_chains; ++c) { - results[c].reserve(model.parameter_dimension(), no_iter); - } - - if (no_threads > 1) { - - Rcpp::Rcout << "Running multi-threaded MCMC sampling..." << std::endl; - std::vector> models; - models.reserve(no_chains); - for (size_t c = 0; c < no_chains; ++c) { - models.push_back(model.clone()); // deep copy via virtual clone - } - run_mcmc_sampler_threaded(results, models, no_iter, no_warmup, seed, no_threads, pm); - - } else { - - model.set_seed(seed); - Rcpp::Rcout << "Running single-threaded MCMC sampling..." 
<< std::endl; - // TODO: this is actually not correct, each chain should have its own model object - // now chain 2 continues from chain 1 state - for (size_t c = 0; c < no_chains; ++c) { - run_mcmc_sampler_single_thread(results[c], model, no_iter, no_warmup, c, pm); - } - - } - return results; -} - -Rcpp::List convert_sampler_output_to_ggm_result(const std::vector& results) { - - Rcpp::List output(results.size()); - for (size_t i = 0; i < results.size(); ++i) { - - Rcpp::List chain_i; - chain_i["chain_id"] = results[i].chain_id; - if (results[i].error) { - chain_i["error"] = results[i].error_msg; - } else { - chain_i["samples"] = results[i].samples; - chain_i["userInterrupt"] = results[i].userInterrupt; - - } - output[i] = chain_i; - } - return output; -} +#include "mcmc/mcmc_runner.h" +#include "mcmc/sampler_config.h" // [[Rcpp::export]] Rcpp::List sample_ggm( @@ -173,18 +24,30 @@ Rcpp::List sample_ggm( const int progress_type ) { - // should be done dynamically - // also adaptation method should be specified differently - // GaussianVariables model(X, prior_inclusion_prob, initial_edge_indicators, edge_selection); - GaussianVariables model = createGaussianVariablesFromR(inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + // Create model from R input + GaussianVariables model = createGaussianVariablesFromR( + inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + // Configure sampler - GGM only supports MH + SamplerConfig config; + config.sampler_type = "mh"; + config.no_iter = no_iter; + config.no_warmup = no_warmup; + config.edge_selection = edge_selection; + config.seed = seed; + // Edge selection starts at no_warmup/2 by default (handled by get_edge_selection_start()) + + // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); - std::vector output = run_mcmc_sampler(model, no_iter, no_warmup, no_chains, seed, no_threads, pm); + // Run MCMC using unified infrastructure + 
std::vector results = run_mcmc_sampler( + model, config, no_chains, no_threads, pm); - Rcpp::List ggm_result = convert_sampler_output_to_ggm_result(output); + // Convert to R list format + Rcpp::List output = convert_results_to_list(results); pm.finish(); - return ggm_result; + return output; } \ No newline at end of file diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp new file mode 100644 index 00000000..6e6c45a0 --- /dev/null +++ b/src/sample_omrf.cpp @@ -0,0 +1,85 @@ +/** + * sample_omrf.cpp - R interface for OMRF model sampling + * + * Uses the unified MCMC runner infrastructure to sample from OMRF models. + * Supports MH, NUTS, and HMC samplers with optional edge selection. + */ +#include +#include +#include + +#include "omrf_model.h" +#include "utils/progress_manager.h" +#include "chainResultNew.h" +#include "mcmc/mcmc_runner.h" +#include "mcmc/sampler_config.h" + +/** + * R-exported function to sample from an OMRF model + * + * @param inputFromR List with model specification + * @param prior_inclusion_prob Prior inclusion probabilities (p × p matrix) + * @param initial_edge_indicators Initial edge indicators (p × p integer matrix) + * @param no_iter Number of post-warmup iterations + * @param no_warmup Number of warmup iterations + * @param no_chains Number of parallel chains + * @param edge_selection Whether to do edge selection (spike-and-slab) + * @param sampler_type "mh", "nuts", or "hmc" + * @param seed Random seed + * @param no_threads Number of threads for parallel execution + * @param progress_type Progress bar type + * @param target_acceptance Target acceptance rate for NUTS/HMC (default: 0.8) + * @param max_tree_depth Maximum tree depth for NUTS (default: 10) + * @param num_leapfrogs Number of leapfrog steps for HMC (default: 10) + * @param edge_selection_start Iteration to start edge selection (-1 = no_warmup/2) + * + * @return List with per-chain results including samples and diagnostics + */ +// [[Rcpp::export]] +Rcpp::List sample_omrf( + 
const Rcpp::List& inputFromR, + const arma::mat& prior_inclusion_prob, + const arma::imat& initial_edge_indicators, + const int no_iter, + const int no_warmup, + const int no_chains, + const bool edge_selection, + const std::string& sampler_type, + const int seed, + const int no_threads, + const int progress_type, + const double target_acceptance = 0.8, + const int max_tree_depth = 10, + const int num_leapfrogs = 10, + const int edge_selection_start = -1 +) { + // Create model from R input + OMRFModel model = createOMRFModelFromR( + inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + + // Configure sampler + SamplerConfig config; + config.sampler_type = sampler_type; + config.no_iter = no_iter; + config.no_warmup = no_warmup; + config.edge_selection = edge_selection; + config.edge_selection_start = edge_selection_start; // -1 means use default (no_warmup/2) + config.seed = seed; + config.target_acceptance = target_acceptance; + config.max_tree_depth = max_tree_depth; + config.num_leapfrogs = num_leapfrogs; + + // Set up progress manager + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + + // Run MCMC using unified infrastructure + std::vector results = run_mcmc_sampler( + model, config, no_chains, no_threads, pm); + + // Convert to R list format + Rcpp::List output = convert_results_to_list(results); + + pm.finish(); + + return output; +} From ab32893a6b7a930049a3fbdafb74bc1c7499cd13 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Fri, 6 Feb 2026 09:13:57 +0100 Subject: [PATCH 12/23] algorithms mostly separated from classes --- src/base_model.h | 10 ++++++++++ src/mcmc/mcmc_leapfrog.cpp | 9 +++++++++ src/mcmc/mcmc_runner.h | 8 ++++++++ src/mcmc/nuts_sampler.h | 15 ++++++++++++--- src/mcmc/sampler_config.h | 4 ++-- src/omrf_model.h | 4 ++-- 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/src/base_model.h b/src/base_model.h index bee092ad..71a3527f 100644 --- a/src/base_model.h +++ 
b/src/base_model.h @@ -95,6 +95,16 @@ class BaseModel { // Get active inverse mass (for models with edge selection, may be subset) virtual arma::vec get_active_inv_mass() const { return inv_mass_; } + // Edge selection activation + virtual void set_edge_selection_active(bool active) { + (void)active; // Default: no-op + } + + // Initialize graph structure for edge selection + virtual void initialize_graph() { + // Default: no-op + } + protected: BaseModel() = default; double step_size_ = 0.1; diff --git a/src/mcmc/mcmc_leapfrog.cpp b/src/mcmc/mcmc_leapfrog.cpp index ceadcfe8..0316d3c9 100644 --- a/src/mcmc/mcmc_leapfrog.cpp +++ b/src/mcmc/mcmc_leapfrog.cpp @@ -82,7 +82,16 @@ std::pair leapfrog_memo( arma::vec theta_new = theta; auto grad1 = memo.cached_grad(theta_new); + if (grad1.n_elem != theta_new.n_elem) { + Rcpp::Rcout << "LEAPFROG: grad1.n_elem=" << grad1.n_elem + << " theta.n_elem=" << theta_new.n_elem << std::endl; + } r_half += 0.5 * eps * grad1; + + if (inv_mass_diag.n_elem != r_half.n_elem) { + Rcpp::Rcout << "LEAPFROG: inv_mass_diag.n_elem=" << inv_mass_diag.n_elem + << " r_half.n_elem=" << r_half.n_elem << std::endl; + } theta_new += eps * (inv_mass_diag % r_half); auto grad2 = memo.cached_grad(theta_new); r_half += 0.5 * eps * grad2; diff --git a/src/mcmc/mcmc_runner.h b/src/mcmc/mcmc_runner.h index 475aac57..0d80f0cb 100644 --- a/src/mcmc/mcmc_runner.h +++ b/src/mcmc/mcmc_runner.h @@ -88,6 +88,14 @@ inline void run_mcmc_chain( // Finalize warmup (samplers fix their adapted parameters) sampler->finalize_warmup(); + // ========================================================================= + // Activate edge selection mode (if enabled) + // ========================================================================= + if (config.edge_selection && model.has_edge_selection()) { + model.set_edge_selection_active(true); + model.initialize_graph(); // Randomly initialize graph structure + } + // 
========================================================================= // Sampling phase // ========================================================================= diff --git a/src/mcmc/nuts_sampler.h b/src/mcmc/nuts_sampler.h index e1327579..3f9e92dc 100644 --- a/src/mcmc/nuts_sampler.h +++ b/src/mcmc/nuts_sampler.h @@ -212,7 +212,7 @@ class NUTSSampler : public BaseSampler { } /** - * Execute one NUTS step using the sampler's learned mass matrix + * Execute one NUTS step using the model's active inverse mass matrix */ SamplerResult do_nuts_step(BaseModel& model) { // Get current state @@ -227,13 +227,22 @@ class NUTSSampler : public BaseSampler { return model.logp_and_gradient(params).second; }; - // Call the NUTS free function with our learned inverse mass + // Get active inverse mass from model (handles dimension changes for edge selection) + arma::vec active_inv_mass = model.get_active_inv_mass(); + + // Debug: check dimension match + if (theta.n_elem != active_inv_mass.n_elem) { + Rcpp::Rcout << "DIMENSION MISMATCH: theta=" << theta.n_elem + << " active_inv_mass=" << active_inv_mass.n_elem << std::endl; + } + + // Call the NUTS free function with the active inverse mass SamplerResult result = nuts_sampler( theta, step_size_, log_post, grad_fn, - inv_mass_, + active_inv_mass, rng, max_tree_depth_ ); diff --git a/src/mcmc/sampler_config.h b/src/mcmc/sampler_config.h index 7aee7469..c985a43d 100644 --- a/src/mcmc/sampler_config.h +++ b/src/mcmc/sampler_config.h @@ -27,7 +27,7 @@ struct SamplerConfig { // Edge selection settings bool edge_selection = false; - int edge_selection_start = -1; // -1 = no_warmup / 2 (default) + int edge_selection_start = -1; // -1 = no_warmup (default, start at sampling) // Random seed int seed = 42; @@ -38,7 +38,7 @@ struct SamplerConfig { // Get actual edge selection start iteration int get_edge_selection_start() const { if (edge_selection_start < 0) { - return no_warmup / 2; // Default: start at half of warmup + return 
no_warmup; // Default: start at beginning of sampling } return edge_selection_start; } diff --git a/src/omrf_model.h b/src/omrf_model.h index 64f2ebc0..b8bc783e 100644 --- a/src/omrf_model.h +++ b/src/omrf_model.h @@ -136,7 +136,7 @@ class OMRFModel : public BaseModel { /** * Initialize random graph structure (for starting edge selection) */ - void initialize_graph(); + void initialize_graph() override; /** * Impute missing values (if any) @@ -188,7 +188,7 @@ class OMRFModel : public BaseModel { arma::mat& get_proposal_sd_pairwise() { return proposal_sd_pairwise_; } // Control edge selection phase - void set_edge_selection_active(bool active) { edge_selection_active_ = active; } + void set_edge_selection_active(bool active) override { edge_selection_active_ = active; } bool is_edge_selection_active() const { return edge_selection_active_; } private: From 8ab9f584668de38a6a1e4e9902d69765b7273f46 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Fri, 6 Feb 2026 14:08:33 +0100 Subject: [PATCH 13/23] functional and fast but code needs double checking --- src/SkeletonVariables.h | 160 ++++++------- src/mcmc/mcmc_adaptation.h | 10 +- src/mcmc/mcmc_leapfrog.cpp | 8 - src/mcmc/mcmc_utils.cpp | 7 +- src/mcmc/mh_sampler.h | 2 - src/mcmc/nuts_sampler.h | 203 ++++++---------- src/omrf_model.cpp | 268 ++++++++------------- src/omrf_model.h | 5 + src/skeleton_model.cpp | 461 ++++++++++++++++++------------------- 9 files changed, 492 insertions(+), 632 deletions(-) diff --git a/src/SkeletonVariables.h b/src/SkeletonVariables.h index 51fd9c24..0b5d1b43 100644 --- a/src/SkeletonVariables.h +++ b/src/SkeletonVariables.h @@ -1,80 +1,80 @@ -#pragma once - -#include -#include "base_model.h" -#include "adaptiveMetropolis.h" -#include "rng/rng_utils.h" - - -class SkeletonVariables : public BaseModel { -public: - - // constructor from raw data - SkeletonVariables( - const arma::mat& observations, - const arma::mat& inclusion_probability, - const arma::imat& initial_edge_indicators, - 
const bool edge_selection = true - ) : - {} - - // copy constructor - SkeletonVariables(const SkeletonVariables& other) - : BaseModel(other), - {} - - std::unique_ptr clone() const override { - return std::make_unique(*this); // uses copy constructor - } - - bool has_gradient() const { return false; } - bool has_adaptive_mh() const override { return true; } - - double logp(const arma::vec& parameters) override { - // Implement log probability computation - return 0.0; - } - - void do_one_mh_step() override; - - size_t parameter_dimension() const override { - return dim_; - } - - void set_seed(int seed) override { - rng_ = SafeRNG(seed); - } - - // arma::vec get_vectorized_parameters() override { - // // upper triangle of precision_matrix_ - // size_t e = 0; - // for (size_t j = 0; j < p_; ++j) { - // for (size_t i = 0; i <= j; ++i) { - // vectorized_parameters_(e) = precision_matrix_(i, j); - // ++e; - // } - // } - // return vectorized_parameters_; - // } - - // arma::ivec get_vectorized_indicator_parameters() override { - // // upper triangle of precision_matrix_ - // size_t e = 0; - // for (size_t j = 0; j < p_; ++j) { - // for (size_t i = 0; i <= j; ++i) { - // vectorized_indicator_parameters_(e) = edge_indicators_(i, j); - // ++e; - // } - // } - // return vectorized_indicator_parameters_; - // } - - -private: - // data - size_t n_ = 0; - size_t p_ = 0; - size_t dim_ = 0; - - -}; \ No newline at end of file +// #pragma once + +// #include +// #include "base_model.h" +// #include "adaptiveMetropolis.h" +// #include "rng/rng_utils.h" + + +// class SkeletonVariables : public BaseModel { +// public: + +// // constructor from raw data +// SkeletonVariables( +// const arma::mat& observations, +// const arma::mat& inclusion_probability, +// const arma::imat& initial_edge_indicators, +// const bool edge_selection = true +// ) : +// {} + +// // copy constructor +// SkeletonVariables(const SkeletonVariables& other) +// : BaseModel(other), +// {} + +// std::unique_ptr 
clone() const override { +// return std::make_unique(*this); // uses copy constructor +// } + +// bool has_gradient() const { return false; } +// bool has_adaptive_mh() const override { return true; } + +// double logp(const arma::vec& parameters) override { +// // Implement log probability computation +// return 0.0; +// } + +// void do_one_mh_step() override; + +// size_t parameter_dimension() const override { +// return dim_; +// } + +// void set_seed(int seed) override { +// rng_ = SafeRNG(seed); +// } + +// // arma::vec get_vectorized_parameters() override { +// // // upper triangle of precision_matrix_ +// // size_t e = 0; +// // for (size_t j = 0; j < p_; ++j) { +// // for (size_t i = 0; i <= j; ++i) { +// // vectorized_parameters_(e) = precision_matrix_(i, j); +// // ++e; +// // } +// // } +// // return vectorized_parameters_; +// // } + +// // arma::ivec get_vectorized_indicator_parameters() override { +// // // upper triangle of precision_matrix_ +// // size_t e = 0; +// // for (size_t j = 0; j < p_; ++j) { +// // for (size_t i = 0; i <= j; ++i) { +// // vectorized_indicator_parameters_(e) = edge_indicators_(i, j); +// // ++e; +// // } +// // } +// // return vectorized_indicator_parameters_; +// // } + + +// private: +// // data +// size_t n_ = 0; +// size_t p_ = 0; +// size_t dim_ = 0; + + +// }; \ No newline at end of file diff --git a/src/mcmc/mcmc_adaptation.h b/src/mcmc/mcmc_adaptation.h index 9cefb21e..b2168ae9 100644 --- a/src/mcmc/mcmc_adaptation.h +++ b/src/mcmc/mcmc_adaptation.h @@ -94,7 +94,7 @@ class DiagMassMatrixAccumulator { // === Dynamic Warmup Schedule with Adaptive Windows === -// +// // For edge_selection = FALSE: // Stage 1 (init), Stage 2 (doubling windows), Stage 3a (terminal) // total_warmup = user-specified warmup @@ -108,7 +108,7 @@ class DiagMassMatrixAccumulator { // Warning types: // 0 = none // 1 = warmup extremely short (< 50) -// 2 = core stages using proportional fallback +// 2 = core stages using proportional fallback // 
3 = limited proposal SD tuning (edge_selection && warmup < 300) // 4 = Stage 3b skipped (would have < 20 iterations) // @@ -228,7 +228,7 @@ struct WarmupSchedule { bool in_stage3b(int i) const { return !stage3b_skipped && i >= stage3b_start && i < stage3c_start; } bool in_stage3c(int i) const { return enable_selection && !stage3b_skipped && i >= stage3c_start && i < total_warmup; } bool sampling (int i) const { return i >= total_warmup; } - + bool has_warning() const { return warning_type > 0; } bool warmup_extremely_short() const { return warning_type == 1; } bool using_proportional_fallback() const { return warning_type == 2; } @@ -298,7 +298,7 @@ class HMCAdaptationController { mass_accumulator.update(theta); int w = schedule.current_window(iteration); if (iteration + 1 == schedule.window_ends[w]) { - // inv_mass = variance (not 1/variance) + // STAN convention: inv_mass = variance (not 1/variance!) // Higher variance → higher inverse mass → parameter moves more freely inv_mass_ = mass_accumulator.variance(); mass_accumulator.reset(); @@ -339,7 +339,7 @@ class HMCAdaptationController { void reinit_stepsize(double new_step_size) { step_size_ = new_step_size; step_adapter.restart(new_step_size); - // Set mu to log(10 * epsilon) for dual averaging + // Set mu to log(10 * epsilon) as per STAN's approach step_adapter.mu = MY_LOG(10.0 * new_step_size); mass_matrix_updated_ = false; } diff --git a/src/mcmc/mcmc_leapfrog.cpp b/src/mcmc/mcmc_leapfrog.cpp index 0316d3c9..d5d7c120 100644 --- a/src/mcmc/mcmc_leapfrog.cpp +++ b/src/mcmc/mcmc_leapfrog.cpp @@ -82,16 +82,8 @@ std::pair leapfrog_memo( arma::vec theta_new = theta; auto grad1 = memo.cached_grad(theta_new); - if (grad1.n_elem != theta_new.n_elem) { - Rcpp::Rcout << "LEAPFROG: grad1.n_elem=" << grad1.n_elem - << " theta.n_elem=" << theta_new.n_elem << std::endl; - } r_half += 0.5 * eps * grad1; - if (inv_mass_diag.n_elem != r_half.n_elem) { - Rcpp::Rcout << "LEAPFROG: inv_mass_diag.n_elem=" << inv_mass_diag.n_elem - 
<< " r_half.n_elem=" << r_half.n_elem << std::endl; - } theta_new += eps * (inv_mass_diag % r_half); auto grad2 = memo.cached_grad(theta_new); r_half += 0.5 * eps * grad2; diff --git a/src/mcmc/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp index 96ca6a09..51daed3c 100644 --- a/src/mcmc/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -59,7 +59,7 @@ double heuristic_initial_step_size( double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum and evaluate arma::vec r = arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); @@ -131,7 +131,7 @@ double heuristic_initial_step_size( ) { double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum from N(0, M) where M = diag(1/inv_mass_diag) arma::vec r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); @@ -151,7 +151,7 @@ double heuristic_initial_step_size( while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 
2.0 * eps : 0.5 * eps; - // Resample momentum on each iteration for step size search + // Resample momentum (STAN resamples on each iteration) r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; @@ -159,7 +159,6 @@ double heuristic_initial_step_size( // One leapfrog step from original position with new momentum std::tie(theta_new, r_new) = leapfrog(theta, r, eps, grad, 1, inv_mass_diag); - // Evaluate Hamiltonian logp1 = log_post(theta_new); kin1 = kinetic_energy(r_new, inv_mass_diag); H1 = logp1 - kin1; diff --git a/src/mcmc/mh_sampler.h b/src/mcmc/mh_sampler.h index 8f0ed8b6..fdbd7ab1 100644 --- a/src/mcmc/mh_sampler.h +++ b/src/mcmc/mh_sampler.h @@ -37,7 +37,6 @@ class MHSampler : public BaseSampler { model.do_one_mh_step(); SamplerResult result; - result.state = model.get_full_vectorized_parameters(); result.accept_prob = 1.0; // Not tracked for component-wise MH return result; } @@ -58,7 +57,6 @@ class MHSampler : public BaseSampler { model.do_one_mh_step(); SamplerResult result; - result.state = model.get_full_vectorized_parameters(); result.accept_prob = 1.0; return result; } diff --git a/src/mcmc/nuts_sampler.h b/src/mcmc/nuts_sampler.h index 3f9e92dc..9fa694fb 100644 --- a/src/mcmc/nuts_sampler.h +++ b/src/mcmc/nuts_sampler.h @@ -8,48 +8,12 @@ #include "mcmc_utils.h" #include "mcmc_nuts.h" #include "mcmc_adaptation.h" +#include "mcmc_leapfrog.h" #include "sampler_config.h" #include "../base_model.h" -/** - * NUTSSampler - No-U-Turn Sampler implementation - * - * Provides a clean interface to the NUTS algorithm for any BaseModel - * with gradient support. Handles: - * - Step size adaptation via dual averaging during warmup - * - Mass matrix adaptation using Welford's algorithm during warmup - * - Trajectory simulation with the no-U-turn criterion - * - Diagnostics collection (tree depth, divergence, energy) - * - * The sampler fully owns all sampling logic. 
The model only provides: - * - logp_and_gradient(theta): Compute log posterior and gradient - * - get_vectorized_parameters(): Get current state as vector - * - set_vectorized_parameters(theta): Update model state from vector - * - get_active_inv_mass(): Get inverse mass diagonal (used for initialization) - * - get_rng(): Get random number generator - * - * Warmup Schedule (Stan-style): - * - Stage 1 (7.5%): Initial adaptation, step size only - * - Stage 2 (82.5%): Mass matrix learning in doubling windows - * - Stage 3 (10%): Final step size tuning with fixed mass - * - * Usage: - * NUTSSampler nuts(config, n_warmup); - * for (iter in warmup) { - * auto result = nuts.warmup_step(model); - * } - * nuts.finalize_warmup(); - * for (iter in sampling) { - * auto result = nuts.sample_step(model); - * } - */ class NUTSSampler : public BaseSampler { public: - /** - * Construct NUTS sampler with configuration - * @param config Sampler configuration (step size, target acceptance, etc.) - * @param n_warmup Number of warmup iterations (for scheduling mass matrix adaptation) - */ explicit NUTSSampler(const SamplerConfig& config, int n_warmup = 1000) : step_size_(config.initial_step_size), target_acceptance_(config.target_acceptance), @@ -63,13 +27,7 @@ class NUTSSampler : public BaseSampler { build_warmup_schedule(n_warmup); } - /** - * Perform one NUTS step during warmup (with step size and mass matrix adaptation) - * @param model The model to sample from - * @return SamplerResult with state and diagnostics - */ SamplerResult warmup_step(BaseModel& model) override { - // Initialize on first warmup iteration if (!initialized_) { initialize(model); initialized_ = true; @@ -83,16 +41,28 @@ class NUTSSampler : public BaseSampler { // During Stage 2, accumulate samples for mass matrix estimation if (in_stage2()) { - mass_accumulator_->update(result.state); + arma::vec full_params = model.get_full_vectorized_parameters(); + mass_accumulator_->update(full_params); - // Check if we're 
at the end of a window if (at_window_end()) { - // Update mass matrix from accumulated samples - inv_mass_ = mass_accumulator_->inverse_mass(); + // Stan convention: inv_mass = variance (high-variance params move more) + inv_mass_ = mass_accumulator_->variance(); mass_accumulator_->reset(); - // Restart step size adaptation with new mass matrix - step_adapter_.restart(step_size_); + // Push adapted mass matrix to model + model.set_inv_mass(inv_mass_); + + // Re-run heuristic step size with new mass matrix + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + auto [log_post, grad_fn] = make_model_functions(model); + arma::vec active_inv_mass = model.get_active_inv_mass(); + + double new_eps = heuristic_initial_step_size( + theta, log_post, grad_fn, active_inv_mass, rng, + 0.625, step_size_); + step_size_ = new_eps; + step_adapter_.restart(new_eps); } } @@ -100,59 +70,27 @@ class NUTSSampler : public BaseSampler { return result; } - /** - * Finalize warmup phase (fix step size to averaged value) - */ void finalize_warmup() override { step_size_ = step_adapter_.averaged(); } - /** - * Perform one NUTS step during sampling (fixed step size and mass matrix) - * @param model The model to sample from - * @return SamplerResult with state and diagnostics - */ SamplerResult sample_step(BaseModel& model) override { return do_nuts_step(model); } - /** - * NUTS produces tree depth, divergence, and energy diagnostics - */ bool has_nuts_diagnostics() const override { return true; } - - /** - * Get the current (or final) step size - */ double get_step_size() const { return step_size_; } - - /** - * Get the averaged step size (for reporting after warmup) - */ - double get_averaged_step_size() const { - return step_adapter_.averaged(); - } - - /** - * Get the current inverse mass matrix diagonal - */ + double get_averaged_step_size() const { return step_adapter_.averaged(); } const arma::vec& get_inv_mass() const { return inv_mass_; } private: - 
/** - * Build Stan-style warmup schedule with doubling windows - */ void build_warmup_schedule(int n_warmup) { - // Stage 1: 7.5% of warmup stage1_end_ = static_cast(0.075 * n_warmup); - - // Stage 3 starts at 90% of warmup stage3_start_ = n_warmup - static_cast(0.10 * n_warmup); - // Stage 2: build doubling windows between stage1_end and stage3_start window_ends_.clear(); int cur = stage1_end_; - int wsize = 25; // Initial window size + int wsize = 25; while (cur < stage3_start_) { int win = std::min(wsize, stage3_start_ - cur); @@ -162,94 +100,91 @@ class NUTSSampler : public BaseSampler { } } - /** - * Check if we're in Stage 2 (mass matrix learning phase) - */ bool in_stage2() const { return warmup_iteration_ >= stage1_end_ && warmup_iteration_ < stage3_start_; } - /** - * Check if we're at the end of a Stage 2 window - */ bool at_window_end() const { for (int end : window_ends_) { - if (warmup_iteration_ + 1 == end) { - return true; - } + if (warmup_iteration_ + 1 == end) return true; } return false; } /** - * Initialize step size and mass matrix on first iteration + * Create log_post and grad lambdas that share a single logp_and_gradient call. + * Avoids doubling computation when the memoizer requests both at the same point. 
*/ + static std::pair< + std::function, + std::function + > make_model_functions(BaseModel& model) { + struct JointCache { + arma::vec theta; + double logp; + arma::vec grad; + bool valid = false; + }; + auto cache = std::make_shared(); + + auto ensure = [&model, cache](const arma::vec& params) { + if (!cache->valid || + params.n_elem != cache->theta.n_elem || + !arma::approx_equal(params, cache->theta, "absdiff", 1e-14)) { + auto [lp, gr] = model.logp_and_gradient(params); + cache->theta = params; + cache->logp = lp; + cache->grad = std::move(gr); + cache->valid = true; + } + }; + + auto log_post = [ensure, cache](const arma::vec& params) -> double { + ensure(params); + return cache->logp; + }; + auto grad_fn = [ensure, cache](const arma::vec& params) -> arma::vec { + ensure(params); + return cache->grad; + }; + + return {log_post, grad_fn}; + } + void initialize(BaseModel& model) { arma::vec theta = model.get_vectorized_parameters(); SafeRNG& rng = model.get_rng(); - // Initialize inverse mass to identity (or from model) - inv_mass_ = model.get_active_inv_mass(); + // Initialize inverse mass from model (defaults to ones) + inv_mass_ = arma::ones(model.full_parameter_dimension()); + model.set_inv_mass(inv_mass_); - // Initialize mass matrix accumulator + // Mass matrix accumulator uses full dimension mass_accumulator_ = std::make_unique( - static_cast(theta.n_elem)); + static_cast(model.full_parameter_dimension())); - // Create log posterior and gradient functions - auto log_post = [&model](const arma::vec& params) -> double { - return model.logp_and_gradient(params).first; - }; - auto grad_fn = [&model](const arma::vec& params) -> arma::vec { - return model.logp_and_gradient(params).second; - }; + auto [log_post, grad_fn] = make_model_functions(model); - // Use heuristic to find good initial step size step_size_ = heuristic_initial_step_size( theta, log_post, grad_fn, rng, target_acceptance_); - // Restart dual averaging with the heuristic step size 
step_adapter_.restart(step_size_); } - /** - * Execute one NUTS step using the model's active inverse mass matrix - */ SamplerResult do_nuts_step(BaseModel& model) { - // Get current state arma::vec theta = model.get_vectorized_parameters(); SafeRNG& rng = model.get_rng(); - // Create log posterior and gradient functions that call the model - auto log_post = [&model](const arma::vec& params) -> double { - return model.logp_and_gradient(params).first; - }; - auto grad_fn = [&model](const arma::vec& params) -> arma::vec { - return model.logp_and_gradient(params).second; - }; + auto [log_post, grad_fn] = make_model_functions(model); - // Get active inverse mass from model (handles dimension changes for edge selection) arma::vec active_inv_mass = model.get_active_inv_mass(); - // Debug: check dimension match - if (theta.n_elem != active_inv_mass.n_elem) { - Rcpp::Rcout << "DIMENSION MISMATCH: theta=" << theta.n_elem - << " active_inv_mass=" << active_inv_mass.n_elem << std::endl; - } - - // Call the NUTS free function with the active inverse mass SamplerResult result = nuts_sampler( - theta, - step_size_, - log_post, - grad_fn, - active_inv_mass, - rng, - max_tree_depth_ + theta, step_size_, log_post, grad_fn, + active_inv_mass, rng, max_tree_depth_ ); - // Update model state with new parameters model.set_vectorized_parameters(result.state); - return result; } diff --git a/src/omrf_model.cpp b/src/omrf_model.cpp index 7367dba9..f189e216 100644 --- a/src/omrf_model.cpp +++ b/src/omrf_model.cpp @@ -191,11 +191,16 @@ void OMRFModel::build_interaction_index() { void OMRFModel::update_residual_matrix() { - // Use pre-computed transformed observations (computed once in constructor) residual_matrix_ = observations_double_ * pairwise_effects_; } +void OMRFModel::update_residual_columns(int var1, int var2, double delta) { + residual_matrix_.col(var1) += delta * observations_double_.col(var2); + residual_matrix_.col(var2) += delta * observations_double_.col(var1); +} + + void 
OMRFModel::set_pairwise_effects(const arma::mat& pairwise_effects) { pairwise_effects_ = pairwise_effects; update_residual_matrix(); @@ -597,57 +602,34 @@ double OMRFModel::log_pseudoposterior_internal() const { double OMRFModel::log_pseudoposterior_main_component(int variable, int category, int parameter) const { double log_post = 0.0; - // Lambda for Beta-prime prior on main effects (matches original implementation) - // log p(theta) = alpha * theta - (alpha + beta) * log(1 + exp(theta)) auto log_beta_prior = [this](double x) { return x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); }; int num_cats = num_categories_(variable); - arma::vec bound = num_cats * residual_matrix_.col(variable); if (is_ordinal_variable_(variable)) { - // Ordinal variable: use category log_post += log_beta_prior(main_effects_(variable, category)); log_post += main_effects_(variable, category) * counts_per_category_(category + 1, variable); - // Log-denominator contribution - for (size_t i = 0; i < n_; ++i) { - double max_val = 0.0; - for (int c = 0; c < num_cats; ++c) { - double val = main_effects_(variable, c) + (c + 1) * residual_matrix_(i, variable); - if (val > max_val) max_val = val; - } + arma::vec residual_score = residual_matrix_.col(variable); + arma::vec main_param = main_effects_.row(variable).cols(0, num_cats - 1).t(); + arma::vec bound = num_cats * residual_score; - double denom = std::exp(-max_val); - for (int c = 0; c < num_cats; ++c) { - double val = main_effects_(variable, c) + (c + 1) * residual_matrix_(i, variable); - denom += std::exp(val - max_val); - } - log_post -= (max_val + std::log(denom)); - } + arma::vec denom = compute_denom_ordinal(residual_score, main_param, bound); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); } else { - // Blume-Capel: use parameter (0 = linear, 1 = quadratic) log_post += log_beta_prior(main_effects_(variable, parameter)); log_post += main_effects_(variable, parameter) * blume_capel_stats_(parameter, 
variable); - int baseline = baseline_category_(variable); - for (size_t i = 0; i < n_; ++i) { - double max_val = -std::numeric_limits::infinity(); - for (int c = 0; c <= num_cats; ++c) { - int s = c - baseline; - double val = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * residual_matrix_(i, variable); - if (val > max_val) max_val = val; - } + arma::vec residual_score = residual_matrix_.col(variable); + arma::vec bound(n_); - double denom = 0.0; - for (int c = 0; c <= num_cats; ++c) { - int s = c - baseline; - double val = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * residual_matrix_(i, variable); - denom += std::exp(val - max_val); - } - log_post -= (max_val + std::log(denom)); - } + arma::vec denom = compute_denom_blume_capel( + residual_score, main_effects_(variable, 0), main_effects_(variable, 1), + baseline_category_(variable), num_cats, bound + ); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); } return log_post; @@ -657,47 +639,25 @@ double OMRFModel::log_pseudoposterior_main_component(int variable, int category, double OMRFModel::log_pseudoposterior_pairwise_component(int var1, int var2) const { double log_post = 2.0 * pairwise_effects_(var1, var2) * pairwise_stats_(var1, var2); - // Contribution from both variables' pseudo-likelihoods for (int var : {var1, var2}) { int num_cats = num_categories_(var); - int other_var = (var == var1) ? 
var2 : var1; - - for (size_t i = 0; i < n_; ++i) { - double max_val = -std::numeric_limits::infinity(); - - if (is_ordinal_variable_(var)) { - max_val = 0.0; - for (int c = 0; c < num_cats; ++c) { - double val = main_effects_(var, c) + (c + 1) * residual_matrix_(i, var); - if (val > max_val) max_val = val; - } - - double denom = std::exp(-max_val); - for (int c = 0; c < num_cats; ++c) { - double val = main_effects_(var, c) + (c + 1) * residual_matrix_(i, var); - denom += std::exp(val - max_val); - } - log_post -= (max_val + std::log(denom)); - } else { - int baseline = baseline_category_(var); - for (int c = 0; c <= num_cats; ++c) { - int s = c - baseline; - double val = main_effects_(var, 0) * s + main_effects_(var, 1) * s * s + s * residual_matrix_(i, var); - if (val > max_val) max_val = val; - } + arma::vec residual_score = residual_matrix_.col(var); - double denom = 0.0; - for (int c = 0; c <= num_cats; ++c) { - int s = c - baseline; - double val = main_effects_(var, 0) * s + main_effects_(var, 1) * s * s + s * residual_matrix_(i, var); - denom += std::exp(val - max_val); - } - log_post -= (max_val + std::log(denom)); - } + if (is_ordinal_variable_(var)) { + arma::vec main_param = main_effects_.row(var).cols(0, num_cats - 1).t(); + arma::vec bound = num_cats * residual_score; + arma::vec denom = compute_denom_ordinal(residual_score, main_param, bound); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); + } else { + arma::vec bound(n_); + arma::vec denom = compute_denom_blume_capel( + residual_score, main_effects_(var, 0), main_effects_(var, 1), + baseline_category_(var), num_cats, bound + ); + log_post -= arma::accu(bound + ARMA_MY_LOG(denom)); } } - // Cauchy prior if edge is included if (edge_indicators_(var1, var2) == 1) { log_post += R::dcauchy(pairwise_effects_(var1, var2), 0.0, pairwise_scale_, true); } @@ -712,62 +672,43 @@ double OMRFModel::compute_log_likelihood_ratio_for_variable( double proposed_state, double current_state ) const { - double 
log_ratio = 0.0; int num_cats = num_categories_(variable); - for (size_t i = 0; i < n_; ++i) { - double rest_minus = residual_matrix_(i, variable) - current_state * interacting_score(i); - double rest_prop = rest_minus + proposed_state * interacting_score(i); - double rest_curr = rest_minus + current_state * interacting_score(i); - - double max_prop = -std::numeric_limits::infinity(); - double max_curr = -std::numeric_limits::infinity(); + // Residual without the current interaction contribution + arma::vec rest_base = residual_matrix_.col(variable) - current_state * interacting_score; - if (is_ordinal_variable_(variable)) { - max_prop = 0.0; - max_curr = 0.0; - for (int c = 0; c < num_cats; ++c) { - double val_prop = main_effects_(variable, c) + (c + 1) * rest_prop; - double val_curr = main_effects_(variable, c) + (c + 1) * rest_curr; - if (val_prop > max_prop) max_prop = val_prop; - if (val_curr > max_curr) max_curr = val_curr; - } - - double denom_prop = std::exp(-max_prop); - double denom_curr = std::exp(-max_curr); - for (int c = 0; c < num_cats; ++c) { - double val_prop = main_effects_(variable, c) + (c + 1) * rest_prop; - double val_curr = main_effects_(variable, c) + (c + 1) * rest_curr; - denom_prop += std::exp(val_prop - max_prop); - denom_curr += std::exp(val_curr - max_curr); - } + if (is_ordinal_variable_(variable)) { + arma::vec main_param = main_effects_.row(variable).cols(0, num_cats - 1).t(); - log_ratio += (max_curr + std::log(denom_curr)) - (max_prop + std::log(denom_prop)); - } else { - int baseline = baseline_category_(variable); - for (int c = 0; c <= num_cats; ++c) { - int s = c - baseline; - double val_prop = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_prop; - double val_curr = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_curr; - if (val_prop > max_prop) max_prop = val_prop; - if (val_curr > max_curr) max_curr = val_curr; - } + arma::vec rest_curr = rest_base + 
current_state * interacting_score; + arma::vec bound_curr = num_cats * rest_curr; + arma::vec denom_curr = compute_denom_ordinal(rest_curr, main_param, bound_curr); - double denom_prop = 0.0; - double denom_curr = 0.0; - for (int c = 0; c <= num_cats; ++c) { - int s = c - baseline; - double val_prop = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_prop; - double val_curr = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * rest_curr; - denom_prop += std::exp(val_prop - max_prop); - denom_curr += std::exp(val_curr - max_curr); - } + arma::vec rest_prop = rest_base + proposed_state * interacting_score; + arma::vec bound_prop = num_cats * rest_prop; + arma::vec denom_prop = compute_denom_ordinal(rest_prop, main_param, bound_prop); - log_ratio += (max_curr + std::log(denom_curr)) - (max_prop + std::log(denom_prop)); - } + return arma::accu(bound_curr + ARMA_MY_LOG(denom_curr)) + - arma::accu(bound_prop + ARMA_MY_LOG(denom_prop)); + } else { + arma::vec rest_curr = rest_base + current_state * interacting_score; + arma::vec bound_curr(n_); + arma::vec denom_curr = compute_denom_blume_capel( + rest_curr, main_effects_(variable, 0), main_effects_(variable, 1), + baseline_category_(variable), num_cats, bound_curr + ); + double log_ratio = arma::accu(ARMA_MY_LOG(denom_curr) + bound_curr); + + arma::vec rest_prop = rest_base + proposed_state * interacting_score; + arma::vec bound_prop(n_); + arma::vec denom_prop = compute_denom_blume_capel( + rest_prop, main_effects_(variable, 0), main_effects_(variable, 1), + baseline_category_(variable), num_cats, bound_prop + ); + log_ratio -= arma::accu(ARMA_MY_LOG(denom_prop) + bound_prop); + + return log_ratio; } - - return log_ratio; } @@ -780,22 +721,15 @@ double OMRFModel::log_pseudolikelihood_ratio_interaction( double delta = proposed_state - current_state; double log_ratio = 2.0 * delta * pairwise_stats_(variable1, variable2); - // For Blume-Capel variables, transform 
observations by subtracting baseline - auto get_transformed_score = [this](int var) -> arma::vec { - arma::vec score = arma::conv_to::from(observations_.col(var)); - if (!is_ordinal_variable_(var)) { - score -= static_cast(baseline_category_(var)); - } - return score; - }; - - // Contribution from variable1 - arma::vec interacting_score = get_transformed_score(variable2); - log_ratio += compute_log_likelihood_ratio_for_variable(variable1, interacting_score, proposed_state, current_state); + // Contribution from variable1 (interacting with variable2's observations) + log_ratio += compute_log_likelihood_ratio_for_variable( + variable1, observations_double_.col(variable2), + proposed_state, current_state); - // Contribution from variable2 - interacting_score = get_transformed_score(variable1); - log_ratio += compute_log_likelihood_ratio_for_variable(variable2, interacting_score, proposed_state, current_state); + // Contribution from variable2 (interacting with variable1's observations) + log_ratio += compute_log_likelihood_ratio_for_variable( + variable2, observations_double_.col(variable1), + proposed_state, current_state); return log_ratio; } @@ -1036,55 +970,57 @@ std::pair OMRFModel::logp_and_gradient(const arma::vec& param // ============================================================================= void OMRFModel::update_main_effect_parameter(int variable, int category, int parameter) { - double proposal_sd; double& current = is_ordinal_variable_(variable) ? main_effects_(variable, category) : main_effects_(variable, parameter); - proposal_sd = is_ordinal_variable_(variable) + double proposal_sd = is_ordinal_variable_(variable) ? proposal_sd_main_(variable, category) : proposal_sd_main_(variable, parameter); int cat_for_log = is_ordinal_variable_(variable) ? category : -1; int par_for_log = is_ordinal_variable_(variable) ? 
-1 : parameter; - auto log_post = [&](double theta) { - double old_val = current; - current = theta; - double lp = log_pseudoposterior_main_component(variable, cat_for_log, par_for_log); - current = old_val; - return lp; - }; + double current_value = current; + double proposed_value = rnorm(rng_, current_value, proposal_sd); - SamplerResult result = rwm_sampler(current, proposal_sd, log_post, rng_); - current = result.state[0]; - invalidate_gradient_cache(); + // Evaluate log-posterior at proposed + current = proposed_value; + double lp_proposed = log_pseudoposterior_main_component(variable, cat_for_log, par_for_log); + + // Evaluate log-posterior at current + current = current_value; + double lp_current = log_pseudoposterior_main_component(variable, cat_for_log, par_for_log); + + double log_accept = lp_proposed - lp_current; + if (MY_LOG(runif(rng_)) < log_accept) { + current = proposed_value; + } } void OMRFModel::update_pairwise_effect(int var1, int var2) { if (edge_indicators_(var1, var2) == 0) return; - double& value = pairwise_effects_(var1, var2); + double current_value = pairwise_effects_(var1, var2); double proposal_sd = proposal_sd_pairwise_(var1, var2); - double current = value; + double proposed_value = rnorm(rng_, current_value, proposal_sd); - auto log_post = [&](double theta) { - pairwise_effects_(var1, var2) = theta; - pairwise_effects_(var2, var1) = theta; - update_residual_matrix(); - return log_pseudoposterior_pairwise_component(var1, var2); - }; + double log_accept = log_pseudolikelihood_ratio_interaction( + var1, var2, proposed_value, current_value); - SamplerResult result = rwm_sampler(current, proposal_sd, log_post, rng_); + // Cauchy prior ratio + log_accept += R::dcauchy(proposed_value, 0.0, pairwise_scale_, true) + - R::dcauchy(current_value, 0.0, pairwise_scale_, true); - value = result.state[0]; - pairwise_effects_(var2, var1) = value; + double accept_prob = std::min(1.0, MY_EXP(log_accept)); - if (current != value) { - 
update_residual_matrix(); + if (runif(rng_) < accept_prob) { + double delta = proposed_value - current_value; + pairwise_effects_(var1, var2) = proposed_value; + pairwise_effects_(var2, var1) = proposed_value; + update_residual_columns(var1, var2, delta); } - invalidate_gradient_cache(); } @@ -1114,11 +1050,11 @@ void OMRFModel::update_edge_indicator(int var1, int var2) { edge_indicators_(var1, var2) = updated; edge_indicators_(var2, var1) = updated; + double delta = proposed_state - current_state; pairwise_effects_(var1, var2) = proposed_state; pairwise_effects_(var2, var1) = proposed_state; - update_residual_matrix(); - invalidate_gradient_cache(); + update_residual_columns(var1, var2, delta); } } @@ -1149,11 +1085,7 @@ void OMRFModel::do_one_mh_step() { } } - // Update edge indicators if in selection phase - if (edge_selection_active_) { - update_edge_indicators(); - } - + invalidate_gradient_cache(); proposal_.increment_iteration(); } diff --git a/src/omrf_model.h b/src/omrf_model.h index b8bc783e..03a1dc29 100644 --- a/src/omrf_model.h +++ b/src/omrf_model.h @@ -278,6 +278,11 @@ class OMRFModel : public BaseModel { */ void update_residual_matrix(); + /** + * Incrementally update two residual columns after a single pairwise effect change + */ + void update_residual_columns(int var1, int var2, double delta); + /** * Invalidate gradient cache (call after parameter changes) */ diff --git a/src/skeleton_model.cpp b/src/skeleton_model.cpp index 78fe901a..511d3d0e 100644 --- a/src/skeleton_model.cpp +++ b/src/skeleton_model.cpp @@ -1,231 +1,230 @@ -/** - * ================================================== - * Header dependencies for SkeletonVariables - * ================================================== - * - * This file defines a template ("skeleton") for variable-type classes - * in the bgms codebase. The included headers constitute the minimal - * set of dependencies required by any variable model, independent of - * its specific statistical formulation. 
- * - * Each include serves a specific purpose: - * - : ownership and cloning via std::unique_ptr - * - base_model.h : abstract interface for variable models - * - adaptiveMetropolis.h : Metropolis–Hastings proposal mechanism - * - rng_utils.h : reproducible random number generation - */ -#pragma once - -#include - -#include "base_model.h" -#include "mcmc/adaptiveMetropolis.h" -#include "rng/rng_utils.h" - - -/** - * ================================================== - * SkeletonVariables class - * ================================================== - * - * A class in C++ represents a concrete type that bundles together: - * - internal state, including observed data, model parameters, - * and auxiliary objects that maintain sampling state, - * - and functions (methods) that act on this state to evaluate - * the log posterior, its gradient, and to perform sampling updates. - * - * In the bgms codebase, each statistical variable model is implemented - * as a C++ class that stores the model state and provides the methods - * required for inference. - * - * SkeletonVariables defines a template for such implementation classes. - * It specifies the structure and interface that concrete implementations - * of variable models must follow, without imposing a particular - * statistical formulation. - * - * SkeletonVariables inherits from BaseModel, which defines the common - * interface for all variable-model implementations in bgms: - * - * class SkeletonVariables : public BaseModel - * - * Inheriting from BaseModel means that SkeletonVariables must provide - * a fixed set of functions (such as log_posterior and do_one_mh_step) - * that the rest of the codebase relies on. As a result, code elsewhere - * in bgms can interact with SkeletonVariables through the BaseModel - * interface, without needing to know which specific variable model - * implementation is being used. 
- */ -class SkeletonVariables : public BaseModel { - - /* - * The 'public:' label below marks the beginning of the part of the class - * that is accessible from outside the class. - * - * Functions and constructors declared under 'public' are intended to be - * called by other components of the bgms codebase, such as samplers, - * model-selection routines, and result containers. - * - * Together, these public members define how a variable-model - * implementation can be created, queried, and updated by external code. - */ - public: - - /* - * Constructors are responsible for establishing the complete internal - * state of the object. - * - * A SkeletonVariables object represents a fully specified variable-model - * implementation at a given point in an inference procedure. Therefore, - * all information required to evaluate the log posterior and to perform - * sampling updates must be stored within the object itself. - * - * After construction, the object is expected to be immediately usable: - * no additional initialization steps are required before it can be - * queried, updated, or copied. - * - * The constructor below uses a constructor initializer list to define - * how base classes and data members are constructed. The initializer - * list is evaluated before the constructor body runs and is used to: - * - construct the BaseModel subobject, - * - initialize data members directly from constructor arguments. - * - * This ensures that all base-class and member invariants are established - * before any additional derived quantities are computed in the - * constructor body. 
- */ - SkeletonVariables( - const arma::mat& observations, - const arma::mat& inclusion_probability, - const arma::imat& initial_edge_indicators, - const bool edge_selection = true - ) - : BaseModel(edge_selection), - observations_(observations), - inclusion_probability_(inclusion_probability), - edge_indicators_(initial_edge_indicators) - { - /* - * The constructor body initializes derived state that depends on - * already-constructed members, such as dimensions, parameter vectors, - * and sampling-related objects. - */ - - n_ = observations_.n_rows; - p_ = observations_.n_cols; - - // Dimension of the parameter vector (model-specific). - // For the skeleton, we assume one parameter per variable. - dim_ = p_; - - // Initialize parameter vector - parameters_.zeros(dim_); - - // Initialize adaptive Metropolis–Hastings sampler - adaptive_mh_ = AdaptiveMetropolis(dim_); - } - - /* - * Copy constructor. - * - * The copy constructor creates a new SkeletonVariables object that is an - * exact copy of an existing one. This includes not only the observed data - * and model parameters, but also all internal state required for inference, - * such as sampler state and random number generator state. - * - * Copying is required in bgms because variable-model objects are duplicated - * during inference, for example when running multiple - * chains, or storing and restoring model states. - * - * The copy is performed using a constructor initializer list to ensure - * that the BaseModel subobject and all data members are constructed - * directly from their counterparts in the source object. - * - * After construction, the new object is independent from the original - * but represents the same model state. 
- */ - SkeletonVariables(const SkeletonVariables& other) - : BaseModel(other), - observations_(other.observations_), - inclusion_probability_(other.inclusion_probability_), - edge_indicators_(other.edge_indicators_), - n_(other.n_), - p_(other.p_), - dim_(other.dim_), - parameters_(other.parameters_), - adaptive_mh_(other.adaptive_mh_), - rng_(other.rng_) - {} - - // -------------------------------------------------- - // Polymorphic copy - // -------------------------------------------------- - std::unique_ptr clone() const override { - return std::make_unique(*this); - } - - // -------------------------------------------------- - // Capabilities - // -------------------------------------------------- - bool has_log_posterior() const override { return true; } - bool has_gradient() const override { return true; } - bool has_adaptive_mh() const override { return true; } - - // -------------------------------------------------- - // Log posterior (LIKELIHOOD + PRIOR) - // -------------------------------------------------- - double log_posterior(const arma::vec& parameters) override { - // Skeleton: flat log-density - // Real models will: - // - unpack parameters - // - compute likelihood - // - add priors - return 0.0; - } - - // -------------------------------------------------- - // Gradient of LOG (LIKELIHOOD * PRIOR) - // -------------------------------------------------- - void gradient(const arma::vec& parameters) override { - // Skeleton: - } - - // -------------------------------------------------- - // One Metropolis–Hastings step - // -------------------------------------------------- - void do_one_mh_step(arma::vec& parameters) override { - // Skeleton: - } - - // -------------------------------------------------- - // Required interface - // -------------------------------------------------- - size_t parameter_dimension() const override { - return dim_; - } - - void set_seed(int seed) override { - rng_ = SafeRNG(seed); - } - -protected: - // 
-------------------------------------------------- - // Data - // -------------------------------------------------- - arma::mat observations_; - arma::mat inclusion_probability_; - arma::imat edge_indicators_; - - // -------------------------------------------------- - // Dimensions - // -------------------------------------------------- - size_t n_ = 0; // number of observations - size_t p_ = 0; // number of variables - size_t dim_ = 0; // dimension of parameter vector - - // -------------------------------------------------- - // Parameters & MCMC machinery - // -------------------------------------------------- - arma::vec parameters_; - AdaptiveMetropolis adaptive_mh_; - SafeRNG rng_; -}; +// /** +// * ================================================== +// * Header dependencies for SkeletonVariables +// * ================================================== +// * +// * This file defines a template ("skeleton") for variable-type classes +// * in the bgms codebase. The included headers constitute the minimal +// * set of dependencies required by any variable model, independent of +// * its specific statistical formulation. 
+// * +// * Each include serves a specific purpose: +// * - : ownership and cloning via std::unique_ptr +// * - base_model.h : abstract interface for variable models +// * - adaptiveMetropolis.h : Metropolis–Hastings proposal mechanism +// * - rng_utils.h : reproducible random number generation +// */ + +// #include + +// #include "base_model.h" +// #include "adaptiveMetropolis.h" +// #include "rng/rng_utils.h" + + +// /** +// * ================================================== +// * SkeletonVariables class +// * ================================================== +// * +// * A class in C++ represents a concrete type that bundles together: +// * - internal state, including observed data, model parameters, +// * and auxiliary objects that maintain sampling state, +// * - and functions (methods) that act on this state to evaluate +// * the log posterior, its gradient, and to perform sampling updates. +// * +// * In the bgms codebase, each statistical variable model is implemented +// * as a C++ class that stores the model state and provides the methods +// * required for inference. +// * +// * SkeletonVariables defines a template for such implementation classes. +// * It specifies the structure and interface that concrete implementations +// * of variable models must follow, without imposing a particular +// * statistical formulation. +// * +// * SkeletonVariables inherits from BaseModel, which defines the common +// * interface for all variable-model implementations in bgms: +// * +// * class SkeletonVariables : public BaseModel +// * +// * Inheriting from BaseModel means that SkeletonVariables must provide +// * a fixed set of functions (such as log_posterior and do_one_mh_step) +// * that the rest of the codebase relies on. As a result, code elsewhere +// * in bgms can interact with SkeletonVariables through the BaseModel +// * interface, without needing to know which specific variable model +// * implementation is being used. 
+// */ +// class SkeletonVariables : public BaseModel { + +// /* +// * The 'public:' label below marks the beginning of the part of the class +// * that is accessible from outside the class. +// * +// * Functions and constructors declared under 'public' are intended to be +// * called by other components of the bgms codebase, such as samplers, +// * model-selection routines, and result containers. +// * +// * Together, these public members define how a variable-model +// * implementation can be created, queried, and updated by external code. +// */ +// public: + +// /* +// * Constructors are responsible for establishing the complete internal +// * state of the object. +// * +// * A SkeletonVariables object represents a fully specified variable-model +// * implementation at a given point in an inference procedure. Therefore, +// * all information required to evaluate the log posterior and to perform +// * sampling updates must be stored within the object itself. +// * +// * After construction, the object is expected to be immediately usable: +// * no additional initialization steps are required before it can be +// * queried, updated, or copied. +// * +// * The constructor below uses a constructor initializer list to define +// * how base classes and data members are constructed. The initializer +// * list is evaluated before the constructor body runs and is used to: +// * - construct the BaseModel subobject, +// * - initialize data members directly from constructor arguments. +// * +// * This ensures that all base-class and member invariants are established +// * before any additional derived quantities are computed in the +// * constructor body. 
+// */ +// SkeletonVariables( +// const arma::mat& observations, +// const arma::mat& inclusion_probability, +// const arma::imat& initial_edge_indicators, +// const bool edge_selection = true +// ) +// : BaseModel(edge_selection), +// observations_(observations), +// inclusion_probability_(inclusion_probability), +// edge_indicators_(initial_edge_indicators) +// { +// /* +// * The constructor body initializes derived state that depends on +// * already-constructed members, such as dimensions, parameter vectors, +// * and sampling-related objects. +// */ + +// n_ = observations_.n_rows; +// p_ = observations_.n_cols; + +// // Dimension of the parameter vector (model-specific). +// // For the skeleton, we assume one parameter per variable. +// dim_ = p_; + +// // Initialize parameter vector +// parameters_.zeros(dim_); + +// // Initialize adaptive Metropolis–Hastings sampler +// adaptive_mh_ = AdaptiveMetropolis(dim_); +// } + +// /* +// * Copy constructor. +// * +// * The copy constructor creates a new SkeletonVariables object that is an +// * exact copy of an existing one. This includes not only the observed data +// * and model parameters, but also all internal state required for inference, +// * such as sampler state and random number generator state. +// * +// * Copying is required in bgms because variable-model objects are duplicated +// * during inference, for example when running multiple +// * chains, or storing and restoring model states. +// * +// * The copy is performed using a constructor initializer list to ensure +// * that the BaseModel subobject and all data members are constructed +// * directly from their counterparts in the source object. +// * +// * After construction, the new object is independent from the original +// * but represents the same model state. 
+// */ +// SkeletonVariables(const SkeletonVariables& other) +// : BaseModel(other), +// observations_(other.observations_), +// inclusion_probability_(other.inclusion_probability_), +// edge_indicators_(other.edge_indicators_), +// n_(other.n_), +// p_(other.p_), +// dim_(other.dim_), +// parameters_(other.parameters_), +// adaptive_mh_(other.adaptive_mh_), +// rng_(other.rng_) +// {} + +// // -------------------------------------------------- +// // Polymorphic copy +// // -------------------------------------------------- +// std::unique_ptr clone() const override { +// return std::make_unique(*this); +// } + +// // -------------------------------------------------- +// // Capabilities +// // -------------------------------------------------- +// bool has_log_posterior() const override { return true; } +// bool has_gradient() const override { return true; } +// bool has_adaptive_mh() const override { return true; } + +// // -------------------------------------------------- +// // Log posterior (LIKELIHOOD + PRIOR) +// // -------------------------------------------------- +// double log_posterior(const arma::vec& parameters) override { +// // Skeleton: flat log-density +// // Real models will: +// // - unpack parameters +// // - compute likelihood +// // - add priors +// return 0.0; +// } + +// // -------------------------------------------------- +// // Gradient of LOG (LIKELIHOOD * PRIOR) +// // -------------------------------------------------- +// void gradient(const arma::vec& parameters) override { +// // Skeleton: +// } + +// // -------------------------------------------------- +// // One Metropolis–Hastings step +// // -------------------------------------------------- +// void do_one_mh_step(arma::vec& parameters) override { +// // Skeleton: +// } + +// // -------------------------------------------------- +// // Required interface +// // -------------------------------------------------- +// size_t parameter_dimension() const override { +// return 
dim_; +// } + +// void set_seed(int seed) override { +// rng_ = SafeRNG(seed); +// } + +// protected: +// // -------------------------------------------------- +// // Data +// // -------------------------------------------------- +// arma::mat observations_; +// arma::mat inclusion_probability_; +// arma::imat edge_indicators_; + +// // -------------------------------------------------- +// // Dimensions +// // -------------------------------------------------- +// size_t n_ = 0; // number of observations +// size_t p_ = 0; // number of variables +// size_t dim_ = 0; // dimension of parameter vector + +// // -------------------------------------------------- +// // Parameters & MCMC machinery +// // -------------------------------------------------- +// arma::vec parameters_; +// AdaptiveMetropolis adaptive_mh_; +// SafeRNG rng_; +// }; From 63406f862f97d3d2951d249b1e5d5b4a6f382132 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Fri, 6 Feb 2026 16:57:31 +0100 Subject: [PATCH 14/23] almost everything functional, next up are the priors --- src/mcmc/hmc_sampler.h | 177 ++++++++++++++++++++++++------------ src/mcmc/mcmc_hmc.cpp | 11 ++- src/mcmc/mcmc_leapfrog.cpp | 4 +- src/mcmc/mcmc_memoization.h | 170 +++++++++++----------------------- src/mcmc/mcmc_nuts.cpp | 78 ++++++++++++++++ src/mcmc/mcmc_nuts.h | 10 +- src/mcmc/mcmc_runner.h | 6 +- src/mcmc/mcmc_utils.cpp | 25 ++++- src/mcmc/nuts_sampler.h | 67 ++++---------- src/omrf_model.cpp | 120 ++++++++++++++++++++++-- 10 files changed, 428 insertions(+), 240 deletions(-) diff --git a/src/mcmc/hmc_sampler.h b/src/mcmc/hmc_sampler.h index a4750409..d4ba605e 100644 --- a/src/mcmc/hmc_sampler.h +++ b/src/mcmc/hmc_sampler.h @@ -2,93 +2,150 @@ #include #include +#include #include "base_sampler.h" #include "mcmc_utils.h" #include "mcmc_hmc.h" +#include "mcmc_adaptation.h" #include "sampler_config.h" #include "../base_model.h" /** * HMCSampler - Hamiltonian Monte Carlo sampler * - * Uses fixed-length leapfrog integration 
with optional step size - * adaptation during warmup via dual averaging. + * Uses fixed-length leapfrog integration with full warmup adaptation: + * - Heuristic initial step size + * - Diagonal mass matrix estimation (windowed) + * - Dual averaging step size adaptation * - * The sampler fully owns all sampling logic. The model only provides: - * - logp_and_gradient(theta): Compute log posterior and gradient - * - get_vectorized_parameters(): Get current state as vector - * - set_vectorized_parameters(theta): Update model state from vector - * - get_active_inv_mass(): Get inverse mass diagonal - * - get_rng(): Get random number generator + * Warmup schedule mirrors NUTSSampler (3 stages). */ class HMCSampler : public BaseSampler { public: - /** - * Construct HMC sampler with configuration - * @param config Sampler configuration - */ explicit HMCSampler(const SamplerConfig& config) : step_size_(config.initial_step_size), target_acceptance_(config.target_acceptance), num_leapfrogs_(config.num_leapfrogs), no_warmup_(config.no_warmup), - warmup_iteration_(0) + warmup_iteration_(0), + initialized_(false), + step_adapter_(config.initial_step_size) { - // Initialize dual averaging state - dual_avg_state_.set_size(3); - dual_avg_state_(0) = std::log(step_size_); - dual_avg_state_(1) = std::log(step_size_); - dual_avg_state_(2) = 0.0; + build_warmup_schedule(config.no_warmup); } - /** - * Perform one HMC step during warmup (with step size adaptation) - */ SamplerResult warmup_step(BaseModel& model) override { + if (!initialized_) { + initialize(model); + initialized_ = true; + } + SamplerResult result = do_hmc_step(model); - // Update step size via dual averaging - warmup_iteration_++; - update_step_size_with_dual_averaging( - step_size_, - result.accept_prob, - warmup_iteration_, - dual_avg_state_, - target_acceptance_ - ); - step_size_ = std::exp(dual_avg_state_(0)); + // Adapt step size during all warmup phases + step_adapter_.update(result.accept_prob, 
target_acceptance_); + step_size_ = step_adapter_.current(); + + // During Stage 2, accumulate samples for mass matrix estimation + if (in_stage2()) { + arma::vec full_params = model.get_full_vectorized_parameters(); + mass_accumulator_->update(full_params); + + if (at_window_end()) { + inv_mass_ = mass_accumulator_->variance(); + mass_accumulator_->reset(); + + model.set_inv_mass(inv_mass_); + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + arma::vec active_inv_mass = model.get_active_inv_mass(); + + double new_eps = heuristic_initial_step_size( + theta, log_post, grad_fn, active_inv_mass, rng, + 0.625, step_size_); + step_size_ = new_eps; + step_adapter_.restart(new_eps); + } + } + + warmup_iteration_++; return result; } - /** - * Finalize warmup phase (fix step size to averaged value) - */ void finalize_warmup() override { - step_size_ = std::exp(dual_avg_state_(1)); + step_size_ = step_adapter_.averaged(); } - /** - * Perform one HMC step during sampling (fixed step size) - */ SamplerResult sample_step(BaseModel& model) override { return do_hmc_step(model); } double get_step_size() const { return step_size_; } - double get_averaged_step_size() const { return std::exp(dual_avg_state_(1)); } + double get_averaged_step_size() const { return step_adapter_.averaged(); } private: - /** - * Execute one HMC step using the model's interface - */ + void build_warmup_schedule(int n_warmup) { + stage1_end_ = static_cast(0.075 * n_warmup); + stage3_start_ = n_warmup - static_cast(0.10 * n_warmup); + + window_ends_.clear(); + int cur = stage1_end_; + int wsize = 25; + + while (cur < stage3_start_) { + int win = std::min(wsize, stage3_start_ - cur); + window_ends_.push_back(cur + win); + cur += win; + 
wsize = std::min(wsize * 2, stage3_start_ - cur); + } + } + + bool in_stage2() const { + return warmup_iteration_ >= stage1_end_ && warmup_iteration_ < stage3_start_; + } + + bool at_window_end() const { + for (int end : window_ends_) { + if (warmup_iteration_ + 1 == end) return true; + } + return false; + } + + void initialize(BaseModel& model) { + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + + inv_mass_ = arma::ones(model.full_parameter_dimension()); + model.set_inv_mass(inv_mass_); + + mass_accumulator_ = std::make_unique( + static_cast(model.full_parameter_dimension())); + + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + + step_size_ = heuristic_initial_step_size( + theta, log_post, grad_fn, rng, target_acceptance_); + + step_adapter_.restart(step_size_); + } + SamplerResult do_hmc_step(BaseModel& model) { - // Get current state arma::vec theta = model.get_vectorized_parameters(); arma::vec inv_mass = model.get_active_inv_mass(); SafeRNG& rng = model.get_rng(); - // Create log posterior and gradient functions that call the model auto log_post = [&model](const arma::vec& params) -> double { return model.logp_and_gradient(params).first; }; @@ -96,27 +153,33 @@ class HMCSampler : public BaseSampler { return model.logp_and_gradient(params).second; }; - // Call the HMC free function SamplerResult result = hmc_sampler( - theta, - step_size_, - log_post, - grad_fn, - num_leapfrogs_, - inv_mass, - rng - ); - - // Update model state with new parameters - model.set_vectorized_parameters(result.state); + theta, step_size_, log_post, grad_fn, + num_leapfrogs_, inv_mass, rng); + model.set_vectorized_parameters(result.state); return result; } + // Configuration double step_size_; double target_acceptance_; int num_leapfrogs_; int no_warmup_; + + 
// State tracking int warmup_iteration_; - arma::vec dual_avg_state_; + bool initialized_; + + // Step size adaptation + DualAveraging step_adapter_; + + // Mass matrix adaptation + arma::vec inv_mass_; + std::unique_ptr mass_accumulator_; + + // Warmup schedule + int stage1_end_; + int stage3_start_; + std::vector window_ends_; }; diff --git a/src/mcmc/mcmc_hmc.cpp b/src/mcmc/mcmc_hmc.cpp index adb3c022..dde56bed 100644 --- a/src/mcmc/mcmc_hmc.cpp +++ b/src/mcmc/mcmc_hmc.cpp @@ -24,13 +24,22 @@ SamplerResult hmc_sampler( theta, r, step_size, grad, num_leapfrogs, inv_mass_diag ); + // If leapfrog produced NaN/Inf, reject immediately + if (theta.has_nan() || theta.has_inf() || r.has_nan() || r.has_inf()) { + return {init_theta, 0.0}; + } + // Hamiltonians double current_H = -log_post(init_theta) + kinetic_energy(init_r, inv_mass_diag); double proposed_H = -log_post(theta) + kinetic_energy(r, inv_mass_diag); double log_accept_prob = current_H - proposed_H; - arma::vec state = (MY_LOG(runif(rng)) < log_accept_prob) ? theta : init_theta; + // NaN guard: treat non-finite Hamiltonian as rejection + if (!std::isfinite(log_accept_prob)) { + return {init_theta, 0.0}; + } + arma::vec state = (MY_LOG(runif(rng)) < log_accept_prob) ? 
theta : init_theta; double accept_prob = std::min(1.0, MY_EXP(log_accept_prob)); return {state, accept_prob}; diff --git a/src/mcmc/mcmc_leapfrog.cpp b/src/mcmc/mcmc_leapfrog.cpp index d5d7c120..cb7a8895 100644 --- a/src/mcmc/mcmc_leapfrog.cpp +++ b/src/mcmc/mcmc_leapfrog.cpp @@ -81,11 +81,11 @@ std::pair leapfrog_memo( arma::vec r_half = r; arma::vec theta_new = theta; - auto grad1 = memo.cached_grad(theta_new); + const arma::vec& grad1 = memo.cached_grad(theta_new); r_half += 0.5 * eps * grad1; theta_new += eps * (inv_mass_diag % r_half); - auto grad2 = memo.cached_grad(theta_new); + const arma::vec& grad2 = memo.cached_grad(theta_new); r_half += 0.5 * eps * grad2; return {theta_new, r_half}; diff --git a/src/mcmc/mcmc_memoization.h b/src/mcmc/mcmc_memoization.h index c28574ed..a5ad6947 100644 --- a/src/mcmc/mcmc_memoization.h +++ b/src/mcmc/mcmc_memoization.h @@ -1,143 +1,77 @@ #pragma once #include -#include // for std::unordered_map -#include // for std::function - - - -/** - * Struct: VecHash - * - * A hash function is a way of converting complex input data (like a vector of numbers) - * into a single fixed-size integer value. This hash acts like a digital fingerprint — - * it helps uniquely identify the input and allows fast lookup in data structures - * like hash tables. - * - * In C++, hash tables are implemented using containers like std::unordered_map, - * which require a hash function and an equality comparator for the key type. - * Since arma::vec (a vector of doubles) isn't natively supported as a key, - * this struct defines how to compute a hash for it. - * - * Why use hashing here? - * - We cache (memoize) values like log-posterior or gradient at specific points. - * - To retrieve them quickly later, we need a fast way to index by vector. - * - Hashing gives constant-time access (on average) by turning the vector into a hash code. - * - * This implementation: - * - Initializes a seed based on the vector size. 
- * - Iteratively incorporates each element's hash into the seed using bitwise operations - * and a constant based on the golden ratio to reduce collisions. - * - * The result is a unique and repeatable hash value that lets us efficiently - * store and retrieve cached evaluations. - * - * Provides a custom hash function for arma::vec so that it can be used as a key - * in standard hash-based containers like std::unordered_map. - * - * Since arma::vec is not hashable by default in C++, this functor allows you - * to compute a combined hash from the individual elements of the vector. - * - * Hashing Strategy: - * - Starts with an initial seed equal to the vector's length. - * - Iteratively combines the hash of each element into the seed using - * a standard hashing recipe (XOR, bit-shift mix, and golden ratio). - * - * This ensures that vectors with similar but not identical elements - * produce distinct hash codes, reducing the risk of collisions. - */ -struct VecHash { - std::size_t operator()(const arma::vec& v) const { - std::size_t seed = v.n_elem; - for (size_t i = 0; i < v.n_elem; i++) { - seed ^= std::hash{}(v[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; - } -}; - - - -/** - * Struct: VecEqual - * - * Provides a custom equality comparator for arma::vec, to be used in conjunction - * with VecHash in unordered_map and similar containers. - * - * By default, arma::vec does not implement equality suitable for use as a key - * in hashed containers. This comparator solves that by relying on Armadillo's - * approx_equal function. - * - * Comparison Strategy: - * - Uses "absdiff" mode to check if the absolute difference between vectors - * is below a specified tolerance. - * - The tolerance is hardcoded as 1e-12, assuming numerical precision errors - * should not cause false inequality. - * - * This is crucial for numerical stability, especially when the same vector - * might be computed multiple times with slight rounding differences. 
- */ -struct VecEqual { - bool operator()(const arma::vec& a, const arma::vec& b) const { - return arma::approx_equal(a, b, "absdiff", 1e-12); - } -}; - +#include /** * Class: Memoizer * - * A utility class that caches (memoizes) evaluations of the log-posterior and its gradient - * at specific parameter values (theta) to avoid redundant and costly recomputation. + * Single-entry cache for joint log-posterior and gradient evaluations. * - * Usage: - * - Constructed with a log-posterior function and its gradient. - * - Uses hash-based caching (via unordered_map) keyed by arma::vec inputs. - * That means: - * - Each parameter vector (theta) is used as a key. - * - Results are stored in a lookup table so that repeated evaluations can be skipped. - * - A custom VecHash function generates a hash code for arma::vec. - * - A custom VecEqual comparator defines when two vectors are considered equal (with a small tolerance). - * This enables fast retrieval (average constant time) of previously computed values for log_post and grad. + * In NUTS, the typical access pattern within a leapfrog step is: + * 1. cached_grad(theta) — compute gradient (and cache logp as side-effect) + * 2. cached_log_post(theta) — retrieve the already-cached logp * - * Methods: - * - cached_log_post(const arma::vec& theta): - * Returns the cached log-posterior if available, otherwise computes, stores, and returns it. + * A single-entry cache is optimal here because each leapfrog step produces + * a new unique theta: hash-map lookups would almost never hit, and hashing + * an arma::vec element-by-element is expensive. * - * - cached_grad(const arma::vec& theta): - * Returns the cached gradient if available, otherwise computes, stores, and returns it. - * - * Internal Details: - * - Relies on VecHash and VecEqual to use arma::vec as map keys. - * - Assumes high-precision comparison for theta keys (absdiff tolerance 1e-12). 
+ * The joint evaluation function computes both logp and gradient together + * (since models often share most of the computation between the two). */ class Memoizer { public: - std::function log_post; - std::function grad; + using JointFn = std::function(const arma::vec&)>; + + JointFn joint_fn; - std::unordered_map logp_cache; - std::unordered_map grad_cache; + // Single-entry cache + arma::vec cached_theta; + double cached_logp_val; + arma::vec cached_grad_val; + bool has_cache = false; + /** + * Construct from separate log_post and grad functions. + * Calls them independently (backward-compatible). + */ Memoizer( const std::function& lp, const std::function& gr - ) : log_post(lp), grad(gr) {} + ) : joint_fn([lp, gr](const arma::vec& theta) -> std::pair { + arma::vec g = gr(theta); + double v = lp(theta); + return {v, std::move(g)}; + }) {} + + /** + * Construct from a joint function that computes both at once. + */ + explicit Memoizer(JointFn jf) : joint_fn(std::move(jf)) {} double cached_log_post(const arma::vec& theta) { - auto it = logp_cache.find(theta); - if (it != logp_cache.end()) return it->second; - double val = log_post(theta); - logp_cache[theta] = val; - return val; + ensure_cached(theta); + return cached_logp_val; + } + + const arma::vec& cached_grad(const arma::vec& theta) { + ensure_cached(theta); + return cached_grad_val; } - arma::vec cached_grad(const arma::vec& theta) { - auto it = grad_cache.find(theta); - if (it != grad_cache.end()) return it->second; - arma::vec val = grad(theta); - grad_cache[theta] = val; - return val; +private: + void ensure_cached(const arma::vec& theta) { + if (has_cache && + theta.n_elem == cached_theta.n_elem && + std::memcmp(theta.memptr(), cached_theta.memptr(), + theta.n_elem * sizeof(double)) == 0) { + return; + } + auto [lp, gr] = joint_fn(theta); + cached_theta = theta; + cached_logp_val = lp; + cached_grad_val = std::move(gr); + has_cache = true; } }; \ No newline at end of file diff --git 
a/src/mcmc/mcmc_nuts.cpp b/src/mcmc/mcmc_nuts.cpp index e47c7703..070b6de7 100644 --- a/src/mcmc/mcmc_nuts.cpp +++ b/src/mcmc/mcmc_nuts.cpp @@ -260,5 +260,83 @@ SamplerResult nuts_sampler( diag->divergent = any_divergence; diag->energy = energy; + return {theta, accept_prob, diag}; +} + + +SamplerResult nuts_sampler_joint( + const arma::vec& init_theta, + double step_size, + const std::function(const arma::vec&)>& joint_fn, + const arma::vec& inv_mass_diag, + SafeRNG& rng, + int max_depth +) { + Memoizer memo(joint_fn); + bool any_divergence = false; + + arma::vec r0 = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, init_theta.n_elem); + auto logp0 = memo.cached_log_post(init_theta); + double kin0 = kinetic_energy(r0, inv_mass_diag); + double joint0 = logp0 - kin0; + double log_u = log(runif(rng)) + joint0; + arma::vec theta_min = init_theta, r_min = r0; + arma::vec theta_plus = init_theta, r_plus = r0; + arma::vec theta = init_theta; + arma::vec r = r0; + int j = 0; + int n = 1, s = 1; + + double alpha = 0.5; + int n_alpha = 1; + + while (s == 1 && j < max_depth) { + int v = runif(rng) < 0.5 ? 
-1 : 1; + + BuildTreeResult result; + if (v == -1) { + result = build_tree( + theta_min, r_min, log_u, v, j, step_size, init_theta, r0, logp0, kin0, + memo, inv_mass_diag, rng + ); + theta_min = result.theta_min; + r_min = result.r_min; + } else { + result = build_tree( + theta_plus, r_plus, log_u, v, j, step_size, init_theta, r0, logp0, kin0, + memo, inv_mass_diag, rng + ); + theta_plus = result.theta_plus; + r_plus = result.r_plus; + } + + any_divergence = any_divergence || result.divergent; + alpha = result.alpha; + n_alpha = result.n_alpha; + + if (result.s_prime == 1) { + double prob = static_cast(result.n_prime) / static_cast(n); + if (runif(rng) < prob) { + theta = result.theta_prime; + r = result.r_prime; + } + } + bool no_uturn = !is_uturn(theta_min, theta_plus, r_min, r_plus, inv_mass_diag); + s = result.s_prime * no_uturn; + n += result.n_prime; + j++; + } + + double accept_prob = alpha / static_cast(n_alpha); + + auto logp_final = memo.cached_log_post(theta); + double kin_final = kinetic_energy(r, inv_mass_diag); + double energy = -logp_final + kin_final; + + auto diag = std::make_shared(); + diag->tree_depth = j; + diag->divergent = any_divergence; + diag->energy = energy; + return {theta, accept_prob, diag}; } \ No newline at end of file diff --git a/src/mcmc/mcmc_nuts.h b/src/mcmc/mcmc_nuts.h index 099e1c4d..a730684f 100644 --- a/src/mcmc/mcmc_nuts.h +++ b/src/mcmc/mcmc_nuts.h @@ -56,4 +56,12 @@ SamplerResult nuts_sampler(const arma::vec& init_theta, const std::function& grad, const arma::vec& inv_mass_diag, SafeRNG& rng, - int max_depth = 10); \ No newline at end of file + int max_depth = 10); + +SamplerResult nuts_sampler_joint( + const arma::vec& init_theta, + double step_size, + const std::function(const arma::vec&)>& joint_fn, + const arma::vec& inv_mass_diag, + SafeRNG& rng, + int max_depth = 10); \ No newline at end of file diff --git a/src/mcmc/mcmc_runner.h b/src/mcmc/mcmc_runner.h index 0d80f0cb..4187453a 100644 --- a/src/mcmc/mcmc_runner.h 
+++ b/src/mcmc/mcmc_runner.h @@ -28,10 +28,12 @@ inline std::unique_ptr create_sampler(const SamplerConfig& config) { if (config.sampler_type == "nuts") { return std::make_unique(config, config.no_warmup); - } else if (config.sampler_type == "hmc") { + } else if (config.sampler_type == "hmc" || config.sampler_type == "hamiltonian-mc") { return std::make_unique(config); - } else { + } else if (config.sampler_type == "mh" || config.sampler_type == "adaptive-metropolis") { return std::make_unique(config); + } else { + Rcpp::stop("Unknown sampler_type: '%s'", config.sampler_type.c_str()); } } diff --git a/src/mcmc/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp index 51daed3c..883b371f 100644 --- a/src/mcmc/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "mcmc/mcmc_leapfrog.h" #include "mcmc/mcmc_utils.h" #include "rng/rng_utils.h" @@ -73,10 +74,22 @@ double heuristic_initial_step_size( double kin1 = kinetic_energy(r_new, inv_mass_diag); double H1 = logp1 - kin1; - int direction = 2 * (H1 - H0 > MY_LOG(0.5)) - 1; // +1 or -1 + // NaN guard: treat non-finite H as bad step (force halving) + auto safe_delta_H = [](double H1, double H0) -> double { + double delta = H1 - H0; + return std::isfinite(delta) ? delta : -std::numeric_limits::infinity(); + }; + + int direction = 2 * (safe_delta_H(H1, H0) > MY_LOG(0.5)) - 1; // +1 or -1 int attempts = 0; - while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { + while (attempts < max_attempts) { + double delta = safe_delta_H(H1, H0); + bool keep_going = (direction == 1) + ? (delta > -MY_LOG(2.0)) + : (delta < MY_LOG(2.0)); + if (!keep_going) break; + eps = (direction == 1) ? 
2.0 * eps : 0.5 * eps; // Resample momentum on each iteration for step size search @@ -148,7 +161,13 @@ double heuristic_initial_step_size( int direction = 2 * (H1 - H0 > MY_LOG(0.5)) - 1; // +1 or -1 int attempts = 0; - while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { + while (attempts < max_attempts) { + double delta = safe_delta_H(H1, H0); + bool keep_going = (direction == 1) + ? (delta > -MY_LOG(2.0)) + : (delta < MY_LOG(2.0)); + if (!keep_going) break; + eps = (direction == 1) ? 2.0 * eps : 0.5 * eps; // Resample momentum (STAN resamples on each iteration) diff --git a/src/mcmc/nuts_sampler.h b/src/mcmc/nuts_sampler.h index 9fa694fb..6aa2bbca 100644 --- a/src/mcmc/nuts_sampler.h +++ b/src/mcmc/nuts_sampler.h @@ -8,7 +8,6 @@ #include "mcmc_utils.h" #include "mcmc_nuts.h" #include "mcmc_adaptation.h" -#include "mcmc_leapfrog.h" #include "sampler_config.h" #include "../base_model.h" @@ -52,10 +51,14 @@ class NUTSSampler : public BaseSampler { // Push adapted mass matrix to model model.set_inv_mass(inv_mass_); - // Re-run heuristic step size with new mass matrix arma::vec theta = model.get_vectorized_parameters(); SafeRNG& rng = model.get_rng(); - auto [log_post, grad_fn] = make_model_functions(model); + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; arma::vec active_inv_mass = model.get_active_inv_mass(); double new_eps = heuristic_initial_step_size( @@ -111,59 +114,22 @@ class NUTSSampler : public BaseSampler { return false; } - /** - * Create log_post and grad lambdas that share a single logp_and_gradient call. - * Avoids doubling computation when the memoizer requests both at the same point. 
- */ - static std::pair< - std::function, - std::function - > make_model_functions(BaseModel& model) { - struct JointCache { - arma::vec theta; - double logp; - arma::vec grad; - bool valid = false; - }; - auto cache = std::make_shared(); - - auto ensure = [&model, cache](const arma::vec& params) { - if (!cache->valid || - params.n_elem != cache->theta.n_elem || - !arma::approx_equal(params, cache->theta, "absdiff", 1e-14)) { - auto [lp, gr] = model.logp_and_gradient(params); - cache->theta = params; - cache->logp = lp; - cache->grad = std::move(gr); - cache->valid = true; - } - }; - - auto log_post = [ensure, cache](const arma::vec& params) -> double { - ensure(params); - return cache->logp; - }; - auto grad_fn = [ensure, cache](const arma::vec& params) -> arma::vec { - ensure(params); - return cache->grad; - }; - - return {log_post, grad_fn}; - } - void initialize(BaseModel& model) { arma::vec theta = model.get_vectorized_parameters(); SafeRNG& rng = model.get_rng(); - // Initialize inverse mass from model (defaults to ones) inv_mass_ = arma::ones(model.full_parameter_dimension()); model.set_inv_mass(inv_mass_); - // Mass matrix accumulator uses full dimension mass_accumulator_ = std::make_unique( static_cast(model.full_parameter_dimension())); - auto [log_post, grad_fn] = make_model_functions(model); + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; step_size_ = heuristic_initial_step_size( theta, log_post, grad_fn, rng, target_acceptance_); @@ -175,12 +141,15 @@ class NUTSSampler : public BaseSampler { arma::vec theta = model.get_vectorized_parameters(); SafeRNG& rng = model.get_rng(); - auto [log_post, grad_fn] = make_model_functions(model); + auto joint_fn = [&model](const arma::vec& params) + -> std::pair { + return model.logp_and_gradient(params); + }; arma::vec 
active_inv_mass = model.get_active_inv_mass(); - SamplerResult result = nuts_sampler( - theta, step_size_, log_post, grad_fn, + SamplerResult result = nuts_sampler_joint( + theta, step_size_, joint_fn, active_inv_mass, rng, max_tree_depth_ ); diff --git a/src/omrf_model.cpp b/src/omrf_model.cpp index f189e216..64cb4e63 100644 --- a/src/omrf_model.cpp +++ b/src/omrf_model.cpp @@ -926,12 +926,10 @@ arma::vec OMRFModel::gradient_internal() const { std::pair OMRFModel::logp_and_gradient(const arma::vec& parameters) { - // Ensure gradient cache is initialized ensure_gradient_cache(); - // Use the external-state versions - arma::mat temp_main = main_effects_; - arma::mat temp_pairwise = pairwise_effects_; + arma::mat temp_main(main_effects_.n_rows, main_effects_.n_cols, arma::fill::none); + arma::mat temp_pairwise(p_, p_, arma::fill::zeros); // Unvectorize parameters into temporaries int offset = 0; @@ -958,10 +956,118 @@ std::pair OMRFModel::logp_and_gradient(const arma::vec& param arma::mat temp_residual = observations_double_ * temp_pairwise; - double lp = log_pseudoposterior_with_state(temp_main, temp_pairwise, temp_residual); - arma::vec grad = gradient_with_state(temp_main, temp_pairwise, temp_residual); + // Initialize gradient from cached observed statistics + arma::vec gradient = grad_obs_cache_; + double log_post = 0.0; + + // Merged per-variable loop: compute probability table ONCE per variable + // and derive both logp and gradient contributions from it. 
+ offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_cats = num_categories_(v); + arma::vec residual_score = temp_residual.col(v); + arma::vec bound = num_cats * residual_score; + + if (is_ordinal_variable_(v)) { + arma::vec main_param = temp_main.row(v).cols(0, num_cats - 1).t(); + + // Prior + sufficient statistics for logp + for (int c = 0; c < num_cats; ++c) { + double x = temp_main(v, c); + log_post += x * main_alpha_ - std::log1p(std::exp(x)) * (main_alpha_ + main_beta_); + log_post += x * counts_per_category_(c + 1, v); + } + + // Compute probability table ONCE (replaces separate denom + probs) + arma::mat probs = compute_probs_ordinal(main_param, residual_score, bound, num_cats); + + // Log-partition: bound + log(denom) = -log(probs(:, 0)) + log_post += arma::accu(ARMA_MY_LOG(probs.col(0))); + + // Gradient: main effects + for (int c = 0; c < num_cats; ++c) { + gradient(offset + c) -= arma::accu(probs.col(c + 1)); + } + + // Gradient: pairwise effects + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + arma::vec expected_value = arma::zeros(n_); + for (int c = 1; c <= num_cats; ++c) { + expected_value += c * probs.col(c) % observations_double_.col(j); + } + int location = (v < j) ? 
index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += num_cats; + } else { + int ref = baseline_category_(v); + double lin_eff = temp_main(v, 0); + double quad_eff = temp_main(v, 1); + + // Prior + sufficient statistics for logp + log_post += lin_eff * main_alpha_ - std::log1p(std::exp(lin_eff)) * (main_alpha_ + main_beta_); + log_post += quad_eff * main_alpha_ - std::log1p(std::exp(quad_eff)) * (main_alpha_ + main_beta_); + log_post += lin_eff * blume_capel_stats_(0, v); + log_post += quad_eff * blume_capel_stats_(1, v); + + // Compute probability table ONCE (bound is updated in-place) + arma::mat probs = compute_probs_blume_capel(residual_score, lin_eff, quad_eff, ref, num_cats, bound); + + // Log-partition: bound + log(denom) = ref * r - log(probs(:, ref)) + log_post -= static_cast(ref) * arma::accu(residual_score) + - arma::accu(ARMA_MY_LOG(probs.col(ref))); + + // Gradient: main effects + arma::vec score = arma::regspace(0, num_cats) - static_cast(ref); + arma::vec sq_score = arma::square(score); + gradient(offset) -= arma::accu(probs * score); + gradient(offset + 1) -= arma::accu(probs * sq_score); + + // Gradient: pairwise effects + for (size_t j = 0; j < p_; ++j) { + if (edge_indicators_(v, j) == 0 || v == j) continue; + arma::vec expected_value = arma::zeros(n_); + for (int c = 0; c <= num_cats; ++c) { + int s = c - ref; + expected_value += s * probs.col(c) % observations_double_.col(j); + } + int location = (v < j) ? index_matrix_cache_(v, j) : index_matrix_cache_(j, v); + gradient(location) -= arma::accu(expected_value); + } + + offset += 2; + } + } + + // Gradient: main effect prior (Beta-prime) + offset = 0; + for (size_t v = 0; v < p_; ++v) { + int num_pars = is_ordinal_variable_(v) ? 
num_categories_(v) : 2; + for (int c = 0; c < num_pars; ++c) { + double x = temp_main(v, c); + double prob = 1.0 / (1.0 + std::exp(-x)); + gradient(offset + c) += main_alpha_ - (main_alpha_ + main_beta_) * prob; + } + offset += num_pars; + } + + // Pairwise: sufficient statistics + Cauchy prior (logp) and gradient + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + double effect = temp_pairwise(v1, v2); + log_post += 2.0 * pairwise_stats_(v1, v2) * effect; + log_post += R::dcauchy(effect, 0.0, pairwise_scale_, true); + + int idx = index_matrix_cache_(v1, v2); + gradient(idx) -= 2.0 * effect / (pairwise_scale_ * pairwise_scale_ + effect * effect); + } + } + } - return {lp, grad}; + return {log_post, std::move(gradient)}; } From a0a08a584034be24f6a50eba8f54f9878bea208f Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 9 Feb 2026 11:27:54 +0100 Subject: [PATCH 15/23] changes for imputation and edge priors --- R/RcppExports.R | 4 +- R/bgm.R | 50 ++++++++ R/output_utils.R | 33 +++++ src/RcppExports.cpp | 17 ++- src/base_model.h | 21 ++++ src/ggm_model.h | 18 +++ src/mcmc/mcmc_runner.h | 49 +++++++- src/mcmc/sampler_config.h | 3 + src/omrf_model.cpp | 111 +++++++++++------ src/omrf_model.h | 14 ++- src/priors/edge_prior.h | 249 ++++++++++++++++++++++++++++++++++++++ src/sample_ggm.cpp | 6 +- src/sample_omrf.cpp | 45 ++++++- 13 files changed, 567 insertions(+), 53 deletions(-) create mode 100644 src/priors/edge_prior.h diff --git a/R/RcppExports.R b/R/RcppExports.R index b05ee7b4..fef77cd4 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -37,8 +37,8 @@ sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) } -sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, 
no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, edge_selection_start = -1L) { - .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start) +sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, edge_selection_start = -1L) { + .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start) } compute_Vn_mfm_sbm <- function(no_variables, dirichlet_alpha, t_max, lambda) { diff --git a/R/bgm.R b/R/bgm.R index e0cca531..b7f5a9c1 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -400,6 +400,7 @@ bgm = function( chains = 4, cores = parallel::detectCores(), display_progress = c("per-chain", "total", "none"), + backend = c("legacy", "new"), seed = NULL, standardize = FALSE, interaction_scale, @@ -551,6 +552,9 @@ bgm = function( )) } + # Check backend --------------------------------------------------------------- + backend = match.arg(backend) + # Check display_progress 
------------------------------------------------------ progress_type = progress_type_from_display_progress(display_progress) @@ -592,6 +596,9 @@ bgm = function( } } + # Save data before Blume-Capel centering (needed by the new backend) + x_raw = x + # Precompute the sufficient statistics for the two Blume-Capel parameters ----- blume_capel_stats = matrix(0, nrow = 2, ncol = num_variables) if(any(!variable_bool)) { @@ -691,6 +698,47 @@ bgm = function( seed <- as.integer(seed) + if (backend == "new") { + input_list = list( + observations = x_raw, + num_categories = num_categories, + is_ordinal_variable = variable_bool, + baseline_category = baseline_category, + main_alpha = main_alpha, + main_beta = main_beta, + pairwise_scale = pairwise_scale + ) + + out_raw = sample_omrf( + inputFromR = input_list, + prior_inclusion_prob = matrix(inclusion_probability, + nrow = num_variables, ncol = num_variables), + initial_edge_indicators = indicator, + no_iter = iter, + no_warmup = warmup, + no_chains = chains, + no_threads = cores, + progress_type = progress_type, + edge_selection = edge_selection, + sampler_type = update_method, + seed = seed, + edge_prior = edge_prior, + na_impute = na_impute, + missing_index = missing_index, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda, + target_acceptance = target_accept, + max_tree_depth = nuts_max_depth, + num_leapfrogs = hmc_num_leapfrogs + ) + + out = transform_new_backend_output(out_raw, num_thresholds) + } else { + out = run_bgm_parallel( observations = x, num_categories = num_categories, pairwise_scale = pairwise_scale, edge_prior = edge_prior, @@ -716,6 +764,8 @@ bgm = function( pairwise_scaling_factors = pairwise_scaling_factors ) + } # end backend branch + userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE 
= logical(1L), "userInterrupt")) if(userInterrupt) { warning("Stopped sampling after user interrupt, results are likely uninterpretable.") diff --git a/R/output_utils.R b/R/output_utils.R index 1ae67cfc..bb21ce5e 100644 --- a/R/output_utils.R +++ b/R/output_utils.R @@ -178,6 +178,39 @@ prepare_output_bgm = function( } +# Transform sample_omrf output to match the old backend format. +# +# The new backend returns a flat `samples` matrix (params x iters) containing +# all main + pairwise parameters concatenated. The old backend stores separate +# `main_samples` and `pairwise_samples` matrices (iters x params). NUTS fields +# also differ in naming. This function bridges the gap so that +# `prepare_output_bgm()` can process both backends identically. +transform_new_backend_output = function(out, num_thresholds) { + lapply(out, function(chain) { + samples_t = t(chain$samples) # (params x iters) -> (iters x params) + n_params = ncol(samples_t) + + res = list( + main_samples = samples_t[, seq_len(num_thresholds), drop = FALSE], + pairwise_samples = samples_t[, seq(num_thresholds + 1, n_params), drop = FALSE], + userInterrupt = isTRUE(chain$userInterrupt), + chain_id = chain$chain_id + ) + + if (!is.null(chain$indicator_samples)) { + res$indicator_samples = t(chain$indicator_samples) + } + + # Rename NUTS diagnostics to match old backend convention (trailing __) + if (!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth + if (!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent + if (!is.null(chain$energy)) res[["energy__"]] = chain$energy + + res + }) +} + + # Generate names for bgmCompare parameters generate_param_names_bgmCompare = function( data_columnnames, diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 2bcac928..49b600ba 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -232,8 +232,8 @@ BEGIN_RCPP END_RCPP } // sample_omrf -Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, 
const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const int edge_selection_start); -RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP edge_selection_startSEXP) { +Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const int edge_selection_start); +RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP 
dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP edge_selection_startSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -248,11 +248,20 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); + Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); + Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_nullable(missing_index_nullableSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha_between(beta_bernoulli_alpha_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta_between(beta_bernoulli_beta_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type dirichlet_alpha(dirichlet_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); Rcpp::traits::input_parameter< const double >::type target_acceptance(target_acceptanceSEXP); Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); Rcpp::traits::input_parameter< const int >::type num_leapfrogs(num_leapfrogsSEXP); Rcpp::traits::input_parameter< const int >::type edge_selection_start(edge_selection_startSEXP); - rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start)); + 
rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, edge_selection_start)); return rcpp_result_gen; END_RCPP } @@ -281,7 +290,7 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_run_simulation_parallel", (DL_FUNC) &_bgms_run_simulation_parallel, 12}, {"_bgms_sample_omrf_classed", (DL_FUNC) &_bgms_sample_omrf_classed, 8}, {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 10}, - {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 15}, + {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/base_model.h b/src/base_model.h index 71a3527f..53ff6e3b 100644 --- a/src/base_model.h +++ b/src/base_model.h @@ -105,6 +105,27 @@ class BaseModel { // Default: no-op } + // Missing data imputation + virtual bool has_missing_data() const { return false; } + virtual void impute_missing() { + // Default: no-op + } + + // Edge prior support: models expose these so the external edge prior + // class can read current indicators and update inclusion probabilities. 
+ virtual const arma::imat& get_edge_indicators() const { + throw std::runtime_error("get_edge_indicators not implemented for this model"); + } + virtual arma::mat& get_inclusion_probability() { + throw std::runtime_error("get_inclusion_probability not implemented for this model"); + } + virtual int get_num_variables() const { + throw std::runtime_error("get_num_variables not implemented for this model"); + } + virtual int get_num_pairwise() const { + throw std::runtime_error("get_num_pairwise not implemented for this model"); + } + protected: BaseModel() = default; double step_size_ = 0.1; diff --git a/src/ggm_model.h b/src/ggm_model.h index 34a67359..5e4cf1b3 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -159,6 +159,24 @@ class GaussianVariables : public BaseModel { return vectorized_indicator_parameters_; } + SafeRNG& get_rng() override { return rng_; } + + const arma::imat& get_edge_indicators() const override { + return edge_indicators_; + } + + arma::mat& get_inclusion_probability() override { + return inclusion_probability_; + } + + int get_num_variables() const override { + return static_cast(p_); + } + + int get_num_pairwise() const override { + return static_cast(p_ * (p_ - 1) / 2); + } + std::unique_ptr clone() const override { return std::make_unique(*this); // uses copy constructor } diff --git a/src/mcmc/mcmc_runner.h b/src/mcmc/mcmc_runner.h index 4187453a..c3bf3472 100644 --- a/src/mcmc/mcmc_runner.h +++ b/src/mcmc/mcmc_runner.h @@ -8,6 +8,7 @@ #include "../base_model.h" #include "../chainResultNew.h" +#include "../priors/edge_prior.h" #include "../utils/progress_manager.h" #include "sampler_config.h" #include "base_sampler.h" @@ -53,6 +54,7 @@ inline std::unique_ptr create_sampler(const SamplerConfig& config) inline void run_mcmc_chain( ChainResultNew& chain_result, BaseModel& model, + BaseEdgePrior& edge_prior, const SamplerConfig& config, const int chain_id, ProgressManager& pm @@ -69,7 +71,10 @@ inline void run_mcmc_chain( // 
========================================================================= for (int iter = 0; iter < config.no_warmup; ++iter) { - // model->impute_missing_data(); + // Impute missing data if applicable + if (config.na_impute && model.has_missing_data()) { + model.impute_missing(); + } // Edge selection starts after edge_start iterations if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { @@ -79,6 +84,17 @@ inline void run_mcmc_chain( // Sampler step (unified interface) sampler->warmup_step(model); + // Update edge prior parameters (Beta-Bernoulli, SBM, etc.) + if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { + edge_prior.update( + model.get_edge_indicators(), + model.get_inclusion_probability(), + model.get_num_variables(), + model.get_num_pairwise(), + model.get_rng() + ); + } + // Progress and interrupt check pm.update(chain_id); if (pm.shouldExit()) { @@ -103,7 +119,10 @@ inline void run_mcmc_chain( // ========================================================================= for (int iter = 0; iter < config.no_iter; ++iter) { - // model->impute_missing_data(); + // Impute missing data if applicable + if (config.na_impute && model.has_missing_data()) { + model.impute_missing(); + } // Edge selection continues during sampling if (config.edge_selection && model.has_edge_selection()) { @@ -113,6 +132,17 @@ inline void run_mcmc_chain( // Sampler step (unified interface) SamplerResult result = sampler->sample_step(model); + // Update edge prior parameters (Beta-Bernoulli, SBM, etc.) 
+ if (config.edge_selection && model.has_edge_selection()) { + edge_prior.update( + model.get_edge_indicators(), + model.get_inclusion_probability(), + model.get_num_variables(), + model.get_num_pairwise(), + model.get_rng() + ); + } + // Store NUTS diagnostics if available if (chain_result.has_nuts_diagnostics && sampler->has_nuts_diagnostics()) { auto* diag = dynamic_cast(result.diagnostics.get()); @@ -145,17 +175,20 @@ inline void run_mcmc_chain( struct MCMCChainRunner : public RcppParallel::Worker { std::vector& results_; std::vector>& models_; + std::vector>& edge_priors_; const SamplerConfig& config_; ProgressManager& pm_; MCMCChainRunner( std::vector& results, std::vector>& models, + std::vector>& edge_priors, const SamplerConfig& config, ProgressManager& pm ) : results_(results), models_(models), + edge_priors_(edge_priors), config_(config), pm_(pm) {} @@ -164,10 +197,11 @@ struct MCMCChainRunner : public RcppParallel::Worker { for (std::size_t i = begin; i < end; ++i) { ChainResultNew& chain_result = results_[i]; BaseModel& model = *models_[i]; + BaseEdgePrior& edge_prior = *edge_priors_[i]; model.set_seed(config_.seed + static_cast(i)); try { - run_mcmc_chain(chain_result, model, config_, static_cast(i), pm_); + run_mcmc_chain(chain_result, model, edge_prior, config_, static_cast(i), pm_); } catch (std::exception& e) { chain_result.error = true; chain_result.error_msg = e.what(); @@ -197,6 +231,7 @@ struct MCMCChainRunner : public RcppParallel::Worker { */ inline std::vector run_mcmc_sampler( BaseModel& model, + BaseEdgePrior& edge_prior, const SamplerConfig& config, const int no_chains, const int no_threads, @@ -222,13 +257,16 @@ inline std::vector run_mcmc_sampler( if (no_threads > 1) { // Multi-threaded execution std::vector> models; + std::vector> edge_priors; models.reserve(no_chains); + edge_priors.reserve(no_chains); for (int c = 0; c < no_chains; ++c) { models.push_back(model.clone()); models[c]->set_seed(config.seed + c); + 
edge_priors.push_back(edge_prior.clone()); } - MCMCChainRunner runner(results, models, config, pm); + MCMCChainRunner runner(results, models, edge_priors, config, pm); tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); RcppParallel::parallelFor(0, static_cast(no_chains), runner); @@ -238,7 +276,8 @@ inline std::vector run_mcmc_sampler( for (int c = 0; c < no_chains; ++c) { auto chain_model = model.clone(); chain_model->set_seed(config.seed + c); - run_mcmc_chain(results[c], *chain_model, config, c, pm); + auto chain_edge_prior = edge_prior.clone(); + run_mcmc_chain(results[c], *chain_model, *chain_edge_prior, config, c, pm); } } diff --git a/src/mcmc/sampler_config.h b/src/mcmc/sampler_config.h index c985a43d..55a2e62a 100644 --- a/src/mcmc/sampler_config.h +++ b/src/mcmc/sampler_config.h @@ -29,6 +29,9 @@ struct SamplerConfig { bool edge_selection = false; int edge_selection_start = -1; // -1 = no_warmup (default, start at sampling) + // Missing data imputation + bool na_impute = false; + // Random seed int seed = 42; diff --git a/src/omrf_model.cpp b/src/omrf_model.cpp index 64cb4e63..10d95662 100644 --- a/src/omrf_model.cpp +++ b/src/omrf_model.cpp @@ -1227,52 +1227,93 @@ void OMRFModel::initialize_graph() { void OMRFModel::impute_missing() { if (!has_missing_) return; - // For each missing value, sample from conditional distribution - for (size_t m = 0; m < missing_index_.n_rows; ++m) { - int person = missing_index_(m, 0); - int variable = missing_index_(m, 1); - int num_cats = num_categories_(variable); - - arma::vec log_probs; - if (is_ordinal_variable_(variable)) { - log_probs.set_size(num_cats + 1); - log_probs(0) = 0.0; - for (int c = 0; c < num_cats; ++c) { - log_probs(c + 1) = main_effects_(variable, c) + (c + 1) * residual_matrix_(person, variable); + const int num_variables = p_; + const int num_missings = missing_index_.n_rows; + const int max_num_categories = num_categories_.max(); + + arma::vec 
category_probabilities(max_num_categories + 1); + + for (int miss = 0; miss < num_missings; miss++) { + const int person = missing_index_(miss, 0); + const int variable = missing_index_(miss, 1); + + const double residual_score = residual_matrix_(person, variable); + const int num_cats = num_categories_(variable); + const bool is_ordinal = is_ordinal_variable_(variable); + + double cumsum = 0.0; + + if (is_ordinal) { + cumsum = 1.0; + category_probabilities[0] = cumsum; + for (int cat = 0; cat < num_cats; cat++) { + const int score = cat + 1; + const double exponent = main_effects_(variable, cat) + score * residual_score; + cumsum += MY_EXP(exponent); + category_probabilities[score] = cumsum; } } else { - int baseline = baseline_category_(variable); - log_probs.set_size(num_cats + 1); - for (int c = 0; c <= num_cats; ++c) { - int s = c - baseline; - log_probs(c) = main_effects_(variable, 0) * s + main_effects_(variable, 1) * s * s + s * residual_matrix_(person, variable); + const int ref = baseline_category_(variable); + + cumsum = MY_EXP( + main_effects_(variable, 0) * ref + main_effects_(variable, 1) * ref * ref + ); + category_probabilities[0] = cumsum; + + for (int cat = 0; cat <= num_cats; cat++) { + const int score = cat - ref; + const double exponent = + main_effects_(variable, 0) * score + + main_effects_(variable, 1) * score * score + + score * residual_score; + cumsum += MY_EXP(exponent); + category_probabilities[cat] = cumsum; } } - // Sample from categorical - double max_val = log_probs.max(); - arma::vec probs = arma::exp(log_probs - max_val); - probs /= arma::sum(probs); - - double u = runif(rng_); - double cumsum = 0.0; - int new_value = 0; - for (size_t c = 0; c < probs.n_elem; ++c) { - cumsum += probs(c); - if (u < cumsum) { - new_value = c; - break; - } + // Sample from categorical distribution via inverse transform + const double u = runif(rng_) * cumsum; + int sampled_score = 0; + while (u > category_probabilities[sampled_score]) { + 
sampled_score++; } - int old_value = observations_(person, variable); + int new_value = sampled_score; + if (!is_ordinal) + new_value -= baseline_category_(variable); + const int old_value = observations_(person, variable); + if (new_value != old_value) { observations_(person, variable) = new_value; - // Update sufficient statistics - compute_sufficient_statistics(); - update_residual_matrix(); + observations_double_(person, variable) = static_cast(new_value); + + if (is_ordinal) { + counts_per_category_(old_value, variable)--; + counts_per_category_(new_value, variable)++; + } else { + const int delta = new_value - old_value; + const int delta_sq = new_value * new_value - old_value * old_value; + blume_capel_stats_(0, variable) += delta; + blume_capel_stats_(1, variable) += delta_sq; + } + + // Incrementally update residuals across all variables + for (int var = 0; var < num_variables; var++) { + const double delta_score = (new_value - old_value) * pairwise_effects_(var, variable); + residual_matrix_(person, var) += delta_score; + } } } + + // Recompute pairwise sufficient statistics + arma::mat ps = observations_double_.t() * observations_double_; + pairwise_stats_ = arma::conv_to::from(ps); +} + + +void OMRFModel::set_missing_data(const arma::imat& missing_index) { + missing_index_ = missing_index; + has_missing_ = (missing_index.n_rows > 0 && missing_index.n_cols == 2); } diff --git a/src/omrf_model.h b/src/omrf_model.h index 03a1dc29..32dc09af 100644 --- a/src/omrf_model.h +++ b/src/omrf_model.h @@ -6,6 +6,7 @@ #include "adaptiveMetropolis.h" #include "rng/rng_utils.h" #include "mcmc/mcmc_utils.h" +#include "utils/common_helpers.h" /** * OMRFModel - Ordinal Markov Random Field Model @@ -63,6 +64,7 @@ class OMRFModel : public BaseModel { bool has_gradient() const override { return true; } bool has_adaptive_mh() const override { return true; } bool has_edge_selection() const override { return edge_selection_; } + bool has_missing_data() const override { return 
has_missing_; } /** * Compute log-pseudoposterior for given parameter vector @@ -141,7 +143,12 @@ class OMRFModel : public BaseModel { /** * Impute missing values (if any) */ - void impute_missing(); + void impute_missing() override; + + /** + * Set missing data information + */ + void set_missing_data(const arma::imat& missing_index); // ========================================================================= // Accessors @@ -149,13 +156,16 @@ class OMRFModel : public BaseModel { const arma::mat& get_main_effects() const { return main_effects_; } const arma::mat& get_pairwise_effects() const { return pairwise_effects_; } - const arma::imat& get_edge_indicators() const { return edge_indicators_; } + const arma::imat& get_edge_indicators() const override { return edge_indicators_; } + arma::mat& get_inclusion_probability() override { return inclusion_probability_; } const arma::mat& get_residual_matrix() const { return residual_matrix_; } void set_main_effects(const arma::mat& main_effects) { main_effects_ = main_effects; } void set_pairwise_effects(const arma::mat& pairwise_effects); void set_edge_indicators(const arma::imat& edge_indicators) { edge_indicators_ = edge_indicators; } + int get_num_variables() const override { return static_cast(p_); } + int get_num_pairwise() const override { return static_cast(num_pairwise_); } size_t num_variables() const { return p_; } size_t num_observations() const { return n_; } size_t num_main_effects() const { return num_main_; } diff --git a/src/priors/edge_prior.h b/src/priors/edge_prior.h new file mode 100644 index 00000000..02f5c716 --- /dev/null +++ b/src/priors/edge_prior.h @@ -0,0 +1,249 @@ +#pragma once + +#include +#include +#include "../rng/rng_utils.h" +#include "../utils/common_helpers.h" +#include "sbm_edge_prior.h" +#include "../sbm_edge_prior_interface.h" + + +/** + * Abstract base class for edge inclusion priors. 
+ * + * The edge prior updates the inclusion probability matrix based on the + * current edge indicators. This is independent of the model type (GGM, OMRF, + * etc.), so it is implemented as a separate class hierarchy. + * + * The MCMC runner calls update() after each edge indicator update, passing + * the current edge indicators and inclusion probability matrix. The edge + * prior modifies inclusion_probability in place. + */ +class BaseEdgePrior { +public: + virtual ~BaseEdgePrior() = default; + + virtual void update( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + int num_pairwise, + SafeRNG& rng + ) = 0; + + virtual std::unique_ptr clone() const = 0; +}; + + +/** + * Bernoulli edge prior (fixed inclusion probabilities, no update needed). + */ +class BernoulliEdgePrior : public BaseEdgePrior { +public: + void update( + const arma::imat& /*edge_indicators*/, + arma::mat& /*inclusion_probability*/, + int /*num_variables*/, + int /*num_pairwise*/, + SafeRNG& /*rng*/ + ) override { + // No-op: inclusion probabilities are fixed + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } +}; + + +/** + * Beta-Bernoulli edge prior. + * + * Draws a shared inclusion probability from Beta(alpha + #included, + * beta + #excluded) and assigns it to all edges. 
+ */ +class BetaBernoulliEdgePrior : public BaseEdgePrior { +public: + BetaBernoulliEdgePrior(double alpha = 1.0, double beta = 1.0) + : alpha_(alpha), beta_(beta) {} + + void update( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + int num_pairwise, + SafeRNG& rng + ) override { + int num_edges_included = 0; + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + num_edges_included += edge_indicators(i, j); + } + } + + double prob = rbeta(rng, + alpha_ + num_edges_included, + beta_ + num_pairwise - num_edges_included + ); + + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + inclusion_probability(i, j) = prob; + inclusion_probability(j, i) = prob; + } + } + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + double alpha_; + double beta_; +}; + + +/** + * Stochastic Block Model (MFM-SBM) edge prior. + * + * Maintains cluster allocations and block-level inclusion probabilities. + * Each edge's inclusion probability depends on its endpoints' cluster + * assignments. + */ +class StochasticBlockEdgePrior : public BaseEdgePrior { +public: + StochasticBlockEdgePrior( + double beta_bernoulli_alpha, + double beta_bernoulli_beta, + double beta_bernoulli_alpha_between, + double beta_bernoulli_beta_between, + double dirichlet_alpha, + double lambda + ) : beta_bernoulli_alpha_(beta_bernoulli_alpha), + beta_bernoulli_beta_(beta_bernoulli_beta), + beta_bernoulli_alpha_between_(beta_bernoulli_alpha_between), + beta_bernoulli_beta_between_(beta_bernoulli_beta_between), + dirichlet_alpha_(dirichlet_alpha), + lambda_(lambda), + initialized_(false) + {} + + /** + * Initialize SBM state from the current edge indicators. Called + * automatically on first update(). 
+ */ + void initialize( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + SafeRNG& rng + ) { + cluster_allocations_.set_size(num_variables); + cluster_allocations_[0] = 0; + cluster_allocations_[1] = 1; + for (int i = 2; i < num_variables; i++) { + cluster_allocations_[i] = (runif(rng) > 0.5) ? 1 : 0; + } + + cluster_prob_ = block_probs_mfm_sbm( + cluster_allocations_, + arma::conv_to::from(edge_indicators), + num_variables, + beta_bernoulli_alpha_, beta_bernoulli_beta_, + beta_bernoulli_alpha_between_, beta_bernoulli_beta_between_, + rng + ); + + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + inclusion_probability(i, j) = cluster_prob_(cluster_allocations_[i], cluster_allocations_[j]); + inclusion_probability(j, i) = inclusion_probability(i, j); + } + } + + log_Vn_ = compute_Vn_mfm_sbm( + num_variables, dirichlet_alpha_, num_variables + 10, lambda_); + + initialized_ = true; + } + + void update( + const arma::imat& edge_indicators, + arma::mat& inclusion_probability, + int num_variables, + int /*num_pairwise*/, + SafeRNG& rng + ) override { + if (!initialized_) { + initialize(edge_indicators, inclusion_probability, num_variables, rng); + } + + cluster_allocations_ = block_allocations_mfm_sbm( + cluster_allocations_, num_variables, log_Vn_, cluster_prob_, + arma::conv_to::from(edge_indicators), dirichlet_alpha_, + beta_bernoulli_alpha_, beta_bernoulli_beta_, + beta_bernoulli_alpha_between_, beta_bernoulli_beta_between_, rng + ); + + cluster_prob_ = block_probs_mfm_sbm( + cluster_allocations_, + arma::conv_to::from(edge_indicators), num_variables, + beta_bernoulli_alpha_, beta_bernoulli_beta_, + beta_bernoulli_alpha_between_, beta_bernoulli_beta_between_, rng + ); + + for (int i = 0; i < num_variables - 1; i++) { + for (int j = i + 1; j < num_variables; j++) { + inclusion_probability(i, j) = cluster_prob_(cluster_allocations_[i], cluster_allocations_[j]); + 
inclusion_probability(j, i) = inclusion_probability(i, j); + } + } + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + double beta_bernoulli_alpha_; + double beta_bernoulli_beta_; + double beta_bernoulli_alpha_between_; + double beta_bernoulli_beta_between_; + double dirichlet_alpha_; + double lambda_; + + bool initialized_; + arma::uvec cluster_allocations_; + arma::mat cluster_prob_; + arma::vec log_Vn_; +}; + + +/** + * Factory: create an edge prior from an EdgePrior enum and hyperparameters. + */ +inline std::unique_ptr create_edge_prior( + EdgePrior type, + double beta_bernoulli_alpha = 1.0, + double beta_bernoulli_beta = 1.0, + double beta_bernoulli_alpha_between = 1.0, + double beta_bernoulli_beta_between = 1.0, + double dirichlet_alpha = 1.0, + double lambda = 1.0 +) { + switch (type) { + case Beta_Bernoulli: + return std::make_unique( + beta_bernoulli_alpha, beta_bernoulli_beta); + case Stochastic_Block: + return std::make_unique( + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda); + case Bernoulli: + case Not_Applicable: + default: + return std::make_unique(); + } +} diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index bd635f46..253ea35a 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -6,6 +6,7 @@ #include "ggm_model.h" #include "utils/progress_manager.h" +#include "priors/edge_prior.h" #include "chainResultNew.h" #include "mcmc/mcmc_runner.h" #include "mcmc/sampler_config.h" @@ -40,9 +41,12 @@ Rcpp::List sample_ggm( // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + // Create default edge prior (Bernoulli = no-op) + BernoulliEdgePrior edge_prior; + // Run MCMC using unified infrastructure std::vector results = run_mcmc_sampler( - model, config, no_chains, no_threads, pm); + model, edge_prior, config, no_chains, no_threads, pm); // Convert to R list format 
Rcpp::List output = convert_results_to_list(results); diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp index 6e6c45a0..a57fb790 100644 --- a/src/sample_omrf.cpp +++ b/src/sample_omrf.cpp @@ -10,6 +10,8 @@ #include "omrf_model.h" #include "utils/progress_manager.h" +#include "utils/common_helpers.h" +#include "priors/edge_prior.h" #include "chainResultNew.h" #include "mcmc/mcmc_runner.h" #include "mcmc/sampler_config.h" @@ -18,8 +20,8 @@ * R-exported function to sample from an OMRF model * * @param inputFromR List with model specification - * @param prior_inclusion_prob Prior inclusion probabilities (p × p matrix) - * @param initial_edge_indicators Initial edge indicators (p × p integer matrix) + * @param prior_inclusion_prob Prior inclusion probabilities (p x p matrix) + * @param initial_edge_indicators Initial edge indicators (p x p integer matrix) * @param no_iter Number of post-warmup iterations * @param no_warmup Number of warmup iterations * @param no_chains Number of parallel chains @@ -28,6 +30,15 @@ * @param seed Random seed * @param no_threads Number of threads for parallel execution * @param progress_type Progress bar type + * @param edge_prior Edge prior type: "Bernoulli", "Beta-Bernoulli", "Stochastic-Block" + * @param na_impute Whether to impute missing data + * @param missing_index Matrix of missing data indices (n_missing x 2, 0-based) + * @param beta_bernoulli_alpha Beta-Bernoulli alpha hyperparameter + * @param beta_bernoulli_beta Beta-Bernoulli beta hyperparameter + * @param beta_bernoulli_alpha_between SBM between-cluster alpha + * @param beta_bernoulli_beta_between SBM between-cluster beta + * @param dirichlet_alpha Dirichlet alpha for SBM + * @param lambda Lambda for SBM * @param target_acceptance Target acceptance rate for NUTS/HMC (default: 0.8) * @param max_tree_depth Maximum tree depth for NUTS (default: 10) * @param num_leapfrogs Number of leapfrog steps for HMC (default: 10) @@ -48,6 +59,15 @@ Rcpp::List sample_omrf( const int seed, 
const int no_threads, const int progress_type, + const std::string& edge_prior = "Bernoulli", + const bool na_impute = false, + const Rcpp::Nullable missing_index_nullable = R_NilValue, + const double beta_bernoulli_alpha = 1.0, + const double beta_bernoulli_beta = 1.0, + const double beta_bernoulli_alpha_between = 1.0, + const double beta_bernoulli_beta_between = 1.0, + const double dirichlet_alpha = 1.0, + const double lambda = 1.0, const double target_acceptance = 0.8, const int max_tree_depth = 10, const int num_leapfrogs = 10, @@ -57,24 +77,41 @@ Rcpp::List sample_omrf( OMRFModel model = createOMRFModelFromR( inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); + // Set up missing data imputation + if (na_impute && missing_index_nullable.isNotNull()) { + arma::imat missing_index = Rcpp::as( + Rcpp::IntegerMatrix(missing_index_nullable.get())); + model.set_missing_data(missing_index); + } + + // Create edge prior + EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior); + auto edge_prior_obj = create_edge_prior( + edge_prior_enum, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda + ); + // Configure sampler SamplerConfig config; config.sampler_type = sampler_type; config.no_iter = no_iter; config.no_warmup = no_warmup; config.edge_selection = edge_selection; - config.edge_selection_start = edge_selection_start; // -1 means use default (no_warmup/2) + config.edge_selection_start = edge_selection_start; config.seed = seed; config.target_acceptance = target_acceptance; config.max_tree_depth = max_tree_depth; config.num_leapfrogs = num_leapfrogs; + config.na_impute = na_impute; // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); // Run MCMC using unified infrastructure std::vector results = run_mcmc_sampler( - model, config, no_chains, no_threads, pm); + model, *edge_prior_obj, config, no_chains, no_threads, 
pm); // Convert to R list format Rcpp::List output = convert_results_to_list(results); From 82ba4fb1ae84d93f0fec3707253a2af36f571b64 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 9 Feb 2026 14:49:24 +0100 Subject: [PATCH 16/23] ggm now also available with support for edge priors --- R/RcppExports.R | 4 +- R/bgm.R | 114 +++++++++++++++++++++++- R/function_input_utils.R | 90 +++++++++++-------- R/output_utils.R | 187 +++++++++++++++++++++++++++++++++++++++ src/RcppExports.cpp | 15 +++- src/chainResultNew.h | 23 +++++ src/ggm_model.cpp | 33 ++++++- src/ggm_model.h | 11 +++ src/mcmc/mcmc_runner.h | 15 ++++ src/priors/edge_prior.h | 9 ++ src/sample_ggm.cpp | 23 +++-- 11 files changed, 472 insertions(+), 52 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index fef77cd4..44e0f165 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -33,8 +33,8 @@ sample_omrf_classed <- function(inputFromR, prior_inclusion_prob, initial_edge_i .Call(`_bgms_sample_omrf_classed`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, edge_selection, sampler_type, seed) } -sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) { - .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type) +sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0) { + .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, 
beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda) } sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, edge_selection_start = -1L) { diff --git a/R/bgm.R b/R/bgm.R index b7f5a9c1..c89f3fb9 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -509,6 +509,23 @@ bgm = function( edge_selection = model$edge_selection edge_prior = model$edge_prior inclusion_probability = model$inclusion_probability + is_continuous = model$is_continuous + + # Block NUTS/HMC for the Gaussian model -------------------------------------- + if(is_continuous) { + user_chose_method = length(update_method_input) == 1 + if(user_chose_method && update_method %in% c("nuts", "hamiltonian-mc")) { + stop(paste0( + "The Gaussian model (variable_type = 'continuous') only supports ", + "update_method = 'adaptive-metropolis'. ", + "Got '", update_method, "'." 
+ )) + } + update_method = "adaptive-metropolis" + if(!hasArg(target_accept)) { + target_accept = 0.44 + } + } # Check Gibbs input ----------------------------------------------------------- check_positive_integer(iter, "iter") @@ -558,6 +575,99 @@ bgm = function( # Check display_progress ------------------------------------------------------ progress_type = progress_type_from_display_progress(display_progress) +# Setting the seed + if(missing(seed) || is.null(seed)) { + seed = sample.int(.Machine$integer.max, 1) + } + + if(!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) { + stop("Argument 'seed' must be a single non-negative integer.") + } + + seed <- as.integer(seed) + + data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x) + + # ========================================================================== + # Gaussian (continuous) path + # ========================================================================== + if (is_continuous) { + num_variables = ncol(x) + + # Handle missing data for continuous variables + if (na_action == "listwise") { + missing_rows = apply(x, 1, anyNA) + if (all(missing_rows)) { + stop("All rows in x contain at least one missing response.\n", + "You could try option na_action = 'impute'.") + } + if (sum(missing_rows) > 0) { + warning(sum(missing_rows), " row(s) with missing observations removed (na_action = 'listwise').", + call. = FALSE) + } + x = x[!missing_rows, , drop = FALSE] + na_impute = FALSE + } else { + stop("Imputation is not yet supported for the Gaussian model. 
", + "Use na_action = 'listwise'.") + } + + indicator = matrix(1L, nrow = num_variables, ncol = num_variables) + + out_raw = sample_ggm( + inputFromR = list(X = x), + prior_inclusion_prob = matrix(inclusion_probability, + nrow = num_variables, ncol = num_variables), + initial_edge_indicators = indicator, + no_iter = iter, + no_warmup = warmup, + no_chains = chains, + edge_selection = edge_selection, + seed = seed, + no_threads = cores, + progress_type = progress_type, + edge_prior = edge_prior, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda + ) + + out = transform_ggm_backend_output(out_raw, num_variables) + + userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt")) + if (userInterrupt) { + warning("Stopped sampling after user interrupt, results are likely uninterpretable.") + } + + output = prepare_output_ggm( + out = out, x = x, iter = iter, + data_columnnames = data_columnnames, + warmup = warmup, + edge_selection = edge_selection, edge_prior = edge_prior, + inclusion_probability = inclusion_probability, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda, + na_action = na_action, na_impute = na_impute, + variable_type = variable_type, + update_method = update_method, + target_accept = target_accept, + num_chains = chains + ) + + return(output) + } + + # ========================================================================== + # Ordinal / Blume-Capel path + # ========================================================================== + # Format the data input 
------------------------------------------------------- data = reformat_data( x = x, @@ -773,7 +883,7 @@ bgm = function( output <- tryCatch( prepare_output_bgm( out = out, x = x, num_categories = num_categories, iter = iter, - data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = data_columnnames, is_ordinal_variable = variable_bool, warmup = warmup, pairwise_scale = pairwise_scale, standardize = standardize, pairwise_scaling_factors = pairwise_scaling_factors, @@ -806,7 +916,7 @@ bgm = function( # Main output handler in the wrapper function output = prepare_output_bgm( out = out, x = x, num_categories = num_categories, iter = iter, - data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x), + data_columnnames = data_columnnames, is_ordinal_variable = variable_bool, warmup = warmup, pairwise_scale = pairwise_scale, standardize = standardize, pairwise_scaling_factors = pairwise_scaling_factors, diff --git a/R/function_input_utils.R b/R/function_input_utils.R index 4dde1a1c..4c6aa3a8 100644 --- a/R/function_input_utils.R +++ b/R/function_input_utils.R @@ -36,24 +36,30 @@ check_model = function(x, dirichlet_alpha = dirichlet_alpha, lambda = lambda) { # Check variable type input --------------------------------------------------- + is_continuous = FALSE if(length(variable_type) == 1) { variable_input = variable_type variable_type = try( match.arg( arg = variable_type, - choices = c("ordinal", "blume-capel") + choices = c("ordinal", "blume-capel", "continuous") ), silent = TRUE ) if(inherits(variable_type, what = "try-error")) { stop(paste0( - "The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", + "The bgm function supports variables of type ordinal, blume-capel, ", + "and continuous, but not of type ", variable_input, "." 
)) } - variable_bool = (variable_type == "ordinal") - variable_bool = rep(variable_bool, ncol(x)) + if(variable_type == "continuous") { + is_continuous = TRUE + variable_bool = rep(TRUE, ncol(x)) + } else { + variable_bool = (variable_type == "ordinal") + variable_bool = rep(variable_bool, ncol(x)) + } } else { if(length(variable_type) != ncol(x)) { stop(paste0( @@ -62,41 +68,54 @@ check_model = function(x, )) } - variable_input = unique(variable_type) - variable_type = try(match.arg( - arg = variable_type, - choices = c("ordinal", "blume-capel"), - several.ok = TRUE - ), silent = TRUE) - - if(inherits(variable_type, what = "try-error")) { + has_continuous = any(variable_type == "continuous") + if(has_continuous && !all(variable_type == "continuous")) { stop(paste0( - "The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input, collapse = ", "), "." + "When using continuous variables, all variables must be of type ", + "'continuous'. Mixtures of continuous and ordinal/blume-capel ", + "variables are not supported." )) } + if(has_continuous) { + is_continuous = TRUE + variable_bool = rep(TRUE, ncol(x)) + } else { + variable_input = unique(variable_type) + variable_type = try(match.arg( + arg = variable_type, + choices = c("ordinal", "blume-capel"), + several.ok = TRUE + ), silent = TRUE) - num_types = sapply(variable_input, function(type) { - tmp = try( - match.arg( - arg = type, - choices = c("ordinal", "blume-capel") - ), - silent = TRUE - ) - inherits(tmp, what = "try-error") - }) + if(inherits(variable_type, what = "try-error")) { + stop(paste0( + "The bgm function supports variables of type ordinal, blume-capel, ", + "and continuous, but not of type ", + paste0(variable_input, collapse = ", "), "." 
+ )) + } - if(length(variable_type) != ncol(x)) { - stop(paste0( - "The bgm function supports variables of type ordinal and blume-capel, \n", - "but not of type ", - paste0(variable_input[num_types], collapse = ", "), "." - )) - } + num_types = sapply(variable_input, function(type) { + tmp = try( + match.arg( + arg = type, + choices = c("ordinal", "blume-capel") + ), + silent = TRUE + ) + inherits(tmp, what = "try-error") + }) - variable_bool = (variable_type == "ordinal") + if(length(variable_type) != ncol(x)) { + stop(paste0( + "The bgm function supports variables of type ordinal, blume-capel, ", + "and continuous, but not of type ", + paste0(variable_input[num_types], collapse = ", "), "." + )) + } + + variable_bool = (variable_type == "ordinal") + } } # Check Blume-Capel variable input -------------------------------------------- @@ -316,7 +335,8 @@ check_model = function(x, baseline_category = baseline_category, edge_selection = edge_selection, edge_prior = edge_prior, - inclusion_probability = theta + inclusion_probability = theta, + is_continuous = is_continuous )) } diff --git a/R/output_utils.R b/R/output_utils.R index bb21ce5e..43e32c57 100644 --- a/R/output_utils.R +++ b/R/output_utils.R @@ -178,6 +178,193 @@ prepare_output_bgm = function( } +prepare_output_ggm = function( + out, x, iter, data_columnnames, + warmup, edge_selection, edge_prior, inclusion_probability, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda, + na_action, na_impute, + variable_type, update_method, target_accept, + num_chains +) { + num_variables = ncol(x) + + arguments = list( + num_variables = num_variables, + num_cases = nrow(x), + na_impute = na_impute, + variable_type = variable_type, + iter = iter, + warmup = warmup, + edge_selection = edge_selection, + edge_prior = edge_prior, + inclusion_probability = inclusion_probability, + beta_bernoulli_alpha = beta_bernoulli_alpha, + beta_bernoulli_beta = 
beta_bernoulli_beta, + beta_bernoulli_alpha_between = beta_bernoulli_alpha_between, + beta_bernoulli_beta_between = beta_bernoulli_beta_between, + dirichlet_alpha = dirichlet_alpha, + lambda = lambda, + na_action = na_action, + version = packageVersion("bgms"), + update_method = update_method, + target_accept = target_accept, + num_chains = num_chains, + data_columnnames = data_columnnames, + no_variables = num_variables, + is_continuous = TRUE + ) + + results = list() + + # Parameter names: diagonal = "Var (precision)", off-diagonal = "Var1-Var2" + diag_names = paste0(data_columnnames, " (precision)") + + edge_names = character() + for (i in 1:(num_variables - 1)) { + for (j in (i + 1):num_variables) { + edge_names = c(edge_names, paste0(data_columnnames[i], "-", data_columnnames[j])) + } + } + + # Summarize MCMC chains + summary_list = summarize_fit(out, edge_selection = edge_selection) + main_summary = summary_list$main[, -1] + pairwise_summary = summary_list$pairwise[, -1] + + rownames(main_summary) = diag_names + rownames(pairwise_summary) = edge_names + + results$posterior_summary_main = main_summary + results$posterior_summary_pairwise = pairwise_summary + + if (edge_selection) { + indicator_summary = summarize_indicator(out, param_names = edge_names)[, -1] + rownames(indicator_summary) = edge_names + results$posterior_summary_indicator = indicator_summary + + has_sbm = identical(edge_prior, "Stochastic-Block") && + "allocations" %in% names(out[[1]]) + + if (has_sbm) { + sbm_convergence = summarize_alloc_pairs( + allocations = lapply(out, `[[`, "allocations"), + node_names = data_columnnames + ) + results$posterior_summary_pairwise_allocations = sbm_convergence$sbm_summary + } + } + + # Posterior mean matrices + results$posterior_mean_main = matrix(main_summary$mean, + nrow = num_variables, ncol = 1, + dimnames = list(data_columnnames, "precision_diag")) + + results$posterior_mean_pairwise = matrix(0, + nrow = num_variables, ncol = num_variables, + dimnames 
= list(data_columnnames, data_columnnames)) + results$posterior_mean_pairwise[lower.tri(results$posterior_mean_pairwise)] = pairwise_summary$mean + results$posterior_mean_pairwise = results$posterior_mean_pairwise + t(results$posterior_mean_pairwise) + + if (edge_selection) { + indicator_means = indicator_summary$mean + results$posterior_mean_indicator = matrix(0, + nrow = num_variables, ncol = num_variables, + dimnames = list(data_columnnames, data_columnnames)) + results$posterior_mean_indicator[upper.tri(results$posterior_mean_indicator)] = indicator_means + results$posterior_mean_indicator[lower.tri(results$posterior_mean_indicator)] = + t(results$posterior_mean_indicator)[lower.tri(results$posterior_mean_indicator)] + + if (has_sbm) { + sbm_convergence = summarize_alloc_pairs( + allocations = lapply(out, `[[`, "allocations"), + node_names = data_columnnames + ) + results$posterior_mean_coclustering_matrix = sbm_convergence$co_occur_matrix + + sbm_summary = posterior_summary_SBM( + allocations = lapply(out, `[[`, "allocations"), + arguments = arguments + ) + results$posterior_mean_allocations = sbm_summary$allocations_mean + results$posterior_mode_allocations = sbm_summary$allocations_mode + results$posterior_num_blocks = sbm_summary$blocks + } + } + + results$arguments = arguments + class(results) = "bgms" + + results$raw_samples = list( + main = lapply(out, function(chain) chain$main_samples), + pairwise = lapply(out, function(chain) chain$pairwise_samples), + indicator = if (edge_selection) lapply(out, function(chain) chain$indicator_samples) else NULL, + allocations = if (edge_selection && identical(edge_prior, "Stochastic-Block") && "allocations" %in% names(out[[1]])) lapply(out, `[[`, "allocations") else NULL, + nchains = length(out), + niter = nrow(out[[1]]$main_samples), + parameter_names = list( + main = diag_names, + pairwise = edge_names, + indicator = if (edge_selection) edge_names else NULL, + allocations = if (identical(edge_prior, 
"Stochastic-Block")) data_columnnames else NULL + ) + ) + + return(results) +} + + +# Transform sample_ggm output to match the old backend format. +# +# The GGM backend returns a `samples` matrix (params x iters) where params +# are the upper triangle of the precision matrix stored column-by-column: +# (0,0), (0,1), (1,1), (0,2), (1,2), (2,2), ... +# We split these into diagonal elements ("main") and off-diagonal ("pairwise"). +transform_ggm_backend_output = function(out, p) { + # Build index maps for upper triangle (column-major) + diag_idx = integer(p) + offdiag_idx = integer(p * (p - 1) / 2) + pos = 0L + d = 0L + od = 0L + for (j in seq_len(p)) { + for (i in seq_len(j)) { + pos = pos + 1L + if (i == j) { + d = d + 1L + diag_idx[d] = pos + } else { + od = od + 1L + offdiag_idx[od] = pos + } + } + } + + lapply(out, function(chain) { + samples_t = t(chain$samples) # (params x iters) -> (iters x params) + + res = list( + main_samples = samples_t[, diag_idx, drop = FALSE], + pairwise_samples = samples_t[, offdiag_idx, drop = FALSE], + userInterrupt = isTRUE(chain$userInterrupt), + chain_id = chain$chain_id + ) + + if (!is.null(chain$indicator_samples)) { + indic_t = t(chain$indicator_samples) # (params x iters) -> (iters x params) + res$indicator_samples = indic_t[, offdiag_idx, drop = FALSE] + } + + if (!is.null(chain$allocation_samples)) { + res$allocations = t(chain$allocation_samples) # (variables x iters) -> (iters x variables) + } + + res + }) +} + + # Transform sample_omrf output to match the old backend format. 
# # The new backend returns a flat `samples` matrix (params x iters) containing diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 49b600ba..5273dd1c 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -212,8 +212,8 @@ BEGIN_RCPP END_RCPP } // sample_ggm -Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type); -RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP) { +Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda); +RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -227,7 +227,14 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); 
Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); - rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type)); + Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha_between(beta_bernoulli_alpha_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta_between(beta_bernoulli_beta_betweenSEXP); + Rcpp::traits::input_parameter< const double >::type dirichlet_alpha(dirichlet_alphaSEXP); + Rcpp::traits::input_parameter< const double >::type lambda(lambdaSEXP); + rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda)); return rcpp_result_gen; END_RCPP } @@ -289,7 +296,7 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 9}, {"_bgms_run_simulation_parallel", (DL_FUNC) &_bgms_run_simulation_parallel, 12}, {"_bgms_sample_omrf_classed", (DL_FUNC) &_bgms_sample_omrf_classed, 8}, - {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 10}, + {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 17}, {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} diff --git a/src/chainResultNew.h b/src/chainResultNew.h 
index 3ef3e4df..295e16c9 100644 --- a/src/chainResultNew.h +++ b/src/chainResultNew.h @@ -29,6 +29,10 @@ class ChainResultNew { arma::imat indicator_samples; bool has_indicators = false; + // SBM allocation samples (n_variables × n_iter), only if SBM edge prior + arma::imat allocation_samples; + bool has_allocations = false; + // NUTS/HMC diagnostics (n_iter), only if using NUTS/HMC arma::ivec treedepth_samples; arma::ivec divergent_samples; @@ -54,6 +58,16 @@ class ChainResultNew { has_indicators = true; } + /** + * Reserve storage for SBM allocation samples + * @param n_variables Number of variables + * @param n_iter Number of sampling iterations + */ + void reserve_allocations(const size_t n_variables, const size_t n_iter) { + allocation_samples.set_size(n_variables, n_iter); + has_allocations = true; + } + /** * Reserve storage for NUTS diagnostics * @param n_iter Number of sampling iterations @@ -83,6 +97,15 @@ class ChainResultNew { indicator_samples.col(iter) = indicators; } + /** + * Store SBM allocation sample + * @param iter Iteration index (0-based) + * @param allocations Allocation vector (1-based cluster labels) + */ + void store_allocations(const size_t iter, const arma::ivec& allocations) { + allocation_samples.col(iter) = allocations; + } + /** * Store NUTS diagnostics for one iteration * @param iter Iteration index (0-based) diff --git a/src/ggm_model.cpp b/src/ggm_model.cpp index a263924d..ef1dc0ae 100644 --- a/src/ggm_model.cpp +++ b/src/ggm_model.cpp @@ -107,7 +107,7 @@ void GaussianVariables::update_edge_parameter(size_t i, size_t j) { double Phi_q1q = constants_[1]; double Phi_q1q1 = constants_[2]; - size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form + size_t e = j * (j + 1) / 2 + i; // parameter index in vectorized form (column-major upper triangle) double proposal_sd = proposal_.get_proposal_sd(e); double phi_prop = rnorm(rng_, Phi_q1q, proposal_sd); @@ -184,7 +184,7 @@ void 
GaussianVariables::update_diagonal_parameter(size_t i) { double logdet_omega = get_log_det(cholesky_of_precision_); double logdet_omega_sub_ii = logdet_omega + std::log(covariance_matrix_(i, i)); - size_t e = i * (i + 1) / 2 + i; // parameter index in vectorized form + size_t e = i * (i + 3) / 2; // parameter index in vectorized form (column-major upper triangle, i==j) double proposal_sd = proposal_.get_proposal_sd(e); double theta_curr = (logdet_omega - logdet_omega_sub_ii) / 2; @@ -238,7 +238,7 @@ void GaussianVariables::cholesky_update_after_diag(double omega_ii_old, size_t i void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) { - size_t e = i * (i + 1) / 2 + j; // parameter index in vectorized form + size_t e = j * (j + 1) / 2 + i; // parameter index in vectorized form (column-major upper triangle) double proposal_sd = proposal_.get_proposal_sd(e); if (edge_indicators_(i, j) == 1) { @@ -359,7 +359,7 @@ void GaussianVariables::do_one_mh_step() { update_diagonal_parameter(i); } - if (edge_selection_) { + if (edge_selection_active_) { for (size_t i = 0; i < p_ - 1; ++i) { for (size_t j = i + 1; j < p_; ++j) { update_edge_indicator_parameter_pair(i, j); @@ -371,6 +371,31 @@ void GaussianVariables::do_one_mh_step() { proposal_.increment_iteration(); } +void GaussianVariables::initialize_graph() { + for (size_t i = 0; i < p_ - 1; ++i) { + for (size_t j = i + 1; j < p_; ++j) { + double p = inclusion_probability_(i, j); + int draw = (runif(rng_) < p) ? 
1 : 0; + edge_indicators_(i, j) = draw; + edge_indicators_(j, i) = draw; + if (!draw) { + precision_proposal_ = precision_matrix_; + precision_proposal_(i, j) = 0.0; + precision_proposal_(j, i) = 0.0; + get_constants(i, j); + precision_proposal_(j, j) = R(0.0); + + double omega_ij_old = precision_matrix_(i, j); + double omega_jj_old = precision_matrix_(j, j); + precision_matrix_(j, j) = precision_proposal_(j, j); + precision_matrix_(i, j) = 0.0; + precision_matrix_(j, i) = 0.0; + cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); + } + } + } +} + GaussianVariables createGaussianVariablesFromR( const Rcpp::List& inputFromR, diff --git a/src/ggm_model.h b/src/ggm_model.h index 5e4cf1b3..84d91849 100644 --- a/src/ggm_model.h +++ b/src/ggm_model.h @@ -97,6 +97,16 @@ class GaussianVariables : public BaseModel { bool has_adaptive_mh() const override { return true; } bool has_edge_selection() const override { return edge_selection_; } + void set_edge_selection_active(bool active) override { + edge_selection_active_ = active; + } + + void initialize_graph() override; + + // GGM handles edge indicator updates inside do_one_mh_step(), so the + // external call from the MCMC runner is a no-op. 
+ void update_edge_indicators() override {} + double logp(const arma::vec& parameters) override { // Implement log probability computation return 0.0; @@ -189,6 +199,7 @@ class GaussianVariables : public BaseModel { arma::mat suf_stat_; arma::mat inclusion_probability_; bool edge_selection_; + bool edge_selection_active_ = false; // parameters arma::mat precision_matrix_, cholesky_of_precision_, inv_cholesky_of_precision_, covariance_matrix_; diff --git a/src/mcmc/mcmc_runner.h b/src/mcmc/mcmc_runner.h index c3bf3472..b4068e9e 100644 --- a/src/mcmc/mcmc_runner.h +++ b/src/mcmc/mcmc_runner.h @@ -159,6 +159,11 @@ inline void run_mcmc_chain( chain_result.store_indicators(iter, model.get_vectorized_indicator_parameters()); } + // Store SBM allocations if applicable + if (chain_result.has_allocations && edge_prior.has_allocations()) { + chain_result.store_allocations(iter, edge_prior.get_allocations()); + } + // Progress and interrupt check pm.update(chain_id); if (pm.shouldExit()) { @@ -238,6 +243,8 @@ inline std::vector run_mcmc_sampler( ProgressManager& pm ) { const bool has_nuts_diag = (config.sampler_type == "nuts"); + const bool has_sbm_alloc = edge_prior.has_allocations() || + (config.edge_selection && dynamic_cast(&edge_prior) != nullptr); // Allocate result storage std::vector results(no_chains); @@ -249,6 +256,10 @@ inline std::vector run_mcmc_sampler( results[c].reserve_indicators(n_edges, config.no_iter); } + if (has_sbm_alloc) { + results[c].reserve_allocations(model.get_num_variables(), config.no_iter); + } + if (has_nuts_diag) { results[c].reserve_nuts_diagnostics(config.no_iter); } @@ -320,6 +331,10 @@ inline Rcpp::List convert_results_to_list(const std::vector& res chain_list["indicator_samples"] = chain.indicator_samples; } + if (chain.has_allocations) { + chain_list["allocation_samples"] = chain.allocation_samples; + } + if (chain.has_nuts_diagnostics) { chain_list["treedepth"] = chain.treedepth_samples; chain_list["divergent"] = 
chain.divergent_samples; diff --git a/src/priors/edge_prior.h b/src/priors/edge_prior.h index 02f5c716..c17c838c 100644 --- a/src/priors/edge_prior.h +++ b/src/priors/edge_prior.h @@ -32,6 +32,9 @@ class BaseEdgePrior { ) = 0; virtual std::unique_ptr clone() const = 0; + + virtual bool has_allocations() const { return false; } + virtual arma::ivec get_allocations() const { return arma::ivec(); } }; @@ -205,6 +208,12 @@ class StochasticBlockEdgePrior : public BaseEdgePrior { return std::make_unique(*this); } + bool has_allocations() const override { return initialized_; } + + arma::ivec get_allocations() const override { + return arma::conv_to::from(cluster_allocations_) + 1; // 1-based + } + private: double beta_bernoulli_alpha_; double beta_bernoulli_beta_; diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index 253ea35a..aab3fef0 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -6,6 +6,7 @@ #include "ggm_model.h" #include "utils/progress_manager.h" +#include "utils/common_helpers.h" #include "priors/edge_prior.h" #include "chainResultNew.h" #include "mcmc/mcmc_runner.h" @@ -22,7 +23,14 @@ Rcpp::List sample_ggm( const bool edge_selection, const int seed, const int no_threads, - const int progress_type + const int progress_type, + const std::string& edge_prior = "Bernoulli", + const double beta_bernoulli_alpha = 1.0, + const double beta_bernoulli_beta = 1.0, + const double beta_bernoulli_alpha_between = 1.0, + const double beta_bernoulli_beta_between = 1.0, + const double dirichlet_alpha = 1.0, + const double lambda = 1.0 ) { // Create model from R input @@ -36,17 +44,22 @@ Rcpp::List sample_ggm( config.no_warmup = no_warmup; config.edge_selection = edge_selection; config.seed = seed; - // Edge selection starts at no_warmup/2 by default (handled by get_edge_selection_start()) // Set up progress manager ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); - // Create default edge prior (Bernoulli = no-op) - BernoulliEdgePrior 
edge_prior; + // Create edge prior + EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior); + auto edge_prior_obj = create_edge_prior( + edge_prior_enum, + beta_bernoulli_alpha, beta_bernoulli_beta, + beta_bernoulli_alpha_between, beta_bernoulli_beta_between, + dirichlet_alpha, lambda + ); // Run MCMC using unified infrastructure std::vector results = run_mcmc_sampler( - model, edge_prior, config, no_chains, no_threads, pm); + model, *edge_prior_obj, config, no_chains, no_threads, pm); // Convert to R list format Rcpp::List output = convert_results_to_list(results); From 88e6545c76a02dcd070f7b1d4a398d4cf2bf02d0 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Mon, 9 Feb 2026 15:40:17 +0100 Subject: [PATCH 17/23] cleanup --- R/RcppExports.R | 4 -- src/RcppExports.cpp | 19 ------ src/SkeletonVariables.h | 80 ---------------------- src/base_model.cpp | 3 - src/omrf_model.cpp | 147 ---------------------------------------- 5 files changed, 253 deletions(-) delete mode 100644 src/SkeletonVariables.h delete mode 100644 src/base_model.cpp diff --git a/R/RcppExports.R b/R/RcppExports.R index 44e0f165..528716e0 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -29,10 +29,6 @@ run_simulation_parallel <- function(pairwise_samples, main_samples, draw_indices .Call(`_bgms_run_simulation_parallel`, pairwise_samples, main_samples, draw_indices, no_states, no_variables, no_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } -sample_omrf_classed <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, edge_selection, sampler_type, seed) { - .Call(`_bgms_sample_omrf_classed`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, edge_selection, sampler_type, seed) -} - sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", 
beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0) { .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda) } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 5273dd1c..b395f4be 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -193,24 +193,6 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } -// sample_omrf_classed -Rcpp::List sample_omrf_classed(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const bool edge_selection, const std::string& sampler_type, const int seed); -RcppExport SEXP _bgms_sample_omrf_classed(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< const Rcpp::List& >::type inputFromR(inputFromRSEXP); - Rcpp::traits::input_parameter< const arma::mat& >::type prior_inclusion_prob(prior_inclusion_probSEXP); - Rcpp::traits::input_parameter< const arma::imat& >::type initial_edge_indicators(initial_edge_indicatorsSEXP); - Rcpp::traits::input_parameter< const int >::type no_iter(no_iterSEXP); - Rcpp::traits::input_parameter< const int >::type no_warmup(no_warmupSEXP); - Rcpp::traits::input_parameter< const bool >::type edge_selection(edge_selectionSEXP); - Rcpp::traits::input_parameter< const std::string& >::type sampler_type(sampler_typeSEXP); - Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); - rcpp_result_gen = 
Rcpp::wrap(sample_omrf_classed(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, edge_selection, sampler_type, seed)); - return rcpp_result_gen; -END_RCPP -} // sample_ggm Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda); RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP) { @@ -295,7 +277,6 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 7}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 9}, {"_bgms_run_simulation_parallel", (DL_FUNC) &_bgms_run_simulation_parallel, 12}, - {"_bgms_sample_omrf_classed", (DL_FUNC) &_bgms_sample_omrf_classed, 8}, {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 17}, {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, diff --git a/src/SkeletonVariables.h b/src/SkeletonVariables.h deleted file mode 100644 index 0b5d1b43..00000000 --- a/src/SkeletonVariables.h +++ /dev/null @@ -1,80 +0,0 @@ -// #pragma once - -// #include -// #include "base_model.h" -// #include 
"adaptiveMetropolis.h" -// #include "rng/rng_utils.h" - - -// class SkeletonVariables : public BaseModel { -// public: - -// // constructor from raw data -// SkeletonVariables( -// const arma::mat& observations, -// const arma::mat& inclusion_probability, -// const arma::imat& initial_edge_indicators, -// const bool edge_selection = true -// ) : -// {} - -// // copy constructor -// SkeletonVariables(const SkeletonVariables& other) -// : BaseModel(other), -// {} - -// std::unique_ptr clone() const override { -// return std::make_unique(*this); // uses copy constructor -// } - -// bool has_gradient() const { return false; } -// bool has_adaptive_mh() const override { return true; } - -// double logp(const arma::vec& parameters) override { -// // Implement log probability computation -// return 0.0; -// } - -// void do_one_mh_step() override; - -// size_t parameter_dimension() const override { -// return dim_; -// } - -// void set_seed(int seed) override { -// rng_ = SafeRNG(seed); -// } - -// // arma::vec get_vectorized_parameters() override { -// // // upper triangle of precision_matrix_ -// // size_t e = 0; -// // for (size_t j = 0; j < p_; ++j) { -// // for (size_t i = 0; i <= j; ++i) { -// // vectorized_parameters_(e) = precision_matrix_(i, j); -// // ++e; -// // } -// // } -// // return vectorized_parameters_; -// // } - -// // arma::ivec get_vectorized_indicator_parameters() override { -// // // upper triangle of precision_matrix_ -// // size_t e = 0; -// // for (size_t j = 0; j < p_; ++j) { -// // for (size_t i = 0; i <= j; ++i) { -// // vectorized_indicator_parameters_(e) = edge_indicators_(i, j); -// // ++e; -// // } -// // } -// // return vectorized_indicator_parameters_; -// // } - - -// private: -// // data -// size_t n_ = 0; -// size_t p_ = 0; -// size_t dim_ = 0; - - -// }; \ No newline at end of file diff --git a/src/base_model.cpp b/src/base_model.cpp deleted file mode 100644 index 7ffcdce5..00000000 --- a/src/base_model.cpp +++ /dev/null @@ -1,3 +0,0 
@@ -#include "base_model.h" - -// BaseModel is a header-only abstract class with no cpp implementations needed diff --git a/src/omrf_model.cpp b/src/omrf_model.cpp index 10d95662..832e15dd 100644 --- a/src/omrf_model.cpp +++ b/src/omrf_model.cpp @@ -1354,150 +1354,3 @@ OMRFModel createOMRFModelFromR( } -// ============================================================================= -// R interface: sample_omrf_classed -// ============================================================================= - -// [[Rcpp::export]] -Rcpp::List sample_omrf_classed( - const Rcpp::List& inputFromR, - const arma::mat& prior_inclusion_prob, - const arma::imat& initial_edge_indicators, - const int no_iter, - const int no_warmup, - const bool edge_selection, - const std::string& sampler_type, - const int seed -) { - // Create model from R input - OMRFModel model = createOMRFModelFromR( - inputFromR, - prior_inclusion_prob, - initial_edge_indicators, - edge_selection - ); - - // Set random seed - model.set_seed(seed); - - // Storage for samples - use FIXED size (all parameters) - int full_dim = model.full_parameter_dimension(); - arma::mat samples(no_iter, full_dim); - arma::imat indicator_samples; - if (edge_selection) { - int num_edges = (model.get_p() * (model.get_p() - 1)) / 2; - indicator_samples.set_size(no_iter, num_edges); - } - - // NUTS/HMC diagnostics - arma::ivec treedepth_samples; - arma::ivec divergent_samples; - arma::vec energy_samples; - bool use_nuts = (sampler_type == "nuts"); - bool use_hmc = (sampler_type == "hmc"); - if (use_nuts || use_hmc) { - treedepth_samples.set_size(no_iter); - divergent_samples.set_size(no_iter); - energy_samples.set_size(no_iter); - } - - // Create sampler configuration - SamplerConfig config; - config.sampler_type = sampler_type; - config.initial_step_size = 0.1; - config.target_acceptance = 0.8; - config.max_tree_depth = 10; - config.num_leapfrogs = 10; - config.no_warmup = no_warmup; - - // Create appropriate sampler - 
std::unique_ptr sampler = create_sampler(config); - - // Warmup phase - Rcpp::Rcout << "Running warmup (" << no_warmup << " iterations)..." << std::endl; - for (int iter = 0; iter < no_warmup; ++iter) { - SamplerResult result = sampler->warmup_step(model); - - // Edge selection only after initial warmup period - if (edge_selection && iter > no_warmup / 2) { - model.update_edge_indicators(); - } - - // Check for user interrupt - if ((iter + 1) % 100 == 0) { - Rcpp::checkUserInterrupt(); - Rcpp::Rcout << " Warmup iteration " << (iter + 1) << "/" << no_warmup - << " (step_size=" << model.get_step_size() << ")" << std::endl; - } - } - - // Use averaged step size for sampling - sampler->finalize_warmup(); - Rcpp::Rcout << "Warmup complete." << std::endl; - - // Sampling phase - Rcpp::Rcout << "Running sampling (" << no_iter << " iterations)..." << std::endl; - for (int iter = 0; iter < no_iter; ++iter) { - SamplerResult result = sampler->sample_step(model); - - // Extract NUTS/HMC diagnostics if available - if (sampler->has_nuts_diagnostics()) { - if (auto nuts_diag = std::dynamic_pointer_cast(result.diagnostics)) { - treedepth_samples(iter) = nuts_diag->tree_depth; - divergent_samples(iter) = nuts_diag->divergent ? 
1 : 0; - energy_samples(iter) = nuts_diag->energy; - } - } - - if (edge_selection) { - model.update_edge_indicators(); - } - - // Store samples - use FULL vectorization (fixed size) - samples.row(iter) = model.get_full_vectorized_parameters().t(); - - if (edge_selection) { - arma::imat indicators = model.get_edge_indicators(); - int idx = 0; - for (int i = 0; i < static_cast(model.get_p()) - 1; ++i) { - for (int j = i + 1; j < static_cast(model.get_p()); ++j) { - indicator_samples(iter, idx++) = indicators(i, j); - } - } - } - - // Check for user interrupt - if ((iter + 1) % 100 == 0) { - Rcpp::checkUserInterrupt(); - Rcpp::Rcout << " Sampling iteration " << (iter + 1) << "/" << no_iter << std::endl; - } - } - - // Build output list - Rcpp::List output; - output["samples"] = samples; - - if (edge_selection) { - output["indicator_samples"] = indicator_samples; - // Compute posterior mean of edge indicators - arma::vec posterior_mean_indicator = arma::mean(arma::conv_to::from(indicator_samples), 0).t(); - output["posterior_mean_indicator"] = posterior_mean_indicator; - } - - if (use_nuts || use_hmc) { - output["treedepth"] = treedepth_samples; - output["divergent"] = divergent_samples; - output["energy"] = energy_samples; - // Get final step size from sampler (NUTSSampler and HMCSampler have get_step_size()) - output["final_step_size"] = 0.0; // Could add getter to sampler if needed - } - - output["sampler_type"] = sampler_type; - output["no_iter"] = no_iter; - output["no_warmup"] = no_warmup; - output["edge_selection"] = edge_selection; - output["num_variables"] = model.get_p(); - output["num_observations"] = model.get_n(); - - return output; -} \ No newline at end of file From 5aca4ea00cc533024721b7cf7a70ed4832696dd8 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 10 Feb 2026 08:46:28 +0100 Subject: [PATCH 18/23] push test for ggms --- tests/testthat/test-ggm.R | 229 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 
100644 tests/testthat/test-ggm.R diff --git a/tests/testthat/test-ggm.R b/tests/testthat/test-ggm.R new file mode 100644 index 00000000..a4973429 --- /dev/null +++ b/tests/testthat/test-ggm.R @@ -0,0 +1,229 @@ +test_that("GGM runs and returns bgms object", { + testthat::skip_on_cran() + + set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + colnames(x) <- paste0("V", 1:4) + + fit <- bgm( + x = x, + variable_type = "continuous", + iter = 200, + warmup = 1000, + chains = 1, + display_progress = "none", + seed = 123 + ) + + expect_s3_class(fit, "bgms") + expect_true(fit$arguments$is_continuous) + expect_equal(fit$arguments$num_variables, 4) + expect_equal(fit$arguments$iter, 200) + expect_equal(fit$raw_samples$niter, 200) + expect_equal(fit$raw_samples$nchains, 1) +}) + +test_that("GGM output has correct dimensions", { + testthat::skip_on_cran() + + p <- 5 + set.seed(42) + x <- matrix(rnorm(100 * p), nrow = 100, ncol = p) + colnames(x) <- paste0("V", 1:p) + + fit <- bgm( + x = x, + variable_type = "continuous", + edge_selection = TRUE, + iter = 100, + warmup = 1000, + chains = 1, + display_progress = "none", + seed = 42 + ) + + # main: p diagonal precision elements + expect_equal(nrow(fit$posterior_summary_main), p) + expect_equal(nrow(fit$posterior_mean_main), p) + + # pairwise: p*(p-1)/2 off-diagonal elements + n_edges <- p * (p - 1) / 2 + expect_equal(nrow(fit$posterior_summary_pairwise), n_edges) + expect_equal(nrow(fit$posterior_mean_pairwise), p) + expect_equal(ncol(fit$posterior_mean_pairwise), p) + + # indicators + expect_equal(nrow(fit$posterior_summary_indicator), n_edges) + expect_equal(nrow(fit$posterior_mean_indicator), p) + expect_equal(ncol(fit$posterior_mean_indicator), p) + + # raw samples + expect_equal(ncol(fit$raw_samples$main[[1]]), p) + expect_equal(ncol(fit$raw_samples$pairwise[[1]]), n_edges) + expect_equal(nrow(fit$raw_samples$main[[1]]), 100) +}) + +test_that("GGM without edge selection works", { + testthat::skip_on_cran() + + 
set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + colnames(x) <- paste0("V", 1:4) + + fit <- bgm( + x = x, + variable_type = "continuous", + edge_selection = FALSE, + iter = 100, + warmup = 1000, + chains = 1, + display_progress = "none", + seed = 99 + ) + + expect_s3_class(fit, "bgms") + expect_null(fit$posterior_summary_indicator) + expect_null(fit$posterior_mean_indicator) +}) + +test_that("GGM rejects NUTS and HMC", { + set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + + expect_error( + bgm(x = x, variable_type = "continuous", update_method = "nuts"), + "only supports.*adaptive-metropolis" + ) + expect_error( + bgm(x = x, variable_type = "continuous", update_method = "hamiltonian-mc"), + "only supports.*adaptive-metropolis" + ) +}) + +test_that("Mixed continuous and ordinal is rejected", { + set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + + expect_error( + bgm(x = x, variable_type = c("continuous", "ordinal", "ordinal", "ordinal")), + "all variables must be of type" + ) +}) + +test_that("GGM is reproducible", { + testthat::skip_on_cran() + + set.seed(42) + x <- matrix(rnorm(200), nrow = 50, ncol = 4) + colnames(x) <- paste0("V", 1:4) + + fit1 <- bgm( + x = x, variable_type = "continuous", + iter = 100, warmup = 1000, chains = 1, + display_progress = "none", seed = 321 + ) + fit2 <- bgm( + x = x, variable_type = "continuous", + iter = 100, warmup = 1000, chains = 1, + display_progress = "none", seed = 321 + ) + + expect_equal(fit1$raw_samples$main, fit2$raw_samples$main) + expect_equal(fit1$raw_samples$pairwise, fit2$raw_samples$pairwise) +}) + +test_that("GGM posterior precision diagonal means are positive", { + testthat::skip_on_cran() + + set.seed(42) + x <- matrix(rnorm(500), nrow = 100, ncol = 5) + colnames(x) <- paste0("V", 1:5) + + fit <- bgm( + x = x, variable_type = "continuous", + edge_selection = FALSE, + iter = 500, warmup = 1000, chains = 1, + display_progress = "none", seed = 42 + ) + + # Diagonal precision 
elements should be positive + expect_true(all(fit$posterior_summary_main$mean > 0)) +}) + + +test_that("GGM posterior compare against simulated data", { + testthat::skip_on_cran() + + n <- 1000 + p <- 10 + ne <- p * (p - 1) / 2 + # avoid a test dependency on BDgraph and a random graph structure by using a fixed precision matrix + # set.seed(42) + # adj <- matrix(0, p, p) + # adj[lower.tri(adj)] <- runif(ne) < .5 + # adj <- adj + t(adj) + # omega <- zapsmall(BDgraph::rgwish(1, adj)) + omega <- structure(c(6.240119, 0, 0, -0.370239, 0, 0, 0, 0, -1.622902, + 0, 0, 1.905013, 0, -0.194995, 0, 0, -2.468628, -0.557277, 0, + 0, 0, 0, 5.509142, -7.942389, 1.40081, 0, 0, -0.76775, 0, 0, + -0.370239, -0.194995, -7.942389, 15.521405, -3.537489, 0, 4.60785, + 0, 3.278511, 0, 0, 0, 1.40081, -3.537489, 2.78257, 0, 0, 1.374641, + 0, -1.198092, 0, 0, 0, 0, 0, 1.350879, 0, 0.230677, -1.357952, + 0, 0, -2.468628, 0, 4.60785, 0, 0, 15.88698, 0, 1.20017, -1.973919, + 0, -0.557277, -0.76775, 0, 1.374641, 0.230677, 0, 7.007312, 1.597035, + 0, -1.622902, 0, 0, 3.278511, 0, -1.357952, 1.20017, 1.597035, + 13.378039, -4.769958, 0, 0, 0, 0, -1.198092, 0, -1.973919, 0, + -4.769958, 5.536877), dim = c(10L, 10L)) + adj <- omega != 0 + diag(adj) <- 0 + covmat <- solve(omega) + chol <- chol(covmat) + + set.seed(43) + x <- matrix(rnorm(n * p), nrow = n, ncol = p) %*% chol + # cov(x) - covmat + + fit_no_vs <- bgm( + x = x, variable_type = "continuous", + edge_selection = FALSE, + iter = 5000, warmup = 1000, chains = 2, + display_progress = "none", seed = 42 + ) + + expect_true(cor(fit_no_vs$posterior_summary_main$mean, diag(omega)) > 0.9) + expect_true(cor(fit_no_vs$posterior_summary_pairwise$mean, omega[upper.tri(omega)]) > 0.9) + + fit_vs <- bgm( + x = x, variable_type = "continuous", + edge_selection = TRUE, + iter = 10000, warmup = 2000, chains = 2, + display_progress = "none", seed = 42 + ) + + expect_true(cor(fit_vs$posterior_summary_main$mean, diag(omega)) > 0.9) + 
expect_true(cor(fit_vs$posterior_summary_pairwise$mean, omega[upper.tri(omega)]) > 0.9) + expect_true(cor(fit_vs$posterior_summary_indicator$mean, adj[upper.tri(adj)]) > 0.85) + + # This test somehow fails? Can you fix this? + fit_vs_sbm <- bgm( + x = x, variable_type = "continuous", + edge_selection = TRUE, + edge_prior = "Stochastic-Block", + iter = 5000, warmup = 1000, chains = 2, + display_progress = "none", seed = 42 + ) + + expect_true(cor(fit_vs_sbm$posterior_summary_main$mean, diag(omega)) > 0.9) + expect_true(cor(fit_vs_sbm$posterior_summary_pairwise$mean, omega[upper.tri(omega)]) > 0.9) + expect_true(cor(fit_vs_sbm$posterior_summary_indicator$mean, adj[upper.tri(adj)]) > 0.85) + + # SBM-specific output + expect_false(is.null(fit_vs_sbm$posterior_mean_coclustering_matrix)) + expect_equal(nrow(fit_vs_sbm$posterior_mean_coclustering_matrix), p) + expect_equal(ncol(fit_vs_sbm$posterior_mean_coclustering_matrix), p) + expect_false(is.null(fit_vs_sbm$posterior_num_blocks)) + expect_false(is.null(fit_vs_sbm$posterior_mode_allocations)) + expect_false(is.null(fit_vs_sbm$raw_samples$allocations)) + + +}) From cab6251a43e9f6177e2f94d413bf417c3373286f Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Tue, 10 Feb 2026 09:49:17 +0100 Subject: [PATCH 19/23] revert changes to mcmc_utils.cpp --- src/mcmc/mcmc_utils.cpp | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/src/mcmc/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp index 883b371f..1e37577c 100644 --- a/src/mcmc/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include "mcmc/mcmc_leapfrog.h" #include "mcmc/mcmc_utils.h" #include "rng/rng_utils.h" @@ -74,22 +73,10 @@ double heuristic_initial_step_size( double kin1 = kinetic_energy(r_new, inv_mass_diag); double H1 = logp1 - kin1; - // NaN guard: treat non-finite H as bad step (force halving) - auto safe_delta_H = [](double H1, double H0) -> double { - double delta = H1 - 
H0; - return std::isfinite(delta) ? delta : -std::numeric_limits::infinity(); - }; - - int direction = 2 * (safe_delta_H(H1, H0) > MY_LOG(0.5)) - 1; // +1 or -1 + int direction = 2 * (H1 - H0 > MY_LOG(0.5)) - 1; // +1 or -1 int attempts = 0; - while (attempts < max_attempts) { - double delta = safe_delta_H(H1, H0); - bool keep_going = (direction == 1) - ? (delta > -MY_LOG(2.0)) - : (delta < MY_LOG(2.0)); - if (!keep_going) break; - + while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 2.0 * eps : 0.5 * eps; // Resample momentum on each iteration for step size search @@ -161,16 +148,10 @@ double heuristic_initial_step_size( int direction = 2 * (H1 - H0 > MY_LOG(0.5)) - 1; // +1 or -1 int attempts = 0; - while (attempts < max_attempts) { - double delta = safe_delta_H(H1, H0); - bool keep_going = (direction == 1) - ? (delta > -MY_LOG(2.0)) - : (delta < MY_LOG(2.0)); - if (!keep_going) break; - + while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 
2.0 * eps : 0.5 * eps; - // Resample momentum (STAN resamples on each iteration) + // Resample momentum on each iteration for step size search r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; @@ -178,6 +159,7 @@ double heuristic_initial_step_size( // One leapfrog step from original position with new momentum std::tie(theta_new, r_new) = leapfrog(theta, r, eps, grad, 1, inv_mass_diag); + // Evaluate Hamiltonian logp1 = log_post(theta_new); kin1 = kinetic_energy(r_new, inv_mass_diag); H1 = logp1 - kin1; @@ -186,4 +168,4 @@ double heuristic_initial_step_size( } return eps; -} +} \ No newline at end of file From 75098272059d88eb0b8ea84c8b612b53cf258307 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 11 Feb 2026 10:56:22 +0100 Subject: [PATCH 20/23] cleanup --- R/RcppExports.R | 4 - src/RcppExports.cpp | 15 - src/bgm/bgm_sampler.cpp | 6 +- src/bgmCompare/bgmCompare_sampler.cpp | 6 +- src/mcmc/adaptive_gradient_sampler.h | 156 +++++++++ src/mcmc/base_sampler.h | 6 +- src/{chainResultNew.h => mcmc/chain_result.h} | 2 +- src/mcmc/hmc_sampler.h | 167 +-------- src/mcmc/mcmc_adaptation.h | 3 +- src/mcmc/mcmc_hmc.cpp | 11 +- src/mcmc/mcmc_runner.cpp | 237 +++++++++++++ src/mcmc/mcmc_runner.h | 324 ++---------------- src/mcmc/mcmc_utils.cpp | 10 +- src/mcmc/mh_sampler.h | 8 +- src/mcmc/nuts_sampler.h | 170 +-------- src/mixedVariables.cpp | 43 --- src/mixedVariables.h | 169 --------- .../adaptive_metropolis.h} | 4 +- src/{ => models}/base_model.h | 0 src/{ => models/ggm}/cholupdate.cpp | 6 +- src/{ => models/ggm}/cholupdate.h | 0 src/{ => models/ggm}/ggm_model.cpp | 110 +++--- src/{ => models/ggm}/ggm_model.h | 156 ++++----- src/models/mixed/mixed_variables.cpp | 43 +++ src/models/mixed/mixed_variables.h | 169 +++++++++ src/{ => models/omrf}/omrf_model.cpp | 126 +++---- src/{ => models/omrf}/omrf_model.h | 15 +- src/{ => models}/skeleton_model.cpp | 4 +- src/sample_ggm.cpp | 6 +- 
src/sample_omrf.cpp | 4 +- 30 files changed, 872 insertions(+), 1108 deletions(-) create mode 100644 src/mcmc/adaptive_gradient_sampler.h rename src/{chainResultNew.h => mcmc/chain_result.h} (99%) create mode 100644 src/mcmc/mcmc_runner.cpp delete mode 100644 src/mixedVariables.cpp delete mode 100644 src/mixedVariables.h rename src/{adaptiveMetropolis.h => models/adaptive_metropolis.h} (92%) rename src/{ => models}/base_model.h (100%) rename src/{ => models/ggm}/cholupdate.cpp (96%) rename src/{ => models/ggm}/cholupdate.h (100%) rename src/{ => models/ggm}/ggm_model.cpp (79%) rename src/{ => models/ggm}/ggm_model.h (70%) create mode 100644 src/models/mixed/mixed_variables.cpp create mode 100644 src/models/mixed/mixed_variables.h rename src/{ => models/omrf}/omrf_model.cpp (95%) rename src/{ => models/omrf}/omrf_model.h (97%) rename src/{ => models}/skeleton_model.cpp (99%) diff --git a/R/RcppExports.R b/R/RcppExports.R index 528716e0..806fcb8d 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -9,10 +9,6 @@ run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_ .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type, pairwise_scaling_factors) } -chol_update_arma <- function(R, u, downdate = FALSE, eps = 1e-12) { - .Call(`_bgms_chol_update_arma`, R, u, downdate, eps) -} - compute_conditional_probs <- function(observations, predict_vars, interactions, thresholds, no_categories, variable_type, baseline_category) 
{ .Call(`_bgms_compute_conditional_probs`, observations, predict_vars, interactions, thresholds, no_categories, variable_type, baseline_category) } diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index b395f4be..b82898ca 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -104,20 +104,6 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } -// chol_update_arma -arma::mat chol_update_arma(arma::mat& R, arma::vec& u, bool downdate, double eps); -RcppExport SEXP _bgms_chol_update_arma(SEXP RSEXP, SEXP uSEXP, SEXP downdateSEXP, SEXP epsSEXP) { -BEGIN_RCPP - Rcpp::RObject rcpp_result_gen; - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< arma::mat& >::type R(RSEXP); - Rcpp::traits::input_parameter< arma::vec& >::type u(uSEXP); - Rcpp::traits::input_parameter< bool >::type downdate(downdateSEXP); - Rcpp::traits::input_parameter< double >::type eps(epsSEXP); - rcpp_result_gen = Rcpp::wrap(chol_update_arma(R, u, downdate, eps)); - return rcpp_result_gen; -END_RCPP -} // compute_conditional_probs Rcpp::List compute_conditional_probs(arma::imat observations, arma::ivec predict_vars, arma::mat interactions, arma::mat thresholds, arma::ivec no_categories, Rcpp::StringVector variable_type, arma::ivec baseline_category); RcppExport SEXP _bgms_compute_conditional_probs(SEXP observationsSEXP, SEXP predict_varsSEXP, SEXP interactionsSEXP, SEXP thresholdsSEXP, SEXP no_categoriesSEXP, SEXP variable_typeSEXP, SEXP baseline_categorySEXP) { @@ -272,7 +258,6 @@ END_RCPP static const R_CallMethodDef CallEntries[] = { {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 38}, {"_bgms_run_bgm_parallel", (DL_FUNC) &_bgms_run_bgm_parallel, 35}, - {"_bgms_chol_update_arma", (DL_FUNC) &_bgms_chol_update_arma, 4}, {"_bgms_compute_conditional_probs", (DL_FUNC) &_bgms_compute_conditional_probs, 7}, {"_bgms_sample_omrf_gibbs", (DL_FUNC) &_bgms_sample_omrf_gibbs, 7}, {"_bgms_sample_bcomrf_gibbs", (DL_FUNC) &_bgms_sample_bcomrf_gibbs, 9}, diff 
--git a/src/bgm/bgm_sampler.cpp b/src/bgm/bgm_sampler.cpp index bbb70543..e68c7cd8 100644 --- a/src/bgm/bgm_sampler.cpp +++ b/src/bgm/bgm_sampler.cpp @@ -588,7 +588,8 @@ void update_hmc_bgm( adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix. Use current step size as starting point. + // step size for the new mass matrix (following STAN's approach). + // STAN uses the current step size as the starting point for the heuristic. if (adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( adapt.inv_mass_diag(), inclusion_indicator, num_categories, @@ -733,7 +734,8 @@ SamplerResult update_nuts_bgm( adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix. Use current step size as starting point. + // step size for the new mass matrix (following STAN's approach). + // STAN uses the current step size as the starting point for the heuristic. if (adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( adapt.inv_mass_diag(), inclusion_indicator, num_categories, diff --git a/src/bgmCompare/bgmCompare_sampler.cpp b/src/bgmCompare/bgmCompare_sampler.cpp index 1ea890cb..567705bd 100644 --- a/src/bgmCompare/bgmCompare_sampler.cpp +++ b/src/bgmCompare/bgmCompare_sampler.cpp @@ -735,7 +735,8 @@ void update_hmc_bgmcompare( hmc_adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix. Use current step size as starting point. + // step size for the new mass matrix (following STAN's approach). + // STAN uses the current step size as the starting point for the heuristic. 
if (hmc_adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( hmc_adapt.inv_mass_diag(), inclusion_indicator, num_groups, num_categories, @@ -910,7 +911,8 @@ SamplerResult update_nuts_bgmcompare( hmc_adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix. Use current step size as starting point. + // step size for the new mass matrix (following STAN's approach). + // STAN uses the current step size as the starting point for the heuristic. if (hmc_adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( hmc_adapt.inv_mass_diag(), inclusion_indicator, num_groups, num_categories, diff --git a/src/mcmc/adaptive_gradient_sampler.h b/src/mcmc/adaptive_gradient_sampler.h new file mode 100644 index 00000000..02a05ef2 --- /dev/null +++ b/src/mcmc/adaptive_gradient_sampler.h @@ -0,0 +1,156 @@ +#pragma once + +#include +#include +#include +#include +#include "mcmc/base_sampler.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/mcmc_adaptation.h" +#include "mcmc/sampler_config.h" +#include "models/base_model.h" + +/** + * AdaptiveGradientSampler - Base for gradient-based MCMC with warmup adaptation + * + * Shared warmup logic for NUTS and HMC: + * Stage 1: step-size adaptation only + * Stage 2: mass matrix estimation in doubling windows + step-size re-tuning + * Stage 3: final step-size averaging + */ +class AdaptiveGradientSampler : public BaseSampler { +public: + AdaptiveGradientSampler(double step_size, double target_acceptance, int n_warmup) + : step_size_(step_size), + target_acceptance_(target_acceptance), + n_warmup_(n_warmup), + warmup_iteration_(0), + initialized_(false), + step_adapter_(step_size) + { + build_warmup_schedule(n_warmup); + } + + SamplerResult warmup_step(BaseModel& model) override { + if (!initialized_) { + initialize(model); + initialized_ = true; + } + + SamplerResult result = 
do_gradient_step(model); + + step_adapter_.update(result.accept_prob, target_acceptance_); + step_size_ = step_adapter_.current(); + + if (in_stage2()) { + arma::vec full_params = model.get_full_vectorized_parameters(); + mass_accumulator_->update(full_params); + + if (at_window_end()) { + inv_mass_ = mass_accumulator_->variance(); + mass_accumulator_->reset(); + model.set_inv_mass(inv_mass_); + + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + arma::vec active_inv_mass = model.get_active_inv_mass(); + + double new_eps = heuristic_initial_step_size( + theta, log_post, grad_fn, active_inv_mass, rng, + 0.625, step_size_); + step_size_ = new_eps; + step_adapter_.restart(new_eps); + } + } + + warmup_iteration_++; + return result; + } + + void finalize_warmup() override { + step_size_ = step_adapter_.averaged(); + } + + SamplerResult sample_step(BaseModel& model) override { + return do_gradient_step(model); + } + + double get_step_size() const { return step_size_; } + double get_averaged_step_size() const { return step_adapter_.averaged(); } + const arma::vec& get_inv_mass() const { return inv_mass_; } + +protected: + virtual SamplerResult do_gradient_step(BaseModel& model) = 0; + + double step_size_; + double target_acceptance_; + +private: + void build_warmup_schedule(int n_warmup) { + stage1_end_ = static_cast(0.075 * n_warmup); + stage3_start_ = n_warmup - static_cast(0.10 * n_warmup); + + window_ends_.clear(); + int cur = stage1_end_; + int wsize = 25; + + while (cur < stage3_start_) { + int win = std::min(wsize, stage3_start_ - cur); + window_ends_.push_back(cur + win); + cur += win; + wsize = std::min(wsize * 2, stage3_start_ - cur); + } + } + + bool in_stage2() const { + return 
warmup_iteration_ >= stage1_end_ && warmup_iteration_ < stage3_start_; + } + + bool at_window_end() const { + for (int end : window_ends_) { + if (warmup_iteration_ + 1 == end) return true; + } + return false; + } + + void initialize(BaseModel& model) { + arma::vec theta = model.get_vectorized_parameters(); + SafeRNG& rng = model.get_rng(); + + inv_mass_ = arma::ones(model.full_parameter_dimension()); + model.set_inv_mass(inv_mass_); + + mass_accumulator_ = std::make_unique( + static_cast(model.full_parameter_dimension())); + + auto log_post = [&model](const arma::vec& params) -> double { + return model.logp_and_gradient(params).first; + }; + auto grad_fn = [&model](const arma::vec& params) -> arma::vec { + return model.logp_and_gradient(params).second; + }; + + step_size_ = heuristic_initial_step_size( + theta, log_post, grad_fn, rng, target_acceptance_); + + step_adapter_.restart(step_size_); + } + + int n_warmup_; + int warmup_iteration_; + bool initialized_; + + DualAveraging step_adapter_; + arma::vec inv_mass_; + std::unique_ptr mass_accumulator_; + + int stage1_end_; + int stage3_start_; + std::vector window_ends_; +}; diff --git a/src/mcmc/base_sampler.h b/src/mcmc/base_sampler.h index 4a23157c..576840c8 100644 --- a/src/mcmc/base_sampler.h +++ b/src/mcmc/base_sampler.h @@ -1,9 +1,9 @@ #pragma once #include -#include "mcmc_utils.h" -#include "sampler_config.h" -#include "../base_model.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/sampler_config.h" +#include "models/base_model.h" /** * BaseSampler - Abstract base class for MCMC samplers diff --git a/src/chainResultNew.h b/src/mcmc/chain_result.h similarity index 99% rename from src/chainResultNew.h rename to src/mcmc/chain_result.h index 295e16c9..d43f7564 100644 --- a/src/chainResultNew.h +++ b/src/mcmc/chain_result.h @@ -12,7 +12,7 @@ class ChainResultNew { public: - ChainResultNew() {} + ChainResultNew() = default; // Error handling bool error = false; diff --git a/src/mcmc/hmc_sampler.h 
b/src/mcmc/hmc_sampler.h index d4ba605e..458dcb18 100644 --- a/src/mcmc/hmc_sampler.h +++ b/src/mcmc/hmc_sampler.h @@ -1,147 +1,23 @@ #pragma once -#include -#include -#include -#include "base_sampler.h" -#include "mcmc_utils.h" -#include "mcmc_hmc.h" -#include "mcmc_adaptation.h" -#include "sampler_config.h" -#include "../base_model.h" +#include "mcmc/adaptive_gradient_sampler.h" +#include "mcmc/mcmc_hmc.h" /** - * HMCSampler - Hamiltonian Monte Carlo sampler + * HMCSampler - Hamiltonian Monte Carlo * - * Uses fixed-length leapfrog integration with full warmup adaptation: - * - Heuristic initial step size - * - Diagonal mass matrix estimation (windowed) - * - Dual averaging step size adaptation - * - * Warmup schedule mirrors NUTSSampler (3 stages). + * Fixed-length leapfrog integration. Inherits warmup adaptation + * (step size + diagonal mass matrix) from AdaptiveGradientSampler. */ -class HMCSampler : public BaseSampler { +class HMCSampler : public AdaptiveGradientSampler { public: explicit HMCSampler(const SamplerConfig& config) - : step_size_(config.initial_step_size), - target_acceptance_(config.target_acceptance), - num_leapfrogs_(config.num_leapfrogs), - no_warmup_(config.no_warmup), - warmup_iteration_(0), - initialized_(false), - step_adapter_(config.initial_step_size) - { - build_warmup_schedule(config.no_warmup); - } - - SamplerResult warmup_step(BaseModel& model) override { - if (!initialized_) { - initialize(model); - initialized_ = true; - } - - SamplerResult result = do_hmc_step(model); - - // Adapt step size during all warmup phases - step_adapter_.update(result.accept_prob, target_acceptance_); - step_size_ = step_adapter_.current(); - - // During Stage 2, accumulate samples for mass matrix estimation - if (in_stage2()) { - arma::vec full_params = model.get_full_vectorized_parameters(); - mass_accumulator_->update(full_params); - - if (at_window_end()) { - inv_mass_ = mass_accumulator_->variance(); - mass_accumulator_->reset(); - - 
model.set_inv_mass(inv_mass_); - - arma::vec theta = model.get_vectorized_parameters(); - SafeRNG& rng = model.get_rng(); - auto log_post = [&model](const arma::vec& params) -> double { - return model.logp_and_gradient(params).first; - }; - auto grad_fn = [&model](const arma::vec& params) -> arma::vec { - return model.logp_and_gradient(params).second; - }; - arma::vec active_inv_mass = model.get_active_inv_mass(); - - double new_eps = heuristic_initial_step_size( - theta, log_post, grad_fn, active_inv_mass, rng, - 0.625, step_size_); - step_size_ = new_eps; - step_adapter_.restart(new_eps); - } - } - - warmup_iteration_++; - return result; - } - - void finalize_warmup() override { - step_size_ = step_adapter_.averaged(); - } - - SamplerResult sample_step(BaseModel& model) override { - return do_hmc_step(model); - } - - double get_step_size() const { return step_size_; } - double get_averaged_step_size() const { return step_adapter_.averaged(); } + : AdaptiveGradientSampler(config.initial_step_size, config.target_acceptance, config.no_warmup), + num_leapfrogs_(config.num_leapfrogs) + {} -private: - void build_warmup_schedule(int n_warmup) { - stage1_end_ = static_cast(0.075 * n_warmup); - stage3_start_ = n_warmup - static_cast(0.10 * n_warmup); - - window_ends_.clear(); - int cur = stage1_end_; - int wsize = 25; - - while (cur < stage3_start_) { - int win = std::min(wsize, stage3_start_ - cur); - window_ends_.push_back(cur + win); - cur += win; - wsize = std::min(wsize * 2, stage3_start_ - cur); - } - } - - bool in_stage2() const { - return warmup_iteration_ >= stage1_end_ && warmup_iteration_ < stage3_start_; - } - - bool at_window_end() const { - for (int end : window_ends_) { - if (warmup_iteration_ + 1 == end) return true; - } - return false; - } - - void initialize(BaseModel& model) { - arma::vec theta = model.get_vectorized_parameters(); - SafeRNG& rng = model.get_rng(); - - inv_mass_ = arma::ones(model.full_parameter_dimension()); - 
model.set_inv_mass(inv_mass_); - - mass_accumulator_ = std::make_unique( - static_cast(model.full_parameter_dimension())); - - auto log_post = [&model](const arma::vec& params) -> double { - return model.logp_and_gradient(params).first; - }; - auto grad_fn = [&model](const arma::vec& params) -> arma::vec { - return model.logp_and_gradient(params).second; - }; - - step_size_ = heuristic_initial_step_size( - theta, log_post, grad_fn, rng, target_acceptance_); - - step_adapter_.restart(step_size_); - } - - SamplerResult do_hmc_step(BaseModel& model) { +protected: + SamplerResult do_gradient_step(BaseModel& model) override { arma::vec theta = model.get_vectorized_parameters(); arma::vec inv_mass = model.get_active_inv_mass(); SafeRNG& rng = model.get_rng(); @@ -161,25 +37,6 @@ class HMCSampler : public BaseSampler { return result; } - // Configuration - double step_size_; - double target_acceptance_; +private: int num_leapfrogs_; - int no_warmup_; - - // State tracking - int warmup_iteration_; - bool initialized_; - - // Step size adaptation - DualAveraging step_adapter_; - - // Mass matrix adaptation - arma::vec inv_mass_; - std::unique_ptr mass_accumulator_; - - // Warmup schedule - int stage1_end_; - int stage3_start_; - std::vector window_ends_; }; diff --git a/src/mcmc/mcmc_adaptation.h b/src/mcmc/mcmc_adaptation.h index b2168ae9..4fdf4fcf 100644 --- a/src/mcmc/mcmc_adaptation.h +++ b/src/mcmc/mcmc_adaptation.h @@ -93,7 +93,7 @@ class DiagMassMatrixAccumulator { }; -// === Dynamic Warmup Schedule with Adaptive Windows === +// === Stan-style Dynamic Warmup Schedule with Adaptive Windows === // // For edge_selection = FALSE: // Stage 1 (init), Stage 2 (doubling windows), Stage 3a (terminal) @@ -332,6 +332,7 @@ class HMCAdaptationController { * This should be called after running heuristic_initial_step_size() with * the new mass matrix to find an appropriate starting step size. 
* + * Following STAN's approach: * - Set the new step size * - Set mu = log(10 * new_step_size) as the adaptation target * - Restart the dual averaging counters diff --git a/src/mcmc/mcmc_hmc.cpp b/src/mcmc/mcmc_hmc.cpp index dde56bed..adb3c022 100644 --- a/src/mcmc/mcmc_hmc.cpp +++ b/src/mcmc/mcmc_hmc.cpp @@ -24,22 +24,13 @@ SamplerResult hmc_sampler( theta, r, step_size, grad, num_leapfrogs, inv_mass_diag ); - // If leapfrog produced NaN/Inf, reject immediately - if (theta.has_nan() || theta.has_inf() || r.has_nan() || r.has_inf()) { - return {init_theta, 0.0}; - } - // Hamiltonians double current_H = -log_post(init_theta) + kinetic_energy(init_r, inv_mass_diag); double proposed_H = -log_post(theta) + kinetic_energy(r, inv_mass_diag); double log_accept_prob = current_H - proposed_H; - // NaN guard: treat non-finite Hamiltonian as rejection - if (!std::isfinite(log_accept_prob)) { - return {init_theta, 0.0}; - } - arma::vec state = (MY_LOG(runif(rng)) < log_accept_prob) ? theta : init_theta; + double accept_prob = std::min(1.0, MY_EXP(log_accept_prob)); return {state, accept_prob}; diff --git a/src/mcmc/mcmc_runner.cpp b/src/mcmc/mcmc_runner.cpp new file mode 100644 index 00000000..a4a4c149 --- /dev/null +++ b/src/mcmc/mcmc_runner.cpp @@ -0,0 +1,237 @@ +#include "mcmc/mcmc_runner.h" + +#include "mcmc/nuts_sampler.h" +#include "mcmc/hmc_sampler.h" +#include "mcmc/mh_sampler.h" + + +std::unique_ptr create_sampler(const SamplerConfig& config) { + if (config.sampler_type == "nuts") { + return std::make_unique(config, config.no_warmup); + } else if (config.sampler_type == "hmc" || config.sampler_type == "hamiltonian-mc") { + return std::make_unique(config); + } else if (config.sampler_type == "mh" || config.sampler_type == "adaptive-metropolis") { + return std::make_unique(config); + } else { + Rcpp::stop("Unknown sampler_type: '%s'", config.sampler_type.c_str()); + } +} + + +void run_mcmc_chain( + ChainResultNew& chain_result, + BaseModel& model, + BaseEdgePrior& 
edge_prior, + const SamplerConfig& config, + const int chain_id, + ProgressManager& pm +) { + chain_result.chain_id = chain_id + 1; + + const int edge_start = config.get_edge_selection_start(); + + auto sampler = create_sampler(config); + + // Warmup phase + for (int iter = 0; iter < config.no_warmup; ++iter) { + + if (config.na_impute && model.has_missing_data()) { + model.impute_missing(); + } + + if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { + model.update_edge_indicators(); + } + + sampler->warmup_step(model); + + if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { + edge_prior.update( + model.get_edge_indicators(), + model.get_inclusion_probability(), + model.get_num_variables(), + model.get_num_pairwise(), + model.get_rng() + ); + } + + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + return; + } + } + + sampler->finalize_warmup(); + + // Activate edge selection mode + if (config.edge_selection && model.has_edge_selection()) { + model.set_edge_selection_active(true); + model.initialize_graph(); + } + + // Sampling phase + for (int iter = 0; iter < config.no_iter; ++iter) { + + if (config.na_impute && model.has_missing_data()) { + model.impute_missing(); + } + + if (config.edge_selection && model.has_edge_selection()) { + model.update_edge_indicators(); + } + + SamplerResult result = sampler->sample_step(model); + + if (config.edge_selection && model.has_edge_selection()) { + edge_prior.update( + model.get_edge_indicators(), + model.get_inclusion_probability(), + model.get_num_variables(), + model.get_num_pairwise(), + model.get_rng() + ); + } + + if (chain_result.has_nuts_diagnostics && sampler->has_nuts_diagnostics()) { + auto* diag = dynamic_cast(result.diagnostics.get()); + if (diag) { + chain_result.store_nuts_diagnostics(iter, diag->tree_depth, diag->divergent, diag->energy); + } + } + + chain_result.store_sample(iter, 
model.get_full_vectorized_parameters()); + + if (chain_result.has_indicators) { + chain_result.store_indicators(iter, model.get_vectorized_indicator_parameters()); + } + + if (chain_result.has_allocations && edge_prior.has_allocations()) { + chain_result.store_allocations(iter, edge_prior.get_allocations()); + } + + pm.update(chain_id); + if (pm.shouldExit()) { + chain_result.userInterrupt = true; + return; + } + } +} + + +void MCMCChainRunner::operator()(std::size_t begin, std::size_t end) { + for (std::size_t i = begin; i < end; ++i) { + ChainResultNew& chain_result = results_[i]; + BaseModel& model = *models_[i]; + BaseEdgePrior& edge_prior = *edge_priors_[i]; + model.set_seed(config_.seed + static_cast(i)); + + try { + run_mcmc_chain(chain_result, model, edge_prior, config_, static_cast(i), pm_); + } catch (std::exception& e) { + chain_result.error = true; + chain_result.error_msg = e.what(); + } catch (...) { + chain_result.error = true; + chain_result.error_msg = "Unknown error"; + } + } +} + + +std::vector run_mcmc_sampler( + BaseModel& model, + BaseEdgePrior& edge_prior, + const SamplerConfig& config, + const int no_chains, + const int no_threads, + ProgressManager& pm +) { + const bool has_nuts_diag = (config.sampler_type == "nuts"); + const bool has_sbm_alloc = edge_prior.has_allocations() || + (config.edge_selection && dynamic_cast(&edge_prior) != nullptr); + + std::vector results(no_chains); + for (int c = 0; c < no_chains; ++c) { + results[c].reserve(model.full_parameter_dimension(), config.no_iter); + + if (config.edge_selection) { + size_t n_edges = model.get_vectorized_indicator_parameters().n_elem; + results[c].reserve_indicators(n_edges, config.no_iter); + } + + if (has_sbm_alloc) { + results[c].reserve_allocations(model.get_num_variables(), config.no_iter); + } + + if (has_nuts_diag) { + results[c].reserve_nuts_diagnostics(config.no_iter); + } + } + + if (no_threads > 1) { + std::vector> models; + std::vector> edge_priors; + 
models.reserve(no_chains); + edge_priors.reserve(no_chains); + for (int c = 0; c < no_chains; ++c) { + models.push_back(model.clone()); + models[c]->set_seed(config.seed + c); + edge_priors.push_back(edge_prior.clone()); + } + + MCMCChainRunner runner(results, models, edge_priors, config, pm); + tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); + RcppParallel::parallelFor(0, static_cast(no_chains), runner); + + } else { + model.set_seed(config.seed); + for (int c = 0; c < no_chains; ++c) { + auto chain_model = model.clone(); + chain_model->set_seed(config.seed + c); + auto chain_edge_prior = edge_prior.clone(); + run_mcmc_chain(results[c], *chain_model, *chain_edge_prior, config, c, pm); + } + } + + return results; +} + + +Rcpp::List convert_results_to_list(const std::vector& results) { + Rcpp::List output(results.size()); + + for (size_t i = 0; i < results.size(); ++i) { + const ChainResultNew& chain = results[i]; + Rcpp::List chain_list; + + chain_list["chain_id"] = chain.chain_id; + + if (chain.error) { + chain_list["error"] = true; + chain_list["error_msg"] = chain.error_msg; + } else { + chain_list["error"] = false; + chain_list["samples"] = chain.samples; + chain_list["userInterrupt"] = chain.userInterrupt; + + if (chain.has_indicators) { + chain_list["indicator_samples"] = chain.indicator_samples; + } + + if (chain.has_allocations) { + chain_list["allocation_samples"] = chain.allocation_samples; + } + + if (chain.has_nuts_diagnostics) { + chain_list["treedepth"] = chain.treedepth_samples; + chain_list["divergent"] = chain.divergent_samples; + chain_list["energy"] = chain.energy_samples; + } + } + + output[i] = chain_list; + } + + return output; +} diff --git a/src/mcmc/mcmc_runner.h b/src/mcmc/mcmc_runner.h index b4068e9e..df7e78f5 100644 --- a/src/mcmc/mcmc_runner.h +++ b/src/mcmc/mcmc_runner.h @@ -6,177 +6,30 @@ #include #include -#include "../base_model.h" -#include "../chainResultNew.h" -#include 
"../priors/edge_prior.h" -#include "../utils/progress_manager.h" -#include "sampler_config.h" -#include "base_sampler.h" -#include "nuts_sampler.h" -#include "hmc_sampler.h" -#include "mh_sampler.h" -#include "mcmc_utils.h" +#include "models/base_model.h" +#include "mcmc/chain_result.h" +#include "priors/edge_prior.h" +#include "utils/progress_manager.h" +#include "mcmc/sampler_config.h" +#include "mcmc/base_sampler.h" +#include "mcmc/mcmc_utils.h" -/** - * Create a sampler based on configuration - * - * Factory function that returns the appropriate sampler type. - * - * @param config Sampler configuration - * @return Unique pointer to the created sampler - */ -inline std::unique_ptr create_sampler(const SamplerConfig& config) { - if (config.sampler_type == "nuts") { - return std::make_unique(config, config.no_warmup); - } else if (config.sampler_type == "hmc" || config.sampler_type == "hamiltonian-mc") { - return std::make_unique(config); - } else if (config.sampler_type == "mh" || config.sampler_type == "adaptive-metropolis") { - return std::make_unique(config); - } else { - Rcpp::stop("Unknown sampler_type: '%s'", config.sampler_type.c_str()); - } -} +/// Create a sampler matching config.sampler_type. +std::unique_ptr create_sampler(const SamplerConfig& config); - -/** - * Run MCMC sampling for a single chain - * - * Supports MH, NUTS, and HMC samplers with optional edge selection. - * Handles warmup adaptation and diagnostic collection. - * - * @param chain_result Output storage for this chain - * @param model The model to sample from - * @param config Sampler configuration - * @param chain_id Chain identifier (0-based) - * @param pm Progress manager for user feedback - */ -inline void run_mcmc_chain( +/// Run a single MCMC chain (warmup + sampling) writing into chain_result. 
+void run_mcmc_chain( ChainResultNew& chain_result, BaseModel& model, BaseEdgePrior& edge_prior, const SamplerConfig& config, - const int chain_id, + int chain_id, ProgressManager& pm -) { - chain_result.chain_id = chain_id + 1; - - const int edge_start = config.get_edge_selection_start(); - - // Create sampler for this chain - auto sampler = create_sampler(config); - - // ========================================================================= - // Warmup phase - // ========================================================================= - for (int iter = 0; iter < config.no_warmup; ++iter) { - - // Impute missing data if applicable - if (config.na_impute && model.has_missing_data()) { - model.impute_missing(); - } - - // Edge selection starts after edge_start iterations - if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { - model.update_edge_indicators(); - } - - // Sampler step (unified interface) - sampler->warmup_step(model); - - // Update edge prior parameters (Beta-Bernoulli, SBM, etc.) 
- if (config.edge_selection && iter >= edge_start && model.has_edge_selection()) { - edge_prior.update( - model.get_edge_indicators(), - model.get_inclusion_probability(), - model.get_num_variables(), - model.get_num_pairwise(), - model.get_rng() - ); - } - - // Progress and interrupt check - pm.update(chain_id); - if (pm.shouldExit()) { - chain_result.userInterrupt = true; - return; - } - } - - // Finalize warmup (samplers fix their adapted parameters) - sampler->finalize_warmup(); - - // ========================================================================= - // Activate edge selection mode (if enabled) - // ========================================================================= - if (config.edge_selection && model.has_edge_selection()) { - model.set_edge_selection_active(true); - model.initialize_graph(); // Randomly initialize graph structure - } - - // ========================================================================= - // Sampling phase - // ========================================================================= - for (int iter = 0; iter < config.no_iter; ++iter) { - - // Impute missing data if applicable - if (config.na_impute && model.has_missing_data()) { - model.impute_missing(); - } - - // Edge selection continues during sampling - if (config.edge_selection && model.has_edge_selection()) { - model.update_edge_indicators(); - } - - // Sampler step (unified interface) - SamplerResult result = sampler->sample_step(model); - - // Update edge prior parameters (Beta-Bernoulli, SBM, etc.) 
- if (config.edge_selection && model.has_edge_selection()) { - edge_prior.update( - model.get_edge_indicators(), - model.get_inclusion_probability(), - model.get_num_variables(), - model.get_num_pairwise(), - model.get_rng() - ); - } - - // Store NUTS diagnostics if available - if (chain_result.has_nuts_diagnostics && sampler->has_nuts_diagnostics()) { - auto* diag = dynamic_cast(result.diagnostics.get()); - if (diag) { - chain_result.store_nuts_diagnostics(iter, diag->tree_depth, diag->divergent, diag->energy); - } - } - - // Store samples - chain_result.store_sample(iter, model.get_full_vectorized_parameters()); - - // Store edge indicators if applicable - if (chain_result.has_indicators) { - chain_result.store_indicators(iter, model.get_vectorized_indicator_parameters()); - } - - // Store SBM allocations if applicable - if (chain_result.has_allocations && edge_prior.has_allocations()) { - chain_result.store_allocations(iter, edge_prior.get_allocations()); - } +); - // Progress and interrupt check - pm.update(chain_id); - if (pm.shouldExit()) { - chain_result.userInterrupt = true; - return; - } - } -} - -/** - * Worker struct for parallel chain execution - */ +/// Worker struct for TBB parallel chain execution. struct MCMCChainRunner : public RcppParallel::Worker { std::vector& results_; std::vector>& models_; @@ -198,152 +51,19 @@ struct MCMCChainRunner : public RcppParallel::Worker { pm_(pm) {} - void operator()(std::size_t begin, std::size_t end) { - for (std::size_t i = begin; i < end; ++i) { - ChainResultNew& chain_result = results_[i]; - BaseModel& model = *models_[i]; - BaseEdgePrior& edge_prior = *edge_priors_[i]; - model.set_seed(config_.seed + static_cast(i)); - - try { - run_mcmc_chain(chain_result, model, edge_prior, config_, static_cast(i), pm_); - } catch (std::exception& e) { - chain_result.error = true; - chain_result.error_msg = e.what(); - } catch (...) 
{ - chain_result.error = true; - chain_result.error_msg = "Unknown error"; - } - } - } + void operator()(std::size_t begin, std::size_t end); }; -/** - * Run MCMC sampling with parallel chains - * - * Main entry point for multi-chain MCMC. Handles: - * - Chain allocation and model cloning - * - Parallel or sequential execution based on no_threads - * - Result collection - * - * @param model Template model (will be cloned for each chain) - * @param config Sampler configuration - * @param no_chains Number of chains to run - * @param no_threads Number of threads (1 = sequential) - * @param pm Progress manager - * @return Vector of chain results - */ -inline std::vector run_mcmc_sampler( +/// Run multi-chain MCMC (parallel or sequential based on no_threads). +std::vector run_mcmc_sampler( BaseModel& model, BaseEdgePrior& edge_prior, const SamplerConfig& config, - const int no_chains, - const int no_threads, + int no_chains, + int no_threads, ProgressManager& pm -) { - const bool has_nuts_diag = (config.sampler_type == "nuts"); - const bool has_sbm_alloc = edge_prior.has_allocations() || - (config.edge_selection && dynamic_cast(&edge_prior) != nullptr); - - // Allocate result storage - std::vector results(no_chains); - for (int c = 0; c < no_chains; ++c) { - results[c].reserve(model.full_parameter_dimension(), config.no_iter); - - if (config.edge_selection) { - size_t n_edges = model.get_vectorized_indicator_parameters().n_elem; - results[c].reserve_indicators(n_edges, config.no_iter); - } - - if (has_sbm_alloc) { - results[c].reserve_allocations(model.get_num_variables(), config.no_iter); - } - - if (has_nuts_diag) { - results[c].reserve_nuts_diagnostics(config.no_iter); - } - } - - if (no_threads > 1) { - // Multi-threaded execution - std::vector> models; - std::vector> edge_priors; - models.reserve(no_chains); - edge_priors.reserve(no_chains); - for (int c = 0; c < no_chains; ++c) { - models.push_back(model.clone()); - models[c]->set_seed(config.seed + c); - 
edge_priors.push_back(edge_prior.clone()); - } - - MCMCChainRunner runner(results, models, edge_priors, config, pm); - tbb::global_control control(tbb::global_control::max_allowed_parallelism, no_threads); - RcppParallel::parallelFor(0, static_cast(no_chains), runner); - - } else { - // Single-threaded execution - model.set_seed(config.seed); - for (int c = 0; c < no_chains; ++c) { - auto chain_model = model.clone(); - chain_model->set_seed(config.seed + c); - auto chain_edge_prior = edge_prior.clone(); - run_mcmc_chain(results[c], *chain_model, *chain_edge_prior, config, c, pm); - } - } - - return results; -} - - -/** - * Convert chain results to Rcpp::List format - * - * Creates a standardized output format for both GGM and OMRF models. - * Each chain is a list with: - * - chain_id: Chain identifier - * - samples: Parameter samples matrix (param_dim × n_iter) - * - indicator_samples: Edge indicators (if edge_selection) - * - treedepth / divergent / energy: NUTS diagnostics (if NUTS/HMC) - * - error / error_msg: Error information (if error occurred) - * - * @param results Vector of chain results - * @return Rcpp::List with per-chain output - */ -inline Rcpp::List convert_results_to_list(const std::vector& results) { - Rcpp::List output(results.size()); - - for (size_t i = 0; i < results.size(); ++i) { - const ChainResultNew& chain = results[i]; - Rcpp::List chain_list; - - chain_list["chain_id"] = chain.chain_id; - - if (chain.error) { - chain_list["error"] = true; - chain_list["error_msg"] = chain.error_msg; - } else { - chain_list["error"] = false; - chain_list["samples"] = chain.samples; - chain_list["userInterrupt"] = chain.userInterrupt; - - if (chain.has_indicators) { - chain_list["indicator_samples"] = chain.indicator_samples; - } - - if (chain.has_allocations) { - chain_list["allocation_samples"] = chain.allocation_samples; - } - - if (chain.has_nuts_diagnostics) { - chain_list["treedepth"] = chain.treedepth_samples; - chain_list["divergent"] = 
chain.divergent_samples; - chain_list["energy"] = chain.energy_samples; - } - } - - output[i] = chain_list; - } +); - return output; -} +/// Convert chain results to Rcpp::List for return to R. +Rcpp::List convert_results_to_list(const std::vector& results); diff --git a/src/mcmc/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp index 1e37577c..48ace92a 100644 --- a/src/mcmc/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -59,7 +59,7 @@ double heuristic_initial_step_size( double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum and evaluate arma::vec r = arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); @@ -79,7 +79,7 @@ double heuristic_initial_step_size( while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 2.0 * eps : 0.5 * eps; - // Resample momentum on each iteration for step size search + // Resample momentum (STAN resamples on each iteration) r = arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; @@ -131,7 +131,7 @@ double heuristic_initial_step_size( ) { double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum from N(0, M) where M = diag(1/inv_mass_diag) arma::vec r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); @@ -151,7 +151,7 @@ double heuristic_initial_step_size( while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 
2.0 * eps : 0.5 * eps; - // Resample momentum on each iteration for step size search + // Resample momentum (STAN resamples on each iteration) r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; @@ -168,4 +168,4 @@ double heuristic_initial_step_size( } return eps; -} \ No newline at end of file +} diff --git a/src/mcmc/mh_sampler.h b/src/mcmc/mh_sampler.h index fdbd7ab1..3902f732 100644 --- a/src/mcmc/mh_sampler.h +++ b/src/mcmc/mh_sampler.h @@ -1,10 +1,10 @@ #pragma once #include -#include "base_sampler.h" -#include "mcmc_utils.h" -#include "sampler_config.h" -#include "../base_model.h" +#include "mcmc/base_sampler.h" +#include "mcmc/mcmc_utils.h" +#include "mcmc/sampler_config.h" +#include "models/base_model.h" /** * MHSampler - Metropolis-Hastings sampler diff --git a/src/mcmc/nuts_sampler.h b/src/mcmc/nuts_sampler.h index 6aa2bbca..d27a895e 100644 --- a/src/mcmc/nuts_sampler.h +++ b/src/mcmc/nuts_sampler.h @@ -1,143 +1,25 @@ #pragma once -#include -#include -#include -#include -#include "base_sampler.h" -#include "mcmc_utils.h" -#include "mcmc_nuts.h" -#include "mcmc_adaptation.h" -#include "sampler_config.h" -#include "../base_model.h" - -class NUTSSampler : public BaseSampler { +#include "mcmc/adaptive_gradient_sampler.h" +#include "mcmc/mcmc_nuts.h" + +/** + * NUTSSampler - No-U-Turn Sampler + * + * Adaptive tree-depth leapfrog integration. Inherits warmup adaptation + * (step size + diagonal mass matrix) from AdaptiveGradientSampler. 
+ */ +class NUTSSampler : public AdaptiveGradientSampler { public: explicit NUTSSampler(const SamplerConfig& config, int n_warmup = 1000) - : step_size_(config.initial_step_size), - target_acceptance_(config.target_acceptance), - max_tree_depth_(config.max_tree_depth), - no_warmup_(config.no_warmup), - n_warmup_(n_warmup), - warmup_iteration_(0), - initialized_(false), - step_adapter_(config.initial_step_size) - { - build_warmup_schedule(n_warmup); - } - - SamplerResult warmup_step(BaseModel& model) override { - if (!initialized_) { - initialize(model); - initialized_ = true; - } - - SamplerResult result = do_nuts_step(model); - - // Adapt step size during all warmup phases - step_adapter_.update(result.accept_prob, target_acceptance_); - step_size_ = step_adapter_.current(); - - // During Stage 2, accumulate samples for mass matrix estimation - if (in_stage2()) { - arma::vec full_params = model.get_full_vectorized_parameters(); - mass_accumulator_->update(full_params); - - if (at_window_end()) { - // Stan convention: inv_mass = variance (high-variance params move more) - inv_mass_ = mass_accumulator_->variance(); - mass_accumulator_->reset(); - - // Push adapted mass matrix to model - model.set_inv_mass(inv_mass_); - - arma::vec theta = model.get_vectorized_parameters(); - SafeRNG& rng = model.get_rng(); - auto log_post = [&model](const arma::vec& params) -> double { - return model.logp_and_gradient(params).first; - }; - auto grad_fn = [&model](const arma::vec& params) -> arma::vec { - return model.logp_and_gradient(params).second; - }; - arma::vec active_inv_mass = model.get_active_inv_mass(); - - double new_eps = heuristic_initial_step_size( - theta, log_post, grad_fn, active_inv_mass, rng, - 0.625, step_size_); - step_size_ = new_eps; - step_adapter_.restart(new_eps); - } - } - - warmup_iteration_++; - return result; - } - - void finalize_warmup() override { - step_size_ = step_adapter_.averaged(); - } - - SamplerResult sample_step(BaseModel& model) override { 
- return do_nuts_step(model); - } + : AdaptiveGradientSampler(config.initial_step_size, config.target_acceptance, n_warmup), + max_tree_depth_(config.max_tree_depth) + {} bool has_nuts_diagnostics() const override { return true; } - double get_step_size() const { return step_size_; } - double get_averaged_step_size() const { return step_adapter_.averaged(); } - const arma::vec& get_inv_mass() const { return inv_mass_; } - -private: - void build_warmup_schedule(int n_warmup) { - stage1_end_ = static_cast(0.075 * n_warmup); - stage3_start_ = n_warmup - static_cast(0.10 * n_warmup); - - window_ends_.clear(); - int cur = stage1_end_; - int wsize = 25; - - while (cur < stage3_start_) { - int win = std::min(wsize, stage3_start_ - cur); - window_ends_.push_back(cur + win); - cur += win; - wsize = std::min(wsize * 2, stage3_start_ - cur); - } - } - - bool in_stage2() const { - return warmup_iteration_ >= stage1_end_ && warmup_iteration_ < stage3_start_; - } - bool at_window_end() const { - for (int end : window_ends_) { - if (warmup_iteration_ + 1 == end) return true; - } - return false; - } - - void initialize(BaseModel& model) { - arma::vec theta = model.get_vectorized_parameters(); - SafeRNG& rng = model.get_rng(); - - inv_mass_ = arma::ones(model.full_parameter_dimension()); - model.set_inv_mass(inv_mass_); - - mass_accumulator_ = std::make_unique( - static_cast(model.full_parameter_dimension())); - - auto log_post = [&model](const arma::vec& params) -> double { - return model.logp_and_gradient(params).first; - }; - auto grad_fn = [&model](const arma::vec& params) -> arma::vec { - return model.logp_and_gradient(params).second; - }; - - step_size_ = heuristic_initial_step_size( - theta, log_post, grad_fn, rng, target_acceptance_); - - step_adapter_.restart(step_size_); - } - - SamplerResult do_nuts_step(BaseModel& model) { +protected: + SamplerResult do_gradient_step(BaseModel& model) override { arma::vec theta = model.get_vectorized_parameters(); SafeRNG& rng = 
model.get_rng(); @@ -157,26 +39,6 @@ class NUTSSampler : public BaseSampler { return result; } - // Configuration - double step_size_; - double target_acceptance_; +private: int max_tree_depth_; - int no_warmup_; - int n_warmup_; - - // State tracking - int warmup_iteration_; - bool initialized_; - - // Step size adaptation - DualAveraging step_adapter_; - - // Mass matrix adaptation - arma::vec inv_mass_; - std::unique_ptr mass_accumulator_; - - // Warmup schedule - int stage1_end_; - int stage3_start_; - std::vector window_ends_; }; diff --git a/src/mixedVariables.cpp b/src/mixedVariables.cpp deleted file mode 100644 index e4510128..00000000 --- a/src/mixedVariables.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "ggm_model.h" -#include "mixedVariables.h" -#include "rng/rng_utils.h" - - -void MixedVariableTypes::instantiate_variable_types(Rcpp::List input_from_R) -{ - // instantiate variable_types_ - for (Rcpp::List var_type_list : input_from_R) { - std::string type = Rcpp::as(var_type_list["type"]); - if (type == "Continuous") { - variable_types_.push_back(std::make_unique( - Rcpp::as(var_type_list["observations"]), - Rcpp::as(var_type_list["inclusion_probability"]), - Rcpp::as(var_type_list["initial_edge_indicators"]), - Rcpp::as(var_type_list["edge_selection"]) - )); - // } else if (type == "Ordinal") { - // variable_types_.push_back(std::make_unique( - // var_type_list["observations"], - // var_type_list["inclusion_probability"], - // var_type_list["initial_edge_indicators"], - // var_type_list["edge_selection"] - // )); - // } else if (type == "Blume-Capel") { - // variable_types_.push_back(std::make_unique( - // var_type_list["observations"], - // var_type_list["inclusion_probability"], - // var_type_list["initial_edge_indicators"], - // var_type_list["edge_selection"] - // )); - // } else if (type == "Count") { - // variable_types_.push_back(std::make_unique( - // var_type_list["observations"], - // var_type_list["inclusion_probability"], - // 
var_type_list["initial_edge_indicators"], - // var_type_list["edge_selection"] - // )); - } else { - throw std::runtime_error("MixedVariableTypes received an unknown variable type in sublist fro input_from_R: " + type); - } - } -} \ No newline at end of file diff --git a/src/mixedVariables.h b/src/mixedVariables.h deleted file mode 100644 index 66304b06..00000000 --- a/src/mixedVariables.h +++ /dev/null @@ -1,169 +0,0 @@ -#pragma once - -#include -#include -#include -#include "base_model.h" - - -// Forward declaration - this class is work in progress and not yet functional -class MixedVariableTypes : public BaseModel { -public: - - // Constructor - MixedVariableTypes( - Rcpp::List input_from_R, - const arma::mat& inclusion_probability, - const arma::imat& initial_edge_indicators, - const bool edge_selection = true - ) - { - - instantiate_variable_types(input_from_R); - // instantiate_variable_interactions(); // TODO: not yet implemented - - dim_ = 0; - for (const auto& var_type : variable_types_) { - dim_ += var_type->parameter_dimension(); - } - // dim_ += interactions_.size(); // TODO: proper dimension calculation - } - - // Capability queries - bool has_gradient() const override - { - for (const auto& var_type : variable_types_) { - if (var_type->has_gradient()) { - return true; - } - } - return false; - } - bool has_adaptive_mh() const override - { - for (const auto& var_type : variable_types_) { - if (var_type->has_adaptive_mh()) { - return true; - } - } - return false; - } - - // Return dimensionality of the parameter space - size_t parameter_dimension() const override { - return dim_; - } - - arma::vec get_vectorized_parameters() const override { - arma::vec result(dim_); - size_t current = 0; - for (size_t i = 0; i < variable_types_.size(); ++i) { - arma::vec var_params = variable_types_[i]->get_vectorized_parameters(); - result.subvec(current, current + var_params.n_elem - 1) = var_params; - current += var_params.n_elem; - } - for (size_t i = 0; i < 
interactions_.size(); ++i) { - const arma::mat& interactions_mat = interactions_[i]; - for (size_t c = 0; c < interactions_mat.n_cols; ++c) { - for (size_t r = 0; r < interactions_mat.n_rows; ++r) { - result(current) = interactions_mat(r, c); - ++current; - } - } - } - return result; - } - - arma::ivec get_vectorized_indicator_parameters() override { - for (size_t i = 0; i < variable_types_.size(); ++i) { - auto& [from, to] = indicator_parameters_indices_[i]; - vectorized_indicator_parameters_.subvec(from, to) = variable_types_[i]->get_vectorized_indicator_parameters(); - } - size_t current = indicator_parameters_indices_.empty() ? 0 : indicator_parameters_indices_.back().second + 1; - for (size_t i = 0; i < interactions_indicators_.size(); ++i) { - const arma::imat& indicator_mat = interactions_indicators_[i]; - for (size_t c = 0; c < indicator_mat.n_cols; ++c) { - for (size_t r = 0; r < indicator_mat.n_rows; ++r) { - vectorized_indicator_parameters_(current) = indicator_mat(r, c); - ++current; - } - } - } - - return vectorized_indicator_parameters_; - } - - - double logp(const arma::vec& parameters) override - { - double total_logp = 0.0; - for (size_t i = 0; i < variable_types_.size(); ++i) { - auto& [from, to] = parameters_indices_[i]; - // need to do some transformation here! 
- arma::vec var_params = parameters.subvec(from, to); - total_logp += variable_types_[i]->logp(var_params); - } - // interactions log-probability can be added here if needed - return total_logp; - } - - arma::vec gradient(const arma::vec& parameters) override { - - // TODO: only should call the gradient for variable types that have it - // the rest are assumed to be constant, so have gradient zero - arma::vec total_gradient = arma::zeros(parameters.n_elem); - for (size_t i = 0; i < variable_types_.size(); ++i) - { - if (!variable_types_[i]->has_gradient()) { - continue; - } - auto& [from, to] = parameters_indices_[i]; - arma::vec var_params = parameters.subvec(from, to); - // maybe need to do some transformation here! - arma::vec var_gradient = variable_types_[i]->gradient(var_params); - total_gradient.subvec(from, to) = var_gradient; - } - - return total_gradient; - } - - std::pair logp_and_gradient( - const arma::vec& parameters) override { - if (!has_gradient()) { - throw std::runtime_error("Gradient not implemented for this model"); - } - return {logp(parameters), gradient(parameters)}; - } - - void do_one_mh_step() override { - for (auto& var_type : variable_types_) { - var_type->do_one_mh_step(); - } - } - - void set_seed(int seed) override { - for (auto& var_type : variable_types_) { - var_type->set_seed(seed); - } - } - - std::unique_ptr clone() const override { - throw std::runtime_error("clone method not yet implemented for MixedVariableTypes"); - } - - -private: - std::vector> variable_types_; - std::vector interactions_; - std::vector interactions_indicators_; - size_t dim_; - arma::vec vectorized_parameters_; - arma::ivec vectorized_indicator_parameters_; - arma::ivec indices_from_; - arma::ivec indices_to_; - std::vector> parameters_indices_; - std::vector> indicator_parameters_indices_; - - void instantiate_variable_types(const Rcpp::List input_from_R); - -}; diff --git a/src/adaptiveMetropolis.h b/src/models/adaptive_metropolis.h similarity index 
92% rename from src/adaptiveMetropolis.h rename to src/models/adaptive_metropolis.h index d8d9cb6c..578f5d15 100644 --- a/src/adaptiveMetropolis.h +++ b/src/models/adaptive_metropolis.h @@ -7,10 +7,10 @@ class AdaptiveProposal { public: - AdaptiveProposal(size_t num_params, size_t adaption_window = 50, double target_accept = 0.44) { + AdaptiveProposal(size_t num_params, size_t adaptation_window = 50, double target_accept = 0.44) { proposal_sds_ = arma::vec(num_params, arma::fill::ones) * 0.25; // Initial SD, need to tweak this somehow? acceptance_counts_ = arma::ivec(num_params, arma::fill::zeros); - adaptation_window_ = adaption_window; + adaptation_window_ = adaptation_window; target_accept_ = target_accept; } diff --git a/src/base_model.h b/src/models/base_model.h similarity index 100% rename from src/base_model.h rename to src/models/base_model.h diff --git a/src/cholupdate.cpp b/src/models/ggm/cholupdate.cpp similarity index 96% rename from src/cholupdate.cpp rename to src/models/ggm/cholupdate.cpp index 5ab8f6eb..81331be9 100644 --- a/src/cholupdate.cpp +++ b/src/models/ggm/cholupdate.cpp @@ -1,4 +1,4 @@ -#include "cholupdate.h" +#include "models/ggm/cholupdate.h" extern "C" { @@ -122,8 +122,4 @@ arma::mat chol_update_arma(arma::mat& R, arma::vec& u, bool downdate = false, do cholesky_update(R, u, eps); return R; - int n = R.n_cols; - int up = downdate ? 
0 : 1; - chol_up(R.memptr(), u.memptr(), &n, &up, &eps); - return R; } diff --git a/src/cholupdate.h b/src/models/ggm/cholupdate.h similarity index 100% rename from src/cholupdate.h rename to src/models/ggm/cholupdate.h diff --git a/src/ggm_model.cpp b/src/models/ggm/ggm_model.cpp similarity index 79% rename from src/ggm_model.cpp rename to src/models/ggm/ggm_model.cpp index ef1dc0ae..3d7068c5 100644 --- a/src/ggm_model.cpp +++ b/src/models/ggm/ggm_model.cpp @@ -1,15 +1,14 @@ -#include "ggm_model.h" -#include "adaptiveMetropolis.h" +#include "models/ggm/ggm_model.h" +#include "models/adaptive_metropolis.h" #include "rng/rng_utils.h" -#include "cholupdate.h" +#include "models/ggm/cholupdate.h" -double GaussianVariables::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { +double GGMModel::compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const { return(A(ii, jj) - A(ii, i) * A(jj, i) / A(i, i)); } -void GaussianVariables::get_constants(size_t i, size_t j) { +void GGMModel::get_constants(size_t i, size_t j) { - // TODO: helper function? 
double logdet_omega = get_log_det(cholesky_of_precision_); double log_adj_omega_ii = logdet_omega + std::log(std::abs(covariance_matrix_(i, i))); @@ -23,33 +22,29 @@ void GaussianVariables::get_constants(size_t i, size_t j) { ); double Phi_q1q1 = std::exp((log_adj_omega_jj - log_abs_inv_omega_sub_jj) / 2); - constants_[1] = Phi_q1q; - constants_[2] = Phi_q1q1; - constants_[3] = precision_matrix_(i, j) - Phi_q1q * Phi_q1q1; - constants_[4] = Phi_q1q1; - constants_[5] = precision_matrix_(j, j) - Phi_q1q * Phi_q1q; - constants_[6] = constants_[5] + constants_[3] * constants_[3] / (constants_[4] * constants_[4]); + constants_[0] = Phi_q1q; + constants_[1] = Phi_q1q1; + constants_[2] = precision_matrix_(i, j) - Phi_q1q * Phi_q1q1; + constants_[3] = Phi_q1q1; + constants_[4] = precision_matrix_(j, j) - Phi_q1q * Phi_q1q; + constants_[5] = constants_[4] + constants_[2] * constants_[2] / (constants_[3] * constants_[3]); } -double GaussianVariables::R(const double x) const { - if (x == 0) { - return constants_[6]; +double GGMModel::constrained_diagonal(const double x) const { + if (x == 0) { + return constants_[5]; } else { - return constants_[5] + std::pow((x - constants_[3]) / constants_[4], 2); + return constants_[4] + std::pow((x - constants_[2]) / constants_[3], 2); } } -double GaussianVariables::get_log_det(arma::mat triangular_A) const { - // assume A is an (upper) triangular cholesky factor - // returns the log determinant of A'A - - // TODO: should we just do - // log_det(val, sign, trimatu(A))? +double GGMModel::get_log_det(arma::mat triangular_A) const { + // log-determinant of A'A where A is upper-triangular Cholesky factor return 2 * arma::accu(arma::log(triangular_A.diag())); } -double GaussianVariables::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { +double GGMModel::log_density_impl(const arma::mat& omega, const arma::mat& phi) const { double logdet_omega = get_log_det(phi); // TODO: why not just dot(omega, suf_stat_)? 
@@ -60,9 +55,9 @@ double GaussianVariables::log_density_impl(const arma::mat& omega, const arma::m return log_likelihood; } -double GaussianVariables::log_density_impl_edge(size_t i, size_t j) const { +double GGMModel::log_density_impl_edge(size_t i, size_t j) const { - // this is the log likelihood ratio, not the full log likelihood like GaussianVariables::log_density_impl + // Log-likelihood ratio (not the full log-likelihood) double Ui2 = precision_matrix_(i, j) - precision_proposal_(i, j); double Uj2 = (precision_matrix_(j, j) - precision_proposal_(j, j)) / 2; @@ -81,7 +76,7 @@ double GaussianVariables::log_density_impl_edge(size_t i, size_t j) const { } -double GaussianVariables::log_density_impl_diag(size_t j) const { +double GGMModel::log_density_impl_diag(size_t j) const { // same as above but for i == j, so Ui2 = 0 double Uj2 = (precision_matrix_(j, j) - precision_proposal_(j, j)) / 2; @@ -97,34 +92,33 @@ double GaussianVariables::log_density_impl_diag(size_t j) const { } -void GaussianVariables::update_edge_parameter(size_t i, size_t j) { +void GGMModel::update_edge_parameter(size_t i, size_t j) { if (edge_indicators_(i, j) == 0) { return; // Edge is not included; skip update } get_constants(i, j); - double Phi_q1q = constants_[1]; - double Phi_q1q1 = constants_[2]; + double Phi_q1q = constants_[0]; + double Phi_q1q1 = constants_[1]; size_t e = j * (j + 1) / 2 + i; // parameter index in vectorized form (column-major upper triangle) double proposal_sd = proposal_.get_proposal_sd(e); double phi_prop = rnorm(rng_, Phi_q1q, proposal_sd); - double omega_prop_q1q = constants_[3] + constants_[4] * phi_prop; - double omega_prop_qq = R(omega_prop_q1q); + double omega_prop_q1q = constants_[2] + constants_[3] * phi_prop; + double omega_prop_qq = constrained_diagonal(omega_prop_q1q); // form full proposal matrix for Omega - precision_proposal_ = precision_matrix_; // TODO: needs to be a copy! 
+ precision_proposal_ = precision_matrix_; precision_proposal_(i, j) = omega_prop_q1q; precision_proposal_(j, i) = omega_prop_q1q; precision_proposal_(j, j) = omega_prop_qq; - // double ln_alpha = log_likelihood(precision_proposal_) - log_likelihood(); double ln_alpha = log_density_impl_edge(i, j); - ln_alpha += R::dcauchy(precision_proposal_(i, j), 0.0, 2.5, true); - ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, 2.5, true); + ln_alpha += R::dcauchy(precision_proposal_(i, j), 0.0, pairwise_scale_, true); + ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, pairwise_scale_, true); if (std::log(runif(rng_)) < ln_alpha) { // accept proposal @@ -145,7 +139,7 @@ void GaussianVariables::update_edge_parameter(size_t i, size_t j) { proposal_.update_proposal_sd(e); } -void GaussianVariables::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) +void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) { v2_[0] = omega_ij_old - precision_proposal_(i, j); @@ -178,7 +172,7 @@ void GaussianVariables::cholesky_update_after_edge(double omega_ij_old, double o } -void GaussianVariables::update_diagonal_parameter(size_t i) { +void GGMModel::update_diagonal_parameter(size_t i) { // Implementation of diagonal parameter update // 1-3) from before double logdet_omega = get_log_det(cholesky_of_precision_); @@ -214,7 +208,7 @@ void GaussianVariables::update_diagonal_parameter(size_t i) { proposal_.update_proposal_sd(e); } -void GaussianVariables::cholesky_update_after_diag(double omega_ii_old, size_t i) +void GGMModel::cholesky_update_after_diag(double omega_ii_old, size_t i) { double delta = omega_ii_old - precision_proposal_(i, i); @@ -236,7 +230,7 @@ void GaussianVariables::cholesky_update_after_diag(double omega_ii_old, size_t i } -void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) { +void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { size_t e 
= j * (j + 1) / 2 + i; // parameter index in vectorized form (column-major upper triangle) double proposal_sd = proposal_.get_proposal_sd(e); @@ -247,9 +241,9 @@ void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) precision_proposal_(i, j) = 0.0; precision_proposal_(j, i) = 0.0; - // Update diagonal using R function with omega_ij = 0 + // Update diagonal to preserve positive-definiteness get_constants(i, j); - precision_proposal_(j, j) = R(0.0); + precision_proposal_(j, j) = constrained_diagonal(0.0); // double ln_alpha = log_likelihood(precision_proposal_) - log_likelihood(); double ln_alpha = log_density_impl_edge(i, j); @@ -266,8 +260,8 @@ void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) ln_alpha += std::log(1.0 - inclusion_probability_(i, j)) - std::log(inclusion_probability_(i, j)); - ln_alpha += R::dnorm(precision_matrix_(i, j) / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); - ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, 2.5, true); + ln_alpha += R::dnorm(precision_matrix_(i, j) / constants_[3], 0.0, proposal_sd, true) - std::log(constants_[3]); + ln_alpha -= R::dcauchy(precision_matrix_(i, j), 0.0, pairwise_scale_, true); if (std::log(runif(rng_)) < ln_alpha) { @@ -294,8 +288,8 @@ void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) // Get constants for current state (with edge OFF) get_constants(i, j); - double omega_prop_ij = constants_[4] * epsilon; - double omega_prop_jj = R(omega_prop_ij); + double omega_prop_ij = constants_[3] * epsilon; + double omega_prop_jj = constrained_diagonal(omega_prop_ij); precision_proposal_ = precision_matrix_; precision_proposal_(i, j) = omega_prop_ij; @@ -316,12 +310,11 @@ void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) ln_alpha += std::log(inclusion_probability_(i, j)) - std::log(1.0 - inclusion_probability_(i, j)); // Prior change: add slab (Cauchy prior) - ln_alpha += 
R::dcauchy(omega_prop_ij, 0.0, 2.5, true); + ln_alpha += R::dcauchy(omega_prop_ij, 0.0, pairwise_scale_, true); // Proposal term: proposed edge value given it was generated from truncated normal - ln_alpha -= R::dnorm(omega_prop_ij / constants_[4], 0.0, proposal_sd, true) - std::log(constants_[4]); + ln_alpha -= R::dnorm(omega_prop_ij / constants_[3], 0.0, proposal_sd, true) - std::log(constants_[3]); - // TODO: this can be factored out? if (std::log(runif(rng_)) < ln_alpha) { // Accept: turn ON the edge proposal_.increment_accepts(e); @@ -345,7 +338,7 @@ void GaussianVariables::update_edge_indicator_parameter_pair(size_t i, size_t j) } } -void GaussianVariables::do_one_mh_step() { +void GGMModel::do_one_mh_step() { // Update off-diagonals (upper triangle) for (size_t i = 0; i < p_ - 1; ++i) { @@ -371,7 +364,7 @@ void GaussianVariables::do_one_mh_step() { proposal_.increment_iteration(); } -void GaussianVariables::initialize_graph() { +void GGMModel::initialize_graph() { for (size_t i = 0; i < p_ - 1; ++i) { for (size_t j = i + 1; j < p_; ++j) { double p = inclusion_probability_(i, j); @@ -383,7 +376,7 @@ void GaussianVariables::initialize_graph() { precision_proposal_(i, j) = 0.0; precision_proposal_(j, i) = 0.0; get_constants(i, j); - precision_proposal_(j, j) = R(0.0); + precision_proposal_(j, j) = constrained_diagonal(0.0); double omega_ij_old = precision_matrix_(i, j); double omega_jj_old = precision_matrix_(j, j); @@ -397,30 +390,33 @@ void GaussianVariables::initialize_graph() { } -GaussianVariables createGaussianVariablesFromR( +GGMModel createGGMModelFromR( const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, - const bool edge_selection + const bool edge_selection, + const double pairwise_scale ) { if (inputFromR.containsElementNamed("n") && inputFromR.containsElementNamed("suf_stat")) { int n = Rcpp::as(inputFromR["n"]); arma::mat suf_stat = Rcpp::as(inputFromR["suf_stat"]); - return 
GaussianVariables( + return GGMModel( n, suf_stat, prior_inclusion_prob, initial_edge_indicators, - edge_selection + edge_selection, + pairwise_scale ); } else if (inputFromR.containsElementNamed("X")) { arma::mat X = Rcpp::as(inputFromR["X"]); - return GaussianVariables( + return GGMModel( X, prior_inclusion_prob, initial_edge_indicators, - edge_selection + edge_selection, + pairwise_scale ); } else { throw std::invalid_argument("Input list must contain either 'X' or both 'n' and 'suf_stat'."); diff --git a/src/ggm_model.h b/src/models/ggm/ggm_model.h similarity index 70% rename from src/ggm_model.h rename to src/models/ggm/ggm_model.h index 84d91849..7946e770 100644 --- a/src/ggm_model.h +++ b/src/models/ggm/ggm_model.h @@ -1,56 +1,66 @@ #pragma once +#include #include -#include "base_model.h" -#include "adaptiveMetropolis.h" +#include "models/base_model.h" +#include "models/adaptive_metropolis.h" #include "rng/rng_utils.h" -class GaussianVariables : public BaseModel { +/** + * GGMModel - Gaussian Graphical Model + * + * Bayesian inference on the precision matrix (inverse covariance) of a + * multivariate Gaussian via element-wise Metropolis-Hastings. Edge + * selection uses a spike-and-slab prior with Cauchy slab. + * + * The Cholesky factor of the precision matrix is maintained incrementally + * through rank-1 updates/downdates after each element change. + */ +class GGMModel : public BaseModel { public: - // constructor from raw data - GaussianVariables( + // Construct from raw observations + GGMModel( const arma::mat& observations, const arma::mat& inclusion_probability, const arma::imat& initial_edge_indicators, - const bool edge_selection = true + const bool edge_selection = true, + const double pairwise_scale = 2.5 ) : n_(observations.n_rows), p_(observations.n_cols), - // TODO: need to estimate the means! so + 1 + // TODO: we need to adjust the algorithm to also sample the means! dim_((p_ * (p_ + 1)) / 2), - // TODO: need to store sample means! 
suf_stat_(observations.t() * observations), inclusion_probability_(inclusion_probability), edge_selection_(edge_selection), + pairwise_scale_(pairwise_scale), proposal_(AdaptiveProposal(dim_, 500)), precision_matrix_(arma::eye(p_, p_)), cholesky_of_precision_(arma::eye(p_, p_)), inv_cholesky_of_precision_(arma::eye(p_, p_)), covariance_matrix_(arma::eye(p_, p_)), - edge_indicators_(initial_edge_indicators), - vectorized_parameters_(dim_), vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), - precision_proposal_(arma::mat(p_, p_, arma::fill::none)), - constants_(6) + precision_proposal_(arma::mat(p_, p_, arma::fill::none)) {} - // constructor from sufficient statistics - // TODO: needs to implement same TODOs as above constructor - GaussianVariables( + // Construct from sufficient statistics + GGMModel( const int n, const arma::mat& suf_stat, const arma::mat& inclusion_probability, const arma::imat& initial_edge_indicators, - const bool edge_selection = true + const bool edge_selection = true, + const double pairwise_scale = 2.5 ) : n_(n), p_(suf_stat.n_cols), dim_((p_ * (p_ + 1)) / 2), suf_stat_(suf_stat), inclusion_probability_(inclusion_probability), edge_selection_(edge_selection), + pairwise_scale_(pairwise_scale), proposal_(AdaptiveProposal(dim_, 500)), precision_matrix_(arma::eye(p_, p_)), cholesky_of_precision_(arma::eye(p_, p_)), @@ -59,12 +69,10 @@ class GaussianVariables : public BaseModel { edge_indicators_(initial_edge_indicators), vectorized_parameters_(dim_), vectorized_indicator_parameters_(edge_selection_ ? 
dim_ : 0), - precision_proposal_(arma::mat(p_, p_, arma::fill::none)), - constants_(6) + precision_proposal_(arma::mat(p_, p_, arma::fill::none)) {} - // copy constructor - GaussianVariables(const GaussianVariables& other) + GGMModel(const GGMModel& other) : BaseModel(other), dim_(other.dim_), suf_stat_(other.suf_stat_), @@ -72,6 +80,7 @@ class GaussianVariables : public BaseModel { p_(other.p_), inclusion_probability_(other.inclusion_probability_), edge_selection_(other.edge_selection_), + pairwise_scale_(other.pairwise_scale_), precision_matrix_(other.precision_matrix_), cholesky_of_precision_(other.cholesky_of_precision_), inv_cholesky_of_precision_(other.inv_cholesky_of_precision_), @@ -81,21 +90,17 @@ class GaussianVariables : public BaseModel { vectorized_indicator_parameters_(other.vectorized_indicator_parameters_), proposal_(other.proposal_), rng_(other.rng_), - precision_proposal_(other.precision_proposal_), - constants_(other.constants_) + precision_proposal_(other.precision_proposal_) {} - // // rng_ = SafeRNG(123); - - // } void set_adaptive_proposal(AdaptiveProposal proposal) { proposal_ = proposal; } - bool has_gradient() const { return false; } - bool has_adaptive_mh() const override { return true; } - bool has_edge_selection() const override { return edge_selection_; } + bool has_gradient() const override { return false; } + bool has_adaptive_mh() const override { return true; } + bool has_edge_selection() const override { return edge_selection_; } void set_edge_selection_active(bool active) override { edge_selection_active_ = active; @@ -103,62 +108,33 @@ class GaussianVariables : public BaseModel { void initialize_graph() override; - // GGM handles edge indicator updates inside do_one_mh_step(), so the - // external call from the MCMC runner is a no-op. 
+ // GGM handles edge indicator updates inside do_one_mh_step() void update_edge_indicators() override {} - double logp(const arma::vec& parameters) override { - // Implement log probability computation - return 0.0; - } + // GGM uses component-wise MH; logp is unused. + double logp(const arma::vec& parameters) override { return 0.0; } - // TODO: this can be done more efficiently, no need for the Cholesky! double log_likelihood(const arma::mat& omega) const { return log_density_impl(omega, arma::chol(omega)); }; double log_likelihood() const { return log_density_impl(precision_matrix_, cholesky_of_precision_); } void do_one_mh_step() override; - size_t parameter_dimension() const override { - return dim_; - } - - // For GGM, full dimension is the same as parameter dimension (no edge selection filtering) - size_t full_parameter_dimension() const override { - return dim_; - } + size_t parameter_dimension() const override { return dim_; } + size_t full_parameter_dimension() const override { return dim_; } void set_seed(int seed) override { rng_ = SafeRNG(seed); } arma::vec get_vectorized_parameters() const override { - // upper triangle of precision_matrix_ - arma::vec result(dim_); - size_t e = 0; - for (size_t j = 0; j < p_; ++j) { - for (size_t i = 0; i <= j; ++i) { - result(e) = precision_matrix_(i, j); - ++e; - } - } - return result; + return extract_upper_triangle(); } - // For GGM, full and active parameter vectors are the same arma::vec get_full_vectorized_parameters() const override { - arma::vec result(dim_); - size_t e = 0; - for (size_t j = 0; j < p_; ++j) { - for (size_t i = 0; i <= j; ++i) { - result(e) = precision_matrix_(i, j); - ++e; - } - } - return result; + return extract_upper_triangle(); } arma::ivec get_vectorized_indicator_parameters() override { - // upper triangle of precision_matrix_ size_t e = 0; for (size_t j = 0; j < p_; ++j) { for (size_t i = 0; i <= j; ++i) { @@ -188,11 +164,24 @@ class GaussianVariables : public BaseModel { } 
std::unique_ptr clone() const override { - return std::make_unique(*this); // uses copy constructor + return std::make_unique(*this); } private: - // data + + arma::vec extract_upper_triangle() const { + arma::vec result(dim_); + size_t e = 0; + for (size_t j = 0; j < p_; ++j) { + for (size_t i = 0; i <= j; ++i) { + result(e) = precision_matrix_(i, j); + ++e; + } + } + return result; + } + + // Data size_t n_; size_t p_; size_t dim_; @@ -200,22 +189,28 @@ class GaussianVariables : public BaseModel { arma::mat inclusion_probability_; bool edge_selection_; bool edge_selection_active_ = false; + double pairwise_scale_; - // parameters + // Parameters arma::mat precision_matrix_, cholesky_of_precision_, inv_cholesky_of_precision_, covariance_matrix_; arma::imat edge_indicators_; arma::vec vectorized_parameters_; arma::ivec vectorized_indicator_parameters_; - AdaptiveProposal proposal_; - SafeRNG rng_; - // internal helper variables + // Scratch space arma::mat precision_proposal_; - arma::vec constants_; // Phi_q1q, Phi_q1q1, c[1], c[2], c[3], c[4] + // Workspace for conditional precision reparametrization. + // [0] Phi_q1q, [1] Phi_q1q1, [2] omega_ij - Phi_q1q*Phi_q1q1, + // [3] Phi_q1q1, [4] omega_jj - Phi_q1q^2, [5] constrained diagonal at x=0. + std::array constants_{}; + + // Work vectors for rank-2 Cholesky update. + // A symmetric rank-2 update A + vf1*vf2' + vf2*vf1' is decomposed into + // two rank-1 updates via u1 = (vf1+vf2)/sqrt(2), u2 = (vf1-vf2)/sqrt(2). 
arma::vec v1_ = {0, -1}; arma::vec v2_ = {0, 0}; arma::vec vf1_ = arma::zeros(p_); @@ -223,15 +218,19 @@ class GaussianVariables : public BaseModel { arma::vec u1_ = arma::zeros(p_); arma::vec u2_ = arma::zeros(p_); - // Parameter group updates with optimized likelihood evaluations + // MH updates void update_edge_parameter(size_t i, size_t j); void update_diagonal_parameter(size_t i); void update_edge_indicator_parameter_pair(size_t i, size_t j); - // Helper methods + // Helpers void get_constants(size_t i, size_t j); double compute_inv_submatrix_i(const arma::mat& A, const size_t i, const size_t ii, const size_t jj) const; - double R(const double x) const; + + // Conditional precision constraint: returns the required diagonal + // value omega_jj that keeps the precision matrix positive definite + // after changing the off-diagonal element to x. + double constrained_diagonal(const double x) const; double log_density_impl(const arma::mat& omega, const arma::mat& phi) const; double log_density_impl_edge(size_t i, size_t j) const; @@ -239,16 +238,13 @@ class GaussianVariables : public BaseModel { double get_log_det(arma::mat triangular_A) const; void cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j); void cholesky_update_after_diag(double omega_ii_old, size_t i); - // double find_reasonable_step_size_edge(const arma::mat& omega, size_t i, size_t j); - // double find_reasonable_step_size_diag(const arma::mat& omega, size_t i); - // double edge_log_ratio(const arma::mat& omega, size_t i, size_t j, double proposal); - // double diag_log_ratio(const arma::mat& omega, size_t i, double proposal); }; -GaussianVariables createGaussianVariablesFromR( +GGMModel createGGMModelFromR( const Rcpp::List& inputFromR, const arma::mat& inclusion_probability, const arma::imat& initial_edge_indicators, - const bool edge_selection = true + const bool edge_selection = true, + const double pairwise_scale = 2.5 ); diff --git 
a/src/models/mixed/mixed_variables.cpp b/src/models/mixed/mixed_variables.cpp new file mode 100644 index 00000000..5646f8f3 --- /dev/null +++ b/src/models/mixed/mixed_variables.cpp @@ -0,0 +1,43 @@ +// #include "models/ggm/ggm_model.h" +// #include "models/mixed/mixed_variables.h" +// #include "rng/rng_utils.h" +// +// +// void MixedVariableTypes::instantiate_variable_types(Rcpp::List input_from_R) +// { +// // instantiate variable_types_ +// for (Rcpp::List var_type_list : input_from_R) { +// std::string type = Rcpp::as(var_type_list["type"]); +// if (type == "Continuous") { +// variable_types_.push_back(std::make_unique( +// Rcpp::as(var_type_list["observations"]), +// Rcpp::as(var_type_list["inclusion_probability"]), +// Rcpp::as(var_type_list["initial_edge_indicators"]), +// Rcpp::as(var_type_list["edge_selection"]) +// )); +// // } else if (type == "Ordinal") { +// // variable_types_.push_back(std::make_unique( +// // var_type_list["observations"], +// // var_type_list["inclusion_probability"], +// // var_type_list["initial_edge_indicators"], +// // var_type_list["edge_selection"] +// // )); +// // } else if (type == "Blume-Capel") { +// // variable_types_.push_back(std::make_unique( +// // var_type_list["observations"], +// // var_type_list["inclusion_probability"], +// // var_type_list["initial_edge_indicators"], +// // var_type_list["edge_selection"] +// // )); +// // } else if (type == "Count") { +// // variable_types_.push_back(std::make_unique( +// // var_type_list["observations"], +// // var_type_list["inclusion_probability"], +// // var_type_list["initial_edge_indicators"], +// // var_type_list["edge_selection"] +// // )); +// } else { +// throw std::runtime_error("MixedVariableTypes received an unknown variable type in sublist fro input_from_R: " + type); +// } +// } +// } \ No newline at end of file diff --git a/src/models/mixed/mixed_variables.h b/src/models/mixed/mixed_variables.h new file mode 100644 index 00000000..905d1779 --- /dev/null +++ 
b/src/models/mixed/mixed_variables.h @@ -0,0 +1,169 @@ +// #pragma once +// +// #include +// #include +// #include +// #include "models/base_model.h" +// +// +// // Forward declaration - this class is work in progress and not yet functional +// class MixedVariableTypes : public BaseModel { +// public: +// +// // Constructor +// MixedVariableTypes( +// Rcpp::List input_from_R, +// const arma::mat& inclusion_probability, +// const arma::imat& initial_edge_indicators, +// const bool edge_selection = true +// ) +// { +// +// instantiate_variable_types(input_from_R); +// // instantiate_variable_interactions(); // TODO: not yet implemented +// +// dim_ = 0; +// for (const auto& var_type : variable_types_) { +// dim_ += var_type->parameter_dimension(); +// } +// // dim_ += interactions_.size(); // TODO: proper dimension calculation +// } +// +// // Capability queries +// bool has_gradient() const override +// { +// for (const auto& var_type : variable_types_) { +// if (var_type->has_gradient()) { +// return true; +// } +// } +// return false; +// } +// bool has_adaptive_mh() const override +// { +// for (const auto& var_type : variable_types_) { +// if (var_type->has_adaptive_mh()) { +// return true; +// } +// } +// return false; +// } +// +// // Return dimensionality of the parameter space +// size_t parameter_dimension() const override { +// return dim_; +// } +// +// arma::vec get_vectorized_parameters() const override { +// arma::vec result(dim_); +// size_t current = 0; +// for (size_t i = 0; i < variable_types_.size(); ++i) { +// arma::vec var_params = variable_types_[i]->get_vectorized_parameters(); +// result.subvec(current, current + var_params.n_elem - 1) = var_params; +// current += var_params.n_elem; +// } +// for (size_t i = 0; i < interactions_.size(); ++i) { +// const arma::mat& interactions_mat = interactions_[i]; +// for (size_t c = 0; c < interactions_mat.n_cols; ++c) { +// for (size_t r = 0; r < interactions_mat.n_rows; ++r) { +// result(current) = 
interactions_mat(r, c); +// ++current; +// } +// } +// } +// return result; +// } +// +// arma::ivec get_vectorized_indicator_parameters() override { +// for (size_t i = 0; i < variable_types_.size(); ++i) { +// auto& [from, to] = indicator_parameters_indices_[i]; +// vectorized_indicator_parameters_.subvec(from, to) = variable_types_[i]->get_vectorized_indicator_parameters(); +// } +// size_t current = indicator_parameters_indices_.empty() ? 0 : indicator_parameters_indices_.back().second + 1; +// for (size_t i = 0; i < interactions_indicators_.size(); ++i) { +// const arma::imat& indicator_mat = interactions_indicators_[i]; +// for (size_t c = 0; c < indicator_mat.n_cols; ++c) { +// for (size_t r = 0; r < indicator_mat.n_rows; ++r) { +// vectorized_indicator_parameters_(current) = indicator_mat(r, c); +// ++current; +// } +// } +// } +// +// return vectorized_indicator_parameters_; +// } +// +// +// double logp(const arma::vec& parameters) override +// { +// double total_logp = 0.0; +// for (size_t i = 0; i < variable_types_.size(); ++i) { +// auto& [from, to] = parameters_indices_[i]; +// // need to do some transformation here! +// arma::vec var_params = parameters.subvec(from, to); +// total_logp += variable_types_[i]->logp(var_params); +// } +// // interactions log-probability can be added here if needed +// return total_logp; +// } +// +// arma::vec gradient(const arma::vec& parameters) override { +// +// // TODO: only should call the gradient for variable types that have it +// // the rest are assumed to be constant, so have gradient zero +// arma::vec total_gradient = arma::zeros(parameters.n_elem); +// for (size_t i = 0; i < variable_types_.size(); ++i) +// { +// if (!variable_types_[i]->has_gradient()) { +// continue; +// } +// auto& [from, to] = parameters_indices_[i]; +// arma::vec var_params = parameters.subvec(from, to); +// // maybe need to do some transformation here! 
+// arma::vec var_gradient = variable_types_[i]->gradient(var_params); +// total_gradient.subvec(from, to) = var_gradient; +// } +// +// return total_gradient; +// } +// +// std::pair logp_and_gradient( +// const arma::vec& parameters) override { +// if (!has_gradient()) { +// throw std::runtime_error("Gradient not implemented for this model"); +// } +// return {logp(parameters), gradient(parameters)}; +// } +// +// void do_one_mh_step() override { +// for (auto& var_type : variable_types_) { +// var_type->do_one_mh_step(); +// } +// } +// +// void set_seed(int seed) override { +// for (auto& var_type : variable_types_) { +// var_type->set_seed(seed); +// } +// } +// +// std::unique_ptr clone() const override { +// throw std::runtime_error("clone method not yet implemented for MixedVariableTypes"); +// } +// +// +// private: +// std::vector> variable_types_; +// std::vector interactions_; +// std::vector interactions_indicators_; +// size_t dim_; +// arma::vec vectorized_parameters_; +// arma::ivec vectorized_indicator_parameters_; +// arma::ivec indices_from_; +// arma::ivec indices_to_; +// std::vector> parameters_indices_; +// std::vector> indicator_parameters_indices_; +// +// void instantiate_variable_types(const Rcpp::List input_from_R); +// +// }; diff --git a/src/omrf_model.cpp b/src/models/omrf/omrf_model.cpp similarity index 95% rename from src/omrf_model.cpp rename to src/models/omrf/omrf_model.cpp index 832e15dd..7c724dc3 100644 --- a/src/omrf_model.cpp +++ b/src/models/omrf/omrf_model.cpp @@ -1,6 +1,6 @@ #include -#include "omrf_model.h" -#include "adaptiveMetropolis.h" +#include "models/omrf/omrf_model.h" +#include "models/adaptive_metropolis.h" #include "rng/rng_utils.h" #include "mcmc/mcmc_hmc.h" #include "mcmc/mcmc_nuts.h" @@ -317,6 +317,38 @@ void OMRFModel::unvectorize_parameters(const arma::vec& param_vec) { } +void OMRFModel::unvectorize_to_temps( + const arma::vec& parameters, + arma::mat& temp_main, + arma::mat& temp_pairwise, + arma::mat& 
temp_residual +) const { + int offset = 0; + for (size_t v = 0; v < p_; ++v) { + if (is_ordinal_variable_(v)) { + int num_cats = num_categories_(v); + for (int c = 0; c < num_cats; ++c) { + temp_main(v, c) = parameters(offset++); + } + } else { + temp_main(v, 0) = parameters(offset++); + temp_main(v, 1) = parameters(offset++); + } + } + + for (size_t v1 = 0; v1 < p_ - 1; ++v1) { + for (size_t v2 = v1 + 1; v2 < p_; ++v2) { + if (edge_indicators_(v1, v2) == 1) { + temp_pairwise(v1, v2) = parameters(offset++); + temp_pairwise(v2, v1) = temp_pairwise(v1, v2); + } + } + } + + temp_residual = observations_double_ * temp_pairwise; +} + + arma::vec OMRFModel::get_vectorized_parameters() const { return vectorize_parameters(); } @@ -493,37 +525,10 @@ void OMRFModel::get_active_inv_mass_into(arma::vec& active_inv_mass) const { // ============================================================================= double OMRFModel::logp(const arma::vec& parameters) { - // Unvectorize into temporary matrices (safe approach) arma::mat temp_main = main_effects_; arma::mat temp_pairwise = pairwise_effects_; - - // Unvectorize parameters into temporaries - int offset = 0; - for (size_t v = 0; v < p_; ++v) { - if (is_ordinal_variable_(v)) { - int num_cats = num_categories_(v); - for (int c = 0; c < num_cats; ++c) { - temp_main(v, c) = parameters(offset++); - } - } else { - temp_main(v, 0) = parameters(offset++); - temp_main(v, 1) = parameters(offset++); - } - } - - for (size_t v1 = 0; v1 < p_ - 1; ++v1) { - for (size_t v2 = v1 + 1; v2 < p_; ++v2) { - if (edge_indicators_(v1, v2) == 1) { - temp_pairwise(v1, v2) = parameters(offset++); - temp_pairwise(v2, v1) = temp_pairwise(v1, v2); - } - } - } - - // Compute residual matrix from temp_pairwise - arma::mat temp_residual = arma::conv_to::from(observations_) * temp_pairwise; - - // Compute log-posterior with temporaries + arma::mat temp_residual; + unvectorize_to_temps(parameters, temp_main, temp_pairwise, temp_residual); return 
log_pseudoposterior_with_state(temp_main, temp_pairwise, temp_residual); } @@ -785,36 +790,10 @@ void OMRFModel::ensure_gradient_cache() { arma::vec OMRFModel::gradient(const arma::vec& parameters) { - // Unvectorize into temporary matrices (safe approach) arma::mat temp_main = main_effects_; arma::mat temp_pairwise = pairwise_effects_; - - // Unvectorize parameters into temporaries - int offset = 0; - for (size_t v = 0; v < p_; ++v) { - if (is_ordinal_variable_(v)) { - int num_cats = num_categories_(v); - for (int c = 0; c < num_cats; ++c) { - temp_main(v, c) = parameters(offset++); - } - } else { - temp_main(v, 0) = parameters(offset++); - temp_main(v, 1) = parameters(offset++); - } - } - - for (size_t v1 = 0; v1 < p_ - 1; ++v1) { - for (size_t v2 = v1 + 1; v2 < p_; ++v2) { - if (edge_indicators_(v1, v2) == 1) { - temp_pairwise(v1, v2) = parameters(offset++); - temp_pairwise(v2, v1) = temp_pairwise(v1, v2); - } - } - } - - // Compute residual matrix from temp_pairwise - arma::mat temp_residual = arma::conv_to::from(observations_) * temp_pairwise; - + arma::mat temp_residual; + unvectorize_to_temps(parameters, temp_main, temp_pairwise, temp_residual); return gradient_with_state(temp_main, temp_pairwise, temp_residual); } @@ -930,31 +909,8 @@ std::pair OMRFModel::logp_and_gradient(const arma::vec& param arma::mat temp_main(main_effects_.n_rows, main_effects_.n_cols, arma::fill::none); arma::mat temp_pairwise(p_, p_, arma::fill::zeros); - - // Unvectorize parameters into temporaries - int offset = 0; - for (size_t v = 0; v < p_; ++v) { - if (is_ordinal_variable_(v)) { - int num_cats = num_categories_(v); - for (int c = 0; c < num_cats; ++c) { - temp_main(v, c) = parameters(offset++); - } - } else { - temp_main(v, 0) = parameters(offset++); - temp_main(v, 1) = parameters(offset++); - } - } - - for (size_t v1 = 0; v1 < p_ - 1; ++v1) { - for (size_t v2 = v1 + 1; v2 < p_; ++v2) { - if (edge_indicators_(v1, v2) == 1) { - temp_pairwise(v1, v2) = parameters(offset++); - 
temp_pairwise(v2, v1) = temp_pairwise(v1, v2); - } - } - } - - arma::mat temp_residual = observations_double_ * temp_pairwise; + arma::mat temp_residual; + unvectorize_to_temps(parameters, temp_main, temp_pairwise, temp_residual); // Initialize gradient from cached observed statistics arma::vec gradient = grad_obs_cache_; @@ -962,7 +918,7 @@ std::pair OMRFModel::logp_and_gradient(const arma::vec& param // Merged per-variable loop: compute probability table ONCE per variable // and derive both logp and gradient contributions from it. - offset = 0; + int offset = 0; for (size_t v = 0; v < p_; ++v) { int num_cats = num_categories_(v); arma::vec residual_score = temp_residual.col(v); diff --git a/src/omrf_model.h b/src/models/omrf/omrf_model.h similarity index 97% rename from src/omrf_model.h rename to src/models/omrf/omrf_model.h index 32dc09af..d8a0ab2a 100644 --- a/src/omrf_model.h +++ b/src/models/omrf/omrf_model.h @@ -2,8 +2,8 @@ #include #include -#include "base_model.h" -#include "adaptiveMetropolis.h" +#include "models/base_model.h" +#include "models/adaptive_metropolis.h" #include "rng/rng_utils.h" #include "mcmc/mcmc_utils.h" #include "utils/common_helpers.h" @@ -388,6 +388,17 @@ class OMRFModel : public BaseModel { */ void unvectorize_parameters(const arma::vec& param_vec); + /** + * Unvectorize a parameter vector into temporary main/pairwise matrices, + * then compute the corresponding residual matrix. 
+ */ + void unvectorize_to_temps( + const arma::vec& parameters, + arma::mat& temp_main, + arma::mat& temp_pairwise, + arma::mat& temp_residual + ) const; + /** * Extract active inverse mass (only for included edges) */ diff --git a/src/skeleton_model.cpp b/src/models/skeleton_model.cpp similarity index 99% rename from src/skeleton_model.cpp rename to src/models/skeleton_model.cpp index 511d3d0e..07207453 100644 --- a/src/skeleton_model.cpp +++ b/src/models/skeleton_model.cpp @@ -17,8 +17,8 @@ // #include -// #include "base_model.h" -// #include "adaptiveMetropolis.h" +// #include "models/base_model.h" +// #include "models/adaptive_metropolis.h" // #include "rng/rng_utils.h" diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index aab3fef0..481e92b9 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -4,11 +4,11 @@ #include #include -#include "ggm_model.h" +#include "models/ggm/ggm_model.h" #include "utils/progress_manager.h" #include "utils/common_helpers.h" #include "priors/edge_prior.h" -#include "chainResultNew.h" +#include "mcmc/chain_result.h" #include "mcmc/mcmc_runner.h" #include "mcmc/sampler_config.h" @@ -34,7 +34,7 @@ Rcpp::List sample_ggm( ) { // Create model from R input - GaussianVariables model = createGaussianVariablesFromR( + GGMModel model = createGGMModelFromR( inputFromR, prior_inclusion_prob, initial_edge_indicators, edge_selection); // Configure sampler - GGM only supports MH diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp index a57fb790..9525f99b 100644 --- a/src/sample_omrf.cpp +++ b/src/sample_omrf.cpp @@ -8,11 +8,11 @@ #include #include -#include "omrf_model.h" +#include "models/omrf/omrf_model.h" #include "utils/progress_manager.h" #include "utils/common_helpers.h" #include "priors/edge_prior.h" -#include "chainResultNew.h" +#include "mcmc/chain_result.h" #include "mcmc/mcmc_runner.h" #include "mcmc/sampler_config.h" From 84d086b30fdb14ea251c088f51dda60aa60fc93d Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 
11 Feb 2026 10:58:09 +0100 Subject: [PATCH 21/23] cleanup --- src/bgm/bgm_sampler.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/bgm/bgm_sampler.cpp b/src/bgm/bgm_sampler.cpp index e68c7cd8..bbb70543 100644 --- a/src/bgm/bgm_sampler.cpp +++ b/src/bgm/bgm_sampler.cpp @@ -588,8 +588,7 @@ void update_hmc_bgm( adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix (following STAN's approach). - // STAN uses the current step size as the starting point for the heuristic. + // step size for the new mass matrix. Use current step size as starting point. if (adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( adapt.inv_mass_diag(), inclusion_indicator, num_categories, @@ -734,8 +733,7 @@ SamplerResult update_nuts_bgm( adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix (following STAN's approach). - // STAN uses the current step size as the starting point for the heuristic. + // step size for the new mass matrix. Use current step size as starting point. 
if (adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( adapt.inv_mass_diag(), inclusion_indicator, num_categories, From f8ae7f4f1b266354d0bb4adaa2a0bdd18cf1a37a Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 11 Feb 2026 11:01:44 +0100 Subject: [PATCH 22/23] cleanup --- src/bgmCompare/bgmCompare_sampler.cpp | 6 ++---- src/mcmc/mcmc_adaptation.h | 7 +++---- src/mcmc/mcmc_utils.cpp | 8 ++++---- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/bgmCompare/bgmCompare_sampler.cpp b/src/bgmCompare/bgmCompare_sampler.cpp index 567705bd..1ea890cb 100644 --- a/src/bgmCompare/bgmCompare_sampler.cpp +++ b/src/bgmCompare/bgmCompare_sampler.cpp @@ -735,8 +735,7 @@ void update_hmc_bgmcompare( hmc_adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix (following STAN's approach). - // STAN uses the current step size as the starting point for the heuristic. + // step size for the new mass matrix. Use current step size as starting point. if (hmc_adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( hmc_adapt.inv_mass_diag(), inclusion_indicator, num_groups, num_categories, @@ -911,8 +910,7 @@ SamplerResult update_nuts_bgmcompare( hmc_adapt.update(current_state, result.accept_prob, iteration); // If mass matrix was just updated, re-run the heuristic to find a good - // step size for the new mass matrix (following STAN's approach). - // STAN uses the current step size as the starting point for the heuristic. + // step size for the new mass matrix. Use current step size as starting point. 
if (hmc_adapt.mass_matrix_just_updated()) { arma::vec new_inv_mass = inv_mass_active( hmc_adapt.inv_mass_diag(), inclusion_indicator, num_groups, num_categories, diff --git a/src/mcmc/mcmc_adaptation.h b/src/mcmc/mcmc_adaptation.h index 4fdf4fcf..16e51649 100644 --- a/src/mcmc/mcmc_adaptation.h +++ b/src/mcmc/mcmc_adaptation.h @@ -93,7 +93,7 @@ class DiagMassMatrixAccumulator { }; -// === Stan-style Dynamic Warmup Schedule with Adaptive Windows === +// === Dynamic Warmup Schedule with Adaptive Windows === // // For edge_selection = FALSE: // Stage 1 (init), Stage 2 (doubling windows), Stage 3a (terminal) @@ -298,7 +298,7 @@ class HMCAdaptationController { mass_accumulator.update(theta); int w = schedule.current_window(iteration); if (iteration + 1 == schedule.window_ends[w]) { - // STAN convention: inv_mass = variance (not 1/variance!) + // inv_mass = variance (not 1/variance!) // Higher variance → higher inverse mass → parameter moves more freely inv_mass_ = mass_accumulator.variance(); mass_accumulator.reset(); @@ -332,7 +332,6 @@ class HMCAdaptationController { * This should be called after running heuristic_initial_step_size() with * the new mass matrix to find an appropriate starting step size. 
* - * Following STAN's approach: * - Set the new step size * - Set mu = log(10 * new_step_size) as the adaptation target * - Restart the dual averaging counters @@ -340,7 +339,7 @@ class HMCAdaptationController { void reinit_stepsize(double new_step_size) { step_size_ = new_step_size; step_adapter.restart(new_step_size); - // Set mu to log(10 * epsilon) as per STAN's approach + // Set mu to log(10 * epsilon) for dual averaging step_adapter.mu = MY_LOG(10.0 * new_step_size); mass_matrix_updated_ = false; } diff --git a/src/mcmc/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp index 48ace92a..0703d428 100644 --- a/src/mcmc/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -59,7 +59,7 @@ double heuristic_initial_step_size( double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum and evaluate arma::vec r = arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); @@ -79,7 +79,7 @@ double heuristic_initial_step_size( while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 2.0 * eps : 0.5 * eps; - // Resample momentum (STAN resamples on each iteration) + // Resample momentum for step size search r = arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; @@ -131,7 +131,7 @@ double heuristic_initial_step_size( ) { double eps = init_step; double logp0 = log_post(theta); // Only compute once - position doesn't change - + // Sample initial momentum from N(0, M) where M = diag(1/inv_mass_diag) arma::vec r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); double kin0 = kinetic_energy(r, inv_mass_diag); @@ -151,7 +151,7 @@ double heuristic_initial_step_size( while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 
2.0 * eps : 0.5 * eps; - // Resample momentum (STAN resamples on each iteration) + // Resample momentum for step size search r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; From 99b3c1871a4b339fa1a6ddb02e9feb22392aaf05 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 11 Feb 2026 11:04:41 +0100 Subject: [PATCH 23/23] cleanup --- src/mcmc/mcmc_adaptation.h | 2 +- src/mcmc/mcmc_utils.cpp | 4 +- src/models/mixed/mixed_variables.cpp | 43 ------- src/models/mixed/mixed_variables.h | 169 --------------------------- test_ggm.R | 113 ------------------ 5 files changed, 3 insertions(+), 328 deletions(-) delete mode 100644 src/models/mixed/mixed_variables.cpp delete mode 100644 src/models/mixed/mixed_variables.h delete mode 100644 test_ggm.R diff --git a/src/mcmc/mcmc_adaptation.h b/src/mcmc/mcmc_adaptation.h index 16e51649..fec3441b 100644 --- a/src/mcmc/mcmc_adaptation.h +++ b/src/mcmc/mcmc_adaptation.h @@ -298,7 +298,7 @@ class HMCAdaptationController { mass_accumulator.update(theta); int w = schedule.current_window(iteration); if (iteration + 1 == schedule.window_ends[w]) { - // inv_mass = variance (not 1/variance!) + // inv_mass = variance (not 1/variance) // Higher variance → higher inverse mass → parameter moves more freely inv_mass_ = mass_accumulator.variance(); mass_accumulator.reset(); diff --git a/src/mcmc/mcmc_utils.cpp b/src/mcmc/mcmc_utils.cpp index 0703d428..d500c588 100644 --- a/src/mcmc/mcmc_utils.cpp +++ b/src/mcmc/mcmc_utils.cpp @@ -79,7 +79,7 @@ double heuristic_initial_step_size( while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 
2.0 * eps : 0.5 * eps; - // Resample momentum for step size search + // Resample momentum on each iteration for step size search r = arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; @@ -151,7 +151,7 @@ double heuristic_initial_step_size( while (direction * (H1 - H0) > -direction * MY_LOG(2.0) && attempts < max_attempts) { eps = (direction == 1) ? 2.0 * eps : 0.5 * eps; - // Resample momentum for step size search + // Resample momentum on each iteration for step size search r = arma::sqrt(1.0 / inv_mass_diag) % arma_rnorm_vec(rng, theta.n_elem); kin0 = kinetic_energy(r, inv_mass_diag); H0 = logp0 - kin0; diff --git a/src/models/mixed/mixed_variables.cpp b/src/models/mixed/mixed_variables.cpp deleted file mode 100644 index 5646f8f3..00000000 --- a/src/models/mixed/mixed_variables.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// #include "models/ggm/ggm_model.h" -// #include "models/mixed/mixed_variables.h" -// #include "rng/rng_utils.h" -// -// -// void MixedVariableTypes::instantiate_variable_types(Rcpp::List input_from_R) -// { -// // instantiate variable_types_ -// for (Rcpp::List var_type_list : input_from_R) { -// std::string type = Rcpp::as(var_type_list["type"]); -// if (type == "Continuous") { -// variable_types_.push_back(std::make_unique( -// Rcpp::as(var_type_list["observations"]), -// Rcpp::as(var_type_list["inclusion_probability"]), -// Rcpp::as(var_type_list["initial_edge_indicators"]), -// Rcpp::as(var_type_list["edge_selection"]) -// )); -// // } else if (type == "Ordinal") { -// // variable_types_.push_back(std::make_unique( -// // var_type_list["observations"], -// // var_type_list["inclusion_probability"], -// // var_type_list["initial_edge_indicators"], -// // var_type_list["edge_selection"] -// // )); -// // } else if (type == "Blume-Capel") { -// // variable_types_.push_back(std::make_unique( -// // var_type_list["observations"], -// // var_type_list["inclusion_probability"], -// // 
var_type_list["initial_edge_indicators"], -// // var_type_list["edge_selection"] -// // )); -// // } else if (type == "Count") { -// // variable_types_.push_back(std::make_unique( -// // var_type_list["observations"], -// // var_type_list["inclusion_probability"], -// // var_type_list["initial_edge_indicators"], -// // var_type_list["edge_selection"] -// // )); -// } else { -// throw std::runtime_error("MixedVariableTypes received an unknown variable type in sublist fro input_from_R: " + type); -// } -// } -// } \ No newline at end of file diff --git a/src/models/mixed/mixed_variables.h b/src/models/mixed/mixed_variables.h deleted file mode 100644 index 905d1779..00000000 --- a/src/models/mixed/mixed_variables.h +++ /dev/null @@ -1,169 +0,0 @@ -// #pragma once -// -// #include -// #include -// #include -// #include "models/base_model.h" -// -// -// // Forward declaration - this class is work in progress and not yet functional -// class MixedVariableTypes : public BaseModel { -// public: -// -// // Constructor -// MixedVariableTypes( -// Rcpp::List input_from_R, -// const arma::mat& inclusion_probability, -// const arma::imat& initial_edge_indicators, -// const bool edge_selection = true -// ) -// { -// -// instantiate_variable_types(input_from_R); -// // instantiate_variable_interactions(); // TODO: not yet implemented -// -// dim_ = 0; -// for (const auto& var_type : variable_types_) { -// dim_ += var_type->parameter_dimension(); -// } -// // dim_ += interactions_.size(); // TODO: proper dimension calculation -// } -// -// // Capability queries -// bool has_gradient() const override -// { -// for (const auto& var_type : variable_types_) { -// if (var_type->has_gradient()) { -// return true; -// } -// } -// return false; -// } -// bool has_adaptive_mh() const override -// { -// for (const auto& var_type : variable_types_) { -// if (var_type->has_adaptive_mh()) { -// return true; -// } -// } -// return false; -// } -// -// // Return dimensionality of the parameter 
space -// size_t parameter_dimension() const override { -// return dim_; -// } -// -// arma::vec get_vectorized_parameters() const override { -// arma::vec result(dim_); -// size_t current = 0; -// for (size_t i = 0; i < variable_types_.size(); ++i) { -// arma::vec var_params = variable_types_[i]->get_vectorized_parameters(); -// result.subvec(current, current + var_params.n_elem - 1) = var_params; -// current += var_params.n_elem; -// } -// for (size_t i = 0; i < interactions_.size(); ++i) { -// const arma::mat& interactions_mat = interactions_[i]; -// for (size_t c = 0; c < interactions_mat.n_cols; ++c) { -// for (size_t r = 0; r < interactions_mat.n_rows; ++r) { -// result(current) = interactions_mat(r, c); -// ++current; -// } -// } -// } -// return result; -// } -// -// arma::ivec get_vectorized_indicator_parameters() override { -// for (size_t i = 0; i < variable_types_.size(); ++i) { -// auto& [from, to] = indicator_parameters_indices_[i]; -// vectorized_indicator_parameters_.subvec(from, to) = variable_types_[i]->get_vectorized_indicator_parameters(); -// } -// size_t current = indicator_parameters_indices_.empty() ? 0 : indicator_parameters_indices_.back().second + 1; -// for (size_t i = 0; i < interactions_indicators_.size(); ++i) { -// const arma::imat& indicator_mat = interactions_indicators_[i]; -// for (size_t c = 0; c < indicator_mat.n_cols; ++c) { -// for (size_t r = 0; r < indicator_mat.n_rows; ++r) { -// vectorized_indicator_parameters_(current) = indicator_mat(r, c); -// ++current; -// } -// } -// } -// -// return vectorized_indicator_parameters_; -// } -// -// -// double logp(const arma::vec& parameters) override -// { -// double total_logp = 0.0; -// for (size_t i = 0; i < variable_types_.size(); ++i) { -// auto& [from, to] = parameters_indices_[i]; -// // need to do some transformation here! 
-// arma::vec var_params = parameters.subvec(from, to); -// total_logp += variable_types_[i]->logp(var_params); -// } -// // interactions log-probability can be added here if needed -// return total_logp; -// } -// -// arma::vec gradient(const arma::vec& parameters) override { -// -// // TODO: only should call the gradient for variable types that have it -// // the rest are assumed to be constant, so have gradient zero -// arma::vec total_gradient = arma::zeros(parameters.n_elem); -// for (size_t i = 0; i < variable_types_.size(); ++i) -// { -// if (!variable_types_[i]->has_gradient()) { -// continue; -// } -// auto& [from, to] = parameters_indices_[i]; -// arma::vec var_params = parameters.subvec(from, to); -// // maybe need to do some transformation here! -// arma::vec var_gradient = variable_types_[i]->gradient(var_params); -// total_gradient.subvec(from, to) = var_gradient; -// } -// -// return total_gradient; -// } -// -// std::pair logp_and_gradient( -// const arma::vec& parameters) override { -// if (!has_gradient()) { -// throw std::runtime_error("Gradient not implemented for this model"); -// } -// return {logp(parameters), gradient(parameters)}; -// } -// -// void do_one_mh_step() override { -// for (auto& var_type : variable_types_) { -// var_type->do_one_mh_step(); -// } -// } -// -// void set_seed(int seed) override { -// for (auto& var_type : variable_types_) { -// var_type->set_seed(seed); -// } -// } -// -// std::unique_ptr clone() const override { -// throw std::runtime_error("clone method not yet implemented for MixedVariableTypes"); -// } -// -// -// private: -// std::vector> variable_types_; -// std::vector interactions_; -// std::vector interactions_indicators_; -// size_t dim_; -// arma::vec vectorized_parameters_; -// arma::ivec vectorized_indicator_parameters_; -// arma::ivec indices_from_; -// arma::ivec indices_to_; -// std::vector> parameters_indices_; -// std::vector> indicator_parameters_indices_; -// -// void 
instantiate_variable_types(const Rcpp::List input_from_R); -// -// }; diff --git a/test_ggm.R b/test_ggm.R deleted file mode 100644 index 5f1995ca..00000000 --- a/test_ggm.R +++ /dev/null @@ -1,113 +0,0 @@ -library(bgms) - -# Dimension and true precision -p <- 10 - -adj <- matrix(0, nrow = p, ncol = p) -adj[lower.tri(adj)] <- rbinom(p * (p - 1) / 2, size = 1, prob = 0.3) -adj <- adj + t(adj) -# qgraph::qgraph(adj) -Omega <- BDgraph::rgwish(1, adj = adj, b = p + sample(0:p, 1), D = diag(p)) -Sigma <- solve(Omega) -zapsmall(Omega) - -# Data -n <- 1e3 -x <- mvtnorm::rmvnorm(n = n, mean = rep(0, p), sigma = Sigma) - - -# ---- Run MCMC with warmup and sampling ------------------------------------ - -# debugonce(mbgms:::bgm_gaussian) -sampling_results <- bgms:::sample_ggm( - X = x, - prior_inclusion_prob = matrix(.5, p, p), - initial_edge_indicators = adj, - no_iter = 500, - no_warmup = 500, - no_chains = 3, - edge_selection = FALSE, - no_threads = 1, - seed = 123, - progress_type = 1 -) - -true_values <- zapsmall(Omega[upper.tri(Omega, TRUE)]) -posterior_means <- rowMeans(sampling_results[[2]]$samples) -cbind(true_values, posterior_means) - -plot(true_values, posterior_means) -abline(0, 1) - -sampling_results2 <- bgms:::sample_ggm( - X = x, - prior_inclusion_prob = matrix(.5, p, p), - initial_edge_indicators = adj, - no_iter = 500, - no_warmup = 500, - no_chains = 3, - edge_selection = TRUE, - no_threads = 1, - seed = 123, - progress_type = 1 -) - -true_values <- zapsmall(Omega[upper.tri(Omega, TRUE)]) -posterior_means <- rowMeans(sampling_results2[[2]]$samples) - -plot(true_values, posterior_means) -abline(0, 1) - -plot(posterior_means, rowMeans(sampling_results2[[2]]$samples != 0)) - - -mmm <- matrix(c( - 1.6735, 0, 0, 0, 0, - 0, 1.0000, 0, 0, -3.4524, - 0, 0, 1.0000, 0, 0, - 0, 0, 0, 1.0000, 0, - 0, -3.4524, 0, 0, 9.6674 -), p, p) -mmm -chol(mmm) -base::isSymmetric(mmm) -eigen(mmm) - -profvis::profvis({ - sampling_results <- bgm_gaussian( - x = x, - n = n, - n_iter = 
400, - n_warmup = 400, - n_phases = 10 - ) -}) - -# Extract results -aveOmega <- sampling_results$aveOmega -aveGamma <- sampling_results$aveGamma -aOmega <- sampling_results$aOmega -aGamma <- sampling_results$aGamma -prob <- sampling_results$prob -proposal_sd <- sampling_results$proposal_sd - -library(patchwork) -library(ggplot2) -df <- data.frame( - true = aveOmega[lower.tri(aveOmega)], - Omega[lower.tri(Omega)], - estimated = aveOmega[lower.tri(aveOmega)], - p_inclusion = aveGamma[lower.tri(aveGamma)] -) -p1 <- ggplot(df, aes(x = true, y = estimated)) + - geom_point(size = 5, alpha = 0.8, shape = 21, fill = "grey") + - geom_abline(slope = 1, intercept = 0, color = "grey") + - labs(x = "True Values Omega", y = "Estimated Values Omega (Posterior Mean)") -p2 <- ggplot(df, aes(x = estimated, y = p_inclusion)) + - geom_point(size = 5, alpha = 0.8, shape = 21, fill = "grey") + - labs( - x = "Estimated Values Omega (Posterior Mean)", - y = "Estimated Inclusion Probabilities" - ) -(p1 + p2) + plot_layout(ncol = 1) & theme_bw(base_size = 20) -