Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataframe constructor script for scoring #403

Merged
merged 7 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@ epi_datas = map(gi_params["gi_means"]) do μ
Gamma(shape, scale)
end .|> gen_dist -> EpiData(gen_distribution = gen_dist)

## Calculate the prediction dataframe
prediction_df = mapreduce(vcat, files) do filename
## Calculate the prediction and scoring dataframes
# Element-wise `vcat` over a pair of (prediction_df, scoring_df) tuples, so a
# stream of per-file tuples can be reduced into one combined tuple.
double_vcat = (left, right) -> (vcat(left[1], right[1]), vcat(left[2], right[2]))

## Build the combined (prediction, scoring) dataframe pair across all output files.
# NOTE(review): this hunk iterates `xs`, but no definition of `xs` is visible and the
# line it replaced iterated `files` — reverting to `files`; confirm against the full script.
dfs = mapreduce(double_vcat, files) do filename
    output = load(joinpath(datadir("epiaware_observables"), filename))
    (
        make_prediction_dataframe_from_output(filename, output, epi_datas, pipelines),
        make_scoring_dataframe_from_output(filename, output, epi_datas, pipelines)
    )
end

## Save the prediction dataframe
CSV.write(plotsdir("analysis_df.csv"), prediction_df)
## Save the prediction and scoring dataframes
CSV.write(plotsdir("analysis_df.csv"), dfs[1])
CSV.write(plotsdir("scoring_df.csv"), dfs[2])
3 changes: 2 additions & 1 deletion pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ export define_forecast_epiprob, generate_forecasts
export score_parameters, simple_crps, summarise_crps

# Exported functions: Analysis functions for constructing dataframes
export make_prediction_dataframe_from_output, make_truthdata_dataframe
export make_prediction_dataframe_from_output, make_truthdata_dataframe,
make_scoring_dataframe_from_output

# Exported functions: Make main plots
export figureone, figuretwo
Expand Down
1 change: 1 addition & 0 deletions pipeline/src/analysis/analysis.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include("make_truthdata_dataframe.jl")
include("make_prediction_dataframe_from_output.jl")
include("make_scoring_dataframe_from_output.jl")
57 changes: 57 additions & 0 deletions pipeline/src/analysis/make_scoring_dataframe_from_output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Create a dataframe containing CRPS scoring results based on the given output and input data.

NB: For non-Renewal infection generating processes (IGP), the function loops over
different GI mean scenarios to generate the CRPS scores. The reason for this is that
for these IGPs the choice of GI is not used in forward simulation, and so we calculate the
effects on inference in post-inference.

# Arguments
- `filename`: The name of the file.
- `output`: The output data containing inference configuration, inference/forecast
  results, IGP model, and other information.
- `epi_datas`: The input data for the epidemiological model.
- `pipelines`: The pipeline objects used to resolve the scenario from `filename`.

# Returns
A dataframe containing the CRPS scoring results, or `nothing` if scoring failed.

"""
function make_scoring_dataframe_from_output(filename, output, epi_datas, pipelines)
    #Get the scenario, IGP model, latent model and true mean GI
    inference_config = output["inference_config"]
    igp_model = inference_config.igp |> string
    scenario = EpiAwarePipeline._get_scenario_from_filename(filename, pipelines)
    latent_model = EpiAwarePipeline._get_latent_model_from_filename(filename)
    true_mean_gi = EpiAwarePipeline._get_true_gi_mean_from_filename(filename)

    #summarise_crps needs the inference and forecast results; pull them from the
    #output dict (previously these names were referenced but never defined here).
    #NOTE(review): confirm the dict keys match what `infer` stores in `output`.
    inference_results = output["inference_results"]
    forecast_results = output["forecast_results"]

    #Get the quantiles for the targets across the gi mean scenarios
    #if Renewal model, then we use the underlying epi model
    #otherwise we use the epi datas to loop over different gi mean implications
    used_epi_datas = igp_model == "Renewal" ? [output["epiprob"].epi_model.data] : epi_datas

    try
        summaries = map(used_epi_datas) do epi_data
            summarise_crps(inference_config, inference_results, forecast_results, epi_data)
        end
        used_gi_means = igp_model == "Renewal" ?
                        [EpiAwarePipeline._get_used_gi_mean_from_filename(filename)] :
                        make_gi_params(EpiAwareExamplePipeline())["gi_means"]

        #Create the dataframe columnwise
        df = mapreduce(vcat, summaries, used_gi_means) do summary, used_gi_mean
            _df = DataFrame()
            _df[!, "Scenario"] .= scenario
            _df[!, "IGP_Model"] .= igp_model
            _df[!, "Latent_Model"] .= latent_model
            _df[!, "True_GI_Mean"] .= true_mean_gi
            _df[!, "Used_GI_Mean"] .= used_gi_mean
            _df[!, "Reference_Time"] .= inference_config.tspan[2]
            for name in keys(summary)
                _df[!, name] = summary[name]
            end
            #Return the dataframe from the do-block; previously the `for` loop was
            #the last expression, so the do-block returned `nothing` and the
            #subsequent `vcat` failed (silently, via the catch below).
            _df
        end
        return df
    catch e
        #Surface the underlying error instead of swallowing it; still return
        #`nothing` so a single bad file does not abort the whole scoring run.
        @warn "Error in generating crps summaries for targets in file $filename" exception = e
        return nothing
    end
end
3 changes: 2 additions & 1 deletion pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ function infer(config::InferenceConfig)
forecast_results = generate_forecasts(
inference_results.samples, inference_results.data, epiprob, config.lookahead)

score_results = summarise_crps(config, inference_results, forecast_results, epiprob)
epidata = epiprob.epi_model.data
score_results = summarise_crps(config, inference_results, forecast_results, epidata)

return Dict("inference_results" => inference_results,
"epiprob" => epiprob, "inference_config" => config,
Expand Down
38 changes: 24 additions & 14 deletions pipeline/src/scoring/summarise_crps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ Summarizes the Continuous Ranked Probability Score (CRPS) for different processe
A dictionary containing the summarized CRPS scores for different processes.

"""
function summarise_crps(config, inference_results, forecast_results, epiprob)
function summarise_crps(config, inference_results, forecast_results, epidata)
ts = config.tspan[1]:min(config.tspan[2] + config.lookahead, length(config.truth_I_t))
epidata = epiprob.epi_model.data

procs_names = (:log_I_t, :rt, :Rt, :I_t, :log_Rt)
scores_log_I_t, scores_rt, scores_Rt, scores_I_t, scores_log_Rt = _process_crps_scores(
Expand All @@ -25,6 +24,26 @@ function summarise_crps(config, inference_results, forecast_results, epiprob)
"scores_y_t" => scores_y_t, "scores_log_y_t" => scores_log_y_t)
end

"""
Assemble the posterior-predictive matrix for one process (e.g. `:Rt`), one column
per forecast sample. Each generated trajectory is paired with its sampled initial
log-incidence, converted via `calculate_processes`, and the field named `process`
is extracted.
"""
function _get_predicted_proc(inference_results, forecast_results, epidata, process)
    generated = forecast_results.generated
    init_log_incidence = inference_results.samples[:init_incidence]
    sample_columns = map(generated, init_log_incidence) do gen, log_I0
        processes = calculate_processes(gen.I_t, exp(log_I0), epidata)
        getfield(processes, process)
    end
    return reduce(hcat, sample_columns)
end

"""
Collect the generated observation trajectories from `forecast_results` into a
matrix with one column per forecast sample (rows are time points).
"""
function _get_predicted_y_t(forecast_results)
    sampled_obs = [gen.generated_y_t for gen in forecast_results.generated]
    return reduce(hcat, sampled_obs)
end

"""
Internal method for calculating the CRPS scores for different processes.
"""
Expand All @@ -37,14 +56,8 @@ function _process_crps_scores(
config.truth_I_t[ts], true_Itminusone, epidata) |>
procs -> getfield(procs, process)
# predictions
gens = forecast_results.generated
log_I0s = inference_results.samples[:init_incidence]
predicted_proc = mapreduce(hcat, gens, log_I0s) do gen, logI0
I0 = exp(logI0)
It = gen.I_t
procs = calculate_processes(It, I0, epidata)
getfield(procs, process)
end
predicted_proc = _get_predicted_proc(
inference_results, forecast_results, epidata, process)
scores = [simple_crps(preds, true_proc[t])
for (t, preds) in enumerate(eachrow(predicted_proc))]
return scores
Expand All @@ -57,10 +70,7 @@ Internal method for calculating the CRPS scores for observed cases and log(cases
"""
function _cases_crps_scores(forecast_results, config, ts; jitter = 1e-6)
true_y_t = config.case_data[ts]
gens = forecast_results.generated
predicted_y_t = mapreduce(hcat, gens) do gen
gen.generated_y_t
end
predicted_y_t = _get_predicted_y_t(forecast_results)
scores_y_t = [simple_crps(preds, true_y_t[t])
for (t, preds) in enumerate(eachrow(predicted_y_t))]
scores_log_y_t = [simple_crps(log.(preds .+ jitter), log(true_y_t[t] + jitter))
Expand Down
2 changes: 1 addition & 1 deletion pipeline/test/utils/test_calculate_processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

data = EpiData(pmf, exp)

result = calculate_processes(I_t, I0, data, pipeline)
result = calculate_processes(I_t, I0, data)

# Check if the log of infections is calculated correctly
@testset "Log of infections" begin
Expand Down
Loading