DaggerFlux.jl

This is currently an early-stage integration between Dagger.jl and Flux.jl that enables distributed computation of differentiation pipelines across multiple workers, devices, GPUs, etc. In other words, it brings model parallelism to Flux models.

Basic Usage

To see the package in action, start Julia with multiple workers.

Also make sure that the workers have access to the environment and the code that will be run. This is typically done with the exeflags keyword to addprocs; something like addprocs(2, exeflags = "--project") is usually enough. Please ensure that the environment has DaggerFlux available.
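
For reference, a minimal session setup might look like the sketch below. Installing straight from the repository URL is just one option; use whatever method fits your environment.

julia> using Pkg; Pkg.add(url = "https://github.com/FluxML/DaggerFlux.jl")   # make DaggerFlux available in the active project

julia> using Distributed

julia> addprocs(2, exeflags = "--project");   # two extra workers, sharing the active project environment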

julia> using DaggerFlux, Dagger, Flux, Zygote

julia> @everywhere function layer(x)
         @show myid()
         x
       end

julia> ip = rand(3,3);

julia> c = Chain(layer, layer, layer, layer)
Chain(layer, layer, layer, layer)

julia> dc = DaggerChain(c)
DaggerChain(Chain(layer, layer, layer, layer))

julia> dc(ip) # notice the output is a Dagger Thunk rather than an eager evaluation
Thunk[4](layer, (Thunk[3](layer, ...),))

julia> collect(dc(ip))
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
3×3 Matrix{Float64}:
 0.813575   0.828228  0.0630336
 0.0755053  0.215495  0.64503
 0.462957   0.345485  0.83312

Notice that the model was evaluated across multiple workers once the result was collected.

Flux models

This works just as before, but here we demonstrate how to differentiate through Flux models.

julia> y, back = Zygote.pullback((m,x) -> m(x), dc, ip)
(Thunk[135](layer, (Thunk[131](layer, ...),)), Zygote.var"#46#47"{typeof((#11))}(∂(#11)))

julia> collect(y)
      From worker 3:    myid() = 3
      From worker 3:    myid() = 3
      From worker 2:    myid() = 2
      From worker 2:    myid() = 2
3×3 Matrix{Float64}:
 0.813575   0.828228  0.0630336
 0.0755053  0.215495  0.64503
 0.462957   0.345485  0.83312

julia> back(one.(y))
      From worker 2:    myid() = 2
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
      [...]
      From worker 2:    myid() = 2
      From worker 3:    myid() = 3
      From worker 2:    myid() = 2
((chain = (layers = (nothing, nothing, nothing, nothing),),), [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0])

The layer functions in this dummy chain have no trainable parameters, so their gradients are all nothing; the second element is the gradient with respect to the input. With real layers, one can now optimise entire models!

Of course, one can substitute this dummy model with more realistic ones, such as ResNet from Metalhead.jl. Here's a slightly simpler model as an example.

julia> m = Chain(Dense(2,2), Dense(2,2))
Chain(
  Dense(2, 2),                          # 6 parameters
  Dense(2, 2),                          # 6 parameters
)                   # Total: 4 arrays, 12 parameters, 304 bytes.

julia> dm = DaggerChain(m)
DaggerChain(Chain(Dense(2, 2), Dense(2, 2)))

julia> y, b = Zygote.pullback((m,x) -> m(x), dm, rand(Float32, 2, 2))
(Thunk[150](Dense(2, 2), (Thunk[149](Dense(2, 2), ...),)), Zygote.var"#46#47"{typeof((#13))}(∂(#13)))

julia> b(one.(y))
((chain = (layers = ((weight = Float32[1.0398567 0.45392603; 0.4867683 0.21248773], bias = Float32[1.6065784, 0.75205684], σ = nothing), (weight = Float32[-1.247205 1.2783735; -1.247205 1.278735], bias = Float32[2.0, 2.0], σ = nothing)),),), Float32[-0.14533046 -0.14533046; -0.58934844 -0.58934844])
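
With these gradients in hand, one can take an actual optimisation step. The sketch below uses Optimisers.jl and assumes that DaggerChain is a Functors-compatible wrapper (the gradient structure above suggests it is); the learning rate and the choice of Optimisers.jl are illustrative, not part of the DaggerFlux API.

julia> using Optimisers

julia> grads, _ = b(one.(y));                                   # gradient w.r.t. the DaggerChain, as above

julia> state = Optimisers.setup(Optimisers.Descent(0.1), dm);   # optimiser state for every parameter

julia> state, dm = Optimisers.update(state, dm, grads);         # one gradient-descent step

The updated dm can then be called on new inputs exactly as before.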

Contributions are welcome on the GitHub repository!
