policy_eval

library(polle)

This vignette is a guide to policy_eval() and some of its associated S3 methods. The purpose of policy_eval() is to estimate (evaluate) the value of a user-defined policy or of a policy learning algorithm. For details on the methodology, see the associated paper (Nordland and Holst 2023).

As a general setup, we consider a fixed two-stage problem. We simulate data using sim_two_stage() and create a policy_data object using policy_data():

d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
## Policy data with n = 2000 observations and maximal K = 2 stages.
## 
##      action
## stage    0    1    n
##     1 1017  983 2000
##     2  819 1181 2000
## 
## Baseline covariates: B, BB
## State covariates: L, C
## Average utility: 0.84

Evaluating a user-defined policy

User-defined policies are created using policy_def(). In this case we define a simple static policy always selecting action '1':

p1 <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")

As we want to apply the same policy function at both stages we set reuse = TRUE.

policy_eval() implements three types of policy evaluation: inverse probability weighting (IPW) estimation, outcome regression (OR) estimation, and doubly robust (DR) estimation. As doubly robust estimation combines the two other types, we focus on this approach. For details on the implementation, see Algorithm 1 in Nordland and Holst (2023).

(pe1 <- policy_eval(policy_data = pd,
                    policy = p1,
                    type = "dr"))
##                  Estimate Std.Err   2.5% 97.5%   P-value
## E[Z(d)]: d=(A=1)   0.8213  0.1115 0.6027  1.04 1.796e-13

policy_eval() returns an object of class policy_eval, which prints like a lava::estimate object. The policy value estimate and its variance are available via coef() and vcov():

coef(pe1)
## [1] 0.8213233
vcov(pe1)
##            [,1]
## [1,] 0.01244225
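As a sanity check, the printed standard error and confidence limits can be reproduced from these two numbers. This is a minimal sketch in base R, assuming that the printed interval is a standard Wald interval based on the normal distribution; the values are copied from the output above.

```r
# Values copied from coef(pe1) and vcov(pe1) above.
est <- 0.8213233
v   <- 0.01244225

se <- sqrt(v)                                  # standard error
ci <- est + c(-1, 1) * qnorm(0.975) * se       # 95% Wald interval
z  <- est / se                                 # Wald statistic for H0: value = 0
p  <- 2 * pnorm(-abs(z))                       # two-sided p-value

round(c(Std.Err = se, lower = ci[1], upper = ci[2]), 4)
```

The rounded standard error and limits agree with the Std.Err, 2.5% and 97.5% columns printed for pe1.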

Working with policy_eval objects

The policy_eval object behaves like a lava::estimate object, which can also be accessed directly using estimate().

estimate objects make it easy to work with estimates that have an iid decomposition given by the influence curve/function; see the estimate vignette.

The influence curve is available via IC():

IC(pe1) |> head()
##            [,1]
## [1,]  2.5515875
## [2,] -5.6787782
## [3,]  4.9506000
## [4,]  2.0661524
## [5,]  0.7939672
## [6,] -2.2932160
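To see why an influence curve is useful for inference, consider the simplest possible estimator, the sample mean. The following base R sketch is an illustration only and does not use polle: the influence function of the mean is the centered observation, and the variance of the estimator is estimated by the mean squared influence function divided by n.

```r
set.seed(1)
x <- rnorm(200, mean = 1)
n <- length(x)

est <- mean(x)
ic  <- x - est                     # influence function of the sample mean

se_ic      <- sqrt(sum(ic^2)) / n  # sqrt of crossprod(ic) / n^2
se_classic <- sd(x) / sqrt(n)      # textbook standard error

c(ic_based = se_ic, classic = se_classic)
```

The two standard errors differ only by the factor sqrt((n - 1) / n). Assuming the same scaling is used for policy_eval objects (not verified here), sqrt(crossprod(IC(pe1))) / nrow(IC(pe1)) should be close to the standard error printed for pe1.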

Merging estimate objects allows the user to get inference for transformations of the estimates via the delta method. Here we get inference for the average treatment effect, both as a difference and as a ratio:

p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
pe0 <- policy_eval(policy_data = pd,
                   policy = p0,
                   type = "dr")

(est <- merge(pe0, pe1))
##                  Estimate Std.Err    2.5%  97.5%   P-value
## E[Z(d)]: d=(A=0) -0.06123  0.0881 -0.2339 0.1114 4.871e-01
## ────────────────                                          
## E[Z(d)]: d=(A=1)  0.82132  0.1115  0.6027 1.0399 1.796e-13
estimate(est, function(x) x[2]-x[1], labels="ATE-difference")
##                Estimate Std.Err   2.5% 97.5%  P-value
## ATE-difference   0.8825  0.1338 0.6203 1.145 4.25e-11
estimate(est, function(x) x[2]/x[1], labels="ATE-ratio")
##           Estimate Std.Err   2.5% 97.5% P-value
## ATE-ratio   -13.41    19.6 -51.83    25  0.4937
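The delta method calculation behind the ratio can be sketched by hand in base R. The estimates and standard errors below are copied from the printed output; the covariance between the two estimates is not printed, so it is backed out from the standard error of the difference. Because the inputs are rounded, the result only matches to the printed precision.

```r
# Rounded values copied from the output above.
est     <- c(-0.06123, 0.82132)   # (A=0), (A=1)
se      <- c(0.0881, 0.1115)
se_diff <- 0.1338                 # Std.Err of ATE-difference

# Var(x2 - x1) = v1 + v2 - 2 * cov  =>  solve for cov.
cov01 <- (se[1]^2 + se[2]^2 - se_diff^2) / 2
V <- matrix(c(se[1]^2, cov01,
              cov01,   se[2]^2), nrow = 2)

# Delta method for g(x) = x2 / x1 with gradient (-x2 / x1^2, 1 / x1).
ratio    <- est[2] / est[1]
grad     <- c(-est[2] / est[1]^2, 1 / est[1])
se_ratio <- sqrt(drop(t(grad) %*% V %*% grad))

round(c(ratio = ratio, se_ratio = se_ratio), 2)
```

The very large standard error of the ratio is driven by the denominator: the value of the policy (A=0) is close to zero, so the gradient of the ratio is large and the ratio is poorly determined.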

Nuisance models

So far we have relied on the default generalized linear models for the nuisance g-models and Q-models. By default, a single g-model is trained across all stages using the state/Markov type history; see the policy_data vignette. Use get_g_functions() to access the fitted model:

(gf <- get_g_functions(pe1))
## $all_stages
## $model
## 
## Call:  NULL
## 
## Coefficients:
## (Intercept)            L            C            B     BBgroup2     BBgroup3  
##     0.08285      0.03094      0.97993     -0.05753     -0.13970     -0.06122  
## 
## Degrees of Freedom: 3999 Total (i.e. Null);  3994 Residual
## Null Deviance:       5518 
## Residual Deviance: 4356  AIC: 4368
## 
## 
## attr(,"full_history")
## [1] FALSE

The g-functions can be used as input to a new policy evaluation:

policy_eval(policy_data = pd,
            g_functions = gf,
            policy = p0,
            type = "dr")
##                  Estimate Std.Err    2.5%  97.5% P-value
## E[Z(d)]: d=(A=0) -0.06123  0.0881 -0.2339 0.1114  0.4871

or we can get the associated predicted values:

predict(gf, new_policy_data = pd) |> head(6)
## Key: <id, stage>
##       id stage        g_0        g_1
##    <int> <int>      <num>      <num>
## 1:     1     1 0.15628741 0.84371259
## 2:     1     2 0.08850558 0.91149442
## 3:     2     1 0.92994454 0.07005546
## 4:     2     2 0.92580890 0.07419110
## 5:     3     1 0.11184451 0.88815549
## 6:     3     2 0.08082666 0.91917334

Similarly, we can inspect the Q-functions using get_q_functions():

get_q_functions(pe1)
## $stage_1
## $model
## 
## Call:  NULL
## 
## Coefficients:
## (Intercept)           A1            L            C            B     BBgroup2  
##    0.232506     0.682422     0.454642     0.039021    -0.070152    -0.184704  
##    BBgroup3         A1:L         A1:C         A1:B  A1:BBgroup2  A1:BBgroup3  
##   -0.171734    -0.010746     0.938791     0.003772     0.157200     0.270711  
## 
## Degrees of Freedom: 1999 Total (i.e. Null);  1988 Residual
## Null Deviance:       7689 
## Residual Deviance: 3599  AIC: 6877
## 
## 
## $stage_2
## $model
## 
## Call:  NULL
## 
## Coefficients:
## (Intercept)           A1            L            C            B     BBgroup2  
##   -0.043324     0.147356     0.002376    -0.042036     0.005331    -0.001128  
##    BBgroup3         A1:L         A1:C         A1:B  A1:BBgroup2  A1:BBgroup3  
##   -0.108404     0.024424     0.962591    -0.059177    -0.102084     0.094688  
## 
## Degrees of Freedom: 1999 Total (i.e. Null);  1988 Residual
## Null Deviance:       3580 
## Residual Deviance: 1890  AIC: 5588
## 
## 
## attr(,"full_history")
## [1] FALSE

Note that a model is trained for each stage. Again, we can predict from the Q-models using predict().
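Predicting from the Q-functions follows the same pattern as for the g-functions. The sketch below assumes that predict() accepts the new_policy_data argument for Q-functions in the same way as for g-functions; the object names with the suffix _small are introduced for illustration, and a smaller simulated sample is used to keep the run short.

```r
library(polle)

# Smaller sample than in the vignette, purely to keep the sketch fast.
d_small <- sim_two_stage(n = 500, seed = 1)
pd_small <- policy_data(d_small,
                        action = c("A_1", "A_2"),
                        baseline = c("B", "BB"),
                        covariates = list(L = c("L_1", "L_2"),
                                          C = c("C_1", "C_2")),
                        utility = c("U_1", "U_2", "U_3"))
p1_small <- policy_def(policy_functions = '1', reuse = TRUE, name = "(A=1)")
pe_small <- policy_eval(pd_small, policy = p1_small, type = "dr")

# Extract the fitted Q-functions and predict for every (id, stage) pair.
qf_small <- get_q_functions(pe_small)
q_pred <- predict(qf_small, new_policy_data = pd_small)
head(q_pred)
```

The result is expected to hold one row per observation and stage, with one predicted Q-value column per action.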

Usually, we want to specify the nuisance models ourselves using the g_models and q_models arguments:

pe1 <- policy_eval(pd,
            policy = p1,
            g_models = list(
              g_sl(formula = ~ BB + L_1, SL.library = c("SL.glm", "SL.ranger")),
              g_sl(formula = ~ BB + L_1 + C_2, SL.library = c("SL.glm", "SL.ranger"))
            ),
            g_full_history = TRUE,
            q_models = list(
              q_glm(formula = ~ A * (B + C_1)), # including action interactions
              q_glm(formula = ~ A * (B + C_1 + C_2)) # including action interactions
            ),
            q_full_history = TRUE)
## Loading required namespace: ranger

Here we train a super learner g-model for each stage using the full available history and a generalized linear model for the Q-models. The formula argument is used to construct the model frame passed to the model for training (and prediction). The valid formula terms depending on g_full_history and q_full_history are available via get_history_names():

get_history_names(pd) # state/Markov history
## [1] "L"  "C"  "B"  "BB"
get_history_names(pd, stage = 1) # full history
## [1] "L_1" "C_1" "B"   "BB"
get_history_names(pd, stage = 2) # full history
## [1] "A_1" "L_1" "L_2" "C_1" "C_2" "B"   "BB"

Remember that the action variable at the current stage is always named A. Some models, such as glm, only fit the interactions that are specified explicitly in the formula. For such models it is important to include action interaction terms in the Q-models.
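To see which terms a Q-model formula such as ~ A * (B + C_1) produces, the formula can be expanded with base R alone. This is an illustrative sketch on a small made-up data frame (toy is hypothetical and unrelated to the simulated data); it only shows how R expands the formula, not how polle builds its model frame internally.

```r
f <- ~ A * (B + C_1)

# Main effects first, then the action interactions.
attr(terms(f), "term.labels")

# Design matrix on a toy data frame with a binary action.
toy <- data.frame(A   = factor(c(0, 1, 0, 1)),
                  B   = c(0.1, -0.3, 0.8, 1.2),
                  C_1 = c(1.5, 0.2, -0.7, 0.4))
X <- model.matrix(f, data = toy)
colnames(X)
```

Without the A:B and A:C_1 columns, a linear model would force the effect of the action to be the same for every history, so the fitted Q-function could not favor different actions for different individuals.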

Evaluating a policy learning algorithm

The value of a learned policy is an important performance measure, and policy_eval() allows for direct evaluation of a given policy learning algorithm. For details, see Algorithm 4 in Nordland and Holst (2023).

In polle, policy learning algorithms are specified using policy_learn(), see the associated vignette. These functions can be directly evaluated in policy_eval():

policy_eval(pd,
            policy_learn = policy_learn(type = "ql"))
##               Estimate Std.Err  2.5% 97.5%   P-value
## E[Z(d)]: d=ql    1.306 0.06641 1.176 1.437 3.783e-86

In the above example we evaluate the policy estimated via Q-learning. Alternatively, we can first learn the policy and then pass it to policy_eval():

p_ql <- policy_learn(type = "ql")(pd, q_models = q_glm())
policy_eval(pd,
            policy = get_policy(p_ql))
##               Estimate Std.Err  2.5% 97.5%   P-value
## E[Z(d)]: d=ql    1.306 0.06641 1.176 1.437 3.783e-86
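The learned policy can also be applied to a policy data object to inspect the recommended actions. The sketch below assumes that the function returned by get_policy() can be called on a policy_data object and returns one row per id and stage with the recommended action in a column named d; the object names with the suffix _small are introduced for illustration.

```r
library(polle)

# Smaller sample than in the vignette, purely to keep the sketch fast.
d_small <- sim_two_stage(n = 500, seed = 1)
pd_small <- policy_data(d_small,
                        action = c("A_1", "A_2"),
                        baseline = c("B", "BB"),
                        covariates = list(L = c("L_1", "L_2"),
                                          C = c("C_1", "C_2")),
                        utility = c("U_1", "U_2", "U_3"))

# Learn a Q-learning policy and apply it to the same data.
p_ql_small <- policy_learn(type = "ql")(pd_small, q_models = q_glm())
actions <- get_policy(p_ql_small)(pd_small)
head(actions)

# Recommended actions by stage (assumption: the action column is named d).
table(stage = actions$stage, action = actions$d)
```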

Cross-fitting

A key feature of policy_eval() is that it allows for easy cross-fitting of the nuisance models as well as the learned policy. Here we specify two-fold cross-fitting via the M argument:

pe_cf <- policy_eval(pd,
                     policy_learn = policy_learn(type = "ql"),
                     M = 2)

Specifically, both the nuisance models and the optimal policy are fitted on each training fold. Subsequently, the doubly robust value score is calculated on the validation folds.
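The general idea of cross-fitting can be sketched in base R on a toy regression problem. This is an illustration of the fold logic only, not polle's internal implementation: each observation is scored by a model that was fitted without it.

```r
set.seed(1)
n <- 100
M <- 2
x <- rnorm(n)
y <- 1 + 2 * x + rnorm(n)
dat <- data.frame(x = x, y = y)

# Randomly assign each observation to one of M validation folds.
fold_id <- sample(rep(seq_len(M), length.out = n))
folds <- split(seq_len(n), fold_id)

# Fit on the training part, predict on the held-out validation part.
pred <- numeric(n)
for (m in seq_len(M)) {
  validation <- folds[[m]]
  fit <- lm(y ~ x, data = dat[-validation, ])
  pred[validation] <- predict(fit, newdata = dat[validation, ])
}

# Every observation received exactly one out-of-fold prediction.
mean((y - pred)^2)
```

In policy_eval() the role of the fitted model is played by the nuisance models and the learned policy, and the out-of-fold predictions are replaced by the doubly robust value scores.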

The policy_eval object now contains the fold assignments and a list of policy_eval objects, one for each fold:

pe_cf$folds$fold_1 |> head()
## [1]  3  4  5  7  8 10
pe_cf$cross_fits$fold_1
##               Estimate Std.Err  2.5% 97.5%   P-value
## E[Z(d)]: d=ql    1.261 0.09456 1.075 1.446 1.538e-40

To save memory, particularly when cross-fitting, the fitted nuisance models can be left out of the returned object via the save_g_functions and save_q_functions arguments.
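A minimal sketch of this option follows, assuming that save_g_functions and save_q_functions are logical flags that default to keeping the fitted models. The object names are introduced for illustration, and a smaller simulated sample is used to keep the run short.

```r
library(polle)

# Smaller sample than in the vignette, purely to keep the sketch fast.
d_small <- sim_two_stage(n = 500, seed = 1)
pd_small <- policy_data(d_small,
                        action = c("A_1", "A_2"),
                        baseline = c("B", "BB"),
                        covariates = list(L = c("L_1", "L_2"),
                                          C = c("C_1", "C_2")),
                        utility = c("U_1", "U_2", "U_3"))

# Cross-fitted evaluation without storing the fitted nuisance models
# (assumption: both flags are logical).
pe_light <- policy_eval(pd_small,
                        policy_learn = policy_learn(type = "ql"),
                        M = 2,
                        save_g_functions = FALSE,
                        save_q_functions = FALSE)

# The value estimate is still available; only the fitted models are dropped.
coef(pe_light)
format(object.size(pe_light), units = "Kb")
```

The trade-off is that get_g_functions() and get_q_functions() can no longer be used to inspect or reuse the fitted models from such an object.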

Parallel processing via future.apply

It is easy to parallelize the cross-fitting procedure via the future.apply package:

library(future.apply)
plan("multisession") # local parallel processing
library("progressr") # progress bar
handlers(global = TRUE)

policy_eval(pd,
            policy_learn = policy_learn(type = "ql"),
            q_models = q_rf(),
            M = 20)

plan("sequential") # resetting to sequential processing

SessionInfo

sessionInfo()
## R version 4.4.1 (2024-06-14)
## Platform: aarch64-apple-darwin23.5.0
## Running under: macOS Sonoma 14.6.1
## 
## Matrix products: default
## BLAS:   /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib 
## LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib;  LAPACK version 3.12.0
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## time zone: Europe/Copenhagen
## tzcode source: internal
## 
## attached base packages:
## [1] splines   stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
## [1] ggplot2_3.5.1       data.table_1.15.4   polle_1.5          
## [4] SuperLearner_2.0-29 gam_1.22-4          foreach_1.5.2      
## [7] nnls_1.5           
## 
## loaded via a namespace (and not attached):
##  [1] sass_0.4.9          utf8_1.2.4          future_1.33.2      
##  [4] lattice_0.22-6      listenv_0.9.1       digest_0.6.36      
##  [7] magrittr_2.0.3      evaluate_0.24.0     grid_4.4.1         
## [10] iterators_1.0.14    mvtnorm_1.2-5       policytree_1.2.3   
## [13] fastmap_1.2.0       jsonlite_1.8.8      Matrix_1.7-0       
## [16] survival_3.6-4      fansi_1.0.6         scales_1.3.0       
## [19] numDeriv_2016.8-1.1 codetools_0.2-20    jquerylib_0.1.4    
## [22] lava_1.8.0          cli_3.6.3           rlang_1.1.4        
## [25] mets_1.3.4          parallelly_1.37.1   future.apply_1.11.2
## [28] munsell_0.5.1       withr_3.0.0         cachem_1.1.0       
## [31] yaml_2.3.8          tools_4.4.1         parallel_4.4.1     
## [34] colorspace_2.1-0    ranger_0.16.0       globals_0.16.3     
## [37] vctrs_0.6.5         R6_2.5.1            lifecycle_1.0.4    
## [40] pkgconfig_2.0.3     timereg_2.0.5       progressr_0.14.0   
## [43] bslib_0.7.0         pillar_1.9.0        gtable_0.3.5       
## [46] Rcpp_1.0.13         glue_1.7.0          xfun_0.45          
## [49] tibble_3.2.1        highr_0.11          knitr_1.47         
## [52] farver_2.1.2        htmltools_0.5.8.1   rmarkdown_2.27     
## [55] labeling_0.4.3      compiler_4.4.1

References

Nordland, Andreas, and Klaus K. Holst. 2023. “Policy Learning with the Polle Package.” https://doi.org/10.48550/arXiv.2212.02335.