Skip to content

Commit

Permalink
Dataframe constructor script for scoring (#403)
Browse files Browse the repository at this point in the history
* fix test

* break up scoring into more reusable atomic functions

* make scoring dataframe function

* remove unused kwarg

* modify dataframe constructor script

* Update create_analysis_dataframes.jl

* Update create_analysis_dataframes.jl

---------

Co-authored-by: Sam Abbott <azw1@cdc.gov>
  • Loading branch information
SamuelBrand1 and seabbs authored Jul 28, 2024
1 parent 2b816a8 commit 2e30981
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@ epi_datas = map(gi_params["gi_means"]) do μ
Gamma(shape, scale)
end .|> gen_dist -> EpiData(gen_distribution = gen_dist)

## Calculate the prediction and scoring dataframes
# Pairwise reducer: vcat the prediction dataframes (first slot) and the
# scoring dataframes (second slot) of two (prediction, scoring) tuples.
double_vcat = (dfs1, dfs2) -> (
    vcat(dfs1[1], dfs2[1]), vcat(dfs1[2], dfs2[2])
)

# Fix: the original iterated over `xs`, which is undefined in this script;
# the collection of output files is `files` (as used by the prediction-only
# version of this loop).
dfs = mapreduce(double_vcat, files) do filename
    output = load(joinpath(datadir("epiaware_observables"), filename))
    (
        make_prediction_dataframe_from_output(filename, output, epi_datas, pipelines),
        make_scoring_dataframe_from_output(filename, output, epi_datas, pipelines)
    )
end

## Save the prediction and scoring dataframes
CSV.write(plotsdir("analysis_df.csv"), dfs[1])
CSV.write(plotsdir("scoring_df.csv"), dfs[2])
3 changes: 2 additions & 1 deletion pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ export define_forecast_epiprob, generate_forecasts
export score_parameters, simple_crps, summarise_crps

# Exported functions: Analysis functions for constructing dataframes
export make_prediction_dataframe_from_output, make_truthdata_dataframe
export make_prediction_dataframe_from_output, make_truthdata_dataframe,
make_scoring_dataframe_from_output

# Exported functions: Make main plots
export figureone, figuretwo
Expand Down
1 change: 1 addition & 0 deletions pipeline/src/analysis/analysis.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include("make_truthdata_dataframe.jl")
include("make_prediction_dataframe_from_output.jl")
include("make_scoring_dataframe_from_output.jl")
57 changes: 57 additions & 0 deletions pipeline/src/analysis/make_scoring_dataframe_from_output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Create a dataframe containing scoring results based on the given output and input data.

NB: For non-Renewal infection generating processes (IGP), the function loops over
different GI mean scenarios to generate the CRPS scores. The reason for this is that
for these IGPs the choice of GI is not used in forward simulation, and so we calculate the
effects on inference in post-inference.

# Arguments
- `filename`: The name of the file.
- `output`: The output data containing inference configuration, IGP model, and other information.
- `epi_datas`: The input data for the epidemiological model.
- `pipelines`: The pipeline objects used to resolve the scenario from `filename`.

# Returns
A dataframe containing the CRPS scoring results, or `nothing` if scoring fails.
"""
function make_scoring_dataframe_from_output(filename, output, epi_datas, pipelines)
    #Get the scenario, IGP model, latent model and true mean GI
    inference_config = output["inference_config"]
    igp_model = output["inference_config"].igp |> string
    scenario = EpiAwarePipeline._get_scenario_from_filename(filename, pipelines)
    latent_model = EpiAwarePipeline._get_latent_model_from_filename(filename)
    true_mean_gi = EpiAwarePipeline._get_true_gi_mean_from_filename(filename)

    #Get the quantiles for the targets across the gi mean scenarios
    #if Renewal model, then we use the underlying epi model
    #otherwise we use the epi datas to loop over different gi mean implications
    used_epi_datas = igp_model == "Renewal" ? [output["epiprob"].epi_model.data] : epi_datas

    try
        # Fix: the original called `summarise_crps(config, inference_results,
        # forecast_results, epi_data)` with three undefined names. The config is
        # `inference_config` (above); the results are stored in `output` by
        # `infer` — NOTE(review): confirm `output` carries a "forecast_results"
        # key alongside "inference_results" in the `infer` return Dict.
        inference_results = output["inference_results"]
        forecast_results = output["forecast_results"]
        summaries = map(used_epi_datas) do epi_data
            summarise_crps(inference_config, inference_results, forecast_results, epi_data)
        end
        used_gi_means = igp_model == "Renewal" ?
                        [EpiAwarePipeline._get_used_gi_mean_from_filename(filename)] :
                        make_gi_params(EpiAwareExamplePipeline())["gi_means"]

        #Create the dataframe columnwise
        df = mapreduce(vcat, summaries, used_gi_means) do summary, used_gi_mean
            _df = DataFrame()
            _df[!, "Scenario"] .= scenario
            _df[!, "IGP_Model"] .= igp_model
            _df[!, "Latent_Model"] .= latent_model
            _df[!, "True_GI_Mean"] .= true_mean_gi
            _df[!, "Used_GI_Mean"] .= used_gi_mean
            _df[!, "Reference_Time"] .= inference_config.tspan[2]
            for name in keys(summary)
                _df[!, name] = summary[name]
            end
            # Fix: the do-block must yield the dataframe; the original ended on
            # the `for` loop, which evaluates to `nothing`, so `vcat` never
            # received the rows.
            _df
        end
        return df
    catch e
        # Fix: log the actual exception instead of silently discarding it, so
        # real defects are not masked by the best-effort fallback.
        @warn "Error in generating crps summaries for targets in file $filename" exception=(
            e, catch_backtrace())
        return nothing
    end
end
3 changes: 2 additions & 1 deletion pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ function infer(config::InferenceConfig)
forecast_results = generate_forecasts(
inference_results.samples, inference_results.data, epiprob, config.lookahead)

score_results = summarise_crps(config, inference_results, forecast_results, epiprob)
epidata = epiprob.epi_model.data
score_results = summarise_crps(config, inference_results, forecast_results, epidata)

return Dict("inference_results" => inference_results,
"epiprob" => epiprob, "inference_config" => config,
Expand Down
38 changes: 24 additions & 14 deletions pipeline/src/scoring/summarise_crps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ Summarizes the Continuous Ranked Probability Score (CRPS) for different processe
A dictionary containing the summarized CRPS scores for different processes.
"""
function summarise_crps(config, inference_results, forecast_results, epiprob)
function summarise_crps(config, inference_results, forecast_results, epidata)
ts = config.tspan[1]:min(config.tspan[2] + config.lookahead, length(config.truth_I_t))
epidata = epiprob.epi_model.data

procs_names = (:log_I_t, :rt, :Rt, :I_t, :log_Rt)
scores_log_I_t, scores_rt, scores_Rt, scores_I_t, scores_log_Rt = _process_crps_scores(
Expand All @@ -25,6 +24,26 @@ function summarise_crps(config, inference_results, forecast_results, epiprob)
"scores_y_t" => scores_y_t, "scores_log_y_t" => scores_log_y_t)
end

"""
Internal helper: column-bind the trajectory of `process` derived from each
posterior draw's generated incidence and sampled initial log-incidence.
"""
function _get_predicted_proc(inference_results, forecast_results, epidata, process)
    generated = forecast_results.generated
    init_log_incidences = inference_results.samples[:init_incidence]
    return mapreduce(hcat, generated, init_log_incidences) do gen, log_I0
        initial_incidence = exp(log_I0)
        processes = calculate_processes(gen.I_t, initial_incidence, epidata)
        getfield(processes, process)
    end
end

"""
Internal helper: column-bind the generated case counts (`generated_y_t`) from
each posterior draw in `forecast_results.generated`.
"""
function _get_predicted_y_t(forecast_results)
    sampled_y_ts = (gen.generated_y_t for gen in forecast_results.generated)
    return reduce(hcat, sampled_y_ts)
end

"""
Internal method for calculating the CRPS scores for different processes.
"""
Expand All @@ -37,14 +56,8 @@ function _process_crps_scores(
config.truth_I_t[ts], true_Itminusone, epidata) |>
procs -> getfield(procs, process)
# predictions
gens = forecast_results.generated
log_I0s = inference_results.samples[:init_incidence]
predicted_proc = mapreduce(hcat, gens, log_I0s) do gen, logI0
I0 = exp(logI0)
It = gen.I_t
procs = calculate_processes(It, I0, epidata)
getfield(procs, process)
end
predicted_proc = _get_predicted_proc(
inference_results, forecast_results, epidata, process)
scores = [simple_crps(preds, true_proc[t])
for (t, preds) in enumerate(eachrow(predicted_proc))]
return scores
Expand All @@ -57,10 +70,7 @@ Internal method for calculating the CRPS scores for observed cases and log(cases
"""
function _cases_crps_scores(forecast_results, config, ts; jitter = 1e-6)
true_y_t = config.case_data[ts]
gens = forecast_results.generated
predicted_y_t = mapreduce(hcat, gens) do gen
gen.generated_y_t
end
predicted_y_t = _get_predicted_y_t(forecast_results)
scores_y_t = [simple_crps(preds, true_y_t[t])
for (t, preds) in enumerate(eachrow(predicted_y_t))]
scores_log_y_t = [simple_crps(log.(preds .+ jitter), log(true_y_t[t] + jitter))
Expand Down
2 changes: 1 addition & 1 deletion pipeline/test/utils/test_calculate_processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

data = EpiData(pmf, exp)

result = calculate_processes(I_t, I0, data, pipeline)
result = calculate_processes(I_t, I0, data)

# Check if the log of infections is calculated correctly
@testset "Log of infections" begin
Expand Down

0 comments on commit 2e30981

Please sign in to comment.