I may not be of much help here, so apologies in advance.
However, a PhD student I advise did code up a zero-inflated Poisson distribution; her work is below. Looking at that example might help with your code. I believe the example is taken from Richard McElreath’s blog (McElreath Oxen Example). Good luck!
library("greta")
library("R6")
library("tensorflow")
# Set up ----
# Local aliases for greta internals that the custom distribution below relies
# on. None of these are part of greta's exported user API.

# Reached through the `.internals` object that greta exposes:
as.greta_array <- greta::.internals$greta_arrays$as.greta_array
distribution_node <- greta::.internals$nodes$node_classes$distribution_node
distrib <- greta::.internals$nodes$constructors$distrib
check_dims <- greta::.internals$utils$checks$check_dims
fl <- greta::.internals$utils$misc$fl

# Unexported helpers looked up directly in greta's namespace. In the visible
# part of this file tf_max appears only in commented-out attempts, and tf_sum /
# tf_prod are not referenced at all; they are kept in case other code uses them.
tf_sum <- utils::getFromNamespace("tf_sum", "greta")
tf_prod <- utils::getFromNamespace("tf_prod", "greta")
tf_max <- utils::getFromNamespace("tf_max", "greta")
# Zero-inflated Poisson distribution node for greta.
#
# The probability mass function implemented by `log_prob` is
#   P(x) = prob * 1[x == 0] + (1 - prob) * Poisson(x | rate)
# i.e. with probability `prob` the outcome is a structural zero, otherwise it
# is a Poisson count with mean `rate`.
#
# Parameters (see `initialize`):
#   prob  mixing weight of the extra-zero component
#   rate  Poisson rate; must be strictly positive because the density is
#         evaluated via log(rate)
#   dim   target dimension, reconciled with prob/rate by check_dims()
zero_inflated_distribution <- R6Class(
  "zero_inflated_distribution",
  inherit = distribution_node,
  public = list(
    initialize = function(prob, rate, dim) {
      prob <- as.greta_array(prob)
      rate <- as.greta_array(rate)
      # add the nodes as children and parameters
      dim <- check_dims(prob, rate, target_dim = dim)
      super$initialize("zero_inflated", dim, discrete = TRUE)
      self$add_parameter(prob, "prob")
      self$add_parameter(rate, "rate")
    },
    # Returns a list with the log-density function; no cdf / log-cdf is
    # provided for this distribution.
    tf_distrib = function(parameters, dag) {
      prob <- parameters$prob
      rate <- parameters$rate
      log_prob <- function(x) {
        # Indicator of x == 0 built from relu: relu(1 - x) is 1 at x = 0 and
        # 0 for every count x >= 1.
        # NOTE(review): the cast hard-codes float64; confirm this matches the
        # floating-point precision greta is running with.
        is_zero <- tf$cast(tf$nn$relu(fl(1) - x), tf$float64)
        # Poisson log-pmf evaluated in log space. The previous form,
        # pow(rate, x) * exp(-rate) / exp(lgamma(x + 1)), overflowed to Inf
        # (or produced Inf / Inf = NaN) for moderately large rate or x even
        # though the true log-pmf is finite; exponentiating only once avoids
        # that.
        log_poisson <- x * tf$log(rate) - rate - tf$lgamma(x + fl(1))
        tf$log(prob * is_zero + (fl(1) - prob) * tf$exp(log_poisson))
        # Earlier attempts built the zero indicator with tf$count_nonzero or
        # tf_max instead of relu; the author noted that those either did not
        # work or ran but did not converge.
      }
      list(log_prob = log_prob, cdf = NULL, log_cdf = NULL)
    },
    tf_cdf_function = NULL,
    tf_log_cdf_function = NULL
  )
)
# User-facing constructor for the zero-inflated Poisson distribution.
#
# prob  weight of the extra-zero component
# rate  Poisson rate of the count component
# dim   optional dimension of the resulting variable (default NULL)
#
# The arguments are handed, in this order, to greta's internal `distrib()`
# helper under the family name "zero_inflated".
# NOTE(review): presumably `distrib()` resolves that name to
# `zero_inflated_distribution` and returns a greta array - confirm against
# greta's internals.
zero_inflated <- function(prob, rate, dim = NULL) {
  family_name <- "zero_inflated"
  distrib(family_name, prob, rate, dim)
}
# Data ----
# Simulated data: on days the monks drink they complete no manuscripts;
# otherwise the number completed is a Poisson count.

# Fix R's RNG seed so the simulated data set below is reproducible.
set.seed(1)

## define parameters
prob_drink <- 0.2 # 20% of days
rate_work <- 1 # average 1 manuscript per day
## sample one year of production
N <- 365
## simulate days monks drink
# Named `drank` (not `drink`) so the simulated 0/1 indicator is not silently
# overwritten by the model parameter `drink` defined further down.
drank <- rbinom(N, 1, prob_drink)
## simulate manuscripts completed
y <- (1 - drank) * rpois(N, rate_work)

# Priors ----
# Both parameters are given normal priors on an unconstrained scale and then
# transformed: exp() for the rate, inverse-logit for the probability.
log_rate <- normal(0, 10)
logit_drink <- normal(0, 1)
rate <- exp(log_rate)
drink <- ilogit(logit_drink)

# Likelihood ----
distribution(y) <- zero_inflated(drink, rate)

# Model ----
m <- model(logit_drink, log_rate)

# Plot the model ----
plot(m)

# Draw posterior samples ----
samples <- greta::mcmc(m, warmup = 400, n_samples = 4000)

# Summarise the samples ----
summary(samples)