Skip to content

Commit

Permalink
Merge pull request #99 from TARGENE/generate_iates
Browse files Browse the repository at this point in the history
Generate iates
  • Loading branch information
olivierlabayle authored Dec 18, 2023
2 parents 13fa9a8 + cb4b1ca commit 624f76b
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 38 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.12.2"
version = "0.13.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Expand All @@ -21,6 +21,7 @@ Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Expand Down Expand Up @@ -55,6 +56,7 @@ TableOperations = "1.2"
Tables = "1.6"
YAML = "0.4.9"
Zygote = "0.6"
SplitApplyCombine = "1.2.2"
julia = "1.6, 1.7, 1"

[extras]
Expand Down
3 changes: 2 additions & 1 deletion src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import AbstractDifferentiation as AD
using Graphs
using MetaGraphsNext
using Combinatorics
using SplitApplyCombine

# #############################################################################
# EXPORTS
Expand All @@ -28,7 +29,7 @@ using Combinatorics
export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices
export CM, ATE, IATE
export AVAILABLE_ESTIMANDS
export generateATEs
export generateATEs, generateIATEs
export TMLEE, OSE, NAIVE
export ComposedEstimand
export var, estimate, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test,pvalue, confint, emptyIC
Expand Down
220 changes: 195 additions & 25 deletions src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,37 +231,207 @@ unique_non_missing(dataset, colname) = unique(skipmissing(Tables.getcolumn(datas

"""
    unique_treatment_values(dataset, colnames)

Return a `NamedTuple` with one entry per name in `colnames`, each holding the result of
`unique_non_missing(dataset, colname)` for that column.
"""
function unique_treatment_values(dataset, colnames)
    levels_by_column = (colname => unique_non_missing(dataset, colname) for colname in colnames)
    return (; levels_by_column...)
end

"""
    get_treatments_contrasts(treatments_unique_values)

For each treatment in `treatments_unique_values`, collect every 2-element combination of
its levels. The result is a vector with one collection of pairs per treatment, in the
iteration order of `treatments_unique_values`.
"""
function get_treatments_contrasts(treatments_unique_values)
    return [
        collect(Combinatorics.combinations(levels, 2))
        for levels in values(treatments_unique_values)
    ]
end

"""
    generateComposedEstimandFromContrasts(
        constructor,
        treatments_levels::NamedTuple{names},
        outcome;
        confounders=nothing,
        outcome_extra_covariates=(),
        freq_table=nothing,
        positivity_constraint=nothing
    ) where names

Build a `ComposedEstimand` whose components are produced by calling `constructor` once for
every joint contrast of the treatment variables in `treatments_levels`.

# Arguments
- `constructor`: Callable invoked with the keywords `outcome`, `treatment_values`,
  `treatment_confounders` and `outcome_extra_covariates` (e.g. `ATE` or `IATE`).
- `treatments_levels`: A NamedTuple giving the levels of each treatment variable.
- `outcome`: The outcome variable, forwarded to `constructor`.

# Keywords
- `confounders=nothing`: Forwarded to `constructor` as `treatment_confounders`.
- `outcome_extra_covariates=()`: Forwarded to `constructor`.
- `freq_table=nothing`, `positivity_constraint=nothing`: Forwarded to
  `satisfies_positivity`; a component is kept only when that call returns `true`.

# Returns
`ComposedEstimand(joint_estimand, components)` where `components` is the tuple of retained
estimands.
"""
function generateComposedEstimandFromContrasts(
    constructor,
    treatments_levels::NamedTuple{names},
    outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    freq_table=nothing,
    positivity_constraint=nothing
    ) where names
    # One collection of level pairs per treatment variable.
    treatments_contrasts = get_treatments_contrasts(treatments_levels)
    components = []
    # The Cartesian product of the per-treatment pairs enumerates the joint contrasts.
    for combo in Iterators.product(treatments_contrasts...)
        # Within each pair, the first level is the control and the second is the case.
        treatments_contrast = [
            NamedTuple{(:control, :case)}(treatment_control_case)
            for treatment_control_case in combo
        ]
        Ψ = constructor(
            outcome=outcome,
            treatment_values=NamedTuple{names}(treatments_contrast),
            treatment_confounders=confounders,
            outcome_extra_covariates=outcome_extra_covariates
        )
        if satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint)
            push!(components, Ψ)
        end
    end
    return ComposedEstimand(joint_estimand, Tuple(components))
end

# Shared documentation fragment interpolated into the docstrings of the estimand generators.
# Declared `const` so that it is not an untyped, reassignable module-level global.
const GENERATE_DOCSTRING = """
The components of this estimand are generated from the treatment variables contrasts.
For example, consider two treatment variables T₁ and T₂ each taking three possible values (0, 1, 2).
For each treatment variable, the marginal contrasts are defined by (0 → 1, 1 → 2, 0 → 2), there are thus
3 x 3 = 9 joint contrasts to be generated:
- (T₁: 0 → 1, T₂: 0 → 1)
- (T₁: 0 → 1, T₂: 1 → 2)
- (T₁: 0 → 1, T₂: 0 → 2)
- (T₁: 1 → 2, T₂: 0 → 1)
- (T₁: 1 → 2, T₂: 1 → 2)
- (T₁: 1 → 2, T₂: 0 → 2)
- (T₁: 0 → 2, T₂: 0 → 1)
- (T₁: 0 → 2, T₂: 1 → 2)
- (T₁: 0 → 2, T₂: 0 → 2)
# Return
A `ComposedEstimand` with causal or statistical components.
# Args
- `treatments_levels`: A NamedTuple providing the unique levels each treatment variable can take.
- `outcome`: The outcome variable.
- `confounders=nothing`: The generated components will inherit these confounding variables.
If `nothing`, causal estimands are generated.
- `outcome_extra_covariates=()`: The generated components will inherit these `outcome_extra_covariates`.
- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand`
"""

"""
generateATEs(
treatments_levels::NamedTuple{names}, outcome;
confounders=nothing,
outcome_extra_covariates=(),
freq_table=nothing,
positivity_constraint=nothing
) where names
Generate a `ComposedEstimand` of ATEs from the `treatments_levels`. $GENERATE_DOCSTRING
# Example:
To generate a causal composed estimand with 3 components:
```@example
generateATEs((T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁)
```
To generate a statistical composed estimand with 9 components:
```@example
generateATEs((T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
```
"""
function generateATEs(
    treatments_levels::NamedTuple{names}, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    freq_table=nothing,
    positivity_constraint=nothing
    ) where names
    # Every option is forwarded unchanged; only the estimand constructor (`ATE`) is
    # specific to this method.
    forwarded = (
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint,
    )
    return generateComposedEstimandFromContrasts(ATE, treatments_levels, outcome; forwarded...)
end

"""
    generateATEs(dataset, treatments, outcome;
        confounders=nothing,
        outcome_extra_covariates=(),
        positivity_constraint=nothing
    )

Find all unique values for each treatment variable in the dataset and generate all possible ATEs from these values.
"""
function generateATEs(dataset, treatments, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    positivity_constraint=nothing
    )
    # Levels are read from the data: one entry per treatment column.
    treatments_levels = unique_treatment_values(dataset, treatments)
    # The frequency table is only computed when a positivity constraint is requested.
    freq_table = positivity_constraint !== nothing ? frequency_table(dataset, keys(treatments_levels)) : nothing
    # Delegate to the method that takes explicit treatment levels.
    return generateATEs(
        treatments_levels,
        outcome;
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint
    )
end

"""
    generateIATEs(
        treatments_levels::NamedTuple{names}, outcome;
        confounders=nothing,
        outcome_extra_covariates=(),
        freq_table=nothing,
        positivity_constraint=nothing
    ) where names

Generates a `ComposedEstimand` of Average Interaction Effects from `treatments_levels`. $GENERATE_DOCSTRING
# Example:
To generate a causal composed estimand with 3 components:
```@example
generateIATEs((T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁)
```
To generate a statistical composed estimand with 9 components:
```@example
generateIATEs((T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
```
"""
function generateIATEs(
    treatments_levels::NamedTuple{names}, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    freq_table=nothing,
    positivity_constraint=nothing
    ) where names
    # Same generation procedure as for ATEs; only the estimand constructor (`IATE`) differs.
    return generateComposedEstimandFromContrasts(
        IATE,
        treatments_levels,
        outcome;
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint
    )
end

"""
generateIATEs(dataset, treatments, outcome;
confounders=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing
)
Finds the treatment levels from the dataset and generates a `ComposedEstimand` of Average Interaction Effects from them
(see [`generateIATEs(treatments_levels, outcome; confounders=nothing, outcome_extra_covariates=())`](@ref)).
"""
function generateIATEs(dataset, treatments, outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    positivity_constraint=nothing
    )
    treatments_levels = unique_treatment_values(dataset, treatments)
    # No positivity constraint means no frequency table is needed.
    freq_table = if isnothing(positivity_constraint)
        nothing
    else
        frequency_table(dataset, keys(treatments_levels))
    end
    forwarded = (
        confounders=confounders,
        outcome_extra_covariates=outcome_extra_covariates,
        freq_table=freq_table,
        positivity_constraint=positivity_constraint,
    )
    return generateIATEs(treatments_levels, outcome; forwarded...)
end

# Joint treatment levels of a statistical IATE: the Cartesian product taken over the
# per-treatment entries of `Ψ.treatment_values`.
joint_levels(Ψ::StatisticalIATE) = Iterators.product(values(Ψ.treatment_values)...)

# Joint treatment levels of a statistical ATE: one tuple holding the `:case` value of
# every treatment, followed by one tuple holding the `:control` value of every treatment.
joint_levels(Ψ::StatisticalATE) =
    (Tuple(Ψ.treatment_values[T][c] for T in keys(Ψ.treatment_values)) for c in (:case, :control))

# Joint treatment levels of a statistical CM: a 1-tuple wrapping the values of
# `Ψ.treatment_values`, so callers can iterate it like the other `joint_levels` methods.
joint_levels(Ψ::StatisticalCM) = (values(Ψ.treatment_values),)
22 changes: 21 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,24 @@ default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(
G_default = G
)

"""
    is_binary(dataset, columnname)

Return `true` when the set of non-missing values in column `columnname` of `dataset`
equals `Set([0, 1])`, and `false` otherwise.
"""
is_binary(dataset, columnname) = Set(skipmissing(Tables.getcolumn(dataset, columnname))) == Set([0, 1])

"""
    satisfies_positivity(Ψ, freq_table; positivity_constraint=0.01)

Return `true` when every joint level of `Ψ` (as given by `joint_levels(Ψ)`) is a key of
`freq_table` and its frequency is not below `positivity_constraint`; return `false` as
soon as one level is absent or too rare.
"""
function satisfies_positivity(Ψ, freq_table; positivity_constraint=0.01)
    for level in joint_levels(Ψ)
        # A level that never occurs in the table violates positivity outright.
        haskey(freq_table, level) || return false
        freq_table[level] < positivity_constraint && return false
    end
    return true
end

"""
    satisfies_positivity(Ψ, freq_table::Nothing; positivity_constraint=nothing)

Without a frequency table there is nothing to check: always return `true`.
"""
function satisfies_positivity(Ψ, freq_table::Nothing; positivity_constraint=nothing)
    return true
end

"""
    frequency_table(dataset, colnames)

Count the occurrences of each joint value of the columns `colnames` in `dataset` and
divide every count by `nrows(dataset)`.

The columns are combined in sorted name order, so each key of the returned table lists the
column values in that order.
"""
function frequency_table(dataset, colnames)
    ordered_colnames = sort(collect(colnames))
    columns = (Tables.getcolumn(dataset, name) for name in ordered_colnames)
    # Group identical rows (tuples of column values) and count them.
    frequencies = groupcount(row -> row, zip(columns...))
    # Turn the raw counts into frequencies in place.
    for level in keys(frequencies)
        frequencies[level] = frequencies[level] / nrows(dataset)
    end
    return frequencies
end
Loading

4 comments on commit 624f76b

@olivierlabayle
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/97357

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.0 -m "<description of version>" 624f76b8fa939f555045e22f7bd4e142403bbb73
git push origin v0.13.0

@olivierlabayle
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

New features

  • generateIATEs generates interaction estimands

Breaking changes

  • generateATEs now returns a ComposedEstimand

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/97357

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.0 -m "<description of version>" 624f76b8fa939f555045e22f7bd4e142403bbb73
git push origin v0.13.0

Please sign in to comment.