Skip to content

Commit

Permalink
refactor: remove Requires and move everything to extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 29, 2024
1 parent 5b35d4d commit 7a51894
Show file tree
Hide file tree
Showing 8 changed files with 498 additions and 0 deletions.
55 changes: 55 additions & 0 deletions ext/NNlibNNPACK_jllExt/NNPACK.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using NNPACK_jll

include("libnnpack_types.jl")
include("error.jl")
include("libnnpack.jl")
include("performance.jl")
include("interface.jl")


# Cache of NNPACK threadpools keyed by thread count (each Ref holds a raw
# pthreadpool pointer); populated by `allocate_threadpool` at init time.
const shared_threadpool_dict = Dict{UInt64, Base.RefValue}()

"""
    is_nnpack_available()

Check whether the current hardware is supported by NNPACK. Returns `false`
when `nnp_initialize()` reports unsupported hardware, `true` otherwise.
"""
function is_nnpack_available()
    return nnp_initialize() != nnp_status_unsupported_hardware
end

"""
    allocate_threadpool()

Allocate one NNPACK threadpool for every power-of-two thread count up to
`NNPACK_CPU_THREADS` (itself clamped to a power of two, capped at 8), storing
them in `shared_threadpool_dict`. This lets NNPACK pick the best-sized pool
for each operation.
"""
function allocate_threadpool()
    # Clamp to at most 8 threads; otherwise round down to the nearest power
    # of two so every pool size created below is exactly 2^i.
    global NNPACK_CPU_THREADS = NNPACK_CPU_THREADS > 8 ? UInt64(8) : UInt64(exp2(floor(log2(NNPACK_CPU_THREADS))))
    for i in 0:Int(log2(NNPACK_CPU_THREADS))
        threads = UInt64(2^i)
        # Ref wraps the raw pthreadpool pointer returned by NNPACK.
        push!(shared_threadpool_dict, threads => Ref(pthreadpool_create(threads)))
    end
end

"""
    __init__()

Module initializer: initialize NNPACK (warning if the hardware is
unsupported), pick the thread count from the `NNPACK_CPU_THREADS`
environment variable (default 4), and allocate the shared threadpools.

NOTE: the original used `@init` from Requires.jl, but this commit removes
Requires in favor of package extensions, so the standard `__init__`
mechanism (run when the extension module is loaded) must be used instead.
"""
function __init__()
    status = nnp_initialize()
    if status == nnp_status_unsupported_hardware
        @warn "Hardware is unsupported by NNPACK so falling back to default NNlib"
    end
    try
        global NNPACK_CPU_THREADS = parse(UInt64, ENV["NNPACK_CPU_THREADS"])
    catch
        # Sys.CPU_THREADS should be a better default if we are tuning the benchmark suite on
        # a particular machine. However, we fix the runtime threadpool here to have a max of
        # 4 threads so anything above will be ignored anyways
        global NNPACK_CPU_THREADS = UInt64(4)
    end
    allocate_threadpool()
end
15 changes: 15 additions & 0 deletions ext/NNlibNNPACK_jllExt/NNlibNNPACK_jllExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module NNlibNNPACK_jllExt

using NNlib: NNlib
using NNPACK_jll, Pkg

# Load the NNPACK-backed implementations only when the binary artifact is
# actually available for this platform; otherwise warn and fall back to the
# pure-Julia NNlib kernels.
if isdefined(NNPACK_jll, :libnnpack)
    include("NNPACK.jl")
else
    # Fix: the original warning string was syntactically broken (unbalanced
    # quotes and text outside the literal); rebuilt as a valid concatenation.
    @warn "NNPACK not available for your platform: " *
          "$(Pkg.BinaryPlatforms.platform_name(Pkg.BinaryPlatforms.platform_key_abi())) " *
          "($(Pkg.BinaryPlatforms.triplet(Pkg.BinaryPlatforms.platform_key_abi()))). " *
          "You will be able to use only the default Julia NNlib backend"
end

end
83 changes: 83 additions & 0 deletions ext/NNlibNNPACK_jllExt/error.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Exception carrying an NNPACK `nnp_status` code together with a
# human-readable message (built by the `NNPACKError(status)` constructor).
struct NNPACKError <: Exception
    code::nnp_status        # raw status code returned by an NNPACK call
    msg::AbstractString     # descriptive text for the status code
end

# Compact rendering, e.g. `NNPACKError(code <status>, NNPACK STATUS ...)`.
Base.show(io::IO, err::NNPACKError) = print(io, "NNPACKError(code $(err.code), $(err.msg))")

"""
    NNPACKError(status::nnp_status)

Construct an `NNPACKError` whose message describes `status`. Status codes not
listed below (including success) map to `"NNPACK STATUS SUCCESS"`.
"""
function NNPACKError(status::nnp_status)
    # Ordered (code => message) table, scanned linearly exactly like the
    # original if/elseif chain.
    messages = (
        nnp_status_invalid_batch_size => "NNPACK STATUS INVALID BATCH SIZE",
        nnp_status_invalid_channels => "NNPACK STATUS INVALID CHANNELS",
        nnp_status_invalid_input_channels => "NNPACK STATUS INVALID INPUT CHANNELS",
        nnp_status_invalid_output_channels => "NNPACK STATUS INVALID OUTPUT CHANNELS",
        nnp_status_invalid_input_size => "NNPACK STATUS INVALID INPUT SIZE",
        nnp_status_invalid_input_stride => "NNPACK STATUS INVALID INPUT STRIDE",
        nnp_status_invalid_input_padding => "NNPACK STATUS INVALID INPUT PADDING",
        nnp_status_invalid_kernel_size => "NNPACK STATUS INVALID KERNEL SIZE",
        nnp_status_invalid_pooling_size => "NNPACK STATUS INVALID POOLING SIZE",
        nnp_status_invalid_pooling_stride => "NNPACK STATUS INVALID POOLING STRIDE",
        nnp_status_invalid_algorithm => "NNPACK STATUS INVALID ALGORITHM",
        nnp_status_invalid_transform_strategy => "NNPACK STATUS INVALID TRANSFORM STRATEGY",
        nnp_status_invalid_output_subsampling => "NNPACK STATUS INVALID OUTPUT SUBSAMPLING",
        nnp_status_invalid_activation => "NNPACK STATUS INVALID ACTIVATION",
        nnp_status_invalid_activation_parameters => "NNPACK STATUS INVALID ACTIVATION PARAMETERS",
        nnp_status_unsupported_input_size => "NNPACK STATUS UNSUPPORTED INPUT SIZE",
        nnp_status_unsupported_input_stride => "NNPACK STATUS UNSUPPORTED INPUT STRIDE",
        nnp_status_unsupported_input_padding => "NNPACK STATUS UNSUPPORTED INPUT PADDING",
        nnp_status_unsupported_kernel_size => "NNPACK STATUS UNSUPPORTED KERNEL SIZE",
        nnp_status_unsupported_pooling_size => "NNPACK STATUS UNSUPPORTED POOLING SIZE",
        nnp_status_unsupported_pooling_stride => "NNPACK STATUS UNSUPPORTED POOLING STRIDE",
        nnp_status_unsupported_algorithm => "NNPACK STATUS UNSUPPORTED ALGORITHM",
        nnp_status_unsupported_transform_strategy => "NNPACK STATUS UNSUPPORTED TRANSFORM STRATEGY",
        nnp_status_unsupported_activation => "NNPACK STATUS UNSUPPORTED ACTIVATION",
        nnp_status_unsupported_activation_parameters => "NNPACK STATUS UNSUPPORTED ACTIVATION PARAMETERS",
        nnp_status_uninitialized => "NNPACK STATUS UNINITIALIZED",
        nnp_status_unsupported_hardware => "NNPACK STATUS UNSUPPORTED HARDWARE",
        nnp_status_out_of_memory => "NNPACK STATUS OUT OF MEMORY",
        nnp_status_insufficient_buffer => "NNPACK STATUS INSUFFICIENT BUFFER",
        nnp_status_misaligned_buffer => "NNPACK STATUS MISALIGNED BUFFER",
    )
    msg = "NNPACK STATUS SUCCESS"
    for (code, text) in messages
        if status == code
            msg = text
            break
        end
    end
    return NNPACKError(status, msg)
end

"""
    @nnpack_check nnp_call(...)

Evaluate an NNPACK call in the caller's scope, throw an `NNPACKError` unless
it returns `nnp_status_success`, and otherwise return the status.
"""
macro nnpack_check(nnp_func)
    # `status` is hygienic (gensym'd); only the user expression is escaped.
    return quote
        local status::nnp_status = $(esc(nnp_func))
        if status != nnp_status_success
            throw(NNPACKError(status))
        end
        status
    end
end
50 changes: 50 additions & 0 deletions ext/NNlibNNPACK_jllExt/impl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
    maxpool_nnpack!(y, x, pdims::PoolDims)

Run NNPACK max-pooling of `x` into `y` (Float32, 4-D arrays only) using a
threadpool sized for the batch dimension.
"""
function maxpool_nnpack!(y::A, x::A, pdims::PoolDims) where {A<:Array{Float32, 4}}
    check_dims(size(x), size(y), pdims)
    pool = select_threadpool(pdims, size(y, 4))
    return nnp_max_pooling_output(y, x, kernel_size(pdims), padding = padding(pdims),
                                  stride = stride(pdims), threadpool = pool)
end

"""
    conv_nnpack!(y, x, w, cdims::ConvDims; b, algo)

Compute the convolution of `x` with kernel `w` into `y` via NNPACK
(Float32, 4-D arrays only), with optional bias `b` and algorithm selector
`algo` (0 = let NNPACK choose).

NOTE(review): the default bias length is `size(x, 3)` — the *input* channel
count in WHCN layout — while a convolution bias is normally one value per
*output* channel (`size(y, 3)`). Confirm against `nnp_convolution_output`
before relying on the default.
"""
function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims;
                      b::A2 = zeros(Float32, size(x, 3)),
                      algo = UInt32(0)) where {A1<:Array{Float32, 4},
                                               A2<:Array{Float32, 1}}
    check_dims(size(x), size(w), size(y), cdims)
    threadpool = select_threadpool(cdims, size(y, 4))

    # NNlib kernels are stored for correlation; NNPACK expects true-convolution
    # kernels, so flip unless ConvDims already requested flipped kernels.
    if flipkernel(cdims) == 0
        w = flipweight(w)
    end

    nnp_convolution_output(y, x, w, b, algo = algo, padding = padding(cdims),
                           stride = stride(cdims), threadpool = threadpool)
end

"""
    ∇conv_data_nnpack!(dx, dy, w, cdims::ConvDims; algo)

Compute the gradient of the convolution with respect to the input, writing it
into `dx` (Float32, 4-D arrays only).
"""
function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims;
                            algo = UInt32(0)) where {A<:Array{Float32, 4}}
    check_dims(size(dx), size(w), size(dy), cdims)
    pool = select_threadpool(cdims, size(dy, 4))

    # NNPACK works with flipped (true-convolution) kernels.
    iszero(flipkernel(cdims)) && (w = flipweight(w))

    return nnp_convolution_input_gradient(dx, dy, w, algo = algo,
                                          padding = padding(cdims),
                                          stride = stride(cdims),
                                          threadpool = pool)
end

"""
    ∇conv_filter_nnpack!(dw, x, dy, cdims::ConvDims; algo)

Compute the gradient of the convolution with respect to the kernel, writing it
into `dw` (Float32, 4-D arrays only) and returning `dw`.
"""
function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims;
                              algo = UInt32(0)) where {A<:Array{Float32, 4}}
    check_dims(size(x), size(dw), size(dy), cdims)
    pool = select_threadpool(cdims, size(dy, 4))

    nnp_convolution_kernel_gradient(dw, x, dy, algo = algo,
                                    padding = padding(cdims),
                                    stride = stride(cdims), threadpool = pool)

    # NNPACK produced a flipped kernel gradient; flip it back in place when
    # NNlib expects correlation (unflipped) kernels.
    iszero(flipkernel(cdims)) && (dw .= flipweight(dw))

    return dw
end

44 changes: 44 additions & 0 deletions ext/NNlibNNPACK_jllExt/interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
include("impl.jl")

## NNPACK supports only Float32
# For each NNlib front end (conv!, ∇conv_data!, ∇conv_filter! with the
# `_nnpack` backend suffix) generate a method accepting arbitrary element
# types: it converts all arrays to Float32, dispatches to the Float32-only
# implementation from impl.jl (resolved by bare name in this module), and
# converts the result back to the output's declared eltype `T1`.
for (front_name, backend) in (
    :conv => :_nnpack,
    :∇conv_data => :_nnpack,
    :∇conv_filter => :_nnpack,
)
    @eval begin
        function NNlib.$(Symbol("$(front_name)$(backend)!"))(
                        out::Array{T1,4}, in1::Array{T2,4}, in2::Array{T3,4},
                        cdims::ConvDims; kwargs...) where {T1, T2, T3}
            @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1
            # Output must of the same type as in the function signature
            T1.($(Symbol("$(front_name)$(backend)!"))(Float32.(out), Float32.(in1),
                Float32.(in2), cdims; kwargs...))
        end
    end
end

"""
    maxpool_nnpack!(y, x, pdims::PoolDims; kwargs...)

Element-type-generic front end for NNPACK max-pooling: converts both arrays
to Float32, runs the Float32-only kernel, and casts the result back to the
output array's element type `T1`.
"""
function maxpool_nnpack!(y::Array{T1, 4}, x::Array{T2, 4}, pdims::PoolDims;
                         kwargs...) where {T1, T2}
    @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1
    y32 = Float32.(y)
    x32 = Float32.(x)
    # Dispatches to the stricter Float32-only method; cast back for the caller.
    return T1.(maxpool_nnpack!(y32, x32, pdims; kwargs...))
end

"""
    nnpack_supported_operation(cdims::ConvDims)
    nnpack_supported_operation(pdims::PoolDims)

Returns `true` if nnpack supports the convolution/pooling operation for the given parameters.
"""
function nnpack_supported_operation(pdims::PoolDims{2, K, S, P, (1, 1)}) where {K, S, P}
    # Pooling is supported when the padded spatial extent divides evenly by
    # the stride in both dimensions.
    remainder = (input_size(pdims)[1:2] .+ (P[1] + P[2], P[3] + P[4]) .- K) .% S
    return all(remainder .== 0)
end

"""
    nnpack_supported_operation(cdims::ConvDims{2, K, (1, 1), P, (1, 1)})

2-D convolutions with unit stride and unit dilation are always supported by
NNPACK.
"""
function nnpack_supported_operation(cdims::ConvDims{2, K, (1, 1), P, (1, 1)}) where {K, P}
    # Fix: the original declared an unused type variable `S` in the `where`
    # clause, which triggers a method-definition warning and is dead weight.
    return true
end

# Fallback: any dims type not matched by a more specific method above is not
# supported by NNPACK.
function nnpack_supported_operation(dims)
    return false
end
Loading

0 comments on commit 7a51894

Please sign in to comment.