The fairGNN package provides a complete pipeline for training and evaluating a Gated Neural Network (GNN) designed to mitigate demographic bias in predictive modelling. The package implements a fairness-aware GNN whose custom loss function targets the Equalized Odds fairness criterion by minimising the variance of the True Positive and False Positive Rates across subgroups.
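To make the fairness objective concrete, here is a minimal base R sketch of the quantity the loss is described as minimising: the variance, across subgroups, of the True Positive Rate (TPR) and False Positive Rate (FPR). `group_rates()` and `equalized_odds_penalty()` are hypothetical helpers, not fairGNN functions. How the package computes this term, and how it is weighted against the prediction loss, may differ.

```r
# Hypothetical helper (not part of fairGNN): per-group TPR and FPR.
# `y_score` may be hard 0/1 predictions or predicted probabilities; with
# probabilities, the rates become "soft" (mean score among actual positives/negatives).
group_rates <- function(y_true, y_score, group) {
  sapply(split(seq_along(y_true), group), function(idx) {
    y <- y_true[idx]
    s <- y_score[idx]
    c(TPR = mean(s[y == 1]),  # mean prediction among actual positives
      FPR = mean(s[y == 0]))  # mean prediction among actual negatives
  })
}

# Hypothetical penalty: sample variance of TPR plus variance of FPR across groups.
# It is zero when every group has identical TPR and FPR (Equalized Odds).
equalized_odds_penalty <- function(y_true, y_score, group) {
  rates <- group_rates(y_true, y_score, group)
  var(rates["TPR", ]) + var(rates["FPR", ])
}

y_true <- c(1, 1, 0, 0, 1, 1, 0, 0)
y_pred <- c(1, 0, 0, 0, 1, 1, 1, 0)
group  <- c(0, 0, 0, 0, 1, 1, 1, 1)  # 0 = Male, 1 = Female
group_rates(y_true, y_pred, group)
equalized_odds_penalty(y_true, y_pred, group)  # 0.125 + 0.125 = 0.25
```

Passing predicted probabilities instead of hard labels keeps the penalty a smooth function of the model output, which is what a gradient-based training loss would need.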
This vignette demonstrates the full workflow using the GENDEP dataset to predict antidepressant response, focusing on fairness across gender subgroups.
The package is built around a logical sequence of functions:

- prepare_data(): Cleans and prepares the input data.
- train_gnn(): Trains the Gated Neural Network.
- analyse_gnn_results(): Conducts performance and gate analysis.
- analyse_experts(): Performs analysis of expert specialisation.
- plot_sankey(): Visualises the model’s patient routing behaviour.

First, we load the necessary libraries and the dataset.
library(fairGNN)
library(dplyr)
# library(readxl)  # only needed to read your own data from an Excel file
# For a reproducible vignette, we create a dummy dataframe.
# A user would load their own data here.
set.seed(123)
raw_data <- data.frame(
subjectid = 1:430,
hdremit.all = sample(0:1, 430, replace = TRUE),
sex = sample(1:2, 430, replace = TRUE),
madrs.total = rnorm(430, 25, 5),
feature1 = rnorm(430),
feature2 = rnorm(430)
)

We use prepare_data() to process the raw data for the Male/Female analysis.
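prepare_data() receives a group mapping that translates raw codes into model codes. Its internals are not shown in this vignette; as an assumption, a named list of the form `list('2' = 0, '1' = 1)` could be applied with a simple base R lookup such as the hypothetical `recode_group()` below, which fails loudly on any code the mapping does not cover.

```r
# Hypothetical sketch (not fairGNN's implementation): recode raw group codes
# using a named list mapping raw code -> model code.
recode_group <- function(x, mappings) {
  lookup <- unlist(mappings)               # named numeric vector; names = raw codes
  out <- unname(lookup[as.character(x)])   # look up each raw value by its name
  if (anyNA(out)) {
    stop("Unmapped group codes: ", paste(unique(x[is.na(out)]), collapse = ", "))
  }
  out
}

mappings <- list('2' = 0, '1' = 1)   # raw 2 = Male -> 0, raw 1 = Female -> 1
recode_group(c(1, 2, 2, 1), mappings)  # 1 0 0 1
```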
# Raw coding: 2 = Male, 1 = Female; recoded to 0 = Male, 1 = Female
numeric_mappings_gender <- list('2' = 0, '1' = 1)
cols_to_drop <- c(
"subjectid", "Row.names", "bloodsampleid.x", "madrs.total", "hrsd.total",
"bdi.total", "bdi14wk0", "bdi20wk0", "f61score0", "f62score0",
"f64score0", "f65score0", "k30"
)
prepared_data_gender <- prepare_data(
data = raw_data,
outcome_var = "hdremit.all",
group_var = "sex",
group_mappings = numeric_mappings_gender,
cols_to_remove = cols_to_drop
)

For a fast and reliable vignette, we load pre-computed results from a train_gnn() run instead of re-training the model each time. These results were generated using the real GENDEP data.
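Because this approach depends on the `.rds` files existing at the expected path, a common R pattern is to read cached results when they are present and compute (and save) them otherwise. `load_or_compute()` below is a hypothetical helper, not part of fairGNN; in practice `compute` would wrap your own `train_gnn()` call, whose arguments are not shown in this vignette.

```r
# Hypothetical helper: return cached results from `path` if the file exists;
# otherwise run `compute()`, save its result to `path`, and return it.
load_or_compute <- function(path, compute) {
  if (file.exists(path)) {
    return(readRDS(path))
  }
  result <- compute()
  dir.create(dirname(path), showWarnings = FALSE, recursive = TRUE)
  saveRDS(result, path)
  result
}

cache <- tempfile(fileext = ".rds")
res1 <- load_or_compute(cache, function() list(value = 42))   # computed, then saved
res2 <- load_or_compute(cache, function() stop("not called")) # read from the cache
identical(res1, res2)  # TRUE
```

The second call never evaluates `compute()`, because the cached file already exists.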
# In a real analysis, a user would run train_gnn() here.
# For the vignette, we load the results saved by the create_vignette_data.R script.
gnn_results <- readRDS("data/gnn_results.rds")
expert_analyses <- readRDS("data/expert_analyses.rds")
With the results loaded, we run analyse_gnn_results() to generate all the standard performance plots and gate analyses.
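The gate entropy distribution summarises how decisively the gate routes each patient. Assuming the gate outputs one weight per expert per patient and that each patient's weights sum to one (for example, softmax outputs), per-patient Shannon entropy can be sketched as follows. `gate_entropy()` is hypothetical, uses the natural logarithm, and may not match the exact definition used by analyse_gnn_results().

```r
# Hypothetical sketch: Shannon entropy (natural log) of each row of a
# patients x experts gate-weight matrix. Low entropy = confident routing.
gate_entropy <- function(gates) {
  stopifnot(all(abs(rowSums(gates) - 1) < 1e-8))  # each row must be a distribution
  apply(gates, 1, function(w) {
    w <- w[w > 0]  # treat 0 * log(0) as 0
    -sum(w * log(w))
  })
}

gates <- rbind(c(0.5, 0.5), c(1, 0), c(0.9, 0.1))
colnames(gates) <- c("expert_1", "expert_2")
gate_entropy(gates)  # log(2), 0, about 0.325
```

With two experts, entropy ranges from 0 (all weight on one expert) to log(2) (weight split evenly).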
label_mappings_gender <- list('0' = "Male", '1' = "Female")
# Run basic analysis
basic_analyses <- analyse_gnn_results(
gnn_results = gnn_results,
prepared_data = prepared_data_gender,
group_mappings = label_mappings_gender
)
# --- View all plots from the basic analysis ---
cat("## ROC Curve\n")
#> ## ROC Curve
print(basic_analyses$roc_plot)
cat("\n## Calibration Plot\n")
#>
#> ## Calibration Plot
print(basic_analyses$calibration_plot)
cat("\n## Gate Weight Distribution\n")
#>
#> ## Gate Weight Distribution
print(basic_analyses$gate_density_plot)
cat("\n## Gate Entropy Distribution\n")
#>
#> ## Gate Entropy Distribution
print(basic_analyses$entropy_density_plot)
Next, we inspect the output of analyse_experts(), loaded above as expert_analyses, to investigate how the different expert networks have specialised their learning.
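A table like `pairwise_differences$Female_vs_Male` can be thought of as the per-feature difference between two experts' importance scores, ranked by magnitude. The sketch below assumes importances are available as named numeric vectors; `importance_diff()`, its column names, and the importance values are hypothetical illustrations, not the package's actual output.

```r
# Hypothetical sketch: compare two experts' feature importances and rank
# features by the absolute size of the difference (expert A minus expert B).
importance_diff <- function(imp_a, imp_b) {
  features <- union(names(imp_a), names(imp_b))
  a <- imp_a[features]; a[is.na(a)] <- 0  # features missing from an expert count as 0
  b <- imp_b[features]; b[is.na(b)] <- 0
  out <- data.frame(feature = features,
                    importance_a = unname(a),
                    importance_b = unname(b),
                    difference = unname(a - b),
                    stringsAsFactors = FALSE)
  out[order(abs(out$difference), decreasing = TRUE), , drop = FALSE]
}

female <- c(feature1 = 0.40, feature2 = 0.10, feature3 = 0.30)
male   <- c(feature1 = 0.15, feature2 = 0.20)
importance_diff(female, male)  # feature3, feature1, feature2
```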
# --- View all results from the expert analysis ---
cat("\n## Feature Importance: Female vs. Male\n")
#>
#> ## Feature Importance: Female vs. Male
# This table shows the features with the biggest difference in importance between the two experts.
print(head(expert_analyses$pairwise_differences$Female_vs_Male))
cat("\n## Feature Importance Difference Plot\n")
#>
#> ## Feature Importance Difference Plot
# This plot visualises the differences from the table above.
print(expert_analyses$difference_plot)
Finally, we use plot_sankey() to create the key visualisation from the research paper, showing how patients are routed through the model.
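A Sankey diagram of patient routing draws bands whose widths are counts of patients flowing from one category to another. Assuming each patient is routed to the expert with the largest gate weight, the group-to-expert counts behind such a diagram can be sketched with base R. `routing_table()` and the expert names are hypothetical, and plot_sankey() may define routing differently.

```r
# Hypothetical sketch: assign each patient to the expert with the highest gate
# weight, then count flows from demographic group to expert.
routing_table <- function(gates, group) {
  dominant <- colnames(gates)[max.col(gates, ties.method = "first")]
  as.data.frame(table(group = group, expert = dominant), responseName = "n")
}

demo_gates <- rbind(c(0.8, 0.2), c(0.3, 0.7), c(0.6, 0.4), c(0.1, 0.9))
colnames(demo_gates) <- c("Male_expert", "Female_expert")
demo_group <- c("Male", "Female", "Male", "Female")
routing_table(demo_gates, demo_group)  # one row per group/expert pair with count n
```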
# Generate and print the Sankey plot
sankey_diagram <- plot_sankey(
raw_data = raw_data,
gnn_results = gnn_results,
expert_results = expert_analyses,
group_mappings = label_mappings_gender,
group_var = "sex"
)
print(sankey_diagram)