Skip to content

Commit

Permalink
Merge pull request #16207 from Luukdegram/wasi-threads
Browse files Browse the repository at this point in the history
WASI: Implement experimental threading support
  • Loading branch information
Luukdegram authored Jun 27, 2023
2 parents ff37ccd + 87b8a05 commit 622c5f3
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 8 deletions.
290 changes: 290 additions & 0 deletions lib/std/Thread.zig
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ else if (use_pthreads)
PosixThreadImpl
else if (target.os.tag == .linux)
LinuxThreadImpl
else if (target.os.tag == .wasi)
WasiThreadImpl
else
UnsupportedImpl;

Expand Down Expand Up @@ -266,6 +268,7 @@ pub const Id = switch (target.os.tag) {
.freebsd,
.openbsd,
.haiku,
.wasi,
=> u32,
.macos, .ios, .watchos, .tvos => u64,
.windows => os.windows.DWORD,
Expand Down Expand Up @@ -296,6 +299,8 @@ pub const SpawnConfig = struct {

/// Size in bytes of the Thread's stack
stack_size: usize = 16 * 1024 * 1024,
/// The allocator to be used to allocate memory for the to-be-spawned thread
allocator: ?std.mem.Allocator = null,
};

pub const SpawnError = error{
Expand Down Expand Up @@ -733,6 +738,291 @@ const PosixThreadImpl = struct {
}
};

const WasiThreadImpl = struct {
thread: *WasiThread,

pub const ThreadHandle = i32;
threadlocal var tls_thread_id: Id = 0;

const WasiThread = struct {
/// Thread ID
tid: Atomic(i32) = Atomic(i32).init(0),
/// Contains all memory which was allocated to bootstrap this thread, including:
/// - Guard page
/// - Stack
/// - TLS segment
/// - `Instance`
/// All memory is freed upon call to `join`
memory: []u8,
/// The allocator used to allocate the thread's memory,
/// which is also used during `join` to ensure clean-up.
allocator: std.mem.Allocator,
/// The current state of the thread.
state: State = State.init(.running),
};

/// A meta-data structure used to bootstrap a thread
const Instance = struct {
thread: WasiThread,
/// Contains the offset to the new __tls_base.
/// The offset starting from the memory's base.
tls_offset: usize,
/// Contains the offset to the stack for the newly spawned thread.
/// The offset is calculated starting from the memory's base.
stack_offset: usize,
/// Contains the raw pointer value to the wrapper which holds all arguments
/// for the callback.
raw_ptr: usize,
/// Function pointer to a wrapping function which will call the user's
/// function upon thread spawn. The above mentioned pointer will be passed
/// to this function pointer as its argument.
call_back: *const fn (usize) void,
/// When a thread is in `detached` state, we must free all of its memory
/// upon thread completion. However, as this is done while still within
/// the thread, we must first jump back to the main thread's stack or else
/// we end up freeing the stack that we're currently using.
original_stack_pointer: [*]u8,
};

const State = Atomic(enum(u8) { running, completed, detached });

fn getCurrentId() Id {
return tls_thread_id;
}

fn getHandle(self: Impl) ThreadHandle {
return self.thread.tid.load(.SeqCst);
}

fn detach(self: Impl) void {
switch (self.thread.state.swap(.detached, .SeqCst)) {
.running => {},
.completed => self.join(),
.detached => unreachable,
}
}

fn join(self: Impl) void {
defer {
// Create a copy of the allocator so we do not free the reference to the
// original allocator while freeing the memory.
var allocator = self.thread.allocator;
allocator.free(self.thread.memory);
}

var spin: u8 = 10;
while (true) {
const tid = self.thread.tid.load(.SeqCst);
if (tid == 0) {
break;
}

if (spin > 0) {
spin -= 1;
std.atomic.spinLoopHint();
continue;
}

const result = asm (
\\ local.get %[ptr]
\\ local.get %[expected]
\\ i64.const -1 # infinite
\\ memory.atomic.wait32 0
\\ local.set %[ret]
: [ret] "=r" (-> u32),
: [ptr] "r" (&self.thread.tid.value),
[expected] "r" (tid),
);
switch (result) {
0 => continue, // ok
1 => continue, // expected =! loaded
2 => unreachable, // timeout (infinite)
else => unreachable,
}
}
}

fn spawn(config: std.Thread.SpawnConfig, comptime f: anytype, args: anytype) !WasiThreadImpl {
if (config.allocator == null) return error.OutOfMemory; // an allocator is required to spawn a WASI-thread

// Wrapping struct required to hold the user-provided function arguments.
const Wrapper = struct {
args: @TypeOf(args),
fn entry(ptr: usize) void {
const w: *@This() = @ptrFromInt(ptr);
@call(.auto, f, w.args);
}
};

var stack_offset: usize = undefined;
var tls_offset: usize = undefined;
var wrapper_offset: usize = undefined;
var instance_offset: usize = undefined;

// Calculate the bytes we have to allocate to store all thread information, including:
// - The actual stack for the thread
// - The TLS segment
// - `Instance` - containing information about how to call the user's function.
const map_bytes = blk: {
// start with atleast a single page, which is used as a guard to prevent
// other threads clobbering our new thread.
// Unfortunately, WebAssembly has no notion of read-only segments, so this
// is only a best effort.
var bytes: usize = std.wasm.page_size;

bytes = std.mem.alignForward(usize, bytes, 16); // align stack to 16 bytes
stack_offset = bytes;
bytes += @max(std.wasm.page_size, config.stack_size);

bytes = std.mem.alignForward(usize, bytes, __tls_align());
tls_offset = bytes;
bytes += __tls_size();

bytes = std.mem.alignForward(usize, bytes, @alignOf(Wrapper));
wrapper_offset = bytes;
bytes += @sizeOf(Wrapper);

bytes = std.mem.alignForward(usize, bytes, @alignOf(Instance));
instance_offset = bytes;
bytes += @sizeOf(Instance);

bytes = std.mem.alignForward(usize, bytes, std.wasm.page_size);
break :blk bytes;
};

// Allocate the amount of memory required for all meta data.
const allocated_memory = try config.allocator.?.alloc(u8, map_bytes);

const wrapper: *Wrapper = @ptrCast(@alignCast(&allocated_memory[wrapper_offset]));
wrapper.* = .{ .args = args };

const instance: *Instance = @ptrCast(@alignCast(&allocated_memory[instance_offset]));
instance.* = .{
.thread = .{ .memory = allocated_memory, .allocator = config.allocator.? },
.tls_offset = tls_offset,
.stack_offset = stack_offset,
.raw_ptr = @intFromPtr(wrapper),
.call_back = &Wrapper.entry,
.original_stack_pointer = __get_stack_pointer(),
};

const tid = spawnWasiThread(instance);
// The specification says any value lower than 0 indicates an error.
// The values of such error are unspecified. WASI-Libc treats it as EAGAIN.
if (tid < 0) {
return error.SystemResources;
}
instance.thread.tid.store(tid, .SeqCst);

return .{ .thread = &instance.thread };
}

/// Bootstrap procedure, called by the host environment after thread creation.
export fn wasi_thread_start(tid: i32, arg: *Instance) void {
if (builtin.single_threaded) {
// ensure function is not analyzed in single-threaded mode
return;
}
__set_stack_pointer(arg.thread.memory.ptr + arg.stack_offset);
__wasm_init_tls(arg.thread.memory.ptr + arg.tls_offset);
@atomicStore(u32, &WasiThreadImpl.tls_thread_id, @intCast(tid), .SeqCst);

// Finished bootstrapping, call user's procedure.
arg.call_back(arg.raw_ptr);

switch (arg.thread.state.swap(.completed, .SeqCst)) {
.running => {
// reset the Thread ID
asm volatile (
\\ local.get %[ptr]
\\ i32.const 0
\\ i32.atomic.store 0
:
: [ptr] "r" (&arg.thread.tid.value),
);

// Wake the main thread listening to this thread
asm volatile (
\\ local.get %[ptr]
\\ i32.const 1 # waiters
\\ memory.atomic.notify 0
\\ drop # no need to know the waiters
:
: [ptr] "r" (&arg.thread.tid.value),
);
},
.completed => unreachable,
.detached => {
// restore the original stack pointer so we can free the memory
// without having to worry about freeing the stack
__set_stack_pointer(arg.original_stack_pointer);
// Ensure a copy so we don't free the allocator reference itself
var allocator = arg.thread.allocator;
allocator.free(arg.thread.memory);
},
}
}

/// Asks the host to create a new thread for us.
/// Newly created thread will call `wasi_tread_start` with the thread ID as well
/// as the input `arg` that was provided to `spawnWasiThread`
const spawnWasiThread = @"thread-spawn";
extern "wasi" fn @"thread-spawn"(arg: *Instance) i32;

/// Initializes the TLS data segment starting at `memory`.
/// This is a synthetic function, generated by the linker.
extern fn __wasm_init_tls(memory: [*]u8) void;

/// Returns a pointer to the base of the TLS data segment for the current thread
inline fn __tls_base() [*]u8 {
return asm (
\\ .globaltype __tls_base, i32
\\ global.get __tls_base
\\ local.set %[ret]
: [ret] "=r" (-> [*]u8),
);
}

/// Returns the size of the TLS segment
inline fn __tls_size() u32 {
return asm volatile (
\\ .globaltype __tls_size, i32, immutable
\\ global.get __tls_size
\\ local.set %[ret]
: [ret] "=r" (-> u32),
);
}

/// Returns the alignment of the TLS segment
inline fn __tls_align() u32 {
return asm (
\\ .globaltype __tls_align, i32, immutable
\\ global.get __tls_align
\\ local.set %[ret]
: [ret] "=r" (-> u32),
);
}

/// Allows for setting the stack pointer in the WebAssembly module.
inline fn __set_stack_pointer(addr: [*]u8) void {
asm volatile (
\\ local.get %[ptr]
\\ global.set __stack_pointer
:
: [ptr] "r" (addr),
);
}

/// Returns the current value of the stack pointer
inline fn __get_stack_pointer() [*]u8 {
return asm (
\\ global.get __stack_pointer
\\ local.set %[stack_ptr]
: [stack_ptr] "=r" (-> [*]u8),
);
}
};

const LinuxThreadImpl = struct {
const linux = os.linux;

Expand Down
45 changes: 45 additions & 0 deletions lib/std/Thread/Futex.zig
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ else if (builtin.os.tag == .openbsd)
OpenbsdImpl
else if (builtin.os.tag == .dragonfly)
DragonflyImpl
else if (builtin.target.isWasm())
WasmImpl
else if (std.Thread.use_pthreads)
PosixImpl
else
Expand Down Expand Up @@ -446,6 +448,49 @@ const DragonflyImpl = struct {
}
};

const WasmImpl = struct {
fn wait(ptr: *const Atomic(u32), expect: u32, timeout: ?u64) error{Timeout}!void {
if (!comptime std.Target.wasm.featureSetHas(builtin.target.cpu.features, .atomics)) {
@compileError("WASI target missing cpu feature 'atomics'");
}
const to: i64 = if (timeout) |to| @intCast(to) else -1;
const result = asm (
\\local.get %[ptr]
\\local.get %[expected]
\\local.get %[timeout]
\\memory.atomic.wait32 0
\\local.set %[ret]
: [ret] "=r" (-> u32),
: [ptr] "r" (&ptr.value),
[expected] "r" (@as(i32, @bitCast(expect))),
[timeout] "r" (to),
);
switch (result) {
0 => {}, // ok
1 => {}, // expected =! loaded
2 => return error.Timeout,
else => unreachable,
}
}

fn wake(ptr: *const Atomic(u32), max_waiters: u32) void {
if (!comptime std.Target.wasm.featureSetHas(builtin.target.cpu.features, .atomics)) {
@compileError("WASI target missing cpu feature 'atomics'");
}
assert(max_waiters != 0);
const woken_count = asm (
\\local.get %[ptr]
\\local.get %[waiters]
\\memory.atomic.notify 0
\\local.set %[ret]
: [ret] "=r" (-> u32),
: [ptr] "r" (&ptr.value),
[waiters] "r" (max_waiters),
);
_ = woken_count; // can be 0 when linker flag 'shared-memory' is not enabled
}
};

/// Modified version of linux's futex and Go's sema to implement userspace wait queues with pthread:
/// https://code.woboq.org/linux/linux/kernel/futex.c.html
/// https://go.dev/src/runtime/sema.go
Expand Down
Loading

0 comments on commit 622c5f3

Please sign in to comment.