-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Dataframe constructor script for scoring (#403)
* fix test * breakup scoring into more re-usable atomic functions * make scoring dataframe function * remove unused kwarg * modify dataframe constructor script * Update create_analysis_dataframes.jl * Update create_analysis_dataframes.jl --------- Co-authored-by: Sam Abbott <azw1@cdc.gov>
- Loading branch information
1 parent
2b816a8
commit 2e30981
Showing
7 changed files
with
100 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
include("make_truthdata_dataframe.jl") | ||
include("make_prediction_dataframe_from_output.jl") | ||
include("make_scoring_dataframe_from_output.jl") |
57 changes: 57 additions & 0 deletions
57
pipeline/src/analysis/make_scoring_dataframe_from_output.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
Create a dataframe containing scoring results based on the given output and input data. | ||
NB: For non-Renewal infection generating processes (IGP), the function loops over | ||
different GI mean scenarios to generate the CRPS scores. The reason for this is that | ||
for these IGPs the choice of GI is not used in forward simulation, and so we calculate the | ||
effects on inference in post-inference. | ||
# Arguments | ||
- `filename`: The name of the file. | ||
- `output`: The output data containing inference configuration, IGP model, and other information. | ||
- `epi_datas`: The input data for the epidemiological model. | ||
# Returns | ||
A dataframe containing the CRPS scoring results. | ||
""" | ||
function make_scoring_dataframe_from_output(filename, output, epi_datas, pipelines) | ||
#Get the scenario, IGP model, latent model and true mean GI | ||
inference_config = output["inference_config"] | ||
igp_model = output["inference_config"].igp |> string | ||
scenario = EpiAwarePipeline._get_scenario_from_filename(filename, pipelines) | ||
latent_model = EpiAwarePipeline._get_latent_model_from_filename(filename) | ||
true_mean_gi = EpiAwarePipeline._get_true_gi_mean_from_filename(filename) | ||
|
||
#Get the quantiles for the targets across the gi mean scenarios | ||
#if Renewal model, then we use the underlying epi model | ||
#otherwise we use the epi datas to loop over different gi mean implications | ||
used_epi_datas = igp_model == "Renewal" ? [output["epiprob"].epi_model.data] : epi_datas | ||
|
||
try | ||
summaries = map(used_epi_datas) do epi_data | ||
summarise_crps(config, inference_results, forecast_results, epi_data) | ||
end | ||
used_gi_means = igp_model == "Renewal" ? | ||
[EpiAwarePipeline._get_used_gi_mean_from_filename(filename)] : | ||
make_gi_params(EpiAwareExamplePipeline())["gi_means"] | ||
|
||
#Create the dataframe columnwise | ||
df = mapreduce(vcat, summaries, used_gi_means) do summary, used_gi_mean | ||
_df = DataFrame() | ||
_df[!, "Scenario"] .= scenario | ||
_df[!, "IGP_Model"] .= igp_model | ||
_df[!, "Latent_Model"] .= latent_model | ||
_df[!, "True_GI_Mean"] .= true_mean_gi | ||
_df[!, "Used_GI_Mean"] .= used_gi_mean | ||
_df[!, "Reference_Time"] .= inference_config.tspan[2] | ||
for name in keys(summary) | ||
_df[!, name] = summary[name] | ||
end | ||
end | ||
return df | ||
catch | ||
@warn "Error in generating crps summaries for targets in file $filename" | ||
return nothing | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters