Skip to content

Commit

Permalink
Merge os-log
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Sep 18, 2024
1 parent 5251b6c commit 088d5ee
Show file tree
Hide file tree
Showing 7 changed files with 543 additions and 1 deletion.
1 change: 1 addition & 0 deletions lib/mtl/MTL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ include("events.jl")
include("fences.jl")
include("heap.jl")
include("buffer.jl")
include("log_state.jl")
include("command_queue.jl")
include("command_buf.jl")
include("compute_pipeline.jl")
Expand Down
25 changes: 25 additions & 0 deletions lib/mtl/command_queue.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
export MTLCommandQueueDescriptor

@objcwrapper immutable=false MTLCommandQueueDescriptor <: NSObject

@objcproperties MTLCommandQueueDescriptor begin
@autoproperty maxCommandBufferCount::NSUInteger
@autoproperty logState::id{MTLLogState} setter=setLogState
end

function MTLCommandQueueDescriptor()
handle = @objc [MTLCommandQueueDescriptor alloc]::id{MTLCommandQueueDescriptor}
obj = MTLCommandQueueDescriptor(handle)
finalizer(release, obj)
@objc [obj::id{MTLCommandQueueDescriptor} init]::id{MTLCommandQueueDescriptor}
return obj
end


export MTLCommandQueue

@objcwrapper immutable=false MTLCommandQueue <: NSObject
Expand All @@ -13,3 +31,10 @@ function MTLCommandQueue(dev::MTLDevice)
finalizer(release, obj)
return obj
end

function MTLCommandQueue(dev::MTLDevice, descriptor::MTLCommandQueueDescriptor)
handle = @objc [dev::id{MTLDevice} newCommandQueueWithDescriptor:descriptor::id{MTLCommandQueueDescriptor}]::id{MTLCommandQueue}
obj = MTLCommandQueue(handle)
finalizer(release, obj)
return obj
end
40 changes: 40 additions & 0 deletions lib/mtl/log_state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
export MTLLogLevel

@cenum MTLLogLevel::NSInteger begin
MTLLogLevelUndefined = 0
MTLLogLevelDebug = 1
MTLLogLevelInfo = 2
MTLLogLevelNotice = 3
MTLLogLevelError = 4
MTLLogLevelFault = 5
end

export MTLLogStateDescriptor

@objcwrapper immutable=false MTLLogStateDescriptor <: NSObject

@objcproperties MTLLogStateDescriptor begin
@autoproperty level::MTLLogLevel setter=setLevel
@autoproperty bufferSize::NSInteger setter=setBufferSize
end

function MTLLogStateDescriptor()
handle = @objc [MTLLogStateDescriptor alloc]::id{MTLLogStateDescriptor}
obj = MTLLogStateDescriptor(handle)
finalizer(release, obj)
@objc [obj::id{MTLLogStateDescriptor} init]::id{MTLLogStateDescriptor}
return obj
end


export MTLLogState

@objcwrapper MTLLogState <: NSObject

function MTLLogState(dev::MTLDevice, descriptor::MTLLogStateDescriptor)
err = Ref{id{NSError}}(nil)
handle = @objc [dev::id{MTLDevice} newLogStateWithDescriptor:descriptor::id{MTLLogStateDescriptor}
error:err::Ptr{id{NSError}}]::id{MTLLogState}
err[] == nil || throw(NSError(err[]))
MTLLogState(handle)
end
1 change: 1 addition & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ include("device/intrinsics/memory.jl")
include("device/intrinsics/simd.jl")
include("device/intrinsics/version.jl")
include("device/intrinsics/atomics.jl")
include("device/intrinsics/output.jl")
include("device/quirks.jl")

# array essentials
Expand Down
301 changes: 301 additions & 0 deletions src/device/intrinsics/output.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
const MTLLOG_SUBSYSTEM = "com.juliagpu.metal.jl"
const MTLLOG_CATEGRORY = "mtlprintf"

const __METAL_OS_LOG_TYPE_DEBUG__ = Int32(2)
const __METAL_OS_LOG_TYPE_INFO__ = Int32(1)
const __METAL_OS_LOG_TYPE_DEFAULT__ = Int32(0)
const __METAL_OS_LOG_TYPE_ERROR__ = Int32(16)
const __METAL_OS_LOG_TYPE_FAULT__ = Int32(17)

const ALLOW_DOUBLE_META = "allowdouble"

export @mtlprintf

@generated function promote_c_argument(arg)
# > When a function with a variable-length argument list is called, the variable
# > arguments are passed using C's old ``default argument promotions.'' These say that
# > types char and short int are automatically promoted to int, and type float is
# > automatically promoted to double. Therefore, varargs functions will never receive
# > arguments of type char, short int, or float.

if arg == Cchar || arg == Cshort
return :(Cint(arg))
else
return :(arg)
end
end

@generated function tag_doubles(arg)
@dispose ctx=Context() begin
ret = arg == Cfloat ? Cdouble : arg
T_arg = convert(LLVMType, arg)
T_ret = convert(LLVMType, ret)

f, ft = create_function(T_ret, [T_arg])

@dispose builder=IRBuilder() begin
entry = BasicBlock(f, "entry")
position!(builder, entry)

p1 = parameters(f)[1]

if arg == Cfloat
res = fpext!(builder, p1, LLVM.DoubleType())
metadata(res)["ir_check_ignore"] = MDNode([])
ret!(builder, res)
else
ret!(builder, p1)
end
end

call_function(f, ret, Tuple{arg}, :arg)
end
end


"""
@mtlprintf("%Fmt", args...)
Print a formatted string in device context on the host standard output.
"""
macro mtlprintf(fmt::String, args...)
fmt_val = Val(Symbol(fmt))

return :(_mtlprintf($fmt_val, $(map(arg -> :(tag_doubles(promote_c_argument($arg))), esc.(args))...)))
end

@generated function _mtlprintf(::Val{fmt}, argspec...) where {fmt}
@dispose ctx=Context() begin
arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
arg_types = [argspec...]

T_void = LLVM.VoidType()
T_int32 = LLVM.Int32Type()
T_int64 = LLVM.Int64Type()
T_pint8 = LLVM.PointerType(LLVM.Int8Type())
T_pint8a2 = LLVM.PointerType(LLVM.Int8Type(), 2)

# create functions
param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types]
llvm_f, llvm_ft = create_function(T_void, LLVMType[]; vararg=true)
mod = LLVM.parent(llvm_f)

# generate IR
@dispose builder=IRBuilder() begin
entry = BasicBlock(llvm_f, "entry")
position!(builder, entry)

str = globalstring_ptr!(builder, String(fmt), addrspace=2)

# compute argsize
argtypes = LLVM.StructType(param_types)
dl = datalayout(mod)
arg_size = LLVM.ConstantInt(T_int64, sizeof(dl, argtypes))

alloc = alloca!(builder, T_pint8)
buffer = bitcast!(builder, alloc, T_pint8)
alloc_size = LLVM.ConstantInt(T_int64, sizeof(dl, T_pint8))

lifetime_start_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8])
lifetime_start = LLVM.Function(mod, "llvm.lifetime.start.p0i8", lifetime_start_fty)
call!(builder, lifetime_start_fty, lifetime_start, [alloc_size, buffer])

va_start_fty = LLVM.FunctionType(T_void, [T_pint8])
va_start = LLVM.Function(mod, "llvm.va_start", va_start_fty)
call!(builder, va_start_fty, va_start, [buffer])

# invoke @air.os_log and return
subsystem_str = globalstring_ptr!(builder, MTLLOG_SUBSYSTEM, addrspace=2)
category_str = globalstring_ptr!(builder, MTLLOG_CATEGRORY, addrspace=2)
log_type = LLVM.ConstantInt(T_int32, __METAL_OS_LOG_TYPE_DEBUG__)
os_log_fty = LLVM.FunctionType(T_void, [T_pint8a2, T_pint8a2, T_int32, T_pint8a2, T_pint8, T_int64])
os_log = LLVM.Function(mod, "air.os_log", os_log_fty)

arg_ptr = load!(builder, T_pint8, alloc)

call!(builder, os_log_fty, os_log, [subsystem_str, category_str, log_type, str, arg_ptr, arg_size])

va_end_fty = LLVM.FunctionType(T_void, [T_pint8])
va_end = LLVM.Function(mod, "llvm.va_end", va_end_fty)
call!(builder, va_end_fty, va_end, [buffer])

lifetime_end_fty = LLVM.FunctionType(T_void, [T_int64, T_pint8])
lifetime_end = LLVM.Function(mod, "llvm.lifetime.end.p0i8", lifetime_end_fty)
call!(builder, lifetime_end_fty, lifetime_end, [alloc_size, buffer])

ret!(builder)
end

call_function(llvm_f, Nothing, Tuple{arg_types...}, arg_exprs...)
end
end


## print-like functionality

export @mtlprint, @mtlprintln

# simple conversions, defining an expression and the resulting argument type. nothing fancy,
# `@mtlprint` pretty directly maps to `@mtlprintf`; we should just support `write(::IO)`.
const mtlprint_conversions = [
Float32 => (x->:(Float64($x)), Float64),
Ptr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}),
LLVMPtr{<:Any} => (x->:(reinterpret(Int, $x)), Ptr{Cvoid}),
Bool => (x->:(Int32($x)), Int32),
]

# format specifiers
const mtlprint_specifiers = Dict(
# integers
Int16 => "%hd",
Int32 => "%d",
Int64 => "%ld",
UInt16 => "%hu",
UInt32 => "%u",
UInt64 => "%lu",

# floating-point
Float32 => "%f",

# other
Cchar => "%c",
Ptr{Cvoid} => "%p",
Cstring => "%s",
)

@inline @generated function _mtlprint(parts...)
fmt = ""
args = Expr[]

for i in 1:length(parts)
part = :(parts[$i])
T = parts[i]

# put literals directly in the format string
if T <: Val
fmt *= string(T.parameters[1])
continue
end

# try to convert arguments if they are not supported directly
if !haskey(mtlprint_specifiers, T)
for (Tmatch, rule) in mtlprint_conversions
if T <: Tmatch
part = rule[1](part)
T = rule[2]
break
end
end
end

# render the argument
if haskey(mtlprint_specifiers, T)
fmt *= mtlprint_specifiers[T]
push!(args, part)
elseif T <: Tuple
fmt *= "("
for (j, U) in enumerate(T.parameters)
if haskey(mtlprint_specifiers, U)
fmt *= mtlprint_specifiers[U]
push!(args, :($part[$j]))
if j < length(T.parameters)
fmt *= ", "
elseif length(T.parameters) == 1
fmt *= ","
end
else
@error("@mtlprint does not support values of type $U")
end
end
fmt *= ")"
elseif T <: String
@error("@mtlprint does not support non-literal strings")
elseif T <: Type
fmt *= string(T.parameters[1])
else
@warn("@mtlprint does not support values of type $T")
fmt *= "$(T)(...)"
end
end

quote
@mtlprintf($fmt, $(args...))
end
end

"""
@mtlprint(xs...)
@mtlprintln(xs...)
Print a textual representation of values `xs` to standard output from the GPU. The
functionality builds on `@mtlprintf`, and is intended as a more use friendly alternative of
that API. However, that also means there's only limited support for argument types, handling
16/32/64 signed and unsigned integers, 32 and 64-bit floating point numbers, `Cchar`s and
pointers. For more complex output, use `@mtlprintf` directly.
Limited string interpolation is also possible:
```julia
@mtlprint("Hello, World ", 42, "\\n")
@mtlprint "Hello, World \$(42)\\n"
```
"""
macro mtlprint(parts...)
args = Union{Val,Expr,Symbol}[]

parts = [parts...]
while true
isempty(parts) && break

part = popfirst!(parts)

# handle string interpolation
if isa(part, Expr) && part.head == :string
parts = vcat(part.args, parts)
continue
end

# expose literals to the generator by using Val types
if isbits(part) # literal numbers, etc
push!(args, Val(part))
elseif isa(part, QuoteNode) # literal symbols
push!(args, Val(part.value))
elseif isa(part, String) # literal strings need to be interned
push!(args, Val(Symbol(part)))
else # actual values that will be passed to printf
push!(args, part)
end
end

quote
_mtlprint($(map(esc, args)...))
end
end

@doc (@doc @mtlprint) ->
macro mtlprintln(parts...)
esc(quote
Metal.@mtlprint($(parts...), "\n")
end)
end

export @mtlshow

"""
@mtlshow(ex)
GPU analog of `Base.@show`. It comes with the same type restrictions as [`@mtlprintf`](@ref).
```julia
@mtlshow thread_position_in_grid_1d()
```
"""
macro mtlshow(exs...)
blk = Expr(:block)
for ex in exs
push!(blk.args, :(Metal.@mtlprintln($(sprint(Base.show_unquoted,ex)*" = "),
begin local value = $(esc(ex)) end)))
end
isempty(exs) || push!(blk.args, :value)
blk
end
Loading

0 comments on commit 088d5ee

Please sign in to comment.