## ----setup, message = FALSE, warning = FALSE, comment = NA--------------------
# Chunk-wide knitr defaults: suppress messages and warnings, print output
# without a comment prefix, and use a 6.25 x 5 figure size
knitr::opts_chunk$set(message = FALSE, warning = FALSE, comment = NA, 
                      fig.width = 6.25, fig.height = 5)
library(ANCOMBC)
library(tidyverse)
library(DT)
# Global DT option: an initComplete JavaScript callback that restyles the
# table header (black background, white text)
options(DT.options = list(
  initComplete = JS("function(settings, json) {",
  "$(this.api().table().header()).css({'background-color': 
  '#000', 'color': '#fff'});","}")))

# There appears to be a compatibility issue between the release version of
# phyloseq and lme4; a fresh installation of phyloseq might be needed.
# See this post: https://github.com/lme4/lme4/issues/743
# remotes::install_github("joey711/phyloseq", force = TRUE)

## ----getPackage, eval=FALSE---------------------------------------------------
# if (!requireNamespace("BiocManager", quietly = TRUE))
#     install.packages("BiocManager")
# BiocManager::install("ANCOMBC")

## ----load, eval=FALSE---------------------------------------------------------
# library(ANCOMBC)

## -----------------------------------------------------------------------------
data(QMP, package = "ANCOMBC")
set.seed(123)

# Simulation settings: number of samples, number of taxa, proportion of taxa
# with a non-null effect, and the size of each non-null log fold change
n = 150
d = ncol(QMP)
diff_prop = 0.1
lfc_cont = 1
lfc_cat2_vs_1 = -2
lfc_cat3_vs_1 = 1

# Labels shared by the abundance matrix and both meta data tables
taxa_ids = paste0("T", seq_len(d))
sample_ids = paste0("S", seq_len(n))

# Generate the true abundances and move them to the log scale
abn_data = sim_plnm(abn_table = QMP, taxa_are_rows = FALSE, prv_cut = 0.05,
                    n = n, lib_mean = 1e8, disp = 0.5)
log_abn_data = log(abn_data + 1e-5)
rownames(log_abn_data) = taxa_ids
colnames(log_abn_data) = sample_ids

# Sample meta data. The (log) sampling fraction is drawn from a different
# range for each third of the samples, so it differs by batch.
batch_size = n/3
smd = data.frame(samp_frac = log(c(runif(batch_size, min = 1e-4, max = 1e-3),
                                   runif(batch_size, min = 1e-3, max = 1e-2),
                                   runif(batch_size, min = 1e-2, max = 1e-1))),
                 cont_cov = rnorm(n),
                 cat_cov = as.factor(rep(seq_len(3), each = batch_size)),
                 row.names = sample_ids)

# Draw the true effect of one covariate for every taxon: equal to `effect`
# with probability diff_prop and 0 otherwise
draw_lfc = function(effect) {
  sample(c(0, effect), size = d, replace = TRUE,
         prob = c(1 - diff_prop, diff_prop))
}

# Feature meta data: (log) sequencing efficiency and true log fold changes.
# The draws happen in argument order: seq_eff first, then the three effects.
fmd = data.frame(taxon = taxa_ids,
                 seq_eff = log(runif(d, min = 0.1, max = 1)),
                 lfc_cont = draw_lfc(lfc_cont),
                 lfc_cat2_vs_1 = draw_lfc(lfc_cat2_vs_1),
                 lfc_cat3_vs_1 = draw_lfc(lfc_cat3_vs_1))
fmd$lfc_cat3_vs_2 = fmd$lfc_cat3_vs_1 - fmd$lfc_cat2_vs_1

# Add effect sizes of covariates to the true abundances, one covariate at a
# time: taxon-level effect (rows) times design-matrix column (samples)
smd_dmy = model.matrix(~ 0 + cont_cov + cat_cov, data = smd)
effect_cols = c(lfc_cont = "cont_cov",
                lfc_cat2_vs_1 = "cat_cov2",
                lfc_cat3_vs_1 = "cat_cov3")
for (lfc_name in names(effect_cols)) {
  log_abn_data = log_abn_data +
    outer(fmd[[lfc_name]], smd_dmy[, effect_cols[[lfc_name]]])
}

# Add sample-specific (column-wise) and taxon-specific (row-wise) biases,
# then return to the count scale
log_otu_data = t(t(log_abn_data) + smd$samp_frac) + fmd$seq_eff
otu_data = round(exp(log_otu_data))

## -----------------------------------------------------------------------------
# Run ANCOM-BC2 on the simulated counts with `cont_cov` and `cat_cov` as fixed
# effects and `cat_cov` as the group variable. Only the pairwise comparisons
# are requested (global, dunnet and trend are switched off); p-values are
# Holm-adjusted and the pseudo-count sensitivity analysis is enabled.
set.seed(123)
output = ancombc2(data = otu_data, meta_data = smd, 
                  fix_formula = "cont_cov + cat_cov", rand_formula = NULL,
                  p_adj_method = "holm", pseudo_sens = TRUE,
                  prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
                  group = "cat_cov", struc_zero = FALSE, neg_lb = FALSE,
                  alpha = 0.05, n_cl = 2, verbose = TRUE,
                  global = FALSE, pairwise = TRUE, 
                  dunnet = FALSE, trend = FALSE,
                  iter_control = list(tol = 1e-5, max_iter = 20, 
                                      verbose = FALSE),
                  em_control = list(tol = 1e-5, max_iter = 100),
                  lme_control = NULL, 
                  mdfdr_control = list(fwer_ctrl_method = "holm", B = 100), 
                  trend_control = NULL)

# Primary results and pairwise-comparison results
res_prim = output$res
res_pair = output$res_pair

## -----------------------------------------------------------------------------
# Collapse log fold changes to their direction: 1 (positive), -1 (negative),
# or 0 (null; missing values are mapped to 0 as well)
lfc_direction = function(x) {
  case_when(x > 0 ~ 1,
            x < 0 ~ -1,
            TRUE ~ 0)
}

# Pair the estimated direction of each pairwise comparison with the truth.
#   res_pair    : pairwise results (the `res_pair` element of the output)
#   fmd         : feature meta data holding the true log fold changes
#   diff_prefix : prefix of the indicator columns used to call a taxon
#                 significant ("diff_" or "diff_robust_")
# Returns one row per taxon in `res_pair` with the columns lfc_est1-3 and
# lfc_true1-3 coded as -1/0/1 (1: group 2 vs 1, 2: group 3 vs 1,
# 3: group 3 vs 2). An estimate is set to 0 unless its indicator is TRUE.
merge_pairwise_truth = function(res_pair, fmd, diff_prefix) {
  res_pair %>%
    dplyr::transmute(
      taxon,
      lfc_est1 = lfc_cat_cov2 *
        .data[[paste0(diff_prefix, "cat_cov2")]],
      lfc_est2 = lfc_cat_cov3 *
        .data[[paste0(diff_prefix, "cat_cov3")]],
      lfc_est3 = lfc_cat_cov3_cat_cov2 *
        .data[[paste0(diff_prefix, "cat_cov3_cat_cov2")]]) %>%
    dplyr::left_join(fmd %>%
                       dplyr::transmute(taxon,
                                        lfc_true1 = lfc_cat2_vs_1,
                                        lfc_true2 = lfc_cat3_vs_1,
                                        lfc_true3 = lfc_cat3_vs_2),
                     by = "taxon") %>%
    dplyr::mutate(dplyr::across(-taxon, lfc_direction))
}

# Power and FDR pooled over the three pairwise comparisons.
# A call is a true positive only when its direction matches the truth; a call
# on a null taxon or in the wrong direction is a false positive; a non-null
# taxon that is not called is a false negative.
calc_power_fdr = function(res_merge) {
  est = as.matrix(res_merge[, paste0("lfc_est", 1:3)])
  truth = as.matrix(res_merge[, paste0("lfc_true", 1:3)])

  tp = sum(truth != 0 & est == truth)
  fp = sum(est != 0 & est != truth)
  fn = sum(truth != 0 & est == 0)

  list(power = tp/(tp + fn),
       # With no discoveries at all the FDR is 0 by convention (not 0/0)
       fdr = if (tp + fp > 0) fp/(tp + fp) else 0)
}

# Without the sensitivity score filter: significance from the diff_* columns
res_merge1 = merge_pairwise_truth(res_pair, fmd, diff_prefix = "diff_")
perf1 = calc_power_fdr(res_merge1)
power1 = perf1$power
fdr1 = perf1$fdr

# With the sensitivity score filter: significance from the diff_robust_*
# columns
res_merge2 = merge_pairwise_truth(res_pair, fmd, diff_prefix = "diff_robust_")
perf2 = calc_power_fdr(res_merge2)
power2 = perf2$power
fdr2 = perf2$fdr

tab_summ = data.frame(Comparison = c("Without sensitivity score filter", 
                                     "With sensitivity score filter"),
                      Power = round(c(power1, power2), 2),
                      FDR = round(c(fdr1, fdr2), 2))
tab_summ %>%
    datatable(caption = "Power/FDR Comparison")

## -----------------------------------------------------------------------------
data(atlas1006, package = "microbiome")

# Keep the baseline samples only
pseq = phyloseq::subset_samples(atlas1006, time == 0)

meta_data = microbiome::meta(pseq)

# Re-code the bmi group: the three obesity grades are merged into "obese"
obese_grades = c("obese", "severeobese", "morbidobese")
bmi_chr = as.character(meta_data$bmi_group)
bmi_chr[bmi_chr %in% obese_grades] = "obese"

# Note that by default, levels of a categorical variable in R are sorted
# alphabetically, which would make `lean` the reference level for `bmi`.
# Listing the levels explicitly sets `obese` as the reference level instead;
# any group that is not listed becomes NA.
# You can verify the result by checking levels(meta_data$bmi)
meta_data$bmi = factor(bmi_chr, levels = c("obese", "overweight", "lean"))

# Create the region variable from nationality; nationalities without an entry
# in the lookup are kept unchanged and missing values become "unknown"
region_codes = c(Scandinavia = "NE", UKIE = "NE", SouthEurope = "SE",
                 CentralEurope = "CE", EasternEurope = "EE")
meta_data$region = recode(as.character(meta_data$nationality),
                          !!!region_codes, .missing = "unknown")

phyloseq::sample_data(pseq) = meta_data

# Subset to lean, overweight, and obese subjects, and at the same time
# discard "EE" (it contains only 1 subject) and subjects with a missing region
pseq = phyloseq::subset_samples(
  pseq,
  bmi %in% c("lean", "overweight", "obese") &
    ! region %in% c("EE", "unknown"))

print(pseq)

## -----------------------------------------------------------------------------
set.seed(123)

# Shuffle the bmi labels across samples; every other column of the sample
# data, and the rest of the phyloseq object, is left as it is
meta_data_perm = microbiome::meta(pseq)
shuffled_rows = sample.int(nrow(meta_data_perm))
meta_data_perm$bmi = meta_data_perm$bmi[shuffled_rows]

pseq_perm = pseq
phyloseq::sample_data(pseq_perm) = meta_data_perm

## -----------------------------------------------------------------------------
# Run ANCOM-BC2 on the permuted data at the genus level with `bmi` as the only
# covariate and as the group variable. No prevalence filter is applied
# (prv_cut = 0) and the pseudo-count sensitivity analysis is enabled.
set.seed(123)
output = ancombc2(data = pseq_perm, tax_level = "Genus",
                  fix_formula = "bmi", rand_formula = NULL,
                  p_adj_method = "holm", pseudo_sens = TRUE,
                  prv_cut = 0, lib_cut = 1000, s0_perc = 0.05,
                  group = "bmi", struc_zero = TRUE, neg_lb = TRUE)

## -----------------------------------------------------------------------------
# P-value and sensitivity-analysis flag of the `bmioverweight` term for each
# taxon, plus an updated p-value that accounts for the sensitivity analysis
res_perm = output$res %>%
  dplyr::transmute(taxon = taxon,
                   p = p_bmioverweight,
                   ss_pass = passed_ss_bmioverweight) %>%
  dplyr::mutate(
    p_update = case_when(
      # For taxa with significant p-values that failed the sensitivity
      # analysis, we assign a random value uniformly drawn from the range
      # [0.05, 1]. One value is drawn per row in a single vectorised call,
      # so no row-wise evaluation is needed.
      p < 0.05 & ss_pass == FALSE ~ runif(n(), min = 0.05, max = 1),
      TRUE ~ p
    )) %>%
  dplyr::as_tibble()

# Long format: one row per taxon and p-value type
df_p = res_perm %>%
  dplyr::select(p, p_update) %>%
  pivot_longer(cols = p:p_update, names_to = "type", values_to = "value") 
df_p$type = recode(df_p$type, 
                   `p` = "ANCOM-BC2 (Without Sensitivity Analysis)", 
                   `p_update` = "ANCOM-BC2 (With Sensitivity Analysis)")

num_bins = 30
fig_p = df_p %>%
  ggplot(aes(x = value, fill = type)) +
  geom_histogram(position = "identity", alpha = 0.5, bins = num_bins) +
  # Expected count per bin if the plotted p-values were uniformly distributed:
  # the number of p-values in each histogram divided by the number of bins
  geom_hline(yintercept = nrow(res_perm)/num_bins, linetype = "dashed") +
  labs(title = "Distribution of P-values",
       x = "P-value",
       y = "Frequency",
       fill = NULL) +
  theme_bw() + 
  theme(plot.title = element_text(hjust = 0.5),
        panel.grid.minor.y = element_blank(),
        legend.position = "bottom")
fig_p

## -----------------------------------------------------------------------------
# Full analysis of the (unpermuted) data at the family level: age, region and
# bmi as fixed effects, `bmi` as the group variable, and all four multi-group
# analyses switched on (global, pairwise, dunnet, trend).
set.seed(123)
# Note: the number of bootstrap samples (B) in `trend_control` is set to 10 
# here for computational expediency. 
# Users are advised to use the default value of B, 
# which is 100, or larger values for optimal performance.
# `trend_control` supplies three 2 x 2 contrast matrices, each paired with the
# entry at the same position of `node`.
# NOTE(review): presumably each matrix encodes one candidate ordering of the
# bmi group effects; confirm against the ancombc2 documentation.
output = ancombc2(data = pseq, tax_level = "Family",
                  fix_formula = "age + region + bmi", rand_formula = NULL,
                  p_adj_method = "holm", pseudo_sens = TRUE,
                  prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
                  group = "bmi", struc_zero = TRUE, neg_lb = TRUE,
                  alpha = 0.05, n_cl = 2, verbose = TRUE,
                  global = TRUE, pairwise = TRUE, dunnet = TRUE, trend = TRUE,
                  iter_control = list(tol = 1e-2, max_iter = 20, 
                                      verbose = TRUE),
                  em_control = list(tol = 1e-5, max_iter = 100),
                  lme_control = lme4::lmerControl(),
                  mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
                  trend_control = list(contrast = list(matrix(c(1, 0, -1, 1),
                                                              nrow = 2, 
                                                              byrow = TRUE),
                                                       matrix(c(-1, 0, 1, -1),
                                                              nrow = 2, 
                                                              byrow = TRUE),
                                                       matrix(c(1, 0, 1, -1),
                                                              nrow = 2, 
                                                              byrow = TRUE)),
                                       node = list(2, 2, 1),
                                       solver = "ECOS",
                                       B = 10))

## -----------------------------------------------------------------------------
# Same data and settings as the main analysis, but with an age-by-bmi
# interaction in the fixed-effects formula
set.seed(123)
output_interact = ancombc2(data = pseq, tax_level = "Family",
                          fix_formula = "age * bmi + region", rand_formula = NULL,
                          p_adj_method = "holm", pseudo_sens = TRUE,
                          prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
                          group = "bmi", struc_zero = TRUE, neg_lb = TRUE,
                          alpha = 0.05, n_cl = 2, verbose = TRUE)

# Show the log fold changes of every term starting with "age" (main effect and
# interactions) for the first rows, rounded to two decimals.
# across(where(...)) replaces the superseded mutate_if().
head(output_interact$res) %>%
    dplyr::select(starts_with("lfc_age")) %>%
    dplyr::mutate(dplyr::across(where(is.numeric),
                                function(x) round(x, 2))) %>%
    datatable(caption = "Example of including interaction terms")

## -----------------------------------------------------------------------------
# Structural-zero indicators returned by the analysis, shown as a table
tab_zero = output$zero_ind
datatable(tab_zero, caption = "The detection of structural zeros")

## -----------------------------------------------------------------------------
# Primary results of the main analysis
res_prim = output$res

## -----------------------------------------------------------------------------
# Columns of the age covariate (names ending in "age")
df_age = dplyr::select(res_prim, taxon, ends_with("age"))

# Keep the taxa flagged by `diff_age`, order them from the largest to the
# smallest log fold change (the bars follow this order), and colour a taxon
# label when `diff_robust_age` is also set
df_fig_age = df_age %>%
  dplyr::filter(diff_age == 1) %>%
  dplyr::arrange(desc(lfc_age)) %>%
  dplyr::mutate(
    direct = factor(ifelse(lfc_age > 0, "Positive LFC", "Negative LFC"),
                    levels = c("Positive LFC", "Negative LFC")),
    color = ifelse(diff_robust_age, "aquamarine3", "black"),
    taxon = factor(taxon, levels = taxon))

# Bar chart of the log fold changes with +/- one standard error
fig_age = ggplot(df_fig_age, aes(x = taxon, y = lfc_age, fill = direct)) +
  geom_col(width = 0.7, color = "black",
           position = position_dodge(width = 0.4)) +
  geom_errorbar(aes(ymin = lfc_age - se_age, ymax = lfc_age + se_age),
                width = 0.2, position = position_dodge(0.05),
                color = "black") +
  scale_fill_discrete(name = NULL) +
  scale_color_discrete(name = NULL) +
  labs(x = NULL, y = "Log fold change",
       title = "Log fold changes as one unit increase of age") +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5),
        panel.grid.minor.y = element_blank(),
        axis.text.x = element_text(angle = 60, hjust = 1,
                                   color = df_fig_age$color))
fig_age

## -----------------------------------------------------------------------------
# Columns of the bmi covariate
df_bmi = dplyr::select(res_prim, taxon, contains("bmi"))

# Taxa flagged for at least one of the two bmi terms
df_bmi_sig = dplyr::filter(df_bmi,
                           diff_bmilean == 1 | diff_bmioverweight == 1)

# Cell values: the rounded log fold change where the term is flagged,
# 0 otherwise (lfc1 = overweight, lfc2 = lean)
df_fig_bmi1 = df_bmi_sig %>%
  dplyr::mutate(lfc1 = ifelse(diff_bmioverweight == 1,
                              round(lfc_bmioverweight, 2), 0),
                lfc2 = ifelse(diff_bmilean == 1,
                              round(lfc_bmilean, 2), 0)) %>%
  tidyr::pivot_longer(cols = c(lfc1, lfc2),
                      names_to = "group", values_to = "value") %>%
  dplyr::arrange(taxon)

# Label colours: highlighted when the corresponding diff_robust_* flag is set
df_fig_bmi2 = df_bmi_sig %>%
  dplyr::mutate(lfc1 = ifelse(diff_robust_bmioverweight,
                              "aquamarine3", "black"),
                lfc2 = ifelse(diff_robust_bmilean,
                              "aquamarine3", "black")) %>%
  tidyr::pivot_longer(cols = c(lfc1, lfc2),
                      names_to = "group", values_to = "color") %>%
  dplyr::arrange(taxon)

df_fig_bmi = dplyr::left_join(df_fig_bmi1, df_fig_bmi2,
                              by = c("taxon", "group"))

# Readable comparison labels, in plotting order
df_fig_bmi$group = factor(df_fig_bmi$group,
                          levels = c("lfc1", "lfc2"),
                          labels = c("Overweight - Obese",
                                     "Lean - Obese"))

# Colour scale limits: the value range widened to whole numbers
value_range = range(df_fig_bmi$value)
lo = floor(value_range[1])
up = ceiling(value_range[2])
mid = (lo + up)/2

fig_bmi = ggplot(df_fig_bmi, aes(x = group, y = taxon, fill = value)) +
  geom_tile(color = "black") +
  geom_text(aes(label = value, color = color), size = 4) +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "white", midpoint = mid,
                       limits = c(lo, up), name = NULL) +
  scale_color_identity(guide = "none") +
  labs(x = NULL, y = NULL,
       title = "Log fold changes as compared to obese subjects") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))
fig_bmi

## -----------------------------------------------------------------------------
# Global test results, combined with the bmi log fold changes from the
# primary results
res_global = output$res_global
df_bmi = res_prim %>%
    dplyr::select(taxon, contains("bmi")) 

# Keep the taxa flagged by the global test and show both bmi log fold changes
# for them; a taxon label is highlighted when `diff_robust_abn` is set.
# The join key is stated explicitly instead of relying on the implicit
# natural join over all shared column names.
df_fig_global = df_bmi %>%
    dplyr::left_join(res_global %>%
                       dplyr::transmute(taxon, 
                                        diff_bmi = diff_abn, 
                                        diff_robust_bmi = diff_robust_abn),
                     by = "taxon") %>%
    dplyr::filter(diff_bmi == 1) %>%
    dplyr::transmute(taxon,
                     `Overweight - Obese` = round(lfc_bmioverweight, 2),
                     `Lean - Obese` = round(lfc_bmilean, 2), 
                     color = ifelse(diff_robust_bmi, "aquamarine3", "black")) %>%
    tidyr::pivot_longer(cols = `Overweight - Obese`:`Lean - Obese`, 
                        names_to = "group", values_to = "value") %>%
    dplyr::arrange(taxon)

df_fig_global$group = factor(df_fig_global$group, 
                             levels = c("Overweight - Obese",
                                        "Lean - Obese"))
  
# Colour scale limits: the value range widened to whole numbers
lo = floor(min(df_fig_global$value))
up = ceiling(max(df_fig_global$value))
mid = (lo + up)/2
fig_global = df_fig_global %>%
  ggplot(aes(x = group, y = taxon, fill = value)) + 
  geom_tile(color = "black") +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white", 
                       na.value = "white", midpoint = mid, limits = c(lo, up),
                       name = NULL) +
  geom_text(aes(group, taxon, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Log fold changes for globally significant taxa") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5),
        # One colour per taxon, in the (sorted) row order of the data
        axis.text.y = element_text(color = df_fig_global %>%
                                       dplyr::distinct(taxon, color) %>%
                                       dplyr::pull(color)))
fig_global

## -----------------------------------------------------------------------------
# Pairwise comparison results among the three bmi groups
res_pair = output$res_pair

# Cell values: the rounded log fold change where the comparison is flagged,
# 0 otherwise (lfc1 = overweight vs obese, lfc2 = lean vs obese,
# lfc3 = lean vs overweight). Only taxa flagged at least once are kept.
df_fig_pair1 = res_pair %>%
    dplyr::filter(diff_bmioverweight == 1 |
                      diff_bmilean == 1 | 
                      diff_bmilean_bmioverweight == 1) %>%
    dplyr::mutate(lfc1 = ifelse(diff_bmioverweight == 1, 
                                round(lfc_bmioverweight, 2), 0),
                  lfc2 = ifelse(diff_bmilean == 1, 
                                round(lfc_bmilean, 2), 0),
                  lfc3 = ifelse(diff_bmilean_bmioverweight == 1, 
                                round(lfc_bmilean_bmioverweight, 2), 0)) %>%
    tidyr::pivot_longer(cols = lfc1:lfc3, 
                        names_to = "group", values_to = "value") %>%
    dplyr::arrange(taxon)

# Label colours: highlighted when the corresponding diff_robust_* flag is set
df_fig_pair2 = res_pair %>%
    dplyr::filter(diff_bmioverweight == 1 |
                      diff_bmilean == 1 | 
                      diff_bmilean_bmioverweight == 1) %>%
    dplyr::mutate(lfc1 = ifelse(diff_robust_bmioverweight, 
                                "aquamarine3", "black"),
                  lfc2 = ifelse(diff_robust_bmilean, 
                                "aquamarine3", "black"),
                  lfc3 = ifelse(diff_robust_bmilean_bmioverweight, 
                                "aquamarine3", "black")) %>%
    tidyr::pivot_longer(cols = lfc1:lfc3, 
                        names_to = "group", values_to = "color") %>%
    dplyr::arrange(taxon)

df_fig_pair = df_fig_pair1 %>%
    dplyr::left_join(df_fig_pair2, by = c("taxon", "group"))

df_fig_pair$group = recode(df_fig_pair$group, 
                          `lfc1` = "Overweight - Obese",
                          `lfc2` = "Lean - Obese",
                          `lfc3` = "Lean - Overweight")
df_fig_pair$group = factor(df_fig_pair$group, 
                          levels = c("Overweight - Obese",
                                     "Lean - Obese", 
                                     "Lean - Overweight"))

# Colour scale limits: the value range widened to whole numbers
lo = floor(min(df_fig_pair$value))
up = ceiling(max(df_fig_pair$value))
mid = (lo + up)/2
fig_pair = df_fig_pair %>%
    ggplot(aes(x = group, y = taxon, fill = value)) + 
    geom_tile(color = "black") +
    scale_fill_gradient2(low = "blue", high = "red", mid = "white", 
                         na.value = "white", midpoint = mid, limits = c(lo, up),
                         name = NULL) +
    geom_text(aes(group, taxon, label = value, color = color), size = 4) +
    # guide = "none" replaces the deprecated guide = FALSE
    scale_color_identity(guide = "none") +
    labs(x = NULL, y = NULL, title = "Log fold changes as compared to obese subjects") +
    theme_minimal() +
    theme(plot.title = element_text(hjust = 0.5))
fig_pair

## -----------------------------------------------------------------------------
# Results of the `dunnet = TRUE` analysis (each bmi group vs the reference)
res_dunn = output$res_dunn

# Cell values: the rounded log fold change where the comparison is flagged,
# 0 otherwise (lfc1 = overweight, lfc2 = lean). Only taxa flagged at least
# once are kept.
df_fig_dunn1 = res_dunn %>%
    dplyr::filter(diff_bmioverweight == 1 | 
                      diff_bmilean == 1) %>%
    dplyr::mutate(lfc1 = ifelse(diff_bmioverweight == 1, 
                                round(lfc_bmioverweight, 2), 0),
                  lfc2 = ifelse(diff_bmilean == 1, 
                                round(lfc_bmilean, 2), 0)) %>%
    tidyr::pivot_longer(cols = lfc1:lfc2, 
                        names_to = "group", values_to = "value") %>%
    dplyr::arrange(taxon)

# Label colours: highlighted when the corresponding diff_robust_* flag is set
# (both flags are tested the same way, as in the other figures)
df_fig_dunn2 = res_dunn %>%
    dplyr::filter(diff_bmioverweight == 1 | 
                      diff_bmilean == 1) %>%
    dplyr::mutate(lfc1 = ifelse(diff_robust_bmioverweight, 
                                "aquamarine3", "black"),
                  lfc2 = ifelse(diff_robust_bmilean, 
                                "aquamarine3", "black")) %>%
    tidyr::pivot_longer(cols = lfc1:lfc2, 
                        names_to = "group", values_to = "color") %>%
    dplyr::arrange(taxon)

df_fig_dunn = df_fig_dunn1 %>%
    dplyr::left_join(df_fig_dunn2, by = c("taxon", "group"))

df_fig_dunn$group = recode(df_fig_dunn$group, 
                          `lfc1` = "Overweight - Obese",
                          `lfc2` = "Lean - Obese")
df_fig_dunn$group = factor(df_fig_dunn$group, 
                          levels = c("Overweight - Obese",
                                     "Lean - Obese"))

# Colour scale limits: the value range widened to whole numbers
lo = floor(min(df_fig_dunn$value))
up = ceiling(max(df_fig_dunn$value))
mid = (lo + up)/2
fig_dunn = df_fig_dunn %>%
  ggplot(aes(x = group, y = taxon, fill = value)) + 
  geom_tile(color = "black") +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white", 
                       na.value = "white", midpoint = mid, limits = c(lo, up),
                       name = NULL) +
  geom_text(aes(group, taxon, label = value, color = color), size = 4) +
  # guide = "none" replaces the deprecated guide = FALSE
  scale_color_identity(guide = "none") +
  labs(x = NULL, y = NULL, title = "Log fold changes as compared to obese subjects") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5))
fig_dunn

## -----------------------------------------------------------------------------
# Results of the trend analysis
res_trend = output$res_trend

# Keep the taxa flagged by the trend test and show both bmi log fold changes
# for them; a taxon label is highlighted when `diff_robust_abn` is set
df_fig_trend = res_trend %>%
  dplyr::filter(diff_abn == 1) %>%
  dplyr::mutate(lfc1 = round(lfc_bmioverweight, 2),
                lfc2 = round(lfc_bmilean, 2),
                color = ifelse(diff_robust_abn, "aquamarine3", "black")) %>%
  tidyr::pivot_longer(cols = c(lfc1, lfc2),
                      names_to = "group", values_to = "value") %>%
  dplyr::arrange(taxon)

# Readable comparison labels, in plotting order
df_fig_trend$group = factor(df_fig_trend$group,
                            levels = c("lfc1", "lfc2"),
                            labels = c("Overweight - Obese",
                                       "Lean - Obese"))

# Colour scale limits: the value range widened to whole numbers
value_range = range(df_fig_trend$value)
lo = floor(value_range[1])
up = ceiling(value_range[2])
mid = (lo + up)/2

# One label colour per taxon, in the (sorted) row order of the data
taxon_colors = dplyr::pull(dplyr::distinct(df_fig_trend, taxon, color),
                           color)

fig_trend = ggplot(df_fig_trend, aes(x = group, y = taxon, fill = value)) +
  geom_tile(color = "black") +
  geom_text(aes(label = value), color = "black", size = 4) +
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "white", midpoint = mid,
                       limits = c(lo, up), name = NULL) +
  labs(x = NULL, y = NULL,
       title = "Log fold changes as compared to obese subjects") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5),
        axis.text.y = element_text(color = taxon_colors))
fig_trend

## ----eval=FALSE---------------------------------------------------------------
# tse = mia::makeTreeSummarizedExperimentFromPhyloseq(atlas1006)
# tse = tse[, tse$time == 0]
# tse$bmi = recode(tse$bmi_group,
#                  obese = "obese",
#                  severeobese = "obese",
#                  morbidobese = "obese")
# tse = tse[, tse$bmi %in% c("lean", "overweight", "obese")]
# tse$bmi = factor(tse$bmi, levels = c("obese", "overweight", "lean"))
# tse$region = recode(as.character(tse$nationality),
#                     Scandinavia = "NE", UKIE = "NE", SouthEurope = "SE",
#                     CentralEurope = "CE", EasternEurope = "EE",
#                     .missing = "unknown")
# tse = tse[, ! tse$region %in% c("EE", "unknown")]
# 
# set.seed(123)
# output = ancombc2(data = tse, assay_name = "counts", tax_level = "Family",
#                   fix_formula = "age + region + bmi", rand_formula = NULL,
#                   p_adj_method = "holm", pseudo_sens = TRUE,
#                   prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
#                   group = "bmi", struc_zero = TRUE, neg_lb = TRUE,
#                   alpha = 0.05, n_cl = 2, verbose = TRUE,
#                   global = TRUE, pairwise = TRUE, dunnet = TRUE, trend = TRUE,
#                   iter_control = list(tol = 1e-2, max_iter = 20,
#                                       verbose = TRUE),
#                   em_control = list(tol = 1e-5, max_iter = 100),
#                   lme_control = lme4::lmerControl(),
#                   mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
#                   trend_control = list(contrast = list(matrix(c(1, 0, -1, 1),
#                                                               nrow = 2,
#                                                               byrow = TRUE),
#                                                        matrix(c(-1, 0, 1, -1),
#                                                               nrow = 2,
#                                                               byrow = TRUE),
#                                                        matrix(c(1, 0, 1, -1),
#                                                               nrow = 2,
#                                                               byrow = TRUE)),
#                                        node = list(2, 2, 1),
#                                        solver = "ECOS",
#                                        B = 10))

## ----eval=FALSE---------------------------------------------------------------
# abundance_data = microbiome::abundances(pseq)
# aggregate_data = microbiome::abundances(microbiome::aggregate_taxa(pseq, "Family"))
# meta_data = microbiome::meta(pseq)
# 
# set.seed(123)
# output = ancombc2(data = abundance_data,
#                   aggregate_data = aggregate_data,
#                   meta_data = meta_data,
#                   fix_formula = "age + region + bmi", rand_formula = NULL,
#                   p_adj_method = "holm", pseudo_sens = TRUE,
#                   prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
#                   group = "bmi", struc_zero = TRUE, neg_lb = TRUE,
#                   alpha = 0.05, n_cl = 2, verbose = TRUE,
#                   global = TRUE, pairwise = TRUE, dunnet = TRUE, trend = TRUE,
#                   iter_control = list(tol = 1e-2, max_iter = 20,
#                                       verbose = TRUE),
#                   em_control = list(tol = 1e-5, max_iter = 100),
#                   lme_control = lme4::lmerControl(),
#                   mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
#                   trend_control = list(contrast = list(matrix(c(1, 0, -1, 1),
#                                                               nrow = 2,
#                                                               byrow = TRUE),
#                                                        matrix(c(-1, 0, 1, -1),
#                                                               nrow = 2,
#                                                               byrow = TRUE),
#                                                        matrix(c(1, 0, 1, -1),
#                                                               nrow = 2,
#                                                               byrow = TRUE)),
#                                        node = list(2, 2, 1),
#                                        solver = "ECOS",
#                                        B = 10))

## -----------------------------------------------------------------------------
# Load the 'dietswap' example dataset shipped with the microbiome package
data("dietswap", package = "microbiome")

## -----------------------------------------------------------------------------
# Seed the random number generator so that the run is reproducible.
set.seed(123)
# NOTE: the number of bootstrap samples (B) in the 'trend_control' list is set
# to 10 below purely for computational expediency. Users are advised to keep
# the default value (B = 100) or to use a larger value for optimal performance.
output = ancombc2(
  # Input data and model formulas
  data = dietswap,
  tax_level = "Family",
  fix_formula = "nationality + timepoint + group",
  rand_formula = "(timepoint | subject)",
  # General analysis settings
  p_adj_method = "holm",
  pseudo_sens = TRUE,
  prv_cut = 0.10,
  lib_cut = 1000,
  s0_perc = 0.05,
  group = "group",
  struc_zero = TRUE,
  neg_lb = TRUE,
  alpha = 0.05,
  n_cl = 2,
  verbose = TRUE,
  # Multi-group analyses to run in addition to the primary analysis
  global = TRUE,
  pairwise = TRUE,
  dunnet = TRUE,
  trend = TRUE,
  # Control lists for the individual procedures
  iter_control = list(tol = 1e-2, max_iter = 20, verbose = TRUE),
  em_control = list(tol = 1e-5, max_iter = 100),
  lme_control = lme4::lmerControl(),
  mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
  trend_control = list(
    # One 2 x 2 contrast matrix, written out row by row
    contrast = list(rbind(c(1, 0),
                          c(-1, 1))),
    node = list(2),
    solver = "ECOS",
    B = 10
  )
)

## ----eval=FALSE---------------------------------------------------------------
# set.seed(123)
# output = ancombc2(data = dietswap, tax_level = "Family",
#                   fix_formula = "nationality + timepoint * group",
#                   rand_formula = "(timepoint | subject)",
#                   p_adj_method = "holm", pseudo_sens = TRUE,
#                   prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
#                   group = "group", struc_zero = TRUE, neg_lb = TRUE,
#                   alpha = 0.05, n_cl = 2, verbose = TRUE,
#                   global = TRUE, pairwise = TRUE, dunnet = TRUE, trend = TRUE,
#                   iter_control = list(tol = 1e-2, max_iter = 20,
#                                       verbose = TRUE),
#                   em_control = list(tol = 1e-5, max_iter = 100),
#                   lme_control = lme4::lmerControl(),
#                   mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
#                   trend_control = list(contrast = list(matrix(c(1, 0, -1, 1),
#                                                               nrow = 2,
#                                                               byrow = TRUE)),
#                                        node = list(2),
#                                        solver = "ECOS",
#                                        B = 10))

## -----------------------------------------------------------------------------
# Primary analysis results: round every numeric column to two decimals.
# across(where(...)) replaces the superseded scoped verb mutate_if().
res_prim = output$res %>%
  mutate(across(where(is.numeric), function(x) round(x, 2)))
# Display the first six rows; unlike res_prim[1:6, ], head() does not create
# NA-filled rows when the table has fewer than six rows.
res_prim %>%
  head(6) %>%
  datatable(caption = "ANCOM-BC2 Primary Analysis")

## -----------------------------------------------------------------------------
# Global test results: round every numeric column to two decimals.
# across(where(...)) replaces the superseded scoped verb mutate_if().
res_global = output$res_global %>%
  mutate(across(where(is.numeric), function(x) round(x, 2)))
# Display the first six rows; unlike res_global[1:6, ], head() does not create
# NA-filled rows when the table has fewer than six rows.
res_global %>%
  head(6) %>%
  datatable(caption = "ANCOM-BC2 Global Test")

## -----------------------------------------------------------------------------
# Pairwise comparison results: round every numeric column to two decimals.
# across(where(...)) replaces the superseded scoped verb mutate_if().
res_pair = output$res_pair %>%
  mutate(across(where(is.numeric), function(x) round(x, 2)))
# Display the first six rows; unlike res_pair[1:6, ], head() does not create
# NA-filled rows when the table has fewer than six rows.
res_pair %>%
  head(6) %>%
  datatable(caption = "ANCOM-BC2 Multiple Pairwise Comparisons")

## -----------------------------------------------------------------------------
# Dunnett's type of test results: round every numeric column to two decimals.
# across(where(...)) replaces the superseded scoped verb mutate_if().
res_dunn = output$res_dunn %>%
  mutate(across(where(is.numeric), function(x) round(x, 2)))
# Display the first six rows; unlike res_dunn[1:6, ], head() does not create
# NA-filled rows when the table has fewer than six rows.
res_dunn %>%
  head(6) %>%
  datatable(caption = "ANCOM-BC2 Dunnett's Type of Test")

## -----------------------------------------------------------------------------
# Pattern (trend) analysis results: round every numeric column to two decimals.
# across(where(...)) replaces the superseded scoped verb mutate_if().
res_trend = output$res_trend %>%
  mutate(across(where(is.numeric), function(x) round(x, 2)))
# Display the first six rows; unlike res_trend[1:6, ], head() does not create
# NA-filled rows when the table has fewer than six rows.
res_trend %>%
  head(6) %>%
  datatable(caption = "ANCOM-BC2 Pattern Analysis")

## ----eval=FALSE---------------------------------------------------------------
# tse = mia::makeTreeSummarizedExperimentFromPhyloseq(dietswap)
# 
# set.seed(123)
# output = ancombc2(data = tse, assay_name = "counts", tax_level = "Family",
#                   fix_formula = "nationality + timepoint + group",
#                   rand_formula = "(timepoint | subject)",
#                   p_adj_method = "holm", pseudo_sens = TRUE,
#                   prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
#                   group = "group", struc_zero = TRUE, neg_lb = TRUE,
#                   alpha = 0.05, n_cl = 2, verbose = TRUE,
#                   global = TRUE, pairwise = TRUE, dunnet = TRUE, trend = TRUE,
#                   iter_control = list(tol = 1e-2, max_iter = 20,
#                                       verbose = TRUE),
#                   em_control = list(tol = 1e-5, max_iter = 100),
#                   lme_control = lme4::lmerControl(),
#                   mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
#                   trend_control = list(contrast = list(matrix(c(1, 0, -1, 1),
#                                                               nrow = 2,
#                                                               byrow = TRUE)),
#                                        node = list(2),
#                                        solver = "ECOS",
#                                        B = 10))

## ----eval=FALSE---------------------------------------------------------------
# abundance_data = microbiome::abundances(dietswap)
# aggregate_data = microbiome::abundances(microbiome::aggregate_taxa(dietswap, "Family"))
# meta_data = microbiome::meta(dietswap)
# 
# set.seed(123)
# output = ancombc2(data = abundance_data,
#                   aggregate_data = aggregate_data,
#                   meta_data = meta_data,
#                   fix_formula = "nationality + timepoint + group",
#                   rand_formula = "(timepoint | subject)",
#                   p_adj_method = "holm", pseudo_sens = TRUE,
#                   prv_cut = 0.10, lib_cut = 1000, s0_perc = 0.05,
#                   group = "group", struc_zero = TRUE, neg_lb = TRUE,
#                   alpha = 0.05, n_cl = 2, verbose = TRUE,
#                   global = TRUE, pairwise = TRUE, dunnet = TRUE, trend = TRUE,
#                   iter_control = list(tol = 1e-2, max_iter = 20,
#                                       verbose = TRUE),
#                   em_control = list(tol = 1e-5, max_iter = 100),
#                   lme_control = lme4::lmerControl(),
#                   mdfdr_control = list(fwer_ctrl_method = "holm", B = 100),
#                   trend_control = list(contrast = list(matrix(c(1, 0, -1, 1),
#                                                               nrow = 2,
#                                                               byrow = TRUE)),
#                                        node = list(2),
#                                        solver = "ECOS",
#                                        B = 10))

## -----------------------------------------------------------------------------
# Extract the bias-corrected log abundance table from the ANCOM-BC2 output.
bias_correct_log_table = output$bias_correct_log_table
# ANCOM-BC2 does not add pseudo-counts to zero counts by default, so the
# bias-corrected log abundances can contain NAs. Users may either keep these
# NAs or replace them with zeros, as done here; this replacement is equivalent
# to adding pseudo-counts of ones to the zero counts.
bias_correct_log_table[is.na(bias_correct_log_table)] = 0
# Display the first 6 samples (columns), rounded to two decimals
bias_correct_log_table[, 1:6] %>%
  round(digits = 2) %>%
  datatable(caption = "Bias-corrected log abundances")

## ----sessionInfo, message = FALSE, warning = FALSE, comment = NA--------------
# Print the R version and the loaded/attached packages for reproducibility
sessionInfo()