From 35b07c455ff52471feecc2f1908144b72acba899 Mon Sep 17 00:00:00 2001
From: Oscar Dowson
Date: Thu, 6 Jun 2024 20:01:43 +1200
Subject: [PATCH] Add preliminary NN support (#9)

---
 Project.toml                     |   3 +
 ext/OmeletteLuxExt.jl            |  31 ++++++
 src/models/LinearRegression.jl   |  26 +++--
 src/models/LogisticRegression.jl |   2 -
 src/models/Pipeline.jl           |  58 +++++++++++
 src/models/ReLU.jl               | 178 +++++++++++++++++++++++++++++++
 test/Project.toml                |  14 +++
 test/test_Lux.jl                 |  89 ++++++++++++++++
 test/test_ReLU.jl                |  73 +++++++++++++
 9 files changed, 463 insertions(+), 11 deletions(-)
 create mode 100644 ext/OmeletteLuxExt.jl
 create mode 100644 src/models/Pipeline.jl
 create mode 100644 src/models/ReLU.jl
 create mode 100644 test/test_Lux.jl
 create mode 100644 test/test_ReLU.jl

diff --git a/Project.toml b/Project.toml
index 8d85cd8..0ed3371 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,13 +10,16 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
 
 [weakdeps]
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 
 [extensions]
 OmeletteGLMExt = "GLM"
+OmeletteLuxExt = "Lux"
 
 [compat]
 Distributions = "0.25"
 GLM = "1.9"
+Lux = "0.5"
 JuMP = "1"
 MathOptInterface = "1"
 julia = "1.9"
diff --git a/ext/OmeletteLuxExt.jl b/ext/OmeletteLuxExt.jl
new file mode 100644
index 0000000..6f544f7
--- /dev/null
+++ b/ext/OmeletteLuxExt.jl
@@ -0,0 +1,31 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+module OmeletteLuxExt
+
+import Omelette
+import Lux
+
+function _add_predictor(predictor::Omelette.Pipeline, layer::Lux.Dense, p)
+    push!(predictor.layers, Omelette.LinearRegression(p.weight, vec(p.bias)))
+    if layer.activation === identity
+        # Do nothing
+    elseif layer.activation === Lux.NNlib.relu
+        push!(predictor.layers, Omelette.ReLUBigM(1e4))
+    else
+        error("Unsupported activation function: $(layer.activation)")
+    end
+    return
+end
+
+function Omelette.Pipeline(x::Lux.Experimental.TrainState)
+    predictor = Omelette.Pipeline(Omelette.AbstractPredictor[])
+    for (layer, parameter) in zip(x.model.layers, x.parameters)
+        _add_predictor(predictor, layer, parameter)
+    end
+    return predictor
+end
+
+end #module
diff --git a/src/models/LinearRegression.jl b/src/models/LinearRegression.jl
index a96ff06..8a86179 100644
--- a/src/models/LinearRegression.jl
+++ b/src/models/LinearRegression.jl
@@ -4,13 +4,16 @@
 # in the LICENSE.md file or at https://opensource.org/licenses/MIT.
 
 """
-    LinearRegression(parameters::Matrix)
+    LinearRegression(
+        A::Matrix{Float64},
+        b::Vector{Float64} = zeros(size(A, 1)),
+    )
 
 Represents the linear relationship:
 ```math
-f(x) = A x
+f(x) = A x + b
 ```
-where \$A\$ is the \$m \\times n\$ matrix `parameters`.
+where \$A\$ is the \$m \\times n\$ matrix `A` and \$b\$ is the \$m\$-element vector `b`.
 
 ## Example
 
@@ -22,7 +25,7 @@ julia> model = Model();
 
 julia> @variable(model, x[1:2]);
 
 julia> f = Omelette.LinearRegression([2.0, 3.0])
-Omelette.LinearRegression([2.0 3.0])
+Omelette.LinearRegression([2.0 3.0], [0.0])
 
 julia> y = Omelette.add_predictor(model, f, x)
 1-element Vector{VariableRef}:
  omelette_y[1]
@@ -35,11 +38,16 @@ julia> print(model)
 ```
 """
 struct LinearRegression <: AbstractPredictor
-    parameters::Matrix{Float64}
+    A::Matrix{Float64}
+    b::Vector{Float64}
 end
 
-function LinearRegression(parameters::Vector{Float64})
-    return LinearRegression(reshape(parameters, 1, length(parameters)))
+function LinearRegression(A::Matrix{Float64})
+    return LinearRegression(A, zeros(size(A, 1)))
+end
+
+function LinearRegression(A::Vector{Float64})
+    return LinearRegression(reshape(A, 1, length(A)), [0.0])
 end
 
 function add_predictor(
@@ -47,8 +55,8 @@ function add_predictor(
     predictor::LinearRegression,
     x::Vector{JuMP.VariableRef},
 )
-    m = size(predictor.parameters, 1)
+    m = size(predictor.A, 1)
     y = JuMP.@variable(model, [1:m], base_name = "omelette_y")
-    JuMP.@constraint(model, predictor.parameters * x .== y)
+    JuMP.@constraint(model, predictor.A * x .+ predictor.b .== y)
     return y
 end
diff --git a/src/models/LogisticRegression.jl b/src/models/LogisticRegression.jl
index 9e8d86f..b8a53d3 100644
--- a/src/models/LogisticRegression.jl
+++ b/src/models/LogisticRegression.jl
@@ -42,8 +42,6 @@ function LogisticRegression(parameters::Vector{Float64})
     return LogisticRegression(reshape(parameters, 1, length(parameters)))
 end
 
-Base.size(f::LogisticRegression) = size(f.parameters)
-
 function add_predictor(
     model::JuMP.Model,
     predictor::LogisticRegression,
diff --git a/src/models/Pipeline.jl b/src/models/Pipeline.jl
new file mode 100644
index 0000000..a9015dd
--- /dev/null
+++ b/src/models/Pipeline.jl
@@ -0,0 +1,58 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+"""
+    Pipeline(layers::Vector{AbstractPredictor})
+
+A pipeline of nested layers:
+```math
+f(x) = (l_N \\circ \\ldots \\circ l_2 \\circ l_1)(x)
+```
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.Pipeline(
+           Omelette.LinearRegression([1.0 2.0], [0.0]),
+           Omelette.ReLUQuadratic(),
+       )
+Omelette.Pipeline(Omelette.AbstractPredictor[Omelette.LinearRegression([1.0 2.0], [0.0]), Omelette.ReLUQuadratic()])
+
+julia> y = Omelette.add_predictor(model, f, x)
+1-element Vector{VariableRef}:
+ omelette_y[1]
+
+julia> print(model)
+Feasibility
+Subject to
+ x[1] + 2 x[2] - omelette_y[1] = 0
+ omelette_y[1] - omelette_y[1] + _z[1] = 0
+ omelette_y[1]*_z[1] = 0
+ omelette_y[1] ≥ 0
+ _z[1] ≥ 0
+```
+"""
+struct Pipeline <: AbstractPredictor
+    layers::Vector{AbstractPredictor}
+end
+
+Pipeline(args::AbstractPredictor...) = Pipeline(collect(args))
+
+function add_predictor(
+    model::JuMP.Model,
+    predictor::Pipeline,
+    x::Vector{JuMP.VariableRef},
+)
+    for layer in predictor.layers
+        x = add_predictor(model, layer, x)
+    end
+    return x
+end
diff --git a/src/models/ReLU.jl b/src/models/ReLU.jl
new file mode 100644
index 0000000..6c642fb
--- /dev/null
+++ b/src/models/ReLU.jl
@@ -0,0 +1,178 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+"""
+    ReLUBigM(M::Float64)
+
+Represents the rectified linear unit relationship:
+```math
+f(x) = max(0, x)
+```
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.ReLUBigM(100.0)
+Omelette.ReLUBigM(100.0)
+
+julia> y = Omelette.add_predictor(model, f, x)
+2-element Vector{VariableRef}:
+ omelette_y[1]
+ omelette_y[2]
+
+julia> print(model)
+Feasibility
+Subject to
+ -x[1] + omelette_y[1] ≥ 0
+ -x[2] + omelette_y[2] ≥ 0
+ omelette_y[1] - 100 _[5] ≤ 0
+ omelette_y[2] - 100 _[6] ≤ 0
+ -x[1] + omelette_y[1] + 100 _[5] ≤ 100
+ -x[2] + omelette_y[2] + 100 _[6] ≤ 100
+ omelette_y[1] ≥ 0
+ omelette_y[2] ≥ 0
+ _[5] binary
+ _[6] binary
+```
+"""
+struct ReLUBigM <: AbstractPredictor
+    M::Float64
+end
+
+function add_predictor(
+    model::JuMP.Model,
+    predictor::ReLUBigM,
+    x::Vector{JuMP.VariableRef},
+)
+    m = length(x)
+    y = JuMP.@variable(model, [1:m], lower_bound = 0, base_name = "omelette_y")
+    z = JuMP.@variable(model, [1:m], Bin)
+    JuMP.@constraint(model, y .>= x)
+    JuMP.@constraint(model, y .<= predictor.M * z)
+    JuMP.@constraint(model, y .<= x .+ predictor.M * (1 .- z))
+    return y
+end
+
+"""
+    ReLUSOS1()
+
+Implements the ReLU constraint \$y = max(0, x)\$ by the reformulation:
+```math
+\\begin{aligned}
+x = y - z \\\\
+[y, z] \\in SOS1 \\\\
+y, z \\ge 0
+\\end{aligned}
+```
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.ReLUSOS1()
+Omelette.ReLUSOS1()
+
+julia> y = Omelette.add_predictor(model, f, x)
+2-element Vector{VariableRef}:
+ omelette_y[1]
+ omelette_y[2]
+
+julia> print(model)
+Feasibility
+Subject to
+ x[1] - omelette_y[1] + _z[1] = 0
+ x[2] - omelette_y[2] + _z[2] = 0
+ [omelette_y[1], _z[1]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
+ [omelette_y[2], _z[2]] ∈ MathOptInterface.SOS1{Float64}([1.0, 2.0])
+ omelette_y[1] ≥ 0
+ omelette_y[2] ≥ 0
+ _z[1] ≥ 0
+ _z[2] ≥ 0
+```
+"""
+struct ReLUSOS1 <: AbstractPredictor end
+
+function add_predictor(
+    model::JuMP.Model,
+    predictor::ReLUSOS1,
+    x::Vector{JuMP.VariableRef},
+)
+    m = length(x)
+    y = JuMP.@variable(model, [1:m], lower_bound = 0, base_name = "omelette_y")
+    z = JuMP.@variable(model, [1:m], lower_bound = 0, base_name = "_z")
+    JuMP.@constraint(model, x .== y - z)
+    for i in 1:m
+        JuMP.@constraint(model, [y[i], z[i]] in MOI.SOS1([1.0, 2.0]))
+    end
+    return y
+end
+
+"""
+    ReLUQuadratic()
+
+Implements the ReLU constraint \$y = max(0, x)\$ by the reformulation:
+```math
+\\begin{aligned}
+x = y - z \\\\
+y \\times z = 0 \\\\
+y, z \\ge 0
+\\end{aligned}
+```
+
+## Example
+
+```jldoctest
+julia> using JuMP, Omelette
+
+julia> model = Model();
+
+julia> @variable(model, x[1:2]);
+
+julia> f = Omelette.ReLUQuadratic()
+Omelette.ReLUQuadratic()
+
+julia> y = Omelette.add_predictor(model, f, x)
+2-element Vector{VariableRef}:
+ omelette_y[1]
+ omelette_y[2]
+
+julia> print(model)
+Feasibility
+Subject to
+ x[1] - omelette_y[1] + _z[1] = 0
+ x[2] - omelette_y[2] + _z[2] = 0
+ omelette_y[1]*_z[1] = 0
+ omelette_y[2]*_z[2] = 0
+ omelette_y[1] ≥ 0
+ omelette_y[2] ≥ 0
+ _z[1] ≥ 0
+ _z[2] ≥ 0
+```
+"""
+struct ReLUQuadratic <: AbstractPredictor end
+
+function add_predictor(
+    model::JuMP.Model,
+    predictor::ReLUQuadratic,
+    x::Vector{JuMP.VariableRef},
+)
+    m = length(x)
+    y = JuMP.@variable(model, [1:m], lower_bound = 0, base_name = "omelette_y")
+    z = JuMP.@variable(model, [1:m], lower_bound = 0, base_name = "_z")
+    JuMP.@constraint(model, x .== y - z)
+    JuMP.@constraint(model, y .* z .== 0)
+    return y
+end
diff --git a/test/Project.toml b/test/Project.toml
index 137ee72..a48b52d 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,15 +1,29 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
 HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
 Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Omelette = "e52c2cb8-508e-4e12-9dd2-9c4755b60e73"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+ADTypes = "1"
 GLM = "1"
 HiGHS = "1"
 Ipopt = "1"
 JuMP = "1"
+Lux = "0.5"
+Optimisers = "0.3"
+Printf = "<0.0.1, 1.10"
+Random = "<0.0.1, 1.10"
+Statistics = "<0.0.1, 1.10"
 Test = "<0.0.1, 1.6"
+Zygote = "0.6"
 julia = "1.9"
diff --git a/test/test_Lux.jl b/test/test_Lux.jl
new file mode 100644
index 0000000..cdcc138
--- /dev/null
+++ b/test/test_Lux.jl
@@ -0,0 +1,89 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+module LuxTests
+
+using JuMP
+using Test
+
+import ADTypes
+import HiGHS
+import Lux
+import Omelette
+import Optimisers
+import Random
+import Statistics
+import Zygote
+
+is_test(x) = startswith(string(x), "test_")
+
+function runtests()
+    @testset "$name" for name in filter(is_test, names(@__MODULE__; all = true))
+        getfield(@__MODULE__, name)()
+    end
+    return
+end
+
+function generate_data(rng::Random.AbstractRNG, n = 128)
+    x = range(-2.0, 2.0, n)
+    y = -2 .* x .+ x .^ 2 .+ 0.1 .* randn(rng, n)
+    return reshape(collect(x), (1, n)), reshape(y, (1, n))
+end
+
+function loss_function_mse(model, ps, state, (input, output))
+    y_pred, updated_state = Lux.apply(model, input, ps, state)
+    loss = Statistics.mean(abs2, y_pred .- output)
+    return loss, updated_state, ()
+end
+
+function train_cpu(
+    model,
+    input,
+    output;
+    loss_function::Function = loss_function_mse,
+    vjp = ADTypes.AutoZygote(),
+    rng,
+    optimizer,
+    epochs::Int,
+)
+    state = Lux.Experimental.TrainState(rng, model, optimizer)
+    data = (input, output) .|> Lux.cpu_device()
+    for epoch in 1:epochs
+        grads, loss, stats, state =
+            Lux.Experimental.compute_gradients(vjp, loss_function, data, state)
+        state = Lux.Experimental.apply_gradients!(state, grads)
+    end
+    return state
+end
+
+function test_end_to_end()
+    rng = Random.MersenneTwister()
+    Random.seed!(rng, 12345)
+    x, y = generate_data(rng)
+    model = Lux.Chain(Lux.Dense(1 => 16, Lux.relu), Lux.Dense(16 => 1))
+    state = train_cpu(
+        model,
+        x,
+        y;
+        rng = rng,
+        optimizer = Optimisers.Adam(0.03f0),
+        epochs = 250,
+    )
+    f = Omelette.Pipeline(state)
+    model = Model(HiGHS.Optimizer)
+    set_silent(model)
+    @variable(model, x)
+    y = Omelette.add_predictor(model, f, [x])
+    @constraint(model, only(y) <= 4)
+    @objective(model, Min, x)
+    optimize!(model)
+    @assert is_solved_and_feasible(model)
+    @test isapprox(value(x), -1.24; atol = 1e-2)
+    return
+end
+
+end # module
+
+LuxTests.runtests()
diff --git a/test/test_ReLU.jl b/test/test_ReLU.jl
new file mode 100644
index 0000000..0b667d7
--- /dev/null
+++ b/test/test_ReLU.jl
@@ -0,0 +1,73 @@
+# Copyright (c) 2024: Oscar Dowson and contributors
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+module ReLUTests
+
+using JuMP
+using Test
+
+import HiGHS
+import Ipopt
+import Omelette
+
+is_test(x) = startswith(string(x), "test_")
+
+function runtests()
+    @testset "$name" for name in filter(is_test, names(@__MODULE__; all = true))
+        getfield(@__MODULE__, name)()
+    end
+    return
+end
+
+function test_ReLU_BigM()
+    model = Model(HiGHS.Optimizer)
+    set_silent(model)
+    @variable(model, x[1:2])
+    f = Omelette.ReLUBigM(100.0)
+    y = Omelette.add_predictor(model, f, x)
+    @test length(y) == 2
+    @test num_variables(model) == 6
+    @test num_constraints(model, AffExpr, MOI.LessThan{Float64}) == 4
+    @test num_constraints(model, AffExpr, MOI.GreaterThan{Float64}) == 2
+    @objective(model, Min, sum(y))
+    fix.(x, [-1, 2])
+    optimize!(model)
+    @assert is_solved_and_feasible(model)
+    @test value.(y) ≈ [0.0, 2.0]
+    return
+end
+
+function test_ReLU_SOS1()
+    model = Model()
+    @variable(model, x[1:2])
+    f = Omelette.ReLUSOS1()
+    y = Omelette.add_predictor(model, f, x)
+    @test length(y) == 2
+    @test num_variables(model) == 6
+    @test num_constraints(model, Vector{VariableRef}, MOI.SOS1{Float64}) == 2
+    # TODO(odow): add a test for solution with solver that supports SOS1
+    return
+end
+
+function test_ReLU_Quadratic()
+    model = Model(Ipopt.Optimizer)
+    set_silent(model)
+    @variable(model, x[1:2])
+    f = Omelette.ReLUQuadratic()
+    y = Omelette.add_predictor(model, f, x)
+    @test length(y) == 2
+    @test num_variables(model) == 6
+    @test num_constraints(model, QuadExpr, MOI.EqualTo{Float64}) == 2
+    @objective(model, Min, sum(y))
+    fix.(x, [-1, 2])
+    optimize!(model)
+    @assert is_solved_and_feasible(model)
+    @test value.(y) ≈ [0.0, 2.0]
+    return
+end
+
+end
+
+ReLUTests.runtests()
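
Note: the summary below is not part of the patch. The `ReLUBigM` docstring above lists the generated JuMP constraints but never states the formulation itself; the constraints emitted by its `add_predictor` method are the standard big-M encoding of the elementwise ReLU `y_i = max(0, x_i)`:

```latex
% Big-M encoding of y_i = max(0, x_i), as emitted by add_predictor(model, ::ReLUBigM, x).
% The binary z_i selects the branch: z_i = 0 forces y_i = 0, while z_i = 1 forces y_i = x_i.
\begin{aligned}
y_i &\ge 0                  \\
y_i &\ge x_i                \\
y_i &\le M z_i              \\
y_i &\le x_i + M (1 - z_i)  \\
z_i &\in \{0, 1\}
\end{aligned}
```

The encoding is exact only when every feasible pre-activation value satisfies `-M <= x_i <= M`, so the hard-coded `ReLUBigM(1e4)` that `OmeletteLuxExt` inserts implicitly assumes the outputs of each `Dense` layer stay within that range; problem-specific bounds would give a tighter and safer formulation.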
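The end-to-end flow the patch enables is: train a `Lux.Chain` of `Dense` layers, convert the trained `TrainState` to an `Omelette.Pipeline` via the new extension, then embed that pipeline in a JuMP model with `add_predictor`. The sketch below is adapted from `test_end_to_end` in `test/test_Lux.jl` and assumes `state` is the trained `Lux.Experimental.TrainState` produced there; the input bounds and solver choice are illustrative, not part of the patch. The test's expected optimum `value(x) ≈ -1.24` is consistent with `generate_data`: the network approximates `x^2 - 2x`, and minimizing `x` subject to `x^2 - 2x <= 4` gives `x = 1 - sqrt(5) ≈ -1.236`.

```julia
using JuMP
import HiGHS, Omelette

# `state` is assumed to be the trained TrainState from test_end_to_end above
# (a Dense(1 => 16, relu) layer followed by Dense(16 => 1)).
f = Omelette.Pipeline(state)        # conversion added by ext/OmeletteLuxExt.jl

model = Model(HiGHS.Optimizer)
set_silent(model)
# Hypothetical input bounds: not required by the patch, but finite bounds keep
# the big-M ReLU encoding tight.
@variable(model, -3 <= x <= 3)
y = Omelette.add_predictor(model, f, [x])   # one JuMP variable per network output
@constraint(model, only(y) <= 4)            # constrain the surrogate's output
@objective(model, Min, x)                   # optimize over the network's input
optimize!(model)
value(x)                                    # ≈ 1 - sqrt(5) ≈ -1.24 for this data
```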