This vignette is a guide to policy_eval() and some of the associated S3 methods. The purpose of policy_eval() is to estimate (evaluate) the value of a user-defined policy or a policy learning algorithm. For details on the methodology, see the associated paper (Nordland and Holst 2023).
We consider a fixed two-stage problem as a general setup, simulate data using sim_two_stage(), and create a policy_data object using policy_data():
d <- sim_two_stage(n = 2e3, seed = 1)
pd <- policy_data(d,
                  action = c("A_1", "A_2"),
                  baseline = c("B", "BB"),
                  covariates = list(L = c("L_1", "L_2"),
                                    C = c("C_1", "C_2")),
                  utility = c("U_1", "U_2", "U_3"))
pd
## Policy data with n = 2000 observations and maximal K = 2 stages.
##
## action
## stage 0 1 n
## 1 1017 983 2000
## 2 819 1181 2000
##
## Baseline covariates: B, BB
## State covariates: L, C
## Average utility: 0.84
User-defined policies are created using policy_def(). In this case we define a simple static policy always selecting action '1'. As we want to apply the same policy function at both stages, we set reuse = TRUE:
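A minimal sketch of such a policy definition, mirroring the policy_def() call used for the policy p0 later in this vignette (the object name p1 and the label "(A=1)" are assumptions):
# static policy always selecting action '1' at both stages (sketch)
p1 <- policy_def(policy_functions = 1, reuse = TRUE, name = "(A=1)")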
policy_eval() implements three types of policy evaluation: inverse probability weighting estimation, outcome regression estimation, and doubly robust (DR) estimation. As doubly robust estimation is a combination of the other two, we focus on this approach. For details on the implementation, see Algorithm 1 in (Nordland and Holst 2023).
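A sketch of the doubly robust evaluation of p1, mirroring the call used for p0 below (the object name pe1 is an assumption):
# doubly robust evaluation of the static policy p1 (sketch)
pe1 <- policy_eval(policy_data = pd,
                   policy = p1,
                   type = "dr")
pe1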
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=(A=1) 0.8213 0.1115 0.6027 1.04 1.796e-13
policy_eval() returns an object of class policy_eval, which prints like a lava::estimate object. The policy value estimate and variance are available via coef() and vcov():
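For example (assuming pe1 holds the doubly robust evaluation from above):
coef(pe1)  # policy value estimate
vcov(pe1)  # variance of the policy value estimate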
## [1] 0.8213233
## [,1]
## [1,] 0.01244225
policy_eval objects
The policy_eval object behaves like a lava::estimate object, which can also be accessed directly using estimate().
estimate objects make it easy to work with estimates that have an iid decomposition given by the influence curve/function; see the estimate vignette. The influence curve is available via IC():
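For example, the first rows of the influence curve (head() and the object name pe1 are assumptions):
head(IC(pe1))  # one row per observation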
## [,1]
## [1,] 2.5515875
## [2,] -5.6787782
## [3,] 4.9506000
## [4,] 2.0661524
## [5,] 0.7939672
## [6,] -2.2932160
Merging estimate objects allows the user to get inference for transformations of the estimates via the Delta method. Here we get inference for the average treatment effect, both as a difference and as a ratio:
p0 <- policy_def(policy_functions = 0, reuse = TRUE, name = "(A=0)")
pe0 <- policy_eval(policy_data = pd,
                   policy = p0,
                   type = "dr")
(est <- merge(pe0, pe1))
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=(A=0) -0.06123 0.0881 -0.2339 0.1114 4.871e-01
## ────────────────
## E[Z(d)]: d=(A=1) 0.82132 0.1115 0.6027 1.0399 1.796e-13
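The transformed estimates below can be obtained by applying lava's estimate() with a transformation function to the merged object; a sketch, assuming the labels shown in the output:
# Delta method inference for the difference and ratio of the two policy values (sketch)
estimate(est, function(x) x[2] - x[1], labels = "ATE-difference")
estimate(est, function(x) x[2] / x[1], labels = "ATE-ratio")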
## Estimate Std.Err 2.5% 97.5% P-value
## ATE-difference 0.8825 0.1338 0.6203 1.145 4.25e-11
## Estimate Std.Err 2.5% 97.5% P-value
## ATE-ratio -13.41 19.6 -51.83 25 0.4937
So far we have relied on the default generalized linear models for the nuisance g-models and Q-models. By default, a single g-model is trained across all stages using the state/Markov-type history; see the policy_data vignette. Use get_g_functions() to access the fitted model:
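For example, storing the fitted g-functions for later use (the names gf and pe1 are assumptions):
(gf <- get_g_functions(pe1))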
## $all_stages
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) L C B BBgroup2 BBgroup3
## 0.08285 0.03094 0.97993 -0.05753 -0.13970 -0.06122
##
## Degrees of Freedom: 3999 Total (i.e. Null); 3994 Residual
## Null Deviance: 5518
## Residual Deviance: 4356 AIC: 4368
##
##
## attr(,"full_history")
## [1] FALSE
The g-functions can be used as input to a new policy evaluation:
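A sketch, assuming the fitted g-functions are passed via the g_functions argument:
policy_eval(pd,
            policy = p0,
            g_functions = gf)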
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=(A=0) -0.06123 0.0881 -0.2339 0.1114 0.4871
or we can get the associated predicted values:
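A sketch, assuming the predict() method for the g-functions accepts the policy_data object as new data:
predict(gf, pd)  # predicted probabilities g_0 and g_1 for each id and stage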
## Key: <id, stage>
## id stage g_0 g_1
## <int> <int> <num> <num>
## 1: 1 1 0.15628741 0.84371259
## 2: 1 2 0.08850558 0.91149442
## 3: 2 1 0.92994454 0.07005546
## 4: 2 2 0.92580890 0.07419110
## 5: 3 1 0.11184451 0.88815549
## 6: 3 2 0.08082666 0.91917334
Similarly, we can inspect the Q-functions using get_q_functions():
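For example (pe1 assumed as above):
get_q_functions(pe1)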
## $stage_1
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) A1 L C B BBgroup2
## 0.232506 0.682422 0.454642 0.039021 -0.070152 -0.184704
## BBgroup3 A1:L A1:C A1:B A1:BBgroup2 A1:BBgroup3
## -0.171734 -0.010746 0.938791 0.003772 0.157200 0.270711
##
## Degrees of Freedom: 1999 Total (i.e. Null); 1988 Residual
## Null Deviance: 7689
## Residual Deviance: 3599 AIC: 6877
##
##
## $stage_2
## $model
##
## Call: NULL
##
## Coefficients:
## (Intercept) A1 L C B BBgroup2
## -0.043324 0.147356 0.002376 -0.042036 0.005331 -0.001128
## BBgroup3 A1:L A1:C A1:B A1:BBgroup2 A1:BBgroup3
## -0.108404 0.024424 0.962591 -0.059177 -0.102084 0.094688
##
## Degrees of Freedom: 1999 Total (i.e. Null); 1988 Residual
## Null Deviance: 3580
## Residual Deviance: 1890 AIC: 5588
##
##
## attr(,"full_history")
## [1] FALSE
Note that a model is trained for each stage. Again, we can predict from the Q-models using predict().
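A sketch, analogous to the g-function predictions above:
predict(get_q_functions(pe1), pd)  # predicted Q-values for each action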
Usually, we want to specify the nuisance models ourselves using the g_models and q_models arguments:
pe1 <- policy_eval(pd,
                   policy = p1,
                   g_models = list(
                     g_sl(formula = ~ BB + L_1, SL.library = c("SL.glm", "SL.ranger")),
                     g_sl(formula = ~ BB + L_1 + C_2, SL.library = c("SL.glm", "SL.ranger"))
                   ),
                   g_full_history = TRUE,
                   q_models = list(
                     q_glm(formula = ~ A * (B + C_1)),       # including action interactions
                     q_glm(formula = ~ A * (B + C_1 + C_2))  # including action interactions
                   ),
                   q_full_history = TRUE)
## Loading required namespace: ranger
Here we train a super learner g-model for each stage using the full available history, and a generalized linear model for the Q-models. The formula argument is used to construct the model frame passed to the model for training (and prediction). The valid formula terms, which depend on g_full_history and q_full_history, are available via get_history_names():
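A sketch of the corresponding calls (the stage argument is an assumption):
get_history_names(pd)             # state/Markov-type history
get_history_names(pd, stage = 1)  # full history at stage 1
get_history_names(pd, stage = 2)  # full history at stage 2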
## [1] "L" "C" "B" "BB"
## [1] "L_1" "C_1" "B" "BB"
## [1] "A_1" "L_1" "L_2" "C_1" "C_2" "B" "BB"
Remember that the action variable at the current stage is always named A. Some models, like glm, require interactions to be specified via the model frame. Thus, for some models it is important to include action interaction terms in the Q-model formulas.
The value of a learned policy is an important performance measure, and policy_eval() allows for direct evaluation of a given policy learning algorithm. For details, see Algorithm 4 in (Nordland and Holst 2023).
In polle, policy learning algorithms are specified using policy_learn(); see the associated vignette. These functions can be directly evaluated in policy_eval():
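A sketch, assuming the learner is passed via the policy_learn argument and the Q-models default to q_glm():
policy_eval(pd,
            policy_learn = policy_learn(type = "ql"))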
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=ql 1.306 0.06641 1.176 1.437 3.783e-86
In the above example we evaluate the policy estimated via Q-learning. Alternatively, we can first learn the policy and then pass it to policy_eval():
p_ql <- policy_learn(type = "ql")(pd, q_models = q_glm())
policy_eval(pd,
            policy = get_policy(p_ql))
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=ql 1.306 0.06641 1.176 1.437 3.783e-86
A key feature of policy_eval() is that it allows for easy cross-fitting of the nuisance models as well as the learned policy. Here we specify two-fold cross-fitting via the M argument:
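A sketch of the cross-fitted evaluation (the object name pe_cf is an assumption):
pe_cf <- policy_eval(pd,
                     policy_learn = policy_learn(type = "ql"),
                     M = 2)  # two-fold cross-fitting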
Specifically, both the nuisance models and the optimal policy are fitted on each training fold. Subsequently, the doubly robust value score is calculated on the validation folds.
The policy_eval object now consists of a list of policy_eval objects associated with each fold:
## [1] 3 4 5 7 8 10
## Estimate Std.Err 2.5% 97.5% P-value
## E[Z(d)]: d=ql 1.261 0.09456 1.075 1.446 1.538e-40
In order to save memory, particularly when cross-fitting, it is possible not to save the nuisance models via the save_g_functions and save_q_functions arguments.
future.apply
It is easy to parallelize the cross-fitting procedure via the future.apply package:
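A minimal sketch, assuming the folds are processed in parallel when a parallel future plan is set before calling policy_eval():
library("future.apply")
future::plan("multisession")  # run cross-fitting folds in parallel R sessions
pe_cf <- policy_eval(pd,
                     policy_learn = policy_learn(type = "ql"),
                     M = 2)
future::plan("sequential")    # reset to sequential processing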
## R version 4.4.1 (2024-06-14)
## Platform: aarch64-apple-darwin23.5.0
## Running under: macOS Sonoma 14.6.1
##
## Matrix products: default
## BLAS: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRblas.dylib
## LAPACK: /Users/oano/.asdf/installs/R/4.4.1/lib/R/lib/libRlapack.dylib; LAPACK version 3.12.0
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: Europe/Copenhagen
## tzcode source: internal
##
## attached base packages:
## [1] splines stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] ggplot2_3.5.1 data.table_1.15.4 polle_1.5
## [4] SuperLearner_2.0-29 gam_1.22-4 foreach_1.5.2
## [7] nnls_1.5
##
## loaded via a namespace (and not attached):
## [1] sass_0.4.9 utf8_1.2.4 future_1.33.2
## [4] lattice_0.22-6 listenv_0.9.1 digest_0.6.36
## [7] magrittr_2.0.3 evaluate_0.24.0 grid_4.4.1
## [10] iterators_1.0.14 mvtnorm_1.2-5 policytree_1.2.3
## [13] fastmap_1.2.0 jsonlite_1.8.8 Matrix_1.7-0
## [16] survival_3.6-4 fansi_1.0.6 scales_1.3.0
## [19] numDeriv_2016.8-1.1 codetools_0.2-20 jquerylib_0.1.4
## [22] lava_1.8.0 cli_3.6.3 rlang_1.1.4
## [25] mets_1.3.4 parallelly_1.37.1 future.apply_1.11.2
## [28] munsell_0.5.1 withr_3.0.0 cachem_1.1.0
## [31] yaml_2.3.8 tools_4.4.1 parallel_4.4.1
## [34] colorspace_2.1-0 ranger_0.16.0 globals_0.16.3
## [37] vctrs_0.6.5 R6_2.5.1 lifecycle_1.0.4
## [40] pkgconfig_2.0.3 timereg_2.0.5 progressr_0.14.0
## [43] bslib_0.7.0 pillar_1.9.0 gtable_0.3.5
## [46] Rcpp_1.0.13 glue_1.7.0 xfun_0.45
## [49] tibble_3.2.1 highr_0.11 knitr_1.47
## [52] farver_2.1.2 htmltools_0.5.8.1 rmarkdown_2.27
## [55] labeling_0.4.3 compiler_4.4.1