-
-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: remove Requires and move everything to extensions
- Loading branch information
Showing
8 changed files
with
498 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
using NNPACK_jll | ||
|
||
include("libnnpack_types.jl") | ||
include("error.jl") | ||
include("libnnpack.jl") | ||
include("performance.jl") | ||
include("interface.jl") | ||
|
||
|
||
const shared_threadpool_dict = Dict{UInt64, Base.RefValue}() | ||
|
||
""" | ||
is_nnpack_available() | ||
Checks if the current hardware is supported by NNPACK. | ||
""" | ||
function is_nnpack_available()
    # Only the `nnp_status_unsupported_hardware` status counts as "unavailable";
    # every other status returned by `nnp_initialize` yields `true`.
    return nnp_initialize() != nnp_status_unsupported_hardware
end
|
||
""" | ||
allocate_threadpool() | ||
Allocates several threadpools based on the upper limit on the number of threads for the machine. | ||
Allows NNPACK to intelligently choose which threadpool to use for getting the best | ||
performance. | ||
""" | ||
function allocate_threadpool()
    global NNPACK_CPU_THREADS
    # A request of 0 threads would make `log2` return `-Inf`, and the `Int`
    # conversion in the loop range would then throw an `InexactError`.
    # Treat such a request as a single thread instead.
    requested = max(NNPACK_CPU_THREADS, 1)
    # Cap the count at 8; otherwise round down to the nearest power of two.
    NNPACK_CPU_THREADS = if requested > 8
        UInt64(8)
    else
        UInt64(exp2(floor(log2(requested))))
    end
    # Register one pool per power of two: 1, 2, 4, ... up to NNPACK_CPU_THREADS.
    for i in 0:Int(log2(NNPACK_CPU_THREADS))
        threads = UInt64(2^i)
        shared_threadpool_dict[threads] = Ref(pthreadpool_create(threads))
    end
    return nothing
end
|
||
# Module initialiser. `@init` is provided by Requires.jl, which this extension
# does not load, so the standard `__init__` hook is used instead.
function __init__()
    status = nnp_initialize()
    if status == nnp_status_unsupported_hardware
        @warn "Hardware is unsupported by NNPACK so falling back to default NNlib"
    end
    # The thread count is read from ENV["NNPACK_CPU_THREADS"]. When the variable
    # is unset or cannot be parsed as a UInt64, default to 4 threads.
    # Sys.CPU_THREADS could be a better default when tuning the benchmark suite
    # on a particular machine. `allocate_threadpool` clamps the value afterwards.
    requested = tryparse(UInt64, get(ENV, "NNPACK_CPU_THREADS", ""))
    global NNPACK_CPU_THREADS = something(requested, UInt64(4))
    allocate_threadpool()
    return nothing
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
module NNlibNNPACK_jllExt

using NNlib: NNlib
using NNPACK_jll, Pkg

# Load the NNPACK bindings only when the JLL actually provides `libnnpack`;
# otherwise emit a warning naming the host platform and load nothing.
# NOTE(review): only the `NNlib` binding itself is imported here, while the
# included sources use names such as `ConvDims` and `PoolDims` unqualified —
# confirm that those names resolve inside this module.
if !isdefined(NNPACK_jll, :libnnpack)
    let host = Pkg.BinaryPlatforms.platform_key_abi()
        @warn "NNPACK not available for your platform: " *
              "$( Pkg.BinaryPlatforms.platform_name(host))" *
              "($( Pkg.BinaryPlatforms.triplet(host)))
You will be able to use only the default Julia NNlib backend"
    end
else
    include("NNPACK.jl")
end

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""
    NNPACKError(code, msg)

Exception that carries an NNPACK status `code` together with a textual `msg`.
"""
struct NNPACKError <: Exception
    code::nnp_status
    msg::AbstractString
end

# Display the error as `NNPACKError(code <code>, <msg>)`.
function Base.show(io::IO, err::NNPACKError)
    rendered = "NNPACKError(code $(err.code), $(err.msg))"
    print(io, rendered)
    return nothing
end
|
||
# Build an `NNPACKError` from a bare status, attaching the matching message.
# Entries are checked in order; a status not listed below gets the
# "NNPACK STATUS SUCCESS" message.
function NNPACKError(status::nnp_status)
    messages = (
        nnp_status_invalid_batch_size => "NNPACK STATUS INVALID BATCH SIZE",
        nnp_status_invalid_channels => "NNPACK STATUS INVALID CHANNELS",
        nnp_status_invalid_input_channels => "NNPACK STATUS INVALID INPUT CHANNELS",
        nnp_status_invalid_output_channels => "NNPACK STATUS INVALID OUTPUT CHANNELS",
        nnp_status_invalid_input_size => "NNPACK STATUS INVALID INPUT SIZE",
        nnp_status_invalid_input_stride => "NNPACK STATUS INVALID INPUT STRIDE",
        nnp_status_invalid_input_padding => "NNPACK STATUS INVALID INPUT PADDING",
        nnp_status_invalid_kernel_size => "NNPACK STATUS INVALID KERNEL SIZE",
        nnp_status_invalid_pooling_size => "NNPACK STATUS INVALID POOLING SIZE",
        nnp_status_invalid_pooling_stride => "NNPACK STATUS INVALID POOLING STRIDE",
        nnp_status_invalid_algorithm => "NNPACK STATUS INVALID ALGORITHM",
        nnp_status_invalid_transform_strategy =>
            "NNPACK STATUS INVALID TRANSFORM STRATEGY",
        nnp_status_invalid_output_subsampling =>
            "NNPACK STATUS INVALID OUTPUT SUBSAMPLING",
        nnp_status_invalid_activation => "NNPACK STATUS INVALID ACTIVATION",
        nnp_status_invalid_activation_parameters =>
            "NNPACK STATUS INVALID ACTIVATION PARAMETERS",
        nnp_status_unsupported_input_size => "NNPACK STATUS UNSUPPORTED INPUT SIZE",
        nnp_status_unsupported_input_stride => "NNPACK STATUS UNSUPPORTED INPUT STRIDE",
        nnp_status_unsupported_input_padding =>
            "NNPACK STATUS UNSUPPORTED INPUT PADDING",
        nnp_status_unsupported_kernel_size => "NNPACK STATUS UNSUPPORTED KERNEL SIZE",
        nnp_status_unsupported_pooling_size => "NNPACK STATUS UNSUPPORTED POOLING SIZE",
        nnp_status_unsupported_pooling_stride =>
            "NNPACK STATUS UNSUPPORTED POOLING STRIDE",
        nnp_status_unsupported_algorithm => "NNPACK STATUS UNSUPPORTED ALGORITHM",
        nnp_status_unsupported_transform_strategy =>
            "NNPACK STATUS UNSUPPORTED TRANSFORM STRATEGY",
        nnp_status_unsupported_activation => "NNPACK STATUS UNSUPPORTED ACTIVATION",
        nnp_status_unsupported_activation_parameters =>
            "NNPACK STATUS UNSUPPORTED ACTIVATION PARAMETERS",
        nnp_status_uninitialized => "NNPACK STATUS UNINITIALIZED",
        nnp_status_unsupported_hardware => "NNPACK STATUS UNSUPPORTED HARDWARE",
        nnp_status_out_of_memory => "NNPACK STATUS OUT OF MEMORY",
        nnp_status_insufficient_buffer => "NNPACK STATUS INSUFFICIENT BUFFER",
        nnp_status_misaligned_buffer => "NNPACK STATUS MISALIGNED BUFFER",
    )
    for (code, text) in messages
        status == code && return NNPACKError(status, text)
    end
    return NNPACKError(status, "NNPACK STATUS SUCCESS")
end
|
||
# Evaluate `nnp_func` in the caller's scope, throw an `NNPACKError` unless the
# resulting status equals `nnp_status_success`, and otherwise yield the status.
macro nnpack_check(nnp_func)
    return quote
        local status::nnp_status = $(esc(nnp_func))
        status == nnp_status_success || throw(NNPACKError(status))
        status
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Float32 max-pooling through NNPACK: validates the dimensions, selects a
# threadpool from `pdims` and `size(y, 4)`, then forwards to the NNPACK wrapper.
function maxpool_nnpack!(y::A, x::A, pdims::PoolDims) where {A<:Array{Float32, 4}}
    check_dims(size(x), size(y), pdims)
    pool = select_threadpool(pdims, size(y, 4))
    return nnp_max_pooling_output(
        y, x, kernel_size(pdims);
        padding = padding(pdims),
        stride = stride(pdims),
        threadpool = pool,
    )
end
|
||
# Float32 convolution through NNPACK.
#
# Keywords:
# - `b`: bias vector. NNPACK's convolution-output routine takes one bias value
#   per output channel, so the default is a zero vector of length `size(y, 3)`
#   (previously sized by the input channels, `size(x, 3)`, which is too short
#   whenever the layer has more output than input channels).
# - `algo`: algorithm selector forwarded unchanged to `nnp_convolution_output`.
function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims;
                      b::A2 = zeros(Float32, size(y, 3)),
                      algo = UInt32(0)) where {A1<:Array{Float32, 4},
                                               A2<:Array{Float32, 1}}
    check_dims(size(x), size(w), size(y), cdims)
    threadpool = select_threadpool(cdims, size(y, 4))

    # The weights are passed through `flipweight` when `flipkernel(cdims)` is 0.
    if flipkernel(cdims) == 0
        w = flipweight(w)
    end

    return nnp_convolution_output(y, x, w, b; algo = algo, padding = padding(cdims),
                                  stride = stride(cdims), threadpool = threadpool)
end
|
||
# Float32 input-gradient of a convolution through NNPACK. The threadpool is
# chosen from `cdims` and `size(dy, 4)`; `algo` is forwarded unchanged.
function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims;
                            algo = UInt32(0)) where {A<:Array{Float32, 4}}
    check_dims(size(dx), size(w), size(dy), cdims)
    pool = select_threadpool(cdims, size(dy, 4))

    # Use the flipped weights when `flipkernel(cdims)` is 0, else `w` as given.
    kernel = flipkernel(cdims) == 0 ? flipweight(w) : w

    return nnp_convolution_input_gradient(
        dx, dy, kernel;
        algo = algo,
        padding = padding(cdims),
        stride = stride(cdims),
        threadpool = pool,
    )
end
|
||
# Float32 kernel-gradient of a convolution through NNPACK. Writes into `dw`
# and returns it; `algo` is forwarded unchanged.
function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims;
                              algo = UInt32(0)) where {A<:Array{Float32, 4}}
    check_dims(size(x), size(dw), size(dy), cdims)
    pool = select_threadpool(cdims, size(dy, 4))

    nnp_convolution_kernel_gradient(
        dw, x, dy;
        algo = algo,
        padding = padding(cdims),
        stride = stride(cdims),
        threadpool = pool,
    )

    # When `flipkernel(cdims)` is 0 the computed gradient is flipped in place.
    flipkernel(cdims) == 0 && (dw .= flipweight(dw))

    return dw
end
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
include("impl.jl") | ||
|
||
## NNPACK supports only Float32 | ||
# Generate generic-eltype fallbacks for the NNPACK convolution entry points.
# Each fallback converts its array arguments to Float32, runs the Float32
# implementation, and stores the result back into `out`.
for (front_name, backend) in (
        :conv => :_nnpack,
        :∇conv_data => :_nnpack,
        :∇conv_filter => :_nnpack,
    )
    @eval begin
        function NNlib.$(Symbol("$(front_name)$(backend)!"))(
                out::Array{T1,4}, in1::Array{T2,4}, in2::Array{T3,4},
                cdims::ConvDims; kwargs...) where {T1, T2, T3}
            @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1
            result = $(Symbol("$(front_name)$(backend)!"))(
                Float32.(out), Float32.(in1), Float32.(in2), cdims; kwargs...)
            # The Float32 call works on a converted copy of `out`. Write the
            # result back so this `!` method really updates its first argument;
            # the assignment also converts the values to `T1`, so the returned
            # array has the element type promised by the signature.
            out .= result
            return out
        end
    end
end
|
||
# Generic-eltype fallback for NNPACK max-pooling: converts `y` and `x` to
# Float32, runs the Float32 method, and stores the result back into `y`.
function maxpool_nnpack!(y::Array{T1, 4}, x::Array{T2, 4}, pdims::PoolDims;
                         kwargs...) where {T1, T2}
    @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1
    pooled = maxpool_nnpack!(Float32.(y), Float32.(x), pdims; kwargs...)
    # The Float32 call works on a converted copy of `y`. Write the result back
    # so the caller's array is updated in place; the assignment converts the
    # values to `T1`, keeping the output element type as requested.
    y .= pooled
    return y
end
|
||
""" | ||
nnpack_supported_operation(cdims::ConvDims) | ||
nnpack_supported_operation(pdims::PoolDims) | ||
Returns `true` if nnpack supports the convolution/pooling operation for the given parameters. | ||
""" | ||
function nnpack_supported_operation(pdims::PoolDims{2, K, S, P, (1, 1)}) where {K, S, P}
    # Supported only when, in both leading dimensions, the padded input extent
    # minus the window size `K` is an exact multiple of the stride `S`.
    padded = input_size(pdims)[1:2] .+ (P[1] + P[2], P[3] + P[4])
    leftover = (padded .- K) .% S
    return leftover == (0, 0)
end
|
||
# 2-d convolutions whose third and fifth `ConvDims` parameters are both
# `(1, 1)` are reported as supported. The unused type variable `S` was dropped
# from the `where` clause; declaring it made Julia warn at method definition.
function nnpack_supported_operation(cdims::ConvDims{2, K, (1, 1), P, (1, 1)}) where {K, P}
    return true
end
|
||
# Catch-all: any argument not matched by a more specific method is unsupported.
function nnpack_supported_operation(dims)
    return false
end
Oops, something went wrong.