Merge pull request #107 from TARGENE/ordered_factor
Ordered factor
olivierlabayle authored Apr 1, 2024
2 parents 7f7b0a2 + ddc1d34 commit 3d27968
Showing 14 changed files with 104 additions and 101 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "TMLE"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
authors = ["Olivier Labayle"]
version = "0.15.0"
version = "0.16.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
3 changes: 1 addition & 2 deletions docs/src/index.md
@@ -67,8 +67,7 @@ The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as
3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE).

```@example quick-start
-models = (Y=with_encoder(LinearRegressor()), T = LogisticClassifier())
-tmle = TMLEE(models=models)
+tmle = TMLEE()
result, _ = tmle(Ψ, dataset, verbosity=0);
result
```
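With this commit `TMLEE()` can be built without a `models` keyword; the newly exported `default_models()` supplies the defaults. A minimal sketch of overriding a single default model (the `LogisticClassifier` here is illustrative, not the shipped default):

```julia
using TMLE
using MLJLinearModels

# Replace only the propensity-score default G; Q_binary and Q_continuous
# keep the packaged defaults wrapped in their encoder pipelines.
models = default_models(G=LogisticClassifier(lambda=1.0))
tmle = TMLEE(models=models)
```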
13 changes: 4 additions & 9 deletions docs/src/user_guide/estimation.md
@@ -66,12 +66,7 @@ Drawing from the example dataset and `SCM` from the Walk Through section, we can
treatment_confounders=(T₁=[:W₁₁, :W₁₂],),
outcome_extra_covariates=[:C]
)
-models = (
-    Y=with_encoder(LinearRegressor()),
-    T₁=LogisticClassifier(),
-    T₂=LogisticClassifier(),
-)
-tmle = TMLEE(models=models)
+tmle = TMLEE()
result₁, cache = tmle(Ψ₁, dataset);
result₁
nothing # hide
@@ -106,7 +101,7 @@ We could now get an interest in the Average Treatment Effect of `T₂` that we w
treatment_confounders=(T₂=[:W₂₁, :W₂₂],),
outcome_extra_covariates=[:C]
)
-ose = OSE(models=models)
+ose = OSE()
result₂, cache = ose(Ψ₂, dataset;cache=cache);
result₂
nothing # hide
@@ -121,14 +116,14 @@ Both TMLE and OSE can be used with sample-splitting, which, for an additional co
To leverage sample-splitting, simply specify a `resampling` strategy when building an estimator:

```@example estimation
-cvtmle = TMLEE(models=models, resampling=CV())
+cvtmle = TMLEE(resampling=CV())
cvresult₁, _ = cvtmle(Ψ₁, dataset);
```

Similarly, one could build CV-OSE:

```julia
-cvose = OSE(models=models, resampling=CV(nfolds=3))
+cvose = OSE(resampling=CV(nfolds=3))
```

## Caching model fits
19 changes: 0 additions & 19 deletions docs/src/user_guide/misc.md
@@ -12,25 +12,6 @@ The adjustment set consists of all the treatment variable's parents. Additional
BackdoorAdjustment(;outcome_extra_covariates=[:C])
```

-## Treatment Transformer
-
-To account for the fact that treatment variables are categorical variables we provide a MLJ compliant transformer that will either:
-
-- Retrieve the floating point representation of a treatment if it has a natural ordering
-- One hot encode it otherwise
-
-Such transformer can be created with:
-
-```julia
-TreatmentTransformer(;encoder=encoder())
-```
-
-where `encoder` is a [OneHotEncoder](https://alan-turing-institute.github.io/MLJ.jl/dev/models/OneHotEncoder_MLJModels/#OneHotEncoder_MLJModels).
-
-The `with_encoder(model; encoder=TreatmentTransformer())` provides a shorthand to combine a `TreatmentTransformer` with another MLJ model in a pipeline.
-
-Of course you are also free to define your own strategy!

## Serialization

Many objects from TMLE.jl can be serialized to various file formats. This is achieved by transforming these structures to dictionaries that can then be serialized to classic JSON or YAML format. For that purpose you can use the `TMLE.read_json`, `TMLE.write_json`, `TMLE.read_yaml` and `TMLE.write_yaml` functions.
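A minimal round-trip sketch, assuming an estimands `Configuration` object named `configuration` and that the file path comes first in the argument list:

```julia
using TMLE

# Write the configuration to disk, then read it back.
TMLE.write_json("estimands.json", configuration)
configuration = TMLE.read_json("estimands.json")
```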
9 changes: 2 additions & 7 deletions docs/src/walk_through.md
@@ -135,12 +135,7 @@ Alternatively, you can also directly define the statistical parameters (see [Est
Then each parameter can be estimated by building an estimator (which is simply a function) and evaluating it on data. For illustration, we will keep the models simple. We define a Targeted Maximum Likelihood Estimator:

```@example walk-through
-models = (
-    Y = with_encoder(LinearRegressor()),
-    T₁ = LogisticClassifier(),
-    T₂ = LogisticClassifier()
-)
-tmle = TMLEE(models=models)
+tmle = TMLEE()
```

Because we haven't identified the `cm` causal estimand yet, we need to provide the `scm` as well to the estimator:
@@ -153,7 +148,7 @@ result
Statistical Estimands can be estimated without a ``SCM``, let's use the One-Step estimator:

```@example walk-through
-ose = OSE(models=models)
+ose = OSE()
result, cache = ose(statistical_iate, dataset)
result
```
5 changes: 4 additions & 1 deletion examples/double_robustness.jl
@@ -157,7 +157,10 @@ function tmle_inference(data)
treatment_values=(Tcat=(case=1.0, control=0.0),),
treatment_confounders=(Tcat=[:W],)
)
-models = (Y=with_encoder(LinearRegressor()), Tcat=LinearBinaryClassifier())
+models = (
+    Y = with_encoder(LinearRegressor()),
+    Tcat = with_encoder(LinearBinaryClassifier())
+)
tmle = TMLEE(models=models)
result, _ = tmle(Ψ, data; verbosity=0)
lb, ub = confint(OneSampleTTest(result))
2 changes: 1 addition & 1 deletion src/TMLE.jl
@@ -34,7 +34,7 @@ export ComposedEstimand
export var, estimate, pvalue, confint, emptyIC
export significance_test, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test
export compose
-export TreatmentTransformer, with_encoder, encoder
+export default_models, TreatmentTransformer, with_encoder, encoder
export BackdoorAdjustment, identify
export last_fluctuation_epsilon
export Configuration
2 changes: 1 addition & 1 deletion src/estimates.jl
@@ -76,7 +76,7 @@ function likelihood(estimate::ConditionalDistributionEstimate, dataset)
return pdf.(ŷ, y)
end

-function compute_offset(ŷ::UnivariateFiniteVector{Multiclass{2}})
+function compute_offset(ŷ::UnivariateFiniteVector{<:Union{OrderedFactor{2}, Multiclass{2}}})
μy = expected_value(ŷ)
logit!(μy)
return μy
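For a two-level prediction vector, `compute_offset` returns the logit of the expected value, i.e. the log-odds of the second level. A standalone sketch of that arithmetic:

```julia
# The offset used by the fluctuation step is the predicted probability
# of the positive level mapped to the log-odds scale.
logit(p) = log(p / (1 - p))

offsets = logit.([0.3, 0.5, 0.7])  # ≈ [-0.847, 0.0, 0.847]
```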
7 changes: 5 additions & 2 deletions src/treatment_transformer.jl
@@ -13,7 +13,10 @@ Treatments in TMLE are represented by `CategoricalArrays`. If a treatment column
has type `OrderedFactor`, then its integer representation is used, make sure that
the levels correspond to your expectations. All other columns are one-hot encoded.
"""
-TreatmentTransformer(;encoder=encoder()) = TreatmentTransformer(encoder)
+function TreatmentTransformer(;encoder=encoder())
+    Base.depwarn("The TreatmentTransformer is deprecated and will be removed in future version, it has been replaced by MLJModels.ContinuousEncoder.", :TreatmentTransformer, force=true)
+    TreatmentTransformer(encoder)
+end

MLJBase.fit(model::TreatmentTransformer, verbosity::Int, X) =
MLJBase.fit(model.encoder, verbosity, X)
@@ -43,4 +46,4 @@ function MLJBase.transform(model::TreatmentTransformer, fitresult, Xnew)
return merge(Tables.columntable(Xt), ordered_factors)
end

-with_encoder(model; encoder=encoder()) = Pipeline(TreatmentTransformer(;encoder=encoder), model)
+with_encoder(model; encoder=ContinuousEncoder(drop_last=true, one_hot_ordered_factors = false)) = Pipeline(encoder, model)
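A sketch of what the new `with_encoder` builds: the deprecated `TreatmentTransformer` is swapped for an `MLJModels.ContinuousEncoder` which, with `one_hot_ordered_factors = false`, keeps the integer codes of ordered factors instead of one-hot encoding them:

```julia
using MLJBase, MLJModels, MLJLinearModels

# with_encoder(LinearRegressor()) now expands to this two-stage pipeline:
pipe = Pipeline(
    ContinuousEncoder(drop_last=true, one_hot_ordered_factors=false),
    LinearRegressor()
)
```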
34 changes: 29 additions & 5 deletions src/utils.jl
@@ -46,7 +46,7 @@ function indicator_values(indicators, T)
return indic
end

-expected_value(ŷ::UnivariateFiniteVector{Multiclass{2}}) = pdf.(ŷ, levels(first(ŷ))[2])
+expected_value(ŷ::UnivariateFiniteVector{<:Union{OrderedFactor{2}, Multiclass{2}}}) = pdf.(ŷ, levels(first(ŷ))[2])
expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ)
expected_value(ŷ::AbstractVector{<:Real}) = ŷ

@@ -68,10 +68,34 @@ function last_fluctuation_epsilon(cache)
return fp.coef
end

-default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier(), encoder=encoder()) = (
-    Q_binary_default = with_encoder(Q_binary, encoder=encoder),
-    Q_continuous_default = with_encoder(Q_continuous, encoder=encoder),
-    G_default = G
+"""
+    default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier()) = (
+
+Create a NamedTuple containing default models to be used by downstream estimators.
+Each provided model is prepended (in a `MLJ.Pipeline`) with an `MLJ.ContinuousEncoder`.
+
+By default:
+
+- Q_binary is a LinearBinaryClassifier
+- Q_continuous is a LinearRegressor
+- G is a LinearBinaryClassifier
+
+# Example
+
+The following changes the default `Q_binary` to a `LogisticClassifier` and provides a `RidgeRegressor` for `special_y`.
+
+```julia
+using MLJLinearModels
+models = (
+    special_y = RidgeRegressor(),
+    default_models(Q_binary=LogisticClassifier())...
+)
+```
+"""
+default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier()) = (
+    Q_binary_default = with_encoder(Q_binary),
+    Q_continuous_default = with_encoder(Q_continuous),
+    G_default = with_encoder(G)
)

is_binary(dataset, columnname) = Set(skipmissing(Tables.getcolumn(dataset, columnname))) == Set([0, 1])
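To see the widened `expected_value` dispatch in action, here is a hedged sketch; the `UnivariateFinite` construction follows CategoricalDistributions.jl (as re-exported by MLJBase), and the exact keywords should be treated as an assumption:

```julia
using TMLE, MLJBase

# Three binary predictions given as probabilities of the second level "b".
probs = [0.2, 0.7, 0.5]
ŷ = UnivariateFinite(["a", "b"], probs, augment=true, pool=missing, ordered=true)

TMLE.expected_value(ŷ)  # ≈ [0.2, 0.7, 0.5], i.e. P(Y = "b") per observation
```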
29 changes: 15 additions & 14 deletions test/counterfactual_mean_based/double_robustness_ate.jl
@@ -29,9 +29,10 @@ function binary_outcome_binary_treatment_pb(;n=100)
# Convert to dataframe to respect the Tables.jl
# and convert types
W = convert(Array{Float64}, w)
-T = categorical(t)
-Y = categorical(y)
+T = t
+Y = y
dataset = (T=T, W=W, Y=Y)
+dataset = coerce(dataset, autotype(dataset))
# Compute the theoretical ATE
ATE₁ = py_given_aw(1, 1)*p_w() + (1-p_w())*py_given_aw(1, 0)
ATE₀ = py_given_aw(0, 1)*p_w() + (1-p_w())*py_given_aw(0, 0)
@@ -114,7 +115,7 @@ end
# When Q is misspecified but G is well specified
models = (
Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-T = LogisticClassifier(lambda=0)
+T = with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -133,7 +134,7 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(TreatmentTransformer() |> LinearRegressor()),
-T = ConstantClassifier()
+T = with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -150,7 +151,7 @@ end
# When Q is misspecified but G is well specified
models = (
Y = with_encoder(ConstantClassifier()),
-T = LogisticClassifier(lambda=0)
+T = with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -163,7 +164,7 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(LogisticClassifier(lambda=0)),
-T = ConstantClassifier()
+T = with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -181,7 +182,7 @@ end
# When Q is misspecified but G is well specified
models = (
Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-T = LogisticClassifier(lambda=0)
+T = with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -196,7 +197,7 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(LinearRegressor()),
-T = ConstantClassifier()
+T = with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -220,8 +221,8 @@ end
# When Q is misspecified but G is well specified
models = (
Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-T₁ = LogisticClassifier(lambda=0),
-T₂ = LogisticClassifier(lambda=0)
+T₁ = with_encoder(LogisticClassifier(lambda=0)),
+T₂ = with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₁, dataset; verbosity=0)
@@ -230,8 +231,8 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(LinearRegressor()),
-T₁ = ConstantClassifier(),
-T₂ = ConstantClassifier()
+T₁ = with_encoder(ConstantClassifier()),
+T₂ = with_encoder(ConstantClassifier())
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₁, dataset; verbosity=0)
@@ -257,8 +258,8 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-T₁ = LogisticClassifier(lambda=0),
-T₂ = LogisticClassifier(lambda=0),
+T₁ = with_encoder(LogisticClassifier(lambda=0)),
+T₂ = with_encoder(LogisticClassifier(lambda=0)),
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₀, dataset; verbosity=0)
30 changes: 14 additions & 16 deletions test/counterfactual_mean_based/double_robustness_iate.jl
@@ -37,10 +37,8 @@ function binary_outcome_binary_treatment_pb(;n=100)

# Respect the Tables.jl interface and convert types
W = float(W)
-T₁ = categorical(T₁)
-T₂ = categorical(T₂)
-Y = categorical(y)
-dataset = (T₁=T₁, T₂=T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], Y=Y)
+dataset = (T₁=T₁, T₂=T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], Y=y)
+dataset = coerce(dataset, autotype(dataset))
# Compute the theoretical IATE
Wcomb = [1 1 1;
1 1 0;
@@ -181,8 +179,8 @@ end
# When Q is misspecified but G is well specified
models = (
Y = with_encoder(ConstantClassifier()),
-T₁ = LogisticClassifier(lambda=0),
-T₂ = LogisticClassifier(lambda=0),
+T₁ = with_encoder(LogisticClassifier(lambda=0)),
+T₂ = with_encoder(LogisticClassifier(lambda=0)),
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -196,8 +194,8 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(LogisticClassifier(lambda=0)),
-T₁ = ConstantClassifier(),
-T₂ = ConstantClassifier(),
+T₁ = with_encoder(ConstantClassifier()),
+T₂ = with_encoder(ConstantClassifier()),
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -225,8 +223,8 @@ end
# When Q is misspecified but G is well specified
models = (
Y = with_encoder(MLJModels.DeterministicConstantRegressor()),
-T₁ = LogisticClassifier(lambda=0),
-T₂ = LogisticClassifier(lambda=0),
+T₁ = with_encoder(LogisticClassifier(lambda=0)),
+T₂ = with_encoder(LogisticClassifier(lambda=0)),
)

dr_estimators = double_robust_estimators(models)
@@ -241,8 +239,8 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(cont_interacter),
-T₁ = ConstantClassifier(),
-T₂ = ConstantClassifier(),
+T₁ = with_encoder(ConstantClassifier()),
+T₂ = with_encoder(ConstantClassifier()),
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -266,8 +264,8 @@ end
# When Q is misspecified but G is well specified
models = (
Y = with_encoder(ConstantClassifier()),
-T₁ = LogisticClassifier(lambda=0),
-T₂ = LogisticClassifier(lambda=0)
+T₁ = with_encoder(LogisticClassifier(lambda=0)),
+T₂ = with_encoder(LogisticClassifier(lambda=0))
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
@@ -282,8 +280,8 @@ end
# When Q is well specified but G is misspecified
models = (
Y = with_encoder(cat_interacter),
-T₁ = ConstantClassifier(),
-T₂ = ConstantClassifier(),
+T₁ = with_encoder(ConstantClassifier()),
+T₂ = with_encoder(ConstantClassifier()),
)
dr_estimators = double_robust_estimators(models)
results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0)
(Diffs for the remaining 2 of the 14 changed files are not shown.)

2 comments on commit 3d27968

@olivierlabayle (Member, Author)


@JuliaRegistrator register

Release notes:

  • Support for `OrderedFactor{2}`; `Multiclass` should be deprecated in the future (see the sketch after these notes)
  • Replace the deprecated `TreatmentTransformer` with `MLJModels.ContinuousEncoder`
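As a quick illustration of the first note, a binary treatment column can be marked `OrderedFactor` so that its integer representation is used downstream, mirroring the `coerce(dataset, autotype(dataset))` calls added to the tests (`:T` is an illustrative column name):

```julia
using ScientificTypes

# Coerce T to an ordered two-level factor; the ContinuousEncoder will then
# pass through its integer codes rather than one-hot encode it.
dataset = coerce(dataset, :T => OrderedFactor)
```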

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/103989

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.16.0 -m "<description of version>" 3d279684717f5eba0226c341903a02e17b5c311f
git push origin v0.16.0
