## ----include = FALSE----------------------------------------------------------
# Document-wide chunk options: merge code and output blocks, prefix printed
# output with "#>", and keep warnings and messages out of the rendered page.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>",
                      warning = FALSE, message = FALSE)

## ----setup--------------------------------------------------------------------
library(spFFBS)
library(spBPS)

## ----echo=FALSE, warning=FALSE, message=FALSE---------------------------------
library(mniw)
library(MBA)
library(abind)
library(ggplot2)
library(patchwork)
library(reshape2)

## ----results=F----------------------------------------------------------------
# Problem dimensions ----
tmax <- 25   # observed time points
tnew <- 5    # future time points held out for forecasting
n    <- 150  # observed spatial locations
q    <- 2    # response variables (columns of Y)
p    <- 2    # covariates (columns of X)
u    <- 50   # held-out locations used for spatial prediction

# True generating parameters ----
Sigma <- matrix(c(1, -0.3, -0.3, 1), nrow = q, ncol = q)  # covariance across responses
phi   <- 8                                                # decay of the exp(-phi * d) kernel
alfa  <- 0.8
# (1 - alfa) / alfa: scales the observation-noise variance relative to the
# unit-diagonal spatial correlation matrix.
a     <- 1 / alfa - 1
V     <- diag(a, n + u)                                   # observation-noise row covariance

set.seed(1)
# Generate the time-invariant synthetic structure: random locations in the
# unit square, their distance matrix and an exponential correlation matrix.
coords <- matrix(runif((n + u) * 2), ncol = 2)
# Euclidean distances via base R, the same way D is built for the observed and
# held-out subsets further down; this avoids reaching into another package's
# internals with `:::`. unname() drops the "1", "2", ... labels dist() adds.
D <- unname(as.matrix(dist(coords)))
K <- exp(-phi * D)
# State-evolution covariance, block diagonal: identity for the p regression
# coefficients, spatial correlation K for the n + u site effects.
W <- rbind(cbind(diag(p), matrix(0, p, n + u)),
           cbind(matrix(0, n + u, p), K))

# Prior information and initial state
m0 <- matrix(0, n + u + p, q)
C0 <- rbind(cbind(diag(0.005, p), matrix(0, p, n + u)),
            cbind(matrix(0, n + u, p), K))
theta0 <- mniw::rMNorm(n = 1, Lambda = m0, SigmaR = C0, SigmaC = Sigma)

# Generate the dynamic synthetic data: one slice per time point of the
# evolution matrix G, latent state theta, covariates X, design P and
# response Y.
G     <- array(0, c(n + u + p, n + u + p, tmax + tnew))
theta <- array(0, c(n + u + p, q, tmax + tnew))
X     <- array(0, c(n + u, p, tmax + tnew))
P     <- array(0, c(n + u, n + u + p, tmax + tnew))
Y     <- array(0, c(n + u, q, tmax + tnew))

set.seed(1)
for (t in seq_len(tmax + tnew)) {
  # Identity evolution: theta_t = theta_{t-1} + noise, seeded by theta0 at the
  # first step. RNG calls stay in the order state noise, covariates, obs noise.
  G[, , t] <- diag(p + n + u)
  theta[, , t] <- G[, , t] %*% (if (t == 1) theta0 else theta[, , t - 1]) +
    mniw::rMNorm(n = 1, Lambda = m0, SigmaR = W, SigmaC = Sigma)
  # Fresh uniform covariates; P = [X_t | I] combines the regression part with
  # each site's own spatial effect.
  X[, , t] <- matrix(runif((n + u) * p), n + u, p)
  P[, , t] <- cbind(X[, , t], diag(n + u))
  Y[, , t] <- P[, , t] %*% theta[, , t] +
    mniw::rMNorm(n = 1, Lambda = matrix(0, n + u, q), SigmaR = V, SigmaC = Sigma)
}

# Held-out ("tilde") data: the last u locations, plus the future responses at
# the n observed sites. Extract these before Y, X, P and G are overwritten by
# their observed-only subsets below.
Yfuture    <- Y[seq_len(n), , tmax + seq_len(tnew)]
Ytilde     <- Y[-seq_len(n), , ]
# Rows of theta are p coefficients then one effect per site, so dropping the
# first n + p rows keeps the held-out sites' effects.
thetatilde <- theta[-seq_len(n + p), , ]
Xtilde     <- X[-seq_len(n), , ]
crdtilde   <- coords[-seq_len(n), ]
Dtilde     <- as.matrix(dist(crdtilde))
Ktilde     <- exp(-phi * Dtilde)

# Observed data: first n sites; responses restricted to the first tmax times.
Y   <- Y[seq_len(n), , seq_len(tmax)]
X   <- X[seq_len(n), , ]
P   <- P[seq_len(n), seq_len(n + p), ]
G   <- G[seq_len(n + p), seq_len(n + p), ]
crd <- coords[seq_len(n), ]
D   <- as.matrix(dist(crd))
# NOTE(review): unlike the simulation kernel and Ktilde, this kernel has no
# phi (decay fixed at 1); K feeds W here and the prior C0 later — confirm intended.
K   <- exp(-D)
W   <- rbind(cbind(diag(p), matrix(0, p, n)), cbind(matrix(0, n, p), K))
V   <- diag(a, n)

## ----results=F----------------------------------------------------------------
# Prior hyperparameters handed to spFFBS(): m / C for the initial state, and
# nu / Psi, presumably for the inverse-Wishart prior on Sigma (riwish() is
# later driven by filtered nu / Psi).
m0   <- matrix(0, nrow = n + p, ncol = q)
C0   <- rbind(cbind(diag(0.005, p), matrix(0, p, n)),
              cbind(matrix(0, n, p), K))
nu0  <- 3
Psi0 <- diag(q)
prior <- list(m = m0, C = C0, nu = nu0, Psi = Psi0)

# Hyperparameter grid explored by the fit; the `tau` entry holds alfa values.
alfa_seq <- c(0.7, 0.8, 0.9)
phi_seq  <- c(6, 8, 9)
par_grid <- list(tau = alfa_seq, phi = phi_seq)

## ----results=F----------------------------------------------------------------
# Fit over the hyperparameter grid with forward filtering, backward sampling
# (do_BS), a tnew-step forecast (do_forecast) and spatial prediction at the
# held-out locations crdtilde at time tmax + tnew (do_spatial).
# L is presumably the number of posterior draws (later code indexes L draws).
# TRUE is spelled out because T is an ordinary, reassignable variable.
out <- spFFBS::spFFBS(Y = Y, G = G, P = P, D = D,
                      grid = par_grid,
                      prior = prior,
                      L = 200,
                      do_BS = TRUE,
                      do_forecast = TRUE,
                      tnew = tnew,
                      do_spatial = TRUE,
                      spatial = list(crd = crd,
                                     crdtilde = crdtilde,
                                     Xtilde = Xtilde,
                                     t = tmax + tnew),
                      num_threads = 1)

## -----------------------------------------------------------------------------
# Stack the backward-sampled states for times 1..tmax into one array whose
# last dimension is time; the first p rows are the regression coefficients.
theta_post <- simplify2array(lapply(seq_len(tmax), function(k) out$BS[[k]]),
                             higher = TRUE)
beta_post <- theta_post[seq_len(p), seq_len(q), , ]

## ----echo=F, fig.dim = c(7.25, 5)---------------------------------------------
# Posterior bands for the regression coefficients over the observed period,
# one panel per (covariate i, response j), against the true simulated values.
oldpar <- par(no.readonly = TRUE)
# oma[4] = 7 reserves the right outer margin for the shared legend.
par(mfrow = c(p, q), mar = c(3, 3, 2, 1), oma = c(0, 0, 0, 7))

for (i in seq_len(p)) {
  for (j in seq_len(q)) {
    # Dimension 2 of theta_post[i, j, , ] is time (the stacking dimension), so
    # each row of `quants` holds the 2.5%, 50% and 97.5% quantiles at one time.
    quants <- t(apply(theta_post[i, j, , ], 2, quantile, c(0.025, 0.5, 0.975)))
    times <- seq_len(nrow(quants))

    plot(times, quants[, 2], type = "n",
         # Only the first tmax true values are drawn, so only they may
         # stretch the y-axis (the future values are never plotted here).
         ylim = range(quants, theta[i, j, seq_len(tmax)]),
         xlab = "", ylab = "",
         # Lower-case `beta` is the Greek letter in plotmath; `Beta` is the
         # capital, which renders like a Latin "B".
         main = bquote(hat(beta)[list(.(i), .(j))]),
         cex.main = 0.9)

    # Shaded 95% band, dashed band edges, solid median ("MAP" in the legend).
    polygon(c(times, rev(times)), c(quants[, 1], rev(quants[, 3])),
            col = rgb(0, 0, 1, 0.15), border = NA)
    lines(times, quants[, 1], col = 4, lty = 2, lwd = 0.7)
    lines(times, quants[, 3], col = 4, lty = 2, lwd = 0.7)
    lines(times, quants[, 2], col = 4, lwd = 1.5)
    # True simulated coefficient path.
    lines(seq_len(tmax), theta[i, j, seq_len(tmax)], col = 2, lwd = 1.2)
  }
}

# Legend anchored to the whole device rather than to the last panel.
par(xpd = NA)

# Normalized device coordinates over the full page:
# x = 0.92 is roughly the centre of the right outer margin,
# y = 0.5 is the vertical centre of the figure.
legend(x = grconvertX(0.92, from = "ndc"),
       y = grconvertY(0.5, from = "ndc"),
       legend = c("True", "MAP", "95% CI"),
       col = c(2, 4, rgb(0, 0, 1, 0.15)),
       lty = c(1, 1, 1),
       lwd = c(1.2, 1.5, NA),
       pch = c(NA, NA, 15),
       pt.cex = c(0, 0, 1.8),
       bty = "n",
       xjust = 0.5, # centre the legend horizontally on x
       yjust = 0.5, # centre the legend vertically on y
       cex = 0.8)

par(oldpar)

## ----echo=FALSE---------------------------------------------------------------
# Model labels (stored as an attribute of out$Wi) and the global weights
# returned by spFFBS(); uncomment the last two lines to print the weights
# labelled by model.
models_labels <- attr(out$Wi, "model_labels")
Wglobal <- out$Wglobal
# rownames(Wglobal) <- models_labels
# Wglobal

## -----------------------------------------------------------------------------
# Posterior draws of the response covariance Sigma at the last observed time.
# Each draw first picks one grid model according to the global weights, then
# samples Sigma from that model's filtered inverse-Wishart parameters
# (nu, Psi) at time tmax.
Wglobal <- out$Wglobal
# NROW() also handles a plain weight vector, where nrow() would give NULL.
J <- NROW(Wglobal)
# Number of posterior draws; matches the L = 200 passed to spFFBS().
L <- 200

set.seed(1)
# Name `replace` and `prob` explicitly instead of relying on partial argument
# matching, the reassignable `T`, and positional placement of the weights.
indL <- sample(seq_len(J), L, replace = TRUE, prob = as.vector(Wglobal))
Sigma_post <- sapply(seq_len(L), function(l) {
  filt <- out$FF[[tmax]]$filtered_results[[indL[l]]]
  mniw::riwish(1, nu = filt$nu, Psi = filt$Psi)
}, simplify = "array")
# Element-wise posterior median (shown as the "MAP" estimate in the plots).
Sigma_map <- apply(Sigma_post, c(1, 2), median)

## ----echo=F, fig.dim = c(7.25, 3)---------------------------------------------
# Heatmaps of the true Sigma next to its posterior estimate, drawn on one
# shared colour scale.

# Shared value range so both heatmaps use the same fill scale.
all_values <- c(as.vector(Sigma), as.vector(Sigma_map))
common_limits <- range(all_values)

# Draw a matrix as a heatmap with each cell's value printed on it.
#   mat    : numeric matrix to display
#   titolo : plot title (string or plotmath expression)
#   limits : length-2 numeric range for the fill scale
# Returns a ggplot object.
plot_matrix <- function(mat, titolo, limits) {
  df <- melt(mat)
  ggplot(df, aes(Var2, Var1, fill = value)) +
    # Tile borders take `linewidth`; `size` is deprecated for lines and
    # borders since ggplot2 3.4.0.
    geom_tile(color = "white", linewidth = 0.5) +
    geom_text(aes(label = sprintf("%.2f", value)), color = "black", size = 5) +
    # Fixed limits keep the colour scale identical across panels.
    scale_fill_gradient2(low = "#3B9AB2", mid = "#EEEEEE", high = "#F21A00",
                         midpoint = 0, limits = limits) +
    labs(title = titolo, x = NULL, y = NULL, fill = "Valore") +
    theme_minimal() +
    # Reversed y puts row 1 at the top, as when a matrix is printed.
    scale_y_reverse() +
    coord_fixed() +
    theme(plot.title = element_text(hjust = 0.5, size = 16))
}

# bquote() yields plotmath titles (Sigma and hat(Sigma)).
p1 <- plot_matrix(Sigma, bquote(Sigma), common_limits)
p2 <- plot_matrix(Sigma_map, bquote(hat(Sigma)), common_limits)

# Side by side; guides = "collect" merges the two identical legends into one.
(p1 | p2) +
  plot_layout(guides = "collect") +
  plot_annotation(title = "",
                  theme = theme(plot.title = element_text(size = 18, face = "bold", hjust = 0.5)))

## -----------------------------------------------------------------------------
# Posterior predictive forecast draws returned by spFFBS (do_forecast = TRUE).
# NOTE(review): used below as Y_forc[site, variable, draw, time] — confirm this
# layout against the spFFBS documentation.
Y_forc <- out$forecast$Y_pred

## ----echo=F, fig.dim = c(7.25, 3)---------------------------------------------
# Forecast check at one observed site: posterior predictive bands over the
# whole horizon (observed + forecast) against the true path, one panel per
# response variable.
oldpar <- par(no.readonly = TRUE)
# A 5-line right outer margin keeps the shared legend close to the panels.
par(mfrow = c(1, q), mar = c(3, 3, 2, 1), oma = c(0, 0, 0, 5))

site_idx <- 1
t_obs <- tmax
t_total <- tmax + tnew

for (v in seq_len(q)) {
  # NOTE(review): assumes the time points run along the last dimension of
  # Y_forc and cover all t_total steps — confirm against spFFBS output.
  pred_samps <- Y_forc[site_idx, v, , ]
  quants <- t(apply(pred_samps, 2, quantile, c(0.025, 0.5, 0.975)))

  # True path: observed responses followed by the held-out future values.
  y_true_path <- c(Y[site_idx, v, ], Yfuture[site_idx, v, ])
  x_axis <- seq_len(t_total)

  plot(x_axis, quants[, 2], type = "n",
       ylim = range(quants, y_true_path),
       xlab = "", ylab = "",
       main = bquote(hat(Y)[.(v)] ~ " Forecast"), cex.main = 0.9)

  polygon(c(x_axis, rev(x_axis)), c(quants[, 1], rev(quants[, 3])),
          col = rgb(0, 0, 1, 0.15), border = NA)

  lines(x_axis, quants[, 1], col = 4, lty = 2, lwd = 0.7)
  lines(x_axis, quants[, 3], col = 4, lty = 2, lwd = 0.7)
  lines(x_axis, quants[, 2], col = 4, lwd = 1.5)
  lines(x_axis, y_true_path, col = 2, lwd = 1.2)

  # Dotted vertical line marks the end of the observed period.
  abline(v = t_obs, lty = 3, col = "grey50")
}

par(xpd = NA)

# Legend anchored in normalized device coordinates; x = 0.88 with xjust = 0
# (left edge at x) keeps it right next to the panels.
legend(x = grconvertX(0.88, from = "ndc"),
       y = grconvertY(0.5, from = "ndc"),
       legend = c("True", "MAP", "95% CI"),
       col = c(2, 4, rgb(0, 0, 1, 0.15)),
       lty = c(1, 1, 1),
       lwd = c(1.2, 1.5, NA),
       pch = c(NA, NA, 15),
       pt.cex = c(0, 0, 1.5),
       bty = "n",
       xjust = 0,    # left-align at x to save horizontal space
       yjust = 0.5,
       cex = 0.7)    # slightly smaller text so the legend takes less room
par(oldpar)

## -----------------------------------------------------------------------------
# Spatial predictive draws at the u held-out locations, stacked so that the
# draw index 1..L becomes the last dimension.
Y_pred <- simplify2array(
  lapply(seq_len(L), function(draw) out$spatial[[1]][[draw]][seq_len(u), ]),
  higher = TRUE
)

## ----echo=F, fig.dim = c(7.25, 6)---------------------------------------------
# Spatial maps at the final time tmax + tnew: interpolated surface of the true
# held-out responses next to the predicted surface, one row per response.

# Draw one interpolated surface.
#   df    : data frame with grid coordinates x, y and values z
#   title : plot title (string or plotmath expression)
#   lims  : length-2 fill-scale range; NULL lets ggplot2 use the data range
# Returns a ggplot object.
make_spat_plot <- function(df, title, lims = NULL) {
  ggplot(df, aes(x, y, fill = z)) +
    geom_raster(interpolate = TRUE) +
    scale_fill_distiller(palette = "RdBu", limits = lims, name = NULL) +
    # Remove the padding between the data and the axes.
    scale_x_continuous(expand = c(0, 0)) +
    scale_y_continuous(expand = c(0, 0)) +
    labs(title = title, x = "Easting", y = "Northing") +
    theme_minimal() +
    theme(panel.grid = element_blank(),
          # `linewidth` replaces the deprecated `size` of element_rect()
          # since ggplot2 3.4.0.
          panel.border = element_rect(color = "black", fill = NA, linewidth = 1))
}

grid_res <- 100
plot_rows <- vector("list", q)

for (v in seq_len(q)) {
  z_true <- Ytilde[, v, tmax + tnew]
  # Average over the draws (last dimension of Y_pred); the panel is labelled
  # "MAP", but this is the mean of the predictive draws.
  z_pred <- rowMeans(Y_pred[, v, ])

  # Interpolate the scattered held-out values onto a grid_res x grid_res grid.
  surf_t <- mba.surf(cbind(crdtilde, z_true), grid_res, grid_res, extend = TRUE)$xyz.est
  surf_p <- mba.surf(cbind(crdtilde, z_pred), grid_res, grid_res, extend = TRUE)$xyz.est

  df_t <- data.frame(expand.grid(x = surf_t$x, y = surf_t$y), z = as.vector(surf_t$z))
  df_p <- data.frame(expand.grid(x = surf_p$x, y = surf_p$y), z = as.vector(surf_p$z))
  # One fill range per row so truth and prediction are directly comparable.
  lims <- range(c(df_t$z, df_p$z), na.rm = TRUE)

  p1 <- make_spat_plot(df_t, bquote(Y[list(t+.(tnew),.(v))] ~ "(True)"), lims)
  p2 <- make_spat_plot(df_p, bquote(hat(Y)[list(t+.(tnew),.(v))] ~ "(MAP)"), lims)

  plot_rows[[v]] <- (p1 + p2) + plot_layout(guides = "collect")
}

wrap_plots(plot_rows, ncol = 1)

