Skip to content

Commit

Permalink
Add optimization callbacks that fire on a marker function
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Oct 4, 2024
1 parent ced39bb commit 9b54632
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ const __llvm_initialized = Ref(false)
end
end

for (name, plugin) in PLUGINS
if plugin.finalize_module !== nothing
plugin.finalize_module(job, compiled, ir)
end
end

@timeit_debug to "IR post-processing" begin
# mark everything internal except for entrypoints and any exported
# global variables. this makes sure that the optimizer can, e.g.,
Expand Down Expand Up @@ -335,7 +341,7 @@ const __llvm_initialized = Ref(false)
# we want to finish the module after optimization, so we cannot do so
# during deferred code generation. Instead, process the merged module
# from all the jobs here.
if toplevel
if toplevel # TODO: We should be able to remove this now
entry = finish_ir!(job, ir, entry)

# for (job′, fn′) in jobs
Expand Down
41 changes: 40 additions & 1 deletion src/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
tm = llvm_machine(job.config.target)

global current_job
global current_job # ScopedValue?
current_job = job

@dispose pb=NewPMPassBuilder() begin
Expand All @@ -14,6 +14,12 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
register!(pb, LowerKernelStatePass())
register!(pb, CleanupKernelStatePass())

for (name, plugin) in PLUGINS
if plugin.pipeline_callback !== nothing
register!(pb, CallbackPass(name, plugin.pipeline_callback))
end
end

add!(pb, NewPMModulePassManager()) do mpm
buildNewPMPipeline!(mpm, job, opt_level)
end
Expand All @@ -24,6 +30,20 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
return
end

struct Plugin
finalize_module # f(@nospecialize(job), compiled, mod::LLVM,Module)
pipeline_callback # f(@nospecialize(job), intrinsic, mod::LLVM.Module)
end

# TODO: Priority heap to provide order between different plugins
const PLUGINS = Dict{String, Plugin}()
function register_plugin!(name::String, check::Bool=true; finalize_module = nothing, pipeline_callback = nothing)
if check && haskey(PLUGINS, name)
error("GPUCompiler plugin with name $name is already registered")
end
PLUGINS[name] = Plugin(finalize_module, pipeline_callback)
end

function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
buildEarlySimplificationPipeline(mpm, job, opt_level)
add!(mpm, AlwaysInlinerPass())
Expand All @@ -41,6 +61,11 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
add!(fpm, WarnMissedTransformationsPass())
end
end
for (name, plugin) in PLUGINS
if plugin.pipeline_callback !== nothing
add!(mpm, CallbackPass(name, plugin.pipeline_callback))
end
end
buildIntrinsicLoweringPipeline(mpm, job, opt_level)
buildCleanupPipeline(mpm, job, opt_level)
end
Expand Down Expand Up @@ -423,3 +448,17 @@ function lower_ptls!(mod::LLVM.Module)
return changed
end
LowerPTLSPass() = NewPMModulePass("GPULowerPTLS", lower_ptls!)


function callback_pass!(name, callback::F, mod::LLVM.Module) where F
job = current_job::CompilerJob
changed = false

if haskey(functions(mod), name)
marker = functions(mod)[name]
changed = callback(job, marker, mod)
end
return changed
end

CallbackPass(name, callback) = NewPMModulePass("CallbackPass<$name>", (mod)->callback_pass!(name, callback, mod))
31 changes: 31 additions & 0 deletions test/plugin_testsetup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
@testsetup module Plugin

using Test
using ReTestItems
import LLVM
import GPUCompiler

function mark(x)
ccall("extern gpucompiler.mark", llvmcall, Nothing, (Int,), x)
end

function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module)
changed = false

for use in LLVM.uses(intrinsic)
val = LLVM.user(use)
if isempty(LLVM.uses(val))
LLVM.erase!(val)
changed = true
else
# the validator will detect this
end
end

return changed
end

GPUCompiler.register_plugin!("gpucompiler.mark", false,
pipeline_callback=remove_mark!)

end
16 changes: 16 additions & 0 deletions test/ptx_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testitem "PTX" setup=[PTX, Helpers] begin

using LLVM
import InteractiveUtils

############################################################################################

Expand Down Expand Up @@ -406,7 +407,22 @@ precompile_test_harness("Inference caching") do load_path
@test check_presence(identity_mi, token)
end
end
end # testitem

############################################################################################

@testitem "PTX plugin" setup=[PTX, Plugin] begin

import InteractiveUtils

@testset "Pipeline callbacks" begin
function kernel(x)
Plugin.mark(x)
return
end
ir = sprint(io->InteractiveUtils.code_llvm(io, kernel, Tuple{Int}))
@test occursin("gpucompiler.mark", ir)
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Int}))
@test !occursin("gpucompiler.mark", ir)
end
end #testitem
1 change: 0 additions & 1 deletion test/ptx_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

using GPUCompiler


# create a PTX-based test compiler, and generate reflection methods for it

include("runtime.jl")
Expand Down

0 comments on commit 9b54632

Please sign in to comment.