Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ run_simulation_parallel <- function(pairwise_samples, main_samples, draw_indices
.Call(`_bgms_run_simulation_parallel`, pairwise_samples, main_samples, draw_indices, no_states, no_variables, no_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type)
}

# Thin R-level wrapper around the compiled routine `_bgms_sample_ggm`.
#
# All seventeen arguments are forwarded positionally to `.Call()` in the order
# in which they appear in the signature, and whatever the native routine
# returns is returned unchanged. The first ten arguments are required; the
# remaining ones default to `edge_prior = "Bernoulli"` and `1.0` for every
# numeric hyperparameter.
#
# NOTE(review): this lives in R/RcppExports.R, which is presumably regenerated
# by Rcpp::compileAttributes() -- confirm before relying on hand edits here.
sample_ggm <- function(inputFromR,
                       prior_inclusion_prob,
                       initial_edge_indicators,
                       no_iter,
                       no_warmup,
                       no_chains,
                       edge_selection,
                       seed,
                       no_threads,
                       progress_type,
                       edge_prior = "Bernoulli",
                       beta_bernoulli_alpha = 1.0,
                       beta_bernoulli_beta = 1.0,
                       beta_bernoulli_alpha_between = 1.0,
                       beta_bernoulli_beta_between = 1.0,
                       dirichlet_alpha = 1.0,
                       lambda = 1.0) {
  .Call(
    `_bgms_sample_ggm`,
    inputFromR,
    prior_inclusion_prob,
    initial_edge_indicators,
    no_iter,
    no_warmup,
    no_chains,
    edge_selection,
    seed,
    no_threads,
    progress_type,
    edge_prior,
    beta_bernoulli_alpha,
    beta_bernoulli_beta,
    beta_bernoulli_alpha_between,
    beta_bernoulli_beta_between,
    dirichlet_alpha,
    lambda
  )
}

# Thin R-level wrapper around the compiled routine `_bgms_sample_omrf`.
#
# All twenty-four arguments are forwarded positionally to `.Call()` in
# signature order, and the native routine's return value is passed straight
# back to the caller. The first eleven arguments are required; everything from
# `edge_prior` onwards has a default.
#
# NOTE(review): this lives in R/RcppExports.R, which is presumably regenerated
# by Rcpp::compileAttributes() -- confirm before relying on hand edits here.
sample_omrf <- function(inputFromR,
                        prior_inclusion_prob,
                        initial_edge_indicators,
                        no_iter,
                        no_warmup,
                        no_chains,
                        edge_selection,
                        sampler_type,
                        seed,
                        no_threads,
                        progress_type,
                        edge_prior = "Bernoulli",
                        na_impute = FALSE,
                        missing_index_nullable = NULL,
                        beta_bernoulli_alpha = 1.0,
                        beta_bernoulli_beta = 1.0,
                        beta_bernoulli_alpha_between = 1.0,
                        beta_bernoulli_beta_between = 1.0,
                        dirichlet_alpha = 1.0,
                        lambda = 1.0,
                        target_acceptance = 0.8,
                        max_tree_depth = 10L,
                        num_leapfrogs = 10L,
                        edge_selection_start = -1L) {
  .Call(
    `_bgms_sample_omrf`,
    inputFromR,
    prior_inclusion_prob,
    initial_edge_indicators,
    no_iter,
    no_warmup,
    no_chains,
    edge_selection,
    sampler_type,
    seed,
    no_threads,
    progress_type,
    edge_prior,
    na_impute,
    missing_index_nullable,
    beta_bernoulli_alpha,
    beta_bernoulli_beta,
    beta_bernoulli_alpha_between,
    beta_bernoulli_beta_between,
    dirichlet_alpha,
    lambda,
    target_acceptance,
    max_tree_depth,
    num_leapfrogs,
    edge_selection_start
  )
}

# Thin R-level wrapper around the compiled routine `_bgms_compute_Vn_mfm_sbm`.
#
# The four arguments are forwarded positionally to `.Call()` in signature
# order; the native routine's return value is returned unchanged.
#
# NOTE(review): this lives in R/RcppExports.R, which is presumably regenerated
# by Rcpp::compileAttributes() -- confirm before relying on hand edits here.
compute_Vn_mfm_sbm <- function(no_variables,
                               dirichlet_alpha,
                               t_max,
                               lambda) {
  .Call(
    `_bgms_compute_Vn_mfm_sbm`,
    no_variables,
    dirichlet_alpha,
    t_max,
    lambda
  )
}
Expand Down
164 changes: 162 additions & 2 deletions R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ bgm = function(
chains = 4,
cores = parallel::detectCores(),
display_progress = c("per-chain", "total", "none"),
backend = c("legacy", "new"),
seed = NULL,
standardize = FALSE,
interaction_scale,
Expand Down Expand Up @@ -508,6 +509,23 @@ bgm = function(
edge_selection = model$edge_selection
edge_prior = model$edge_prior
inclusion_probability = model$inclusion_probability
is_continuous = model$is_continuous

# Block NUTS/HMC for the Gaussian model --------------------------------------
if(is_continuous) {
user_chose_method = length(update_method_input) == 1
if(user_chose_method && update_method %in% c("nuts", "hamiltonian-mc")) {
stop(paste0(
"The Gaussian model (variable_type = 'continuous') only supports ",
"update_method = 'adaptive-metropolis'. ",
"Got '", update_method, "'."
))
}
update_method = "adaptive-metropolis"
if(!hasArg(target_accept)) {
target_accept = 0.44
}
}

# Check Gibbs input -----------------------------------------------------------
check_positive_integer(iter, "iter")
Expand Down Expand Up @@ -551,9 +569,105 @@ bgm = function(
))
}

# Check backend ---------------------------------------------------------------
backend = match.arg(backend)

# Check display_progress ------------------------------------------------------
progress_type = progress_type_from_display_progress(display_progress)

# Setting the seed
# When the caller supplies no seed (or passes NULL), draw one at random so the
# value handed on to the samplers is always a concrete number.
if(missing(seed) || is.null(seed)) {
  seed = sample.int(.Machine$integer.max, 1)
}

if(!is.numeric(seed) || length(seed) != 1 || is.na(seed) || seed < 0) {
  stop("Argument 'seed' must be a single non-negative integer.")
}

# as.integer() converts values beyond the integer range (including Inf) to NA
# and only emits a warning. Reject such seeds explicitly so that an NA seed is
# never forwarded to the sampling routines.
if(seed > .Machine$integer.max) {
  stop(
    "Argument 'seed' must not exceed .Machine$integer.max (",
    .Machine$integer.max, ")."
  )
}

seed <- as.integer(seed)

data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x)

# ==========================================================================
# Gaussian (continuous) path
# ==========================================================================
if (is_continuous) {
num_variables = ncol(x)

# Handle missing data for continuous variables
if (na_action == "listwise") {
missing_rows = apply(x, 1, anyNA)
if (all(missing_rows)) {
stop("All rows in x contain at least one missing response.\n",
"You could try option na_action = 'impute'.")
}
if (sum(missing_rows) > 0) {
warning(sum(missing_rows), " row(s) with missing observations removed (na_action = 'listwise').",
call. = FALSE)
}
x = x[!missing_rows, , drop = FALSE]
na_impute = FALSE
} else {
stop("Imputation is not yet supported for the Gaussian model. ",
"Use na_action = 'listwise'.")
}

indicator = matrix(1L, nrow = num_variables, ncol = num_variables)

out_raw = sample_ggm(
inputFromR = list(X = x),
prior_inclusion_prob = matrix(inclusion_probability,
nrow = num_variables, ncol = num_variables),
initial_edge_indicators = indicator,
no_iter = iter,
no_warmup = warmup,
no_chains = chains,
edge_selection = edge_selection,
seed = seed,
no_threads = cores,
progress_type = progress_type,
edge_prior = edge_prior,
beta_bernoulli_alpha = beta_bernoulli_alpha,
beta_bernoulli_beta = beta_bernoulli_beta,
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
dirichlet_alpha = dirichlet_alpha,
lambda = lambda
)

out = transform_ggm_backend_output(out_raw, num_variables)

userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
if (userInterrupt) {
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
}

output = prepare_output_ggm(
out = out, x = x, iter = iter,
data_columnnames = data_columnnames,
warmup = warmup,
edge_selection = edge_selection, edge_prior = edge_prior,
inclusion_probability = inclusion_probability,
beta_bernoulli_alpha = beta_bernoulli_alpha,
beta_bernoulli_beta = beta_bernoulli_beta,
beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
beta_bernoulli_beta_between = beta_bernoulli_beta_between,
dirichlet_alpha = dirichlet_alpha,
lambda = lambda,
na_action = na_action, na_impute = na_impute,
variable_type = variable_type,
update_method = update_method,
target_accept = target_accept,
num_chains = chains
)

return(output)
}

# ==========================================================================
# Ordinal / Blume-Capel path
# ==========================================================================

# Format the data input -------------------------------------------------------
data = reformat_data(
x = x,
Expand Down Expand Up @@ -592,6 +706,9 @@ bgm = function(
}
}

# Save data before Blume-Capel centering (needed by the new backend)
x_raw = x

# Precompute the sufficient statistics for the two Blume-Capel parameters -----
blume_capel_stats = matrix(0, nrow = 2, ncol = num_variables)
if(any(!variable_bool)) {
Expand Down Expand Up @@ -691,6 +808,47 @@ bgm = function(

seed <- as.integer(seed)

if (backend == "new") {
input_list = list(
observations = x_raw,
num_categories = num_categories,
is_ordinal_variable = variable_bool,
baseline_category = baseline_category,
main_alpha = main_alpha,
main_beta = main_beta,
pairwise_scale = pairwise_scale
)

# Run the new backend through the sample_omrf() wrapper.
# Every argument is passed under its full formal name. The wrapper's formal for
# the missing-data index is `missing_index_nullable`; the shorter
# `missing_index` was only accepted through R's partial argument matching,
# which breaks as soon as another formal with the same prefix is added.
# `edge_selection_start` is not supplied here, so the wrapper's default applies.
out_raw = sample_omrf(
  inputFromR = input_list,
  prior_inclusion_prob = matrix(inclusion_probability,
    nrow = num_variables, ncol = num_variables),
  initial_edge_indicators = indicator,
  no_iter = iter,
  no_warmup = warmup,
  no_chains = chains,
  no_threads = cores,
  progress_type = progress_type,
  edge_selection = edge_selection,
  sampler_type = update_method,
  seed = seed,
  edge_prior = edge_prior,
  na_impute = na_impute,
  missing_index_nullable = missing_index,
  beta_bernoulli_alpha = beta_bernoulli_alpha,
  beta_bernoulli_beta = beta_bernoulli_beta,
  beta_bernoulli_alpha_between = beta_bernoulli_alpha_between,
  beta_bernoulli_beta_between = beta_bernoulli_beta_between,
  dirichlet_alpha = dirichlet_alpha,
  lambda = lambda,
  target_acceptance = target_accept,
  max_tree_depth = nuts_max_depth,
  num_leapfrogs = hmc_num_leapfrogs
)

out = transform_new_backend_output(out_raw, num_thresholds)
} else {

out = run_bgm_parallel(
observations = x, num_categories = num_categories,
pairwise_scale = pairwise_scale, edge_prior = edge_prior,
Expand All @@ -716,14 +874,16 @@ bgm = function(
pairwise_scaling_factors = pairwise_scaling_factors
)

} # end backend branch

userInterrupt = any(vapply(out, FUN = `[[`, FUN.VALUE = logical(1L), "userInterrupt"))
if(userInterrupt) {
warning("Stopped sampling after user interrupt, results are likely uninterpretable.")
# Try to prepare output, but catch any errors
output <- tryCatch(
prepare_output_bgm(
out = out, x = x, num_categories = num_categories, iter = iter,
data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
data_columnnames = data_columnnames,
is_ordinal_variable = variable_bool,
warmup = warmup, pairwise_scale = pairwise_scale, standardize = standardize,
pairwise_scaling_factors = pairwise_scaling_factors,
Expand Down Expand Up @@ -756,7 +916,7 @@ bgm = function(
# Main output handler in the wrapper function
output = prepare_output_bgm(
out = out, x = x, num_categories = num_categories, iter = iter,
data_columnnames = if(is.null(colnames(x))) paste0("Variable ", seq_len(ncol(x))) else colnames(x),
data_columnnames = data_columnnames,
is_ordinal_variable = variable_bool,
warmup = warmup, pairwise_scale = pairwise_scale, standardize = standardize,
pairwise_scaling_factors = pairwise_scaling_factors,
Expand Down
90 changes: 55 additions & 35 deletions R/function_input_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,30 @@ check_model = function(x,
dirichlet_alpha = dirichlet_alpha,
lambda = lambda) {
# Check variable type input ---------------------------------------------------
is_continuous = FALSE
if(length(variable_type) == 1) {
variable_input = variable_type
variable_type = try(
match.arg(
arg = variable_type,
choices = c("ordinal", "blume-capel")
choices = c("ordinal", "blume-capel", "continuous")
),
silent = TRUE
)
if(inherits(variable_type, what = "try-error")) {
stop(paste0(
"The bgm function supports variables of type ordinal and blume-capel, \n",
"but not of type ",
"The bgm function supports variables of type ordinal, blume-capel, ",
"and continuous, but not of type ",
variable_input, "."
))
}
variable_bool = (variable_type == "ordinal")
variable_bool = rep(variable_bool, ncol(x))
if(variable_type == "continuous") {
is_continuous = TRUE
variable_bool = rep(TRUE, ncol(x))
} else {
variable_bool = (variable_type == "ordinal")
variable_bool = rep(variable_bool, ncol(x))
}
} else {
if(length(variable_type) != ncol(x)) {
stop(paste0(
Expand All @@ -62,41 +68,54 @@ check_model = function(x,
))
}

variable_input = unique(variable_type)
variable_type = try(match.arg(
arg = variable_type,
choices = c("ordinal", "blume-capel"),
several.ok = TRUE
), silent = TRUE)

if(inherits(variable_type, what = "try-error")) {
has_continuous = any(variable_type == "continuous")
if(has_continuous && !all(variable_type == "continuous")) {
stop(paste0(
"The bgm function supports variables of type ordinal and blume-capel, \n",
"but not of type ",
paste0(variable_input, collapse = ", "), "."
"When using continuous variables, all variables must be of type ",
"'continuous'. Mixtures of continuous and ordinal/blume-capel ",
"variables are not supported."
))
}
if(has_continuous) {
is_continuous = TRUE
variable_bool = rep(TRUE, ncol(x))
} else {
variable_input = unique(variable_type)
variable_type = try(match.arg(
arg = variable_type,
choices = c("ordinal", "blume-capel"),
several.ok = TRUE
), silent = TRUE)

num_types = sapply(variable_input, function(type) {
tmp = try(
match.arg(
arg = type,
choices = c("ordinal", "blume-capel")
),
silent = TRUE
)
inherits(tmp, what = "try-error")
})
if(inherits(variable_type, what = "try-error")) {
stop(paste0(
"The bgm function supports variables of type ordinal, blume-capel, ",
"and continuous, but not of type ",
paste0(variable_input, collapse = ", "), "."
))
}

if(length(variable_type) != ncol(x)) {
stop(paste0(
"The bgm function supports variables of type ordinal and blume-capel, \n",
"but not of type ",
paste0(variable_input[num_types], collapse = ", "), "."
))
}
num_types = sapply(variable_input, function(type) {
tmp = try(
match.arg(
arg = type,
choices = c("ordinal", "blume-capel")
),
silent = TRUE
)
inherits(tmp, what = "try-error")
})

variable_bool = (variable_type == "ordinal")
if(length(variable_type) != ncol(x)) {
stop(paste0(
"The bgm function supports variables of type ordinal, blume-capel, ",
"and continuous, but not of type ",
paste0(variable_input[num_types], collapse = ", "), "."
))
}

variable_bool = (variable_type == "ordinal")
}
}

# Check Blume-Capel variable input --------------------------------------------
Expand Down Expand Up @@ -316,7 +335,8 @@ check_model = function(x,
baseline_category = baseline_category,
edge_selection = edge_selection,
edge_prior = edge_prior,
inclusion_probability = theta
inclusion_probability = theta,
is_continuous = is_continuous
))
}

Expand Down
Loading
Loading