Skip to content

Commit

Permalink
378 fig2 code (#387)
Browse files Browse the repository at this point in the history
* new figure 1 function

* add reference date horizon as vline

* reformat

* move legend to under plot

* rename truthdata because its getting reused

* Create create_figure2.jl

* Create figuretwo.jl

* refactor for common df checking across plots

* Update figureone.jl

* add figure 2 plots

* reformat
  • Loading branch information
SamuelBrand1 authored Jul 19, 2024
1 parent fb8cbed commit af5291a
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 59 deletions.
50 changes: 50 additions & 0 deletions pipeline/scripts/create_figure2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
## Script to make figure 2 and alternate latent models for SI
using Pkg
Pkg.activate(joinpath(@__DIR__(), ".."))

using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta,
Statistics, Distributions, CSV

##
pipelines = [
SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(),
SmoothEndemicPipeline(), RoughEndemicPipeline()]

## load some data and create a dataframe for the plot
truth_data_files = readdir(datadir("truth_data")) |>
strs -> filter(s -> occursin("jld2", s), strs)
analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame
truth_df = mapreduce(vcat, truth_data_files) do filename
D = JLD2.load(joinpath(datadir("truth_data"), filename))
make_truthdata_dataframe(filename, D, pipelines)
end

# Define scenario titles and reference times for figure 2
scenario_dict = Dict(
"measures_outbreak" => (title = "Outbreak with measures", T = 28),
"smooth_outbreak" => (title = "Outbreak no measures", T = 35),
"smooth_endemic" => (title = "Smooth endemic", T = 35),
"rough_endemic" => (title = "Rough endemic", T = 35)
)

target_dict = Dict(
"log_I_t" => (title = "log(Incidence)", ylims = (3.5, 6), ord = 1),
"rt" => (title = "Exp. growth rate", ylims = (-0.1, 0.1), ord = 2),
"Rt" => (title = "Reproductive number", ylims = (-0.1, 3), ord = 3)
)

latent_model_dict = Dict(
"wkly_rw" => (title = "Random walk",),
"wkly_ar" => (title = "AR(1)",),
"wkly_diff_ar" => (title = "Diff. AR(1)",)
)

##

fig = figuretwo(
truth_df, analysis_df, "Renewal", scenario_dict, target_dict)
_ = map(analysis_df.IGP_Model |> unique) do igp
fig = figureone(
truth_df, analysis_df, latent_model, scenario_dict, target_dict, latent_model_dict)
save(plotsdir("figure2_$(igp).png"), fig)
end
2 changes: 1 addition & 1 deletion pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export score_parameters
export make_prediction_dataframe_from_output, make_truthdata_dataframe

# Exported functions: Make main plots
export figureone
export figureone, figuretwo

# Exported functions: plot functions
export plot_truth_data, plot_Rt
Expand Down
52 changes: 52 additions & 0 deletions pipeline/src/mainplots/df_checking.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Internal method to check if the required columns are present in the truth dataframe.
# Arguments
- `truth_df`: The truth dataframe to be checked.
"""
function _truth_dataframe_checks(truth_df)
@assert "True_GI_Mean" names(truth_df) "True_GI_Mean col not in truth data"
@assert "Scenario" names(truth_df) "Scenario col not in truth data"
@assert "target_times" names(truth_df) "target_times col not in truth data"
@assert "target_values" names(truth_df) "target_values col not in truth data"
end

"""
Internal method to perform checks on the analysis dataframe to ensure that it contains the required columns.
# Arguments
- `analysis_df`: The analysis dataframe to be checked.
# Raises
- `AssertionError`: If any of the required columns are missing in the analysis dataframe.
"""
function _analysis_dataframe_checks(analysis_df)
@assert "True_GI_Mean" names(analysis_df) "True_GI_Mean col not in analysis data"
@assert "Used_GI_Mean" names(analysis_df) "Used_GI_Mean col not in analysis data"
@assert "Reference_Time" names(analysis_df) "Reference_Time col not in analysis data"
@assert "Scenario" names(analysis_df) "Scenario col not in analysis data"
@assert "IGP_Model" names(analysis_df) "IGP_Model col not in analysis data"
@assert "Latent_Model" names(analysis_df) "Latent_Model col not in analysis data"
@assert "target_times" names(analysis_df) "target_times col not in analysis data"
end

"""
Internal method to perform checks on the truth and analysis dataframes for Figure One.
# Arguments
- `truth_df::DataFrame`: The truth dataframe.
- `analysis_df::DataFrame`: The analysis dataframe.
- `scenario_dict::Dict{String, Any}`: A dictionary containing scenario information.
# Raises
- `AssertionError`: If the scenarios in the truth and analysis dataframes do not match, or if the scenarios in the truth dataframe do not match the keys in the scenario dictionary.
"""
function _dataframe_checks(truth_df, analysis_df, scenario_dict)
@assert issetequal(unique(truth_df.Scenario), unique(analysis_df.Scenario)) "Truth and analysis data scenarios do not match"
@assert issetequal(unique(truth_df.Scenario), keys(scenario_dict)) "Truth and analysis data True_GI_Mean do not match"
_truth_dataframe_checks(truth_df)
_analysis_dataframe_checks(analysis_df)
end
65 changes: 7 additions & 58 deletions pipeline/src/mainplots/figureone.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,3 @@
"""
Internal method to check if the required columns are present in the truth dataframe.
# Arguments
- `truth_df`: The truth dataframe to be checked.
"""
function _figure_one_truth_dataframe_checks(truth_df)
@assert "True_GI_Mean" names(truth_df) "True_GI_Mean col not in truth data"
@assert "Scenario" names(truth_df) "Scenario col not in truth data"
@assert "target_times" names(truth_df) "target_times col not in truth data"
@assert "target_values" names(truth_df) "target_values col not in truth data"
end

"""
Internal method to perform checks on the analysis dataframe to ensure that it contains the required columns.
# Arguments
- `analysis_df`: The analysis dataframe to be checked.
# Raises
- `AssertionError`: If any of the required columns are missing in the analysis dataframe.
"""
function _figure_one_analysis_dataframe_checks(analysis_df)
@assert "True_GI_Mean" names(analysis_df) "True_GI_Mean col not in analysis data"
@assert "Used_GI_Mean" names(analysis_df) "Used_GI_Mean col not in analysis data"
@assert "Reference_Time" names(analysis_df) "Reference_Time col not in analysis data"
@assert "Scenario" names(analysis_df) "Scenario col not in analysis data"
@assert "IGP_Model" names(analysis_df) "IGP_Model col not in analysis data"
@assert "Latent_Model" names(analysis_df) "Latent_Model col not in analysis data"
@assert "target_times" names(analysis_df) "target_times col not in analysis data"
end

"""
Internal method to perform checks on the truth and analysis dataframes for Figure One.
# Arguments
- `truth_df::DataFrame`: The truth dataframe.
- `analysis_df::DataFrame`: The analysis dataframe.
- `scenario_dict::Dict{String, Any}`: A dictionary containing scenario information.
# Raises
- `AssertionError`: If the scenarios in the truth and analysis dataframes do not match, or if the scenarios in the truth dataframe do not match the keys in the scenario dictionary.
"""
function _figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict)
@assert issetequal(unique(truth_df.Scenario), unique(analysis_df.Scenario)) "Truth and analysis data scenarios do not match"
@assert issetequal(unique(truth_df.Scenario), keys(scenario_dict)) "Truth and analysis data True_GI_Mean do not match"
_figure_one_truth_dataframe_checks(truth_df)
_figure_one_analysis_dataframe_checks(analysis_df)
end

"""
Internal method for creating a figure of model inference for a specific scenario
using the given analysis data.
Expand Down Expand Up @@ -98,13 +45,13 @@ Internal method that generates a plot of the truth data for a specific scenario.
- `plt_truth`: The plot of the truth data.
"""
function _figure_one_scenario_truth_data(truth_df, scenario; true_gi_choice)
function _figure_scenario_truth_data(truth_df, scenario; true_gi_choice)
truth_plotting_data = truth_df |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :Scenario.==scenario) |> data
plt_truth = truth_plotting_data *
mapping(:target_times => "T", :target_values => "values",
col = :Target, color = :Latent_Model) *
col = :Target, color = :Latent_Model => "Latent Model") *
visual(Lines)
return plt_truth
end
Expand Down Expand Up @@ -132,13 +79,13 @@ function figureone_with_latent_model(
truth_df, analysis_df, scenario_dict; fig_kws = (; size = (1000, 2000)),
true_gi_choice = 10.0, used_gi_choice = 10.0, legend_title = "Process type")
# Perform checks on the dataframes
_figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict)
_dataframe_checks(truth_df, analysis_df, scenario_dict)
# Treat the truth data as a Latent model option
truth_df[!, "Latent_Model"] .= "Truth data"

scenarios = analysis_df.Scenario |> unique
plt_truth_vect = map(scenarios) do scenario
_figure_one_scenario_truth_data(truth_df, scenario; true_gi_choice)
_figure_scenario_truth_data(truth_df, scenario; true_gi_choice)
end
plt_analysis_vect = map(scenarios) do scenario
_figure_one_scenario(
Expand Down Expand Up @@ -296,7 +243,7 @@ function figureone(
scenarios = [
"measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"])
# Perform checks on the dataframes
EpiAwarePipeline._figure_one_dataframe_checks(truth_df, analysis_df, scenario_dict)
_dataframe_checks(truth_df, analysis_df, scenario_dict)
latent_models = analysis_df.Latent_Model |> unique
@assert latent_model in latent_models "The latent model is not in the analysis data"
@assert latent_model in keys(latent_model_dict) "The latent model is not in the latent_model_dict dictionary"
Expand Down Expand Up @@ -340,8 +287,10 @@ function figureone(
"Latent model\n for infection\n generating\n process:\n$(latent_model_dict[latent_model].title)",
fontsize = 18,
font = :bold)

_leg = (leg[1], leg[2], [legend_title])
Legend(fig[5, 2], _leg...)

resize_to_layout!(fig)
return fig
end
91 changes: 91 additions & 0 deletions pipeline/src/mainplots/figuretwo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
function _make_captions!(df, scenario_dict, target_dict)
scenario_titles = [scenario_dict[scenario].title for scenario in df.Scenario]
target_titles = [target_dict[target].title for target in df.Target]
df.Scenario_Target .= scenario_titles .* "\n" .* target_titles
return nothing
end

function _figure_two_truth_data(
truth_df, scenario_dict, target_dict; true_gi_choice, gi_choices = [
2.0, 10.0, 20.0])
_truth_df = mapreduce(vcat, gi_choices) do used_gi
df = deepcopy(truth_df)
df.Used_GI_Mean .= used_gi
df
end
_make_captions!(_truth_df, scenario_dict, target_dict)

truth_plotting_data = _truth_df |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @transform(df, :Data="Truth data") |> data
plt_truth = truth_plotting_data *
mapping(:target_times => "T", :target_values => "Process values",
row = :Scenario_Target,
col = :Used_GI_Mean => renamer([2.0 => "Underestimate GI",
10.0 => "Good GI", 20.0 => "Overestimate GI"]),
color = :Data => AlgebraOfGraphics.scale(:color2)) *
visual(AlgebraOfGraphics.Scatter)
return plt_truth
end

function _figure_two_scenario(
analysis_df, igp, scenario_dict, target_dict; true_gi_choice,
lower_sym = :q_025, upper_sym = :q_975)
min_ref_time = minimum(analysis_df.Reference_Time)
early_df = analysis_df |>
df -> @subset(df, :Reference_Time.==min_ref_time) |>
df -> @subset(df, :IGP_Model.==igp) |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :target_times.<=min_ref_time - 7)

seqn_df = analysis_df |>
df -> @subset(df, :True_GI_Mean.==true_gi_choice) |>
df -> @subset(df, :IGP_Model.==igp) |>
df -> @subset(df,
:Reference_Time .- :target_times.∈fill(0:6, size(df, 1)))

full_df = vcat(early_df, seqn_df)
_make_captions!(full_df, scenario_dict, target_dict)

model_plotting_data = full_df |> data

plt_model = model_plotting_data *
mapping(:target_times => "T", :q_5 => "Process values",
row = :Scenario_Target,
col = :Used_GI_Mean => renamer([2.0 => "Underestimate GI",
10.0 => "Good GI", 20.0 => "Overestimate GI"]),
color = :Latent_Model => "Latent models") *
mapping(lower = lower_sym, upper = upper_sym) *
visual(LinesFill)

return plt_model
end

function figuretwo(truth_df, analysis_df, igp, scenario_dict,
target_dict; fig_kws = (; size = (1000, 2800)),
true_gi_choice = 10.0, gi_choices = [2.0, 10.0, 20.0])

# Perform checks on the dataframes
_dataframe_checks(truth_df, analysis_df, scenario_dict)

f_td = _figure_two_truth_data(
truth_df, scenario_dict, target_dict; true_gi_choice, gi_choices)
f_mdl = _figure_two_scenario(
analysis_df, igp, scenario_dict, target_dict; true_gi_choice)

fg = draw(f_mdl + f_td; facet = (; linkyaxes = :none),
legend = (; orientation = :horizontal, position = :bottom),
figure = fig_kws,
axis = (; xlabel = "T", ylabel = "Process values"))
for g in fg.grid[1:3:end, :]
g.axis.limits = (nothing, target_dict["rt"].ylims)
end
for g in fg.grid[2:3:end, :]
g.axis.limits = (nothing, target_dict["Rt"].ylims)
end
for g in fg.grid[3:3:end, :]
g.axis.limits = (nothing, target_dict["log_I_t"].ylims)
end

return fg
end
2 changes: 2 additions & 0 deletions pipeline/src/mainplots/mainplots.jl
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
include("df_checking.jl")
include("figureone.jl")
include("figuretwo.jl")

0 comments on commit af5291a

Please sign in to comment.