Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove @compact(name=...) and replace with NoShow #19

Merged
merged 9 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ include("chain.jl")

include("compact.jl")

include("noshow.jl")
export NoShow

include("new_recur.jl")

end # module Fluxperimental
41 changes: 3 additions & 38 deletions src/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,7 @@ for epoch in 1:1000
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
```
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

You may also specify a `name` for the model, which will
be used instead of the default printout, which gives a verbatim
representation of the code used to construct the model:

```
model = @compact(w=rand(3), name="Linear(3 => 1)") do x
sum(w .* x)
end
println(model) # "Linear(3 => 1)"
```

This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.
To specify a custom printout for the model, you may find [`NoShow`](@ref) useful.
"""
macro compact(_exs...)
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
Expand All @@ -108,16 +95,6 @@ macro compact(_exs...)
kwexs2 = map(ex -> Expr(:kw, ex.args...), _kwexs) # handle keyword arguments provided before semicolon
kwexs = (kwexs1..., kwexs2...)

# check if user has named layer:
name = findfirst(ex -> ex.args[1] == :name, kwexs)
if name !== nothing && kwexs[name].args[2] !== nothing
length(kwexs) == 1 && error("expects keyword arguments")
name_str = kwexs[name].args[2]
# remove name from kwexs (a tuple)
kwexs = (kwexs[1:name-1]..., kwexs[name+1:end]...)
name = name_str
end

# make strings
layer = "@compact"
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
Expand All @@ -136,7 +113,7 @@ macro compact(_exs...)
fex = supportself(fex, vars)

# assemble
return esc(:($CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))))
return esc(:($CompactLayer($fex, ($layer, $input, $block), $setup; $(kwexs...))))
end

function supportself(fex::Expr, vars)
Expand All @@ -155,12 +132,11 @@ end

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
fun::F
name::Union{String,Nothing}
strings::NTuple{3,String}
setup_strings::NT1
variables::NT2
end
CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
CompactLayer(f::Function, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, str, setup_str, NamedTuple(kw))
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
Flux.@functor CompactLayer
Expand All @@ -179,16 +155,6 @@ end

function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
setup_strings = obj.setup_strings
local_name = obj.name
has_explicit_name = local_name !== nothing
if has_explicit_name
if indent != 0 || length(Flux.params(obj)) <= 2
_just_show_params(io, local_name, obj, indent)
else # indent == 0
print(io, local_name)
Flux._big_finale(io, obj)
end
else # no name, so print normally
layer, input, block = obj.strings
pre, post = ("(", ")")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
Expand Down Expand Up @@ -220,7 +186,6 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
else
println(io, ",")
end
end
end

# Modified from src/layers/show.jl
Expand Down
62 changes: 62 additions & 0 deletions src/noshow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

"""
NoShow(layer)
NoShow(string, layer)

This alters printing (for instance at the REPL prompt) to let you hide the complexity
of some part of a Flux model. It has no effect on the actual running of the model.

By default it prints `NoShow(...)` instead of the given layer.
If you provide a string, it prints that instead -- it can be anything,
but it may make sense to print the name of a function which will
Copy link
Contributor

@gaurav-arya gaurav-arya Aug 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a callable function to reconstruct the layer is what's desired, I thought a bit about allowing the user to specify the function (and its args) and incorporating something like https://github.com/JuliaLang/julia/blob/197180d8589ad14fc4bc4c23782b76739c4ec5a4/base/show.jl#L522 to make this more robust. I don't think it is worth the implementation complexity, and could also easily be added later if we really wanted it, so just noting for posterity.

re-create the same structure.

# Examples

```jldoctest
julia> Chain(Dense(2 => 3), NoShow(Parallel(vcat, Dense(3 => 4), Dense(3 => 5))), Dense(9 => 10))
Chain(
Dense(2 => 3), # 9 parameters
NoShow(...), # 36 parameters
Dense(9 => 10), # 100 parameters
) # Total: 8 arrays, 145 parameters, 1.191 KiB.

julia> pseudolayer((i,o)::Pair) = NoShow(
"pseudolayer(\$i => \$o)",
Parallel(+, Dense(i => o, relu), Dense(i => o, tanh)),
)
pseudolayer (generic function with 1 method)

julia> Chain(Dense(2 => 3), pseudolayer(3 => 10), Dense(9 => 10))
Chain(
Dense(2 => 3), # 9 parameters
pseudolayer(3 => 10), # 80 parameters
Dense(9 => 10), # 100 parameters
) # Total: 8 arrays, 189 parameters, 1.379 KiB.
```
"""
struct NoShow{T}
    # Text printed in place of the wrapped layer.
    str::String
    # The wrapped layer; calling a `NoShow` forwards its arguments to this.
    layer::T
end

# Accept any string type for the label (e.g. a `SubString`) and store it as a `String`.
# The implicit constructor only matches `str::String`, so without this method other
# `AbstractString` labels would fail with a `MethodError`.
NoShow(str::AbstractString, layer) = NoShow(String(str), layer)

# With no label given, use the fixed placeholder text "NoShow(...)".
NoShow(layer) = NoShow("NoShow(...)", layer)

# Register NoShow with Flux through its `@functor` macro.
# NOTE(review): presumably this is what keeps the wrapped layer's parameters and
# gradients reachable through the wrapper -- confirm against Flux's documentation.
Flux.@functor NoShow

# Calling the wrapper forwards all positional arguments to the wrapped layer.
function (wrapper::NoShow)(inputs...)
    return wrapper.layer(inputs...)
end

# Compact display: print only the stored label, never the wrapped layer.
function Base.show(io::IO, no::NoShow)
    return print(io, no.str)
end

# Hooks into Flux's (internal) layer-printing functions.
# NOTE(review): the original author was unsure whether both overrides are correct and
# required -- confirm against Flux's show.jl before relying on them.
function Flux._show_leaflike(::NoShow)
    return true
end

# Report an empty NamedTuple of children for a NoShow.
function Flux._show_children(::NoShow)
    return NamedTuple()
end

# Rich (text/plain) display, choosing a format from the IO context properties.
function Base.show(io::IO, ::MIME"text/plain", m::NoShow)
    # No :typeinfo set, e.g. top level of the REPL.
    isnothing(get(io, :typeinfo, nothing)) && return Flux._big_show(io, m)
    # Compact context: fall back to the plain two-argument show.
    get(io, :compact, false) && return show(io, m)
    # Otherwise, e.g. printed inside a Vector, but not a matrix.
    return Flux._layer_show(io, m)
end
58 changes: 13 additions & 45 deletions test/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
(1, 128),
(1,),
]
@test size(model(randn(n_in, 32))) == (1, 32)
@test size(model(randn(Float32, n_in, 32))) == (1, 32)
end

@testset "String representations" begin
Expand All @@ -118,15 +118,6 @@ end
@test similar_strings(get_model_string(model), expected_string)
end

mcabbott marked this conversation as resolved.
Show resolved Hide resolved
@testset "Custom naming" begin
model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
tmp = sum(w(x))
return tmp + y
end
expected_string = "Linear(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Hierarchical models" begin
model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
w2(w1(x))
Expand Down Expand Up @@ -161,41 +152,6 @@ end
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Hierarchy with inner model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32), name="Model(32)") do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
) do x
w2 * w1(x)
end
expected_string = """@compact(
Model(32), # 1_024 parameters
w2 = randn(32, 32), # 1_024 parameters
w3 = randn(32), # 32 parameters
) do x
w2 * w1(x)
end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Hierarchy with outer model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32)) do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
name="Model(32)"
) do x
w2 * w1(x)
end
expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Dependent initializations" begin
# Test that initialization lines cannot depend on each other
@test_throws UndefVarError @compact(y = 3, z = y^2) do x
Expand Down Expand Up @@ -234,3 +190,15 @@ end
end
end


@testset "Custom naming of @compact with NoShow" begin
_model = @compact(w=Dense(32, 32)) do x, y
tmp = sum(w(x))
return tmp + y
end
model = NoShow(_model)
expected_string = "NoShow(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
model2 = NoShow("test", _model)
@test contains(get_model_string(model2), "test")
end
28 changes: 28 additions & 0 deletions test/noshow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

@testset "NoShow" begin
    # The same layer objects are used in both models, so they hold identical parameters.
    first_dense = Dense(2 => 3)
    branch_a = Dense(3 => 4, tanh)
    branch_b = Dense(3 => 5, relu)
    last_dense = Dense(9 => 10)

    plain = Chain(first_dense, Parallel(vcat, branch_a, branch_b), last_dense)
    wrapped = Chain(
        first_dense,
        NoShow(Parallel(vcat, branch_a, NoShow("zzz", branch_b))),
        last_dense,
    )

    # Wrapping must neither hide nor duplicate parameters.
    @test sum(length, Flux.params(plain)) == sum(length, Flux.params(wrapped))

    # The forward pass is unaffected by wrapping.
    input = randn(Float32, 2, 7)
    @test plain(input) ≈ wrapped(input)

    # gradients: each NoShow adds one `.layer` level to the gradient's field path
    grad_plain = gradient(m -> m(input)[1], plain)[1]
    grad_wrapped = gradient(m -> m(input)[1], wrapped)[1]

    inner_plain = grad_plain.layers[2].layers
    inner_wrapped = grad_wrapped.layers[2].layer.layers
    @test inner_plain[1].bias ≈ inner_wrapped[1].bias
    @test inner_plain[2].bias ≈ inner_wrapped[2].layer.bias

    # printing -- see also compact.jl for another test
    plain_str = string(plain)
    wrapped_str = string(wrapped)
    @test !contains(plain_str, "NoShow(...)")
    @test contains(wrapped_str, "NoShow(...)")
    @test !contains(wrapped_str, "3 => 4")
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Flux, Fluxperimental
include("chain.jl")

include("compact.jl")
include("noshow.jl")

include("new_recur.jl")

Expand Down
Loading