In this example, we train a simple network to learn a new spiral trajectory using the adjoint method. We start by loading all of the libraries and setting our initial conditions, along with a few constants.
library(reticulate)
library(tensorflow)
library(tfNeuralODE)
library(keras)
library(deSolve)
# Number of optimisation steps performed by the training loop.
niters <- 25
# Integration time grid (the "layers"): 0 to 25 in steps of 0.25.
# NOTE: this rebinds the name `t`; base::t() is still found when `t` is
# used in call position, but matrix() is used below to avoid the ambiguity.
t <- seq(from = 0, to = 25, by = 0.25)

# Initial state as a 1 x 2 float32 tensor, plus a variable initialised from
# it (the variable is what the optimizer updates).
h0 <- tf$cast(matrix(c(1, 0), nrow = 1), dtype = tf$float32)
h0_var <- tf$Variable(h0, name = "")

# Constant 2 x 2 system matrix used by the ODE model.
W <- tf$cast(
  matrix(c(-0.1, 1.0, -0.2, -0.1), nrow = 2, byrow = TRUE),
  dtype = tf$float32
)

# State the trajectory's end point should reach, as a 1 x 2 float32 tensor.
hN_target <- tf$cast(matrix(c(0, 0.5), nrow = 1), dtype = tf$float32)
We solve for the initial trajectory.
# Right-hand side of the reference linear ODE: derivative = u %*% A, where A
# is a fixed 2 x 2 coefficient matrix.
#
# Only the second argument, `u` (the current state), is read; `du`, `p` and
# `t` are accepted but ignored.
# NOTE(review): deSolve presumably calls this positionally as
# (time, state, parms), so `du` would receive the time value -- confirm.
#
# Returns a one-element list containing the matrix product u %*% A.
trueODEfunc <- function(du, u, p, t) {
  coef_mat <- matrix(c(-0.1, 1.0, -0.2, -0.1), nrow = 2, byrow = TRUE)
  list(u %*% coef_mat)
}
# Reference trajectory: integrate trueODEfunc from the state (1, 0) over the
# time grid `t` with deSolve's lsode solver.
init_path <- lsode(y = c(1, 0), times = t, func = trueODEfunc)
Now we instantiate a very simple ODE model following that initial trajectory and an optimizer to train the ODE model with.
# ODE model: a keras Model subclass whose call() returns its input multiplied
# by the constant matrix W (read from the enclosing environment). The class
# defines no layers or variables of its own.
OdeModel(keras$Model) %py_class% {
  call <- function(inputs) {
    tf$matmul(a = inputs, b = W)
  }
}
model <- OdeModel()

# Optimizer used by the training loop: legacy SGD with momentum.
optimizer <- tf$keras$optimizers$legacy$SGD(
  learning_rate = 0.01,
  momentum = 0.95
)
Now we train the model for 25 iterations. On the first iteration and on every 5th iteration after that, a plot of the ODE trajectory is produced.
# Training loop: repeatedly solve the ODE from the current initial state
# `h0_var`, measure the squared distance between the solver output and
# `hN_target`, and update `h0_var` with the resulting gradient.
#
# seq_len() is used instead of 1:niters so that niters == 0 runs zero
# iterations rather than iterating over c(1, 0).
for (i in seq_len(niters)) {
  print(paste("Iteration", i, "out of", niters, "iterations."))

  # Record the loss computation so its gradient w.r.t. `pred` can be taken.
  with(tf$GradientTape() %as% tape, {
    pred <- forward(model, inputs = h0_var, tsteps = t)
    # `pred` is not a variable, so it must be watched explicitly.
    tape$watch(pred)
    loss <- tf$reduce_sum((hN_target - pred)^2)
  })
  #print(paste("loss:", as.numeric(loss)))

  # Gradient of the loss with respect to the solver output.
  dLoss <- tape$gradient(loss, pred)
  # Propagate that gradient back through the solver; element 2 of the result
  # is used as the gradient for h0_var.
  # NOTE(review): presumably the adjoint-method gradient w.r.t. the initial
  # state -- confirm against the tfNeuralODE `backward` documentation.
  dfdh0 <- backward(model, t, pred, output_gradients = dLoss)[[2]]
  # A single (gradient, variable) pair.
  optimizer$apply_gradients(list(c(dfdh0, h0_var)))

  # graphing the Neural ODE
  if (i %% 5 == 0 || i == 1) {
    # Re-run the solver from the current initial state, this time asking for
    # the intermediate states as well.
    pred_y <- forward(
      model = model,
      inputs = tf$cast((as.matrix(h0_var)), dtype = tf$float32),
      tsteps = t,
      return_states = TRUE
    )
    # Join the returned states (element 2) into one tensor, then convert to
    # an R matrix for plotting.
    pred_y_c <- k_concatenate(pred_y[[2]], 1)
    p_m <- as.matrix(pred_y_c)
    # Red: current model trajectory; blue: reference path from lsode
    # (columns 2 and 3 of init_path).
    plot(
      p_m,
      main = paste("Iteration", i), type = "l", col = "red",
      xlim = c(-1, 2), ylim = c(-1, 2)
    )
    lines(init_path[, 2], init_path[, 3], col = "blue")
  }
}
#> [1] "Iteration 1 out of 25 iterations."
#> [1] "Iteration 2 out of 25 iterations."
#> [1] "Iteration 3 out of 25 iterations."
#> [1] "Iteration 4 out of 25 iterations."
#> [1] "Iteration 5 out of 25 iterations."
#> [1] "Iteration 6 out of 25 iterations."
#> [1] "Iteration 7 out of 25 iterations."
#> [1] "Iteration 8 out of 25 iterations."
#> [1] "Iteration 9 out of 25 iterations."
#> [1] "Iteration 10 out of 25 iterations."
#> [1] "Iteration 11 out of 25 iterations."
#> [1] "Iteration 12 out of 25 iterations."
#> [1] "Iteration 13 out of 25 iterations."
#> [1] "Iteration 14 out of 25 iterations."
#> [1] "Iteration 15 out of 25 iterations."
#> [1] "Iteration 16 out of 25 iterations."
#> [1] "Iteration 17 out of 25 iterations."
#> [1] "Iteration 18 out of 25 iterations."
#> [1] "Iteration 19 out of 25 iterations."
#> [1] "Iteration 20 out of 25 iterations."
#> [1] "Iteration 21 out of 25 iterations."
#> [1] "Iteration 22 out of 25 iterations."
#> [1] "Iteration 23 out of 25 iterations."
#> [1] "Iteration 24 out of 25 iterations."
#> [1] "Iteration 25 out of 25 iterations."