Speeding up computations using weights in geex

Bradley Saul

2022-07-24

Motivation

A user needed to estimate parameters from a dataset that contained only categorical predictors. Such data can be represented either with one row per individual or with one row per group, where a group is a unique combination of the categories. In this example, I show that geex computations can be sped up substantially (more than tenfold here) by using the grouped representation together with the weights argument of m_estimate.
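The idea rests on a simple identity. Write O_i for individual i's data, g = 1, ..., G for the unique category combinations, o_g for the values shared by every individual in combination g, and n_g for the number of individuals in it (this notation is introduced here for illustration). If the estimating function ψ depends on an individual only through the categorical variables, each member of group g contributes the same ψ(o_g; θ), so the sums used in M-estimation can be taken over G groups instead of n individuals:

```latex
\begin{aligned}
\sum_{i=1}^{n} \psi(O_i;\theta) &= \sum_{g=1}^{G} n_g\,\psi(o_g;\theta), \\
\sum_{i=1}^{n} \frac{\partial \psi(O_i;\theta)}{\partial \theta^\top} &= \sum_{g=1}^{G} n_g\,\frac{\partial \psi(o_g;\theta)}{\partial \theta^\top}, \\
\sum_{i=1}^{n} \psi(O_i;\theta)\,\psi(O_i;\theta)^\top &= \sum_{g=1}^{G} n_g\,\psi(o_g;\theta)\,\psi(o_g;\theta)^\top.
\end{aligned}
```

The first line is the estimating equation whose root gives the point estimates; the other two are the ingredients of the sandwich variance estimator. Assuming geex multiplies each unit's contribution by its weight, passing the group counts as weights should therefore reproduce the individual-level results.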

Data

The following code generates two datasets: data1 has one row per individual, and data2 has one row per unique combination of the categorical variables, with the number of individuals in each combination stored in the column n.

library(geex)
library(dplyr)
set.seed(42)
n <- 1000

# tibble() (re-exported by dplyr) replaces the deprecated data_frame()
data1 <- tibble(
  ID = 1:n,
  Y_tau = rbinom(n, 1, 0.2),
  S_star = rbinom(n, 1, 0.6),
  Y = rbinom(n, 1, 0.4),
  Z = rbinom(n, 1, 0.5)
)
# count() adds a column n holding the number of individuals per combination
data2 <- data1 %>% group_by(Y_tau, S_star, Y, Z) %>% count()
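As a quick sanity check that nothing is lost by collapsing the data, the sketch below rebuilds both representations in base R and confirms that data2 has at most 2^4 = 16 rows while its counts still add up to all 1,000 individuals. Using aggregate() in place of the dplyr pipeline is an assumption of this sketch, chosen so it runs without extra packages; it is not how the code above builds data2.

```r
# Base-R sketch (no dplyr): rebuild both representations and check that
# collapsing to unique category combinations keeps every individual.
set.seed(42)
n <- 1000
data1 <- data.frame(
  ID     = 1:n,
  Y_tau  = rbinom(n, 1, 0.2),
  S_star = rbinom(n, 1, 0.6),
  Y      = rbinom(n, 1, 0.4),
  Z      = rbinom(n, 1, 0.5)
)

# One row per observed combination of the four binary variables; counting
# ID values per combination gives the number of individuals in each group.
data2 <- aggregate(ID ~ Y_tau + S_star + Y + Z, data = data1, FUN = length)
names(data2)[names(data2) == "ID"] <- "n"

nrow(data2)        # at most 2^4 = 16 rows
sum(data2$n) == n  # TRUE: every individual is counted exactly once
```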

Estimating equations

This is the estimating function the user provided as an example. I don't know what the target parameters represent, but it illustrates the point nicely.

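example <- function(data) {
  function(theta) {
    with(data, c(
      (1 - Y_tau) * (1 - Z) * (Y - theta[1]),  # root: mean of Y when Y_tau = 0, Z = 0
      (1 - Y_tau) * Z * (Y - theta[2]),        # root: mean of Y when Y_tau = 0, Z = 1
      theta[3] - theta[2] * theta[1]           # root: product of the two means
    ))
  }
}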
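Because this function depends on an individual only through Y_tau, Y, and Z (never through ID), every individual in a group contributes the same value. The sketch below checks this numerically: at an arbitrary theta, the sum of the estimating functions over the 1,000 rows of data1 should equal the count-weighted sum over the rows of data2. The helper psi_by_row() and the base-R rebuild of data2 are hypothetical additions for illustration, not part of geex.

```r
# Sketch: the unweighted sum over individuals equals the count-weighted
# sum over groups, at any value of theta.
set.seed(42)
n <- 1000
data1 <- data.frame(
  ID     = 1:n,
  Y_tau  = rbinom(n, 1, 0.2),
  S_star = rbinom(n, 1, 0.6),
  Y      = rbinom(n, 1, 0.4),
  Z      = rbinom(n, 1, 0.5)
)
data2 <- aggregate(ID ~ Y_tau + S_star + Y + Z, data = data1, FUN = length)
names(data2)[names(data2) == "ID"] <- "n"

example <- function(data) {
  function(theta) {
    with(data, c((1 - Y_tau) * (1 - Z) * (Y - theta[1]),
                 (1 - Y_tau) * Z * (Y - theta[2]),
                 theta[3] - theta[2] * theta[1]))
  }
}

# Hypothetical helper: evaluate psi(O_i; theta) for each row i, returning
# an nrow(dat) x 3 matrix with one row per unit.
psi_by_row <- function(dat, theta) {
  t(vapply(seq_len(nrow(dat)),
           function(i) example(dat[i, ])(theta),
           numeric(3)))
}

theta <- c(0.3, 0.6, 0.2)  # arbitrary; the identity holds for any theta
sum_individual <- colSums(psi_by_row(data1, theta))
# Multiplying the matrix by data2$n scales row g by its count n_g
sum_weighted <- colSums(data2$n * psi_by_row(data2, theta))
all.equal(sum_individual, sum_weighted)
```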

Computation time

We compare the time needed to compute the point and variance estimates with each representation:

system.time({
  results1 <- m_estimate(
    estFUN = example,
    data = data1,
    root_control = setup_root_control(start = c(.5, .5, .5))
  )
})
##    user  system elapsed 
##   0.526   0.002   0.528
system.time({
  results2 <- m_estimate(
    estFUN = example,
    data = data2,
    weights = data2$n,  # each row stands for n identical individuals
    root_control = setup_root_control(start = c(.5, .5, .5))
  )
})
##    user  system elapsed 
##   0.036   0.003   0.040

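The grouped representation with weights is clearly preferable: on this run it took 0.040 seconds (elapsed) versus 0.528 seconds, roughly 13 times faster. The gain presumably comes from data2 having at most 16 rows (one per combination of four binary variables) instead of 1,000, so the estimating function is evaluated far fewer times.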

Results comparison

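The point and variance estimates agree. The only visible difference is that the exact zeros in vcov(results1) appear as values on the order of 1e-47 in vcov(results2), which is floating-point noise: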

roots(results1)
## [1] 0.4123711 0.4014423 0.1655432
roots(results2)
## [1] 0.4123711 0.4014423 0.1655432
vcov(results1)
##              [,1]         [,2]         [,3]
## [1,] 0.0006245391 0.0000000000 0.0002507164
## [2,] 0.0000000000 0.0005776115 0.0002381903
## [3,] 0.0002507164 0.0002381903 0.0001988710
vcov(results2)
##              [,1]         [,2]         [,3]
## [1,] 6.245391e-04 6.873914e-47 0.0002507164
## [2,] 6.873914e-47 5.776115e-04 0.0002381903
## [3,] 2.507164e-04 2.381903e-04 0.0001988710
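This agreement is expected. For this estimating function the roots have a closed form (theta[1] and theta[2] are the means of Y among individuals with Y_tau = 0 and Z = 0 or Z = 1, respectively, and theta[3] is their product), and the empirical sandwich variance involves only sums that the counts reproduce exactly. The base-R sketch below computes both by hand for each representation, using hypothetical helpers closed_form_roots() and sandwich_vcov(). I have not checked how geex scales its matrices internally when weights are supplied, so treat this as an illustration of why the two representations agree rather than a reimplementation of m_estimate.

```r
set.seed(42)
n <- 1000
data1 <- data.frame(
  ID     = 1:n,
  Y_tau  = rbinom(n, 1, 0.2),
  S_star = rbinom(n, 1, 0.6),
  Y      = rbinom(n, 1, 0.4),
  Z      = rbinom(n, 1, 0.5)
)
data2 <- aggregate(ID ~ Y_tau + S_star + Y + Z, data = data1, FUN = length)
names(data2)[names(data2) == "ID"] <- "n"

# Hypothetical helper: closed-form roots of the weighted estimating equations
closed_form_roots <- function(dat, w) {
  a0 <- w * (1 - dat$Y_tau) * (1 - dat$Z)  # weight on rows with Y_tau = 0, Z = 0
  a1 <- w * (1 - dat$Y_tau) * dat$Z        # weight on rows with Y_tau = 0, Z = 1
  t1 <- sum(a0 * dat$Y) / sum(a0)
  t2 <- sum(a1 * dat$Y) / sum(a1)
  c(t1, t2, t1 * t2)
}

# Hypothetical helper: empirical sandwich A^{-1} B A^{-T}, where A and B are
# weighted sums (not means), which equals the usual mean-based form divided by n
sandwich_vcov <- function(dat, w, theta) {
  a0 <- (1 - dat$Y_tau) * (1 - dat$Z)
  a1 <- (1 - dat$Y_tau) * dat$Z
  psi <- cbind(a0 * (dat$Y - theta[1]),
               a1 * (dat$Y - theta[2]),
               theta[3] - theta[2] * theta[1])  # scalar, recycled down the column
  A <- matrix(0, 3, 3)                          # A = -sum_i w_i d psi_i / d theta
  A[1, 1] <- sum(w * a0)
  A[2, 2] <- sum(w * a1)
  A[3, ] <- sum(w) * c(theta[2], theta[1], -1)
  B <- crossprod(w * psi, psi)                  # B = sum_i w_i psi_i psi_i^T
  A_inv <- solve(A)
  A_inv %*% B %*% t(A_inv)
}

w1 <- rep(1, nrow(data1))  # one row per individual: unit weights
theta1 <- closed_form_roots(data1, w1)
theta2 <- closed_form_roots(data2, data2$n)
V1 <- sandwich_vcov(data1, w1, theta1)
V2 <- sandwich_vcov(data2, data2$n, theta2)

all.equal(theta1, theta2)
all.equal(V1, V2)
```

With unit weights, theta1 and V1 should be close to roots(results1) and vcov(results1) above, up to the tolerance of geex's root finder and numerical derivatives.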