Skip to content

Commit

Permalink
Tests and fix improved insertion of quadratics.
Browse files Browse the repository at this point in the history
  • Loading branch information
olof3 committed Sep 6, 2017
1 parent 61a7116 commit f172a18
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 35 deletions.
253 changes: 224 additions & 29 deletions src/dev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import Base.start
import Base.next
import Base.done
import Base.length
import Base.getindex

import Base.==

import IterTools

using StaticArrays
using PyPlot
using Polynomials
using QuadGK

Expand All @@ -23,18 +25,18 @@ include("types/QuadraticPolynomial.jl")
include("types/PiecewiseQuadratic.jl")
include("types/QuadraticForm2.jl")


global const DEBUG = false


function roots{T}(p::QuadraticPolynomial{T})
quad_expr = p.b^2 - 4*p.a*p.c
if quad_expr < 0
b2_minus_4ac = p.b^2 - 4*p.a*p.c
if b2_minus_4ac < 0
return NaN, NaN
end

if p.a != 0
term1 = (p.b / 2 / p.a)
term2 = sqrt(quad_expr) / 2 / abs(p.a)
term2 = sqrt(b2_minus_4ac) / 2 / abs(p.a)
return -term1-term2, -term1+term2
else
term = -p.c / p.b
Expand All @@ -53,7 +55,7 @@ Inserts a quadratic polynomial ρ into the linked list Λ which represents a pie
"""
function add_quadratic{T}::PiecewiseQuadratic{T}, ρ::QuadraticPolynomial{T})

if isnull.next)
if isnull.next) # I.e. the piecewise quadratic object is empty, perhaps better to add dummy polynomial
insert(Λ, ρ, -1e9)
return
end
Expand Down Expand Up @@ -183,32 +185,230 @@ function minimize_wrt_x2(qf::QuadraticForm2)
end
end



##

function add_quadratic2{T}::PiecewiseQuadratic{T}, ρ::QuadraticPolynomial{T})

DEBUG && println("Inserting: ", ρ)
if isnull.next) # I.e. the piecewise quadratic object is empty, perhaps better to add dummy polynomial
insert(Λ, ρ, -1e9)
return
end

λ_prev = Λ
λ_curr = Λ.next

#left_endpoint = NaN

while ~isnull(λ_curr)
DEBUG && println(Λ)

left_endpoint = unsafe_get(λ_curr).left_endpoint

# TODO: This is probably not needed now..
if left_endpoint == -1e9
#println("minus-inf")
left_endpoint = -10000.0
end

right_endpoint = get_right_endpoint(unsafe_get(λ_curr))

if right_endpoint == 1e9
#println("inf")
right_endpoint = left_endpoint + 20000.0
end

Δa = ρ.a - unsafe_get(λ_curr).p.a
Δb = ρ.b - unsafe_get(λ_curr).p.b
Δc = ρ.c - unsafe_get(λ_curr).p.c

b2_minus_4ac = Δb^2 - 4*Δa*Δc

if Δa > 0 # ρ has greater curvature, i.e., ρ is smallest in the middle
if b2_minus_4ac <= 0
# No intersections, old quadratic is smallest, just step forward
λ_prev = unsafe_get(λ_curr)
λ_curr = λ_prev.next
else

# Compute the intersections
term1 = -(Δb / 2 / Δa)
term2 = sqrt(b2_minus_4ac) / 2 / abs(Δa)
root1, root2 = term1-term2, term1+term2

DEBUG && println("Δa > 0 root1:", root1, " root2:", root2)

# Check where the intersections are and act accordingly
if root1 >= right_endpoint || root2 <= left_endpoint
# No intersections, old quadratic is smallest, step forward
DEBUG && println("Two intersections to the side")
λ_prev = unsafe_get(λ_curr)
λ_curr = λ_prev.next
elseif root1 <= left_endpoint && root2 >= right_endpoint
# No intersections, new quadratic is smallest
DEBUG && println("One intersections on either side")
λ_prev, λ_curr = update_segment_new(λ_prev, unsafe_get(λ_curr), ρ)
elseif root1 > left_endpoint && root2 < right_endpoint
DEBUG && println("Two intersections within the interval")
λ_prev, λ_curr = update_segment_old_new_old(unsafe_get(λ_curr), ρ, root1, root2)
elseif root1 > left_endpoint
DEBUG && println("Root 1 within the interval")
λ_prev, λ_curr = update_segment_old_new(unsafe_get(λ_curr), ρ, root1)
elseif root2 < right_endpoint
DEBUG && println("Root 2 within the interval")
λ_prev, λ_curr = update_segment_new_old(λ_prev, unsafe_get(λ_curr), ρ, root2)
else
error("Shouldn't end up here")
end
end

elseif Δa < 0 # ρ has lower curvature, i.e., ρ is smallest on the sides
if b2_minus_4ac <= 0
λ_prev, λ_curr = update_segment_new(λ_prev, unsafe_get(λ_curr), ρ)
else
# Compute the intersections
term1 = -(Δb / 2 / Δa)
term2 = sqrt(b2_minus_4ac) / 2 / abs(Δa)
root1, root2 = term1-term2, term1+term2
DEBUG && println("Δa < 0 root1:", root1, " root2:", root2)

# Check where the intersections are and act accordingly
if root1 >= right_endpoint || root2 <= left_endpoint
# No intersections, ρ is smallest
λ_prev, λ_curr = update_segment_new(λ_prev, unsafe_get(λ_curr), ρ)
elseif root1 <= left_endpoint && root2 >= right_endpoint
# No intersections, old quadratic is smallest, just step forward
λ_prev = unsafe_get(λ_curr)
λ_curr = λ_prev.next
elseif root1 > left_endpoint && root2 < right_endpoint
# Two intersections within the interval
λ_prev, λ_curr = update_segment_new_old_new(λ_prev, unsafe_get(λ_curr), ρ, root1, root2)
elseif root1 > left_endpoint
λ_prev, λ_curr = update_segment_new_old(λ_prev, unsafe_get(λ_curr), ρ, root1)
elseif root2 < right_endpoint
λ_prev, λ_curr = update_segment_old_new(unsafe_get(λ_curr), ρ, root2)
else
error("Shouldn't end up here")
end
end
else # a == 0.0
DEBUG && pritnln("Δa == 0")

if Δb == 0
if Δc >= 0
λ_prev, λ_curr = update_segment_do_nothing(unsafe_get(λ_curr))
else
λ_prev, λ_curr = update_segment_new(λ_prev, unsafe_get(λ_curr), ρ)
end
continue
end

root = -Δc / Δb
if Δb > 0
if root < left_endpoint
λ_prev, λ_curr = update_segment_do_nothing(unsafe_get(λ_curr))
elseif root > right_endpoint
λ_prev, λ_curr = update_segment_new(λ_prev, unsafe_get(λ_curr), ρ)
else
λ_prev, λ_curr = update_segment_new_old(λ_prev, unsafe_get(λ_curr), ρ, root)
end
else
if root < left_endpoint
λ_prev, λ_curr = update_segment_new(λ_prev, unsafe_get(λ_curr), ρ)
elseif root > right_endpoint
λ_prev, λ_curr = update_segment_do_nothing(unsafe_get(λ_curr))
else
λ_prev, λ_curr = update_segment_old_new(unsafe_get(λ_curr), ρ, root)
end
end
end

end
return
end


@inline function update_segment_new_old(λ_prev, λ_curr, ρ, break1)
if λ_prev.p === ρ
λ_curr.left_endpoint = break1
else
λ_prev.next = PiecewiseQuadratic(ρ, λ_curr.left_endpoint, λ_curr)
λ_curr.left_endpoint = break1
end
return λ_curr, λ_curr.next
end

@inline function update_segment_new_old_new(λ_prev, λ_curr, ρ, break1, break2)
update_segment_new_old(λ_prev, λ_curr, ρ, break1)
return update_segment_old_new(λ_curr, ρ, break2)
end

@inline function update_segment_old_new_old(λ_curr, ρ, break1, break2)
second_old_pwq_segment = PiecewiseQuadratic(λ_curr.p, break2, λ_curr.next)
new_pwq_segment = PiecewiseQuadratic(ρ, break1, second_old_pwq_segment)
λ_curr.next = new_pwq_segment
return second_old_pwq_segment, second_old_pwq_segment.next
end

@inline function update_segment_old_new(λ_curr, ρ, break1)
new_pwq_segment = PiecewiseQuadratic(ρ, break1, λ_curr.next)
λ_curr.next = new_pwq_segment
return new_pwq_segment, new_pwq_segment.next
end

@inline function update_segment_new(λ_prev, λ_curr, ρ)
if λ_prev.p === ρ
λ_prev.next = λ_curr.next
v1, v2 = λ_prev, λ_curr.next
else
λ_curr.p = ρ
v1, v2 = λ_curr, λ_curr.next
end
return v1, v2 #λ_curr, λ_curr.next
end

@inline function update_segment_do_nothing(λ_curr)
return λ_curr, λ_curr.next #λ_curr, λ_curr.next
end
###





# Takes a quadratic form in [x1; x2] and a polynomial in x2
# and returns the minimum of the sum wrt to x2,
# i.e. a polynomial of x1
function minimize_wrt_x2_fast{T}(qf::QuadraticForm2{T},p::QuadraticPolynomial{T})
@inline function minimize_wrt_x2_fast{T}(qf::QuadraticForm2{T},p::QuadraticPolynomial{T})

# Create quadratic form representing the sum of qf and p
P = qf.P
q = qf.q
r = qf.r

if P[2,2] + p.a > 0
QuadraticPolynomial(P[1,1] - P[1,2]^2 / (P[2,2]+p.a),
q[1] - P[1,2]*(q[2]+p.b) / (P[2,2]+p.a),
(r+p.c) - (q[2]+p.b)^2 / (P[2,2]+p.a)/ 4)
elseif (P[2,2]+p.a) == 0 || P[1,2] == 0 || (q[2]+p.b) == 0 #why are the two last conditions needed?
QuadraticPolynomial(P[1,1], q[1], r+p.c)
P22_new = P[2,2] + p.a

local v
if P22_new > 0
v = QuadraticPolynomial(P[1,1] - P[1,2]^2 / P22_new,
q[1] - P[1,2]*(q[2]+p.b) / P22_new,
(r+p.c) - (q[2]+p.b)^2 / P22_new/ 4)
elseif P22_new == 0 #|| P[1,2] == 0 || (q[2]+p.b) == 0 #why are the two last conditions needed?
v = QuadraticPolynomial(P[1,1], q[1], r+p.c)
else
# FIXME: what are these condtions?
# There are some special cases, but disregards these
QuadraticPolynomial(0.,0.,-Inf)
v = QuadraticPolynomial(0.,0.,-Inf)
end
return v
end

"""
Find optimal fit
"""
function find_optimal_fit{T}(Λ_0::Array{PiecewiseQuadratic{T},1}, ℓ::Array{QuadraticForm2{T},2}, M)
function find_optimal_fit{T}(Λ_0::Array{PiecewiseQuadratic{T},1}, ℓ::Array{QuadraticForm2{T},2}, M::Int, upper_bound=Inf)
N = size(ℓ, 2)

Λ = Array{PiecewiseQuadratic{T}}(M, N)
Expand All @@ -227,10 +427,15 @@ function find_optimal_fit{T}(Λ_0::Array{PiecewiseQuadratic{T},1}, ℓ::Array{Qu
# ℓ[i,ip] + dev.QuadraticForm2{T}(@SMatrix([0. 0; 0 1])*p.a, @SVector([0., 1])*p.b, p.c))
# # Avoid ceting two extra QuadraticForm2
ρ = dev.minimize_wrt_x2_fast(ℓ[i,ip], p)

if unsafe_minimum(ρ) > upper_bound
continue
end

ρ.time_index = ip
ρ.ancestor = p

dev.add_quadratic(Λ_new, ρ)
dev.add_quadratic2(Λ_new, ρ)
end
end
Λ[m, i] = Λ_new
Expand All @@ -240,18 +445,7 @@ function find_optimal_fit{T}(Λ_0::Array{PiecewiseQuadratic{T},1}, ℓ::Array{Qu
end


# Finds the minimum of a positive definite quadratic one variable polynomial
# the find_minimum fcn returns (opt_x, opt_val)
function find_minimum(p::QuadraticPolynomial)
if p.a <= 0
println("No unique minimum exists")
return (NaN, NaN)
else
x_opt = -p.b / 2 / p.a
f_opt = -p.b^2/4/p.a + p.c
return (x_opt, f_opt)
end
end



function find_minimum::PiecewiseQuadratic)
Expand Down Expand Up @@ -368,15 +562,16 @@ function compute_discrete_transition_costs(g)
G3[k] = G3[k-1] + g[k-1]^2
end

# The P-matrices only depend on the distance d=ip-i
P_mats = Vector{SMatrix{2,2,Float64,4}}(N-1)
P_mats[1] = @SMatrix [1.0 0; 0 0]
for d=2:N-1
off_diag_elems = sum([k*(d - k) for k=0:d-1])
P_mats[d] = @SMatrix [P_mats[d-1][1,1] + d^2 off_diag_elems;
off_diag_elems P_mats[d-1][1,1]]
off_diag_elems P_mats[d-1][1,1]]
end

P_mats = P_mats ./ (1.0:N-1).^2
P_mats = P_mats ./ (1.0:N-1).^2 # FIXME: Why can't this be done above in the loop?

#P_invs = inv.(P_mats)

Expand Down
14 changes: 12 additions & 2 deletions src/types/QuadraticPolynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function QuadraticPolynomial{T}(a::T, b::T, c::T) where {T}
# @assert a ≥ 0, # Δ may have negative a ...
new(a, b, c, Nullable{QuadraticPolynomial{T}}(),-1)
end
function QuadraticPolynomial{T}(a::T, b::T, c::T,ancestor, time_index) where {T}
function QuadraticPolynomial{T}(a::T, b::T, c::T, ancestor, time_index) where {T}
@assert a 0
new(a, b, c, ancestor, time_index)
end
Expand Down Expand Up @@ -56,10 +56,20 @@ function find_minimum(p::QuadraticPolynomial)
println("No unique minimum exists")
return (NaN, NaN)
else
return (-p.b / 2 / p.a, -p.b^2/4 + p.c)
x_opt = -p.b / 2 / p.a
f_opt = -p.b^2/4/p.a + p.c
return (x_opt, f_opt)
end
end

"""
Minimum of a quadratic function which is assumed to be positive definite
No checks of this is done
"""
@inline function unsafe_minimum{T}(p::QuadraticPolynomial{T})
return (-p.b^2/4/p.a + p.c)::T
end



"""
Expand Down
Loading

0 comments on commit f172a18

Please sign in to comment.