tfNeuralODE-Adjoint

In this example, we train a simple network to learn a new spiral trajectory using the adjoint method. We start by loading the required libraries and setting our initial conditions, along with a few constants.

library(reticulate)
library(tensorflow)
library(tfNeuralODE)
library(keras)
library(deSolve)

# Number of optimisation steps and the time grid the ODE is integrated over
# (101 evenly spaced points from 0 to 25).
niters <- 25
t <- seq(0, 25, by = 0.25)

# Initial conditions, all as float32 tensors:
#   h0        - starting state, a 1 x 2 row vector
#   W         - constant 2 x 2 matrix used by the linear ODE model
#   h0_var    - variable holding the starting state; this is what the
#               optimizer updates
#   hN_target - state the loss compares the prediction against
h0 <- tf$cast(matrix(c(1, 0), nrow = 1), dtype = tf$float32)
W <- tf$cast(
  matrix(c(-0.1, 1.0, -0.2, -0.1), nrow = 2, byrow = TRUE),
  dtype = tf$float32
)
h0_var <- tf$Variable(h0, name = "")
hN_target <- tf$cast(matrix(c(0, 0.5), nrow = 1), dtype = tf$float32)

We solve for the initial trajectory.

# Right-hand side of the reference linear ODE: returns `u %*% A` wrapped in a
# list, which is the return shape deSolve solvers expect from `func`.
#
# NOTE: deSolve calls `func(time, state, parms)` positionally, so when used
# with lsode() the first argument (`du`) receives the time, `u` receives the
# state and `p` receives the parameters. Only `u` is read; `du`, `p` and the
# fourth argument `t` are ignored.
trueODEfunc <- function(du, u, p, t) {
  true_A <- matrix(c(-0.1, 1.0, -0.2, -0.1), nrow = 2, byrow = TRUE)
  list(u %*% true_A)
}

# Integrate the reference ODE from the state (1, 0) over the time grid `t`.
# The result has one row per time point: time in column 1, the two state
# components in columns 2 and 3.
init_path <- lsode(y = c(1, 0), times = t, func = trueODEfunc)

Now we instantiate a very simple ODE model that follows that initial trajectory, along with an optimizer to train it.

# ODE Model

# Keras model whose forward pass is the linear map `inputs %*% W`. `W` is the
# constant tensor created alongside the initial conditions; the model defines
# no layers of its own.
OdeModel(keras$Model) %py_class% {
  call <- function(inputs) {
    tf$matmul(inputs, W)
  }
}
model <- OdeModel()

# SGD with momentum, used in the training loop to apply the gradient updates.
optimizer <- tf$keras$optimizers$legacy$SGD(
  learning_rate = 1e-2,
  momentum = 0.95
)

Now we train the model for 25 iterations. A plot of the ODE trajectory is produced at the first iteration and at every fifth iteration thereafter.

# Training loop. Each iteration:
#   1. integrates the model forward from `h0_var` over the time grid `t`,
#   2. computes the squared-error loss between the prediction and `hN_target`,
#   3. propagates the loss gradient back through the ODE with `backward()`,
#   4. lets the optimizer update `h0_var`.
# seq_len() is used instead of 1:niters so that niters == 0 runs zero
# iterations rather than iterating over c(1, 0).
for (i in seq_len(niters)) {
  print(paste("Iteration", i, "out of", niters, "iterations."))
  with(tf$GradientTape() %as% tape, {
    pred <- forward(model, inputs = h0_var, tsteps = t)
    # Watch the prediction explicitly so the tape can return d(loss)/d(pred).
    tape$watch(pred)
    loss <- tf$reduce_sum((hN_target - pred)^2)
  })
  #print(paste("loss:", as.numeric(loss)))
  dLoss <- tape$gradient(loss, pred)
  # NOTE(review): element 2 of backward()'s result is presumably the gradient
  # with respect to the initial state -- confirm against the tfNeuralODE docs.
  dfdh0 <- backward(model, t, pred, output_gradients = dLoss)[[2]]
  optimizer$apply_gradients(list(c(dfdh0, h0_var)))

  # graphing the Neural ODE
  if (i %% 5 == 0 || i == 1) {
    # Re-run the forward pass from the current value of h0_var, this time
    # requesting the intermediate states; element 2 of the result is
    # concatenated into a single matrix for plotting.
    pred_y <- forward(
      model = model,
      inputs = tf$cast((as.matrix(h0_var)), dtype = tf$float32),
      tsteps = t,
      return_states = TRUE
    )
    pred_y_c <- k_concatenate(pred_y[[2]], 1)
    p_m <- as.matrix(pred_y_c)
    # Red: current Neural ODE trajectory; blue: reference path from lsode().
    plot(p_m,
         main = paste("Iteration", i), type = "l", col = "red",
         xlim = c(-1, 2), ylim = c(-1, 2))
    lines(init_path[, 2], init_path[, 3], col = "blue")
  }
}
#> [1] "Iteration 1 out of 25 iterations."
plot of iteration 1
plot of iteration 1
#> [1] "Iteration 2 out of 25 iterations."
#> [1] "Iteration 3 out of 25 iterations."
#> [1] "Iteration 4 out of 25 iterations."
#> [1] "Iteration 5 out of 25 iterations."
plot of iteration 5
plot of iteration 5
#> [1] "Iteration 6 out of 25 iterations."
#> [1] "Iteration 7 out of 25 iterations."
#> [1] "Iteration 8 out of 25 iterations."
#> [1] "Iteration 9 out of 25 iterations."
#> [1] "Iteration 10 out of 25 iterations."
plot of iteration 10
plot of iteration 10
#> [1] "Iteration 11 out of 25 iterations."
#> [1] "Iteration 12 out of 25 iterations."
#> [1] "Iteration 13 out of 25 iterations."
#> [1] "Iteration 14 out of 25 iterations."
#> [1] "Iteration 15 out of 25 iterations."
plot of iteration 15
plot of iteration 15
#> [1] "Iteration 16 out of 25 iterations."
#> [1] "Iteration 17 out of 25 iterations."
#> [1] "Iteration 18 out of 25 iterations."
#> [1] "Iteration 19 out of 25 iterations."
#> [1] "Iteration 20 out of 25 iterations."
plot of iteration 20
plot of iteration 20
#> [1] "Iteration 21 out of 25 iterations."
#> [1] "Iteration 22 out of 25 iterations."
#> [1] "Iteration 23 out of 25 iterations."
#> [1] "Iteration 24 out of 25 iterations."
#> [1] "Iteration 25 out of 25 iterations."
plot of iteration 25
plot of iteration 25