# -------------------------------------------------------------
# Multipliers from shocks vs. rule changes in R
# -------------------------------------------------------------
suppressPackageStartupMessages(library(ggplot2))

# Simulate the two-equation system
#   y_t = beta * x_t + u_t
#   x_t = rho * x_{t-1} + gamma * y_{t-1} + v_t
# with (u_t, v_t) ~ N(0, Sigma), Sigma = [[su2, suv], [suv, sv2]].
# The recursion starts from x_1 = y_1 = 0. The first `burn` periods are
# discarded and the next `T` periods are returned as data.frame(t, y, x),
# with t = 1..T. chol() errors if Sigma is not positive definite.
simulate_system <- function(T = 2000, burn = 200,
                            beta = 0.5, rho = 0.6, gamma = 0.2,
                            su2 = 1, sv2 = 1, suv = 0) {
  stopifnot(T >= 1, burn >= 0)
  n <- T + burn
  Sigma <- matrix(c(su2, suv, suv, sv2), 2, 2)
  # chol() returns the UPPER factor R with t(R) %*% R == Sigma. Correlated
  # shocks need the LOWER factor t(R): Cov(t(R) %*% z) == Sigma for z ~ N(0, I).
  L <- t(chol(Sigma))
  # Draw all shock pairs up front; column k drives period k + 1. This consumes
  # the RNG stream in the same order as calling rnorm(2) once per period.
  shocks <- L %*% matrix(rnorm(2 * (n - 1)), nrow = 2)
  y <- x <- numeric(n)
  for (s in seq_len(n)[-1]) {
    u <- shocks[1, s - 1]
    v <- shocks[2, s - 1]
    x[s] <- rho * x[s - 1] + gamma * y[s - 1] + v
    y[s] <- beta * x[s] + u
  }
  # Indexing with burn + seq_len(T) also works for burn = 0
  # (y[-(1:0)] would wrongly drop the first element).
  keep <- burn + seq_len(T)
  data.frame(t = seq_len(T), y = y[keep], x = x[keep])
}

# Naive ("pedagogical") single-equation OLS estimates of the simulated system:
#   beta        from lm(y ~ x)
#   rho, gamma  from lm(x ~ x_{t-1} + y_{t-1})
# Both regressions include an intercept. For data from simulate_system() with
# suv != 0, x_t is correlated with u_t (x_t contains v_t), so the OLS beta is
# not consistent. That bias is the point of the pedagogical example.
# Returns list(beta, rho, gamma) of unnamed numeric scalars.
estimate_pedagogical <- function(df) {
  stopifnot(is.data.frame(df), all(c("x", "y") %in% names(df)))
  fit_y <- lm(y ~ x, data = df)
  # One-period lag in base R (shift right, pad with NA); avoids an undeclared
  # dplyr dependency and works for zero-length input.
  lag1 <- function(v) c(NA, v)[seq_along(v)]
  lagged <- df
  lagged$x_l1 <- lag1(df$x)
  lagged$y_l1 <- lag1(df$y)
  lagged <- lagged[complete.cases(lagged), ]
  fit_x <- lm(x ~ x_l1 + y_l1, data = lagged)
  # unname() so downstream vectors are not labelled with coefficient names.
  list(beta = unname(coef(fit_y)["x"]),
       rho = unname(coef(fit_x)["x_l1"]),
       gamma = unname(coef(fit_x)["y_l1"]))
}

# Impulse responses of z_t = (y_t, x_t) in the structural form
#   A0 z_t = A1 z_{t-1} + (u_t, v_t),
#   A0 = [[1, -beta], [0, 1]],  A1 = [[0, 0], [gamma, rho]].
# The response at horizon h is F^h B shock, with B = A0^{-1} and F = B A1.
# `shock` is the structural impulse vector (u, v); the default c(0, 1) is a
# unit v-shock. Returns a data.frame with columns y, x, h (0..H) and cum_y
# (running sum of the y-responses).
irf_struct <- function(beta, rho, gamma, H = 20, shock = c(0, 1)) {
  stopifnot(length(H) == 1, H >= 0, H == round(H), length(shock) == 2)
  A0 <- matrix(c(1, -beta, 0, 1), 2, 2, byrow = TRUE)
  A1 <- matrix(c(0, 0, gamma, rho), 2, 2, byrow = TRUE)
  B <- solve(A0)
  F_mat <- B %*% A1  # named F_mat, not F, to avoid masking the FALSE alias
  resp <- matrix(0, nrow = H + 1, ncol = 2)
  state <- drop(B %*% shock)  # impact (h = 0) response
  for (k in seq_len(H + 1)) {
    resp[k, ] <- state
    state <- drop(F_mat %*% state)
  }
  irf <- as.data.frame(resp)
  names(irf) <- c("y", "x")
  irf$h <- 0:H
  irf$cum_y <- cumsum(irf$y)
  irf
}

# Demo: simulate the system, estimate it, then compare the cumulative
# y-multipliers under the estimated rule and a counterfactual rule in which
# only the persistence parameter rho is changed.
set.seed(123)
beta_true <- 0.5
rho_true <- 0.6
gamma_true <- 0.2
df <- simulate_system(beta = beta_true, rho = rho_true, gamma = gamma_true)
est <- estimate_pedagogical(df)
H <- 20
irf_base <- irf_struct(est$beta, est$rho, est$gamma, H)
# Counterfactual: beta and gamma stay at their estimates; only rho moves.
rho_policy <- 0.85
irf_policy <- irf_struct(est$beta, rho_policy, est$gamma, H)
irf_base$scenario <- "Baseline rule"
irf_policy$scenario <- "Higher persistence rule"
irf_all <- rbind(irf_base, irf_policy)

# Explicit print(): source() does not auto-print top-level values, so a bare
# ggplot expression would never be drawn when this script is sourced.
print(
  ggplot(irf_all, aes(h, cum_y, linetype = scenario)) +
    geom_line() +
    theme_bw() +
    labs(title = "Cumulative multiplier M_h", y = "Cumulative y-response")
)

# Long-run (h -> infinity) cumulative response of y to a unit v-shock,
# i.e. the limit of cum_y from irf_struct(): beta / (1 - rho - gamma * beta).
# F = A0^{-1} A1 has rank one with nonzero eigenvalue rho + gamma * beta, so the
# cumulative sum converges only when |rho + gamma * beta| < 1. Otherwise the
# closed form is meaningless; those entries are returned as NA with a warning.
# Vectorised over its arguments.
M_inf <- function(beta, rho, gamma) {
  lambda <- rho + gamma * beta
  out <- beta / (1 - lambda)
  unstable <- !is.na(lambda) & abs(lambda) >= 1
  if (any(unstable)) {
    warning("System is not stable (|rho + gamma*beta| >= 1); ",
            "long-run multiplier set to NA.", call. = FALSE)
    out[unstable] <- NA_real_
  }
  out
}
# Long-run cumulative y-multipliers (h -> infinity) for the estimated rule and
# the higher-persistence counterfactual. unname() keeps labels clean if the
# estimates carry coefficient names from lm(); print() makes the result visible
# when the script is run via source().
print(c(
  M_inf_base = unname(M_inf(est$beta, est$rho, est$gamma)),
  M_inf_policy = unname(M_inf(est$beta, rho_policy, est$gamma))
))

