From 088d5eea6aa00c2a19acdf559e0205bd9f1eaf59 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Tue, 17 Sep 2024 22:19:11 -0300 Subject: [PATCH] Merge os-log --- lib/mtl/MTL.jl | 1 + lib/mtl/command_queue.jl | 25 +++ lib/mtl/log_state.jl | 40 +++++ src/Metal.jl | 1 + src/device/intrinsics/output.jl | 301 ++++++++++++++++++++++++++++++++ src/state.jl | 22 ++- test/output.jl | 154 ++++++++++++++++ 7 files changed, 543 insertions(+), 1 deletion(-) create mode 100644 lib/mtl/log_state.jl create mode 100644 src/device/intrinsics/output.jl create mode 100644 test/output.jl diff --git a/lib/mtl/MTL.jl b/lib/mtl/MTL.jl index e8d77898c..fb93111ee 100644 --- a/lib/mtl/MTL.jl +++ b/lib/mtl/MTL.jl @@ -34,6 +34,7 @@ include("events.jl") include("fences.jl") include("heap.jl") include("buffer.jl") +include("log_state.jl") include("command_queue.jl") include("command_buf.jl") include("compute_pipeline.jl") diff --git a/lib/mtl/command_queue.jl b/lib/mtl/command_queue.jl index ed41bd85d..141d17c97 100644 --- a/lib/mtl/command_queue.jl +++ b/lib/mtl/command_queue.jl @@ -1,3 +1,21 @@ +export MTLCommandQueueDescriptor + +@objcwrapper immutable=false MTLCommandQueueDescriptor <: NSObject + +@objcproperties MTLCommandQueueDescriptor begin + @autoproperty maxCommandBufferCount::NSUInteger + @autoproperty logState::id{MTLLogState} setter=setLogState +end + +function MTLCommandQueueDescriptor() + handle = @objc [MTLCommandQueueDescriptor alloc]::id{MTLCommandQueueDescriptor} + obj = MTLCommandQueueDescriptor(handle) + finalizer(release, obj) + @objc [obj::id{MTLCommandQueueDescriptor} init]::id{MTLCommandQueueDescriptor} + return obj +end + + export MTLCommandQueue @objcwrapper immutable=false MTLCommandQueue <: NSObject @@ -13,3 +31,10 @@ function MTLCommandQueue(dev::MTLDevice) finalizer(release, obj) return obj end + +function MTLCommandQueue(dev::MTLDevice, descriptor::MTLCommandQueueDescriptor) + handle = @objc [dev::id{MTLDevice} newCommandQueueWithDescriptor:descriptor::id{MTLCommandQueueDescriptor}]::id{MTLCommandQueue} + obj = MTLCommandQueue(handle) + finalizer(release, obj) + return obj +end \ No newline at end of file diff --git a/lib/mtl/log_state.jl b/lib/mtl/log_state.jl new file mode 100644 index 000000000..f2f8430e8 --- /dev/null +++ b/lib/mtl/log_state.jl @@ -0,0 +1,40 @@ +export MTLLogLevel + +@cenum MTLLogLevel::NSInteger begin + MTLLogLevelUndefined = 0 + MTLLogLevelDebug = 1 + MTLLogLevelInfo = 2 + MTLLogLevelNotice = 3 + MTLLogLevelError = 4 + MTLLogLevelFault = 5 +end + +export MTLLogStateDescriptor + +@objcwrapper immutable=false MTLLogStateDescriptor <: NSObject + +@objcproperties MTLLogStateDescriptor begin + @autoproperty level::MTLLogLevel setter=setLevel + @autoproperty bufferSize::NSInteger setter=setBufferSize +end + +function MTLLogStateDescriptor() + handle = @objc [MTLLogStateDescriptor alloc]::id{MTLLogStateDescriptor} + obj = MTLLogStateDescriptor(handle) + finalizer(release, obj) + @objc [obj::id{MTLLogStateDescriptor} init]::id{MTLLogStateDescriptor} + return obj +end + + +export MTLLogState + +@objcwrapper MTLLogState <: NSObject + +function MTLLogState(dev::MTLDevice, descriptor::MTLLogStateDescriptor) + err = Ref{id{NSError}}(nil) + handle = @objc [dev::id{MTLDevice} newLogStateWithDescriptor:descriptor::id{MTLLogStateDescriptor} + error:err::Ptr{id{NSError}}]::id{MTLLogState} + err[] == nil || throw(NSError(err[])) + MTLLogState(handle) +end \ No newline at end of file diff --git a/src/Metal.jl b/src/Metal.jl index 08eba6039..2ad86a1cc 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -36,6 +36,7 @@ include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") include("device/intrinsics/version.jl") include("device/intrinsics/atomics.jl") +include("device/intrinsics/output.jl") include("device/quirks.jl") # array essentials diff --git a/src/device/intrinsics/output.jl b/src/device/intrinsics/output.jl new file mode 100644 index 000000000..aa92506cb --- /dev/null +++ b/src/device/intrinsics/output.jl @@ -0,0 +1,301 @@ +const MTLLOG_SUBSYSTEM = "com.juliagpu.metal.jl" +const MTLLOG_CATEGRORY = "mtlprintf" + +const __METAL_OS_LOG_TYPE_DEBUG__ = Int32(2) +const __METAL_OS_LOG_TYPE_INFO__ = Int32(1) +const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0) +const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16) +const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17) + +const ALLOW_DOUBLE_META = "allowdouble" + +export @mtlprintf + +@generated function promote_c_argument(arg) + # > When a function with a variable-length argument list is called, the variable + # > arguments are passed using C's old ``default argument promotions.'' These say that + # > types char and short int are automatically promoted to int, and type float is + # > automatically promoted to double. Therefore, varargs functions will never receive + # > arguments of type char, short int, or float. + + if arg == Cchar || arg == Cshort + return :(Cint(arg)) + else + return :(arg) + end +end + +@generated function tag_doubles(arg) + @dispose ctx=Context() begin + ret = arg == Cfloat ? Cdouble : arg + T_arg = convert(LLVMType, arg) + T_ret = convert(LLVMType, ret) + + f, ft = create_function(T_ret, [T_arg]) + + @dispose builder=IRBuilder() begin + entry = BasicBlock(f, "entry") + position!(builder, entry) + + p1 = parameters(f)[1] + + if arg == Cfloat + res = fpext!(builder, p1, LLVM.DoubleType()) + metadata(res)["ir_check_ignore"] = MDNode([]) + ret!(builder, res) + else + ret!(builder, p1) + end + end + + call_function(f, ret, Tuple{arg}, :arg) + end +end + + +""" + @mtlprintf("%Fmt", args...) + +Print a formatted string in device context on the host standard output. +""" +macro mtlprintf(fmt::String, args...) + fmt_val = Val(Symbol(fmt)) + + return :(_mtlprintf($fmt_val, $(map(arg -> :(tag_doubles(promote_c_argument($arg))), esc.(args))...))) +end + +@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt} + @dispose ctx=Context() begin + arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)] + arg_types = [argspec...] + + T_void = LLVM.VoidType() + T_int32 = LLVM.Int32Type() + T_int64 = LLVM.Int64Type() + T_pint8 = LLVM.PointerType(LLVM.Int8Type()) + T_pint8a2 = LLVM.PointerType(LLVM.Int8Type(), 2) + + # create functions + param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types] + llvm_f, llvm_ft = create_function(T_void, LLVMType[]; vararg=true) + mod = LLVM.parent(llvm_f) + + # generate IR + @dispose builder=IRBuilder() begin + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) + + str = globalstring_ptr!(builder, String(fmt), addrspace=2) + + # compute argsize + argtypes = LLVM.StructType(param_types) + dl = datalayout(mod) + arg_size = LLVM.ConstantInt(T_int64, sizeof(dl, argtypes)) + + alloc = alloca!(builder, T_pint8) + buffer = bitcast!(builder, alloc, T_pint8) + alloc_size = LLVM.ConstantInt(T_int64, sizeof(dl, T_pint8)) + + lifetime_start_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8]) + lifetime_start = LLVM.Function(mod, "llvm.lifetime.start.p0i8", lifetime_start_fty) + call!(builder, lifetime_start_fty, lifetime_start, [alloc_size, buffer]) + + va_start_fty = LLVM.FunctionType(T_void, [T_pint8]) + va_start = LLVM.Function(mod, "llvm.va_start", va_start_fty) + call!(builder, va_start_fty, va_start, [buffer]) + + # invoke @air.os_log and return + subsystem_str = globalstring_ptr!(builder, MTLLOG_SUBSYSTEM, addrspace=2) + category_str = globalstring_ptr!(builder, MTLLOG_CATEGRORY, addrspace=2) + log_type = LLVM.ConstantInt(T_int32, __METAL_OS_LOG_TYPE_DEBUG__) + os_log_fty = LLVM.FunctionType(T_void, [T_pint8a2, T_pint8a2, T_int32, T_pint8a2, T_pint8, T_int64]) + os_log = LLVM.Function(mod, "air.os_log", os_log_fty) + + arg_ptr = load!(builder, T_pint8, alloc) + + call!(builder, os_log_fty, os_log, [subsystem_str, category_str, log_type, str, arg_ptr, arg_size]) + + va_end_fty = LLVM.FunctionType(T_void, [T_pint8]) + va_end = LLVM.Function(mod, "llvm.va_end", va_end_fty) + call!(builder, va_end_fty, va_end, [buffer]) + + lifetime_end_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8]) + lifetime_end = LLVM.Function(mod, "llvm.lifetime.end.p0i8", lifetime_end_fty) + call!(builder, lifetime_end_fty, lifetime_end, [alloc_size, buffer]) + + ret!(builder) + end + + call_function(llvm_f, Nothing, Tuple{arg_types...}, arg_exprs...) + end +end + + +## print-like functionality + +export @mtlprint, @mtlprintln + +# simple conversions, defining an expression and the resulting argument type. nothing fancy, +# `@mtlprint` pretty directly maps to `@mtlprintf`; we should just support `write(::IO)`. +const mtlprint_conversions = [ + Float32 => (x->:(Float64($x)), Float64), + Ptr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}), + LLVMPtr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}), + Bool => (x->:(Int32($x)), Int32), +] + +# format specifiers +const mtlprint_specifiers = Dict( + # integers + Int16 => "%hd", + Int32 => "%d", + Int64 => "%ld", + UInt16 => "%hu", + UInt32 => "%u", + UInt64 => "%lu", + + # floating-point + Float32 => "%f", + + # other + Cchar => "%c", + Ptr{Cvoid} => "%p", + Cstring => "%s", +) + +@inline @generated function _mtlprint(parts...) + fmt = "" + args = Expr[] + + for i in 1:length(parts) + part = :(parts[$i]) + T = parts[i] + + # put literals directly in the format string + if T <: Val + fmt *= string(T.parameters[1]) + continue + end + + # try to convert arguments if they are not supported directly + if !haskey(mtlprint_specifiers, T) + for (Tmatch, rule) in mtlprint_conversions + if T <: Tmatch + part = rule[1](part) + T = rule[2] + break + end + end + end + + # render the argument + if haskey(mtlprint_specifiers, T) + fmt *= mtlprint_specifiers[T] + push!(args, part) + elseif T <: Tuple + fmt *= "(" + for (j, U) in enumerate(T.parameters) + if haskey(mtlprint_specifiers, U) + fmt *= mtlprint_specifiers[U] + push!(args, :($part[$j])) + if j < length(T.parameters) + fmt *= ", " + elseif length(T.parameters) == 1 + fmt *= "," + end + else + @error("@mtlprint does not support values of type $U") + end + end + fmt *= ")" + elseif T <: String + @error("@mtlprint does not support non-literal strings") + elseif T <: Type + fmt *= string(T.parameters[1]) + else + @warn("@mtlprint does not support values of type $T") + fmt *= "$(T)(...)" + end + end + + quote + @mtlprintf($fmt, $(args...)) + end +end + +""" + @mtlprint(xs...) + @mtlprintln(xs...) + +Print a textual representation of values `xs` to standard output from the GPU. The +functionality builds on `@mtlprintf`, and is intended as a more use friendly alternative of +that API. However, that also means there's only limited support for argument types, handling +16/32/64 signed and unsigned integers, 32 and 64-bit floating point numbers, `Cchar`s and +pointers. For more complex output, use `@mtlprintf` directly. + +Limited string interpolation is also possible: + +```julia + @mtlprint("Hello, World ", 42, "\\n") + @mtlprint "Hello, World \$(42)\\n" +``` +""" +macro mtlprint(parts...) + args = Union{Val,Expr,Symbol}[] + + parts = [parts...] + while true + isempty(parts) && break + + part = popfirst!(parts) + + # handle string interpolation + if isa(part, Expr) && part.head == :string + parts = vcat(part.args, parts) + continue + end + + # expose literals to the generator by using Val types + if isbits(part) # literal numbers, etc + push!(args, Val(part)) + elseif isa(part, QuoteNode) # literal symbols + push!(args, Val(part.value)) + elseif isa(part, String) # literal strings need to be interned + push!(args, Val(Symbol(part))) + else # actual values that will be passed to printf + push!(args, part) + end + end + + quote + _mtlprint($(map(esc, args)...)) + end +end + +@doc (@doc @mtlprint) -> +macro mtlprintln(parts...) + esc(quote + Metal.@mtlprint($(parts...), "\n") + end) +end + +export @mtlshow + +""" + @mtlshow(ex) + +GPU analog of `Base.@show`. It comes with the same type restrictions as [`@mtlprintf`](@ref). + +```julia +@mtlshow thread_position_in_grid_1d() +``` +""" +macro mtlshow(exs...) + blk = Expr(:block) + for ex in exs + push!(blk.args, :(Metal.@mtlprintln($(sprint(Base.show_unquoted,ex)*" = "), + begin local value = $(esc(ex)) end))) + end + isempty(exs) || push!(blk.args, :value) + blk +end diff --git a/src/state.jl b/src/state.jl index 2208f739e..0707bcaaa 100644 --- a/src/state.jl +++ b/src/state.jl @@ -47,7 +47,27 @@ function global_queue(dev::MTLDevice) @autoreleasepool begin # NOTE: MTLCommandQueue itself is manually reference-counted, # the release pool is for resources used during its construction. - queue = MTLCommandQueue(dev) + + queue = if macos_version() >= v"15" + log_state_descriptor = MTLLogStateDescriptor() + log_state_descriptor.level = MTL.MTLLogLevelDebug + log_state = MTLLogState(dev, log_state_descriptor) + + function log_handler(subSystem, category, logLevel, message) + print(String(NSString(message))) + return nothing + end + + block = @objcblock(log_handler, Nothing, (id{NSString}, id{NSString}, NSInteger, id{NSString})) + + @objc [log_state::id{MTLLogState} addLogHandler:block::id{NSBlock}]::Nothing + + queue_descriptor = MTLCommandQueueDescriptor() + queue_descriptor.logState = log_state + MTLCommandQueue(dev, queue_descriptor) + else + MTLCommandQueue(dev) + end queue.label = "global_queue($(current_task()))" global_queues[queue] = nothing queue diff --git a/test/output.jl b/test/output.jl new file mode 100644 index 000000000..549dfaa01 --- /dev/null +++ b/test/output.jl @@ -0,0 +1,154 @@ +@testset "output" begin + + @testset "formatted output" begin + _, out = @grab_output @on_device @mtlprintf("") + @test out == "" + + _, out = @grab_output @on_device @mtlprintf("Testing...\n") + @test out == "Testing...\n" + + # narrow integer + _, out = @grab_output @on_device @mtlprintf("Testing %d %d...\n", Int32(1), Int32(2)) + @test out == "Testing 1 2...\n" + + # wide integer + _, out = @grab_output @on_device @mtlprintf("Testing %ld %ld...\n", Int64(1), Int64(2)) + @test out == "Testing 1 2...\n" + + _, out = @grab_output @on_device begin + @mtlprintf("foo") + @mtlprintf("bar\n") + end + @test out == "foobar\n" + + # c argument promotions + # function kernel(A) + # @mtlprintf("%f %f\n", A[1], A[1]) + # return + # end + # x = MtlArray(ones(2, 2)) + # _, out = @grab_output begin + # Metal.@sync @metal kernel(x) + # end + # @test out == "1.000000 1.000000\n" + end + + @testset "@mtlprint" begin + # basic @mtlprint/@mtlprintln + + _, out = @grab_output @on_device @mtlprint("Hello, World\n") + @test out == "Hello, World\n" + + _, out = @grab_output @on_device @mtlprintln("Hello, World") + @test out == "Hello, World\n" + + + # argument interpolation (by the macro, so can use literals) + + _, out = @grab_output @on_device @mtlprint("foobar") + @test out == "foobar" + + _, out = @grab_output @on_device @mtlprint(:foobar) + @test out == "foobar" + + _, out = @grab_output @on_device @mtlprint("foo", "bar") + @test out == "foobar" + + _, out = @grab_output @on_device @mtlprint("foobar ", 42) + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint("foobar $(42)") + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint("foobar $(4)", 2) + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint("foobar ", 4, "$(2)") + @test out == "foobar 42" + + _, out = @grab_output @on_device @mtlprint(42) + @test out == "42" + + _, out = @grab_output @on_device @mtlprint(4, 2) + @test out == "42" + + _, out = @grab_output @on_device @mtlprint(Any) + @test out == "Any" + + _, out = @grab_output @on_device @mtlprintln("foobar $(42)") + @test out == "foobar 42\n" + + + # argument types + + # we're testing the generated functions now, so can't use literals + function test_output(val, str) + canary = rand(Int32) # if we mess up the main arg, this one will print wrong + _, out = @grab_output @on_device @mtlprint(val, " (", canary, ")") + @test out == "$(str) ($(Int(canary)))" + end + + for typ in (Int16, Int32, Int64, UInt16, UInt32, UInt64) + test_output(typ(42), "42") + end + + # for typ in (Float32,) + # test_output(typ(42), "42.000000") + # end + + test_output(Cchar('c'), "c") + + for typ in (Ptr{Cvoid}, Ptr{Int}) + ptr = convert(typ, Int(0x12345)) + test_output(ptr, "0x12345") + end + + test_output(true, "1") + test_output(false, "0") + + test_output((1,), "(1,)") + test_output((1,2), "(1, 2)") + # test_output((1,2,3.), "(1, 2, 3.000000)") + + + # escaping + + kernel1(val) = (@mtlprint(val); nothing) + _, out = @grab_output @on_device kernel1(42) + @test out == "42" + + kernel2(val) = (@mtlprintln(val); nothing) + _, out = @grab_output @on_device kernel2(42) + @test out == "42\n" + end + + # @testset "@mtlshow" begin + # function kernel() + # seven_i32 = Int32(7) + # three_f32 = Float32(3) + # @mtlshow seven_i32 + # @mtlshow three_f32 + # @mtlshow 1f0 + 4f0 + # return + # end + + # _, out = @grab_output @on_device kernel() + # @test out == "seven_i32 = 7\nthree_f64 = 3.000000\n1.0f0 + 4.0f0 = 5.000000\n" + # end + + # @testset "@mtlshow array pointers" begin + # function kernel() + # a = mtlStaticSharedArray(Float32, 1) + # b = mtlStaticSharedArray(Float32, 2) + # @mtlshow pointer(a) pointer(b) + # return + # end + + # _, out = @grab_output @on_device kernel() + # @test ocmtlrsin("pointer(a) = ", out) + # @test ocmtlrsin("pointer(b) = ", out) + # @test ocmtlrsin("= 0", out) + # end + +end + \ No newline at end of file