From 2c4ac44f25743f5b7ae9db6bc570ab71f15fd83b Mon Sep 17 00:00:00 2001 From: mlugg Date: Tue, 5 Mar 2024 07:22:47 +0000 Subject: [PATCH] compiler: treat decl_val/decl_ref of potentially generic decls as captures This fixes an issue with the implementation of #18816. Consider the following code: ```zig pub fn Wrap(comptime T: type) type { return struct { pub const T1 = T; inner: struct { x: T1 }, }; } ``` Previously, the type of `inner` was not considered to be "capturing" any value, as `T1` is a decl. However, since it is declared within a generic function, this decl reference depends on the context, and thus should be treated as a capture. AstGen has been augmented to tunnel references to decls through closure when the decl was declared in a potentially-generic context (i.e. within a function). --- lib/std/zig/AstGen.zig | 116 ++++++++++++++++++++++++++++++++--------- lib/std/zig/Zir.zig | 48 ++++++++++++----- src/Autodoc.zig | 22 ++++++-- src/InternPool.zig | 17 ++++-- src/Sema.zig | 32 ++++++++---- src/print_zir.zig | 24 +++++---- 6 files changed, 194 insertions(+), 65 deletions(-) diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index dfb93e0590ba..8535b16806b1 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -44,6 +44,9 @@ compile_errors: ArrayListUnmanaged(Zir.Inst.CompileErrors.Item) = .{}, /// The topmost block of the current function. fn_block: ?*GenZir = null, fn_var_args: bool = false, +/// Whether we are somewhere within a function. If `true`, any container decls may be +/// generic and thus must be tunneled through closure. +within_fn: bool = false, /// The return type of the current function. This may be a trivial `Ref`, or /// otherwise it refers to a `ret_type` instruction. fn_ret_ty: Zir.Inst.Ref = .none, @@ -4050,6 +4053,11 @@ fn fnDecl( }; defer fn_gz.unstack(); + // Set this now, since parameter types, return type, etc may be generic. 
+ const prev_within_fn = astgen.within_fn; + defer astgen.within_fn = prev_within_fn; + astgen.within_fn = true; + const is_pub = fn_proto.visib_token != null; const is_export = blk: { const maybe_export_token = fn_proto.extern_export_inline_token orelse break :blk false; @@ -4311,6 +4319,10 @@ fn fnDecl( const prev_fn_block = astgen.fn_block; const prev_fn_ret_ty = astgen.fn_ret_ty; + defer { + astgen.fn_block = prev_fn_block; + astgen.fn_ret_ty = prev_fn_ret_ty; + } astgen.fn_block = &fn_gz; astgen.fn_ret_ty = if (is_inferred_error or ret_ref.toIndex() != null) r: { // We're essentially guaranteed to need the return type at some point, @@ -4319,10 +4331,6 @@ fn fnDecl( // return type now so the rest of the function can use it. break :r try fn_gz.addNode(.ret_type, decl_node); } else ret_ref; - defer { - astgen.fn_block = prev_fn_block; - astgen.fn_ret_ty = prev_fn_ret_ty; - } const prev_var_args = astgen.fn_var_args; astgen.fn_var_args = is_var_args; @@ -4768,11 +4776,14 @@ fn testDecl( }; defer fn_block.unstack(); + const prev_within_fn = astgen.within_fn; const prev_fn_block = astgen.fn_block; const prev_fn_ret_ty = astgen.fn_ret_ty; + astgen.within_fn = true; astgen.fn_block = &fn_block; astgen.fn_ret_ty = .anyerror_void_error_union_type; defer { + astgen.within_fn = prev_within_fn; astgen.fn_block = prev_fn_block; astgen.fn_ret_ty = prev_fn_ret_ty; } @@ -4871,6 +4882,7 @@ fn structDeclInner( .node = node, .inst = decl_inst, .declaring_gz = gz, + .maybe_generic = astgen.within_fn, }; defer namespace.deinit(gpa); @@ -5195,6 +5207,7 @@ fn unionDeclInner( .node = node, .inst = decl_inst, .declaring_gz = gz, + .maybe_generic = astgen.within_fn, }; defer namespace.deinit(gpa); @@ -5543,6 +5556,7 @@ fn containerDecl( .node = node, .inst = decl_inst, .declaring_gz = gz, + .maybe_generic = astgen.within_fn, }; defer namespace.deinit(gpa); @@ -5709,6 +5723,7 @@ fn containerDecl( .node = node, .inst = decl_inst, .declaring_gz = gz, + .maybe_generic = astgen.within_fn, 
}; defer namespace.deinit(gpa); @@ -8247,9 +8262,14 @@ fn localVarRef( const name_str_index = try astgen.identAsString(ident_token); var s = scope; var found_already: ?Ast.Node.Index = null; // we have found a decl with the same name already + var found_needs_tunnel: bool = undefined; // defined when `found_already != null` + var found_namespaces_out: u32 = undefined; // defined when `found_already != null` + + // The number of namespaces above `gz` that we are currently at var num_namespaces_out: u32 = 0; - // defined when `num_namespaces_out != 0` + // defined when `num_namespaces_out != 0` var capturing_namespace: *Scope.Namespace = undefined; + while (true) switch (s.tag) { .local_val => { const local_val = s.cast(Scope.LocalVal).?; @@ -8267,9 +8287,8 @@ fn localVarRef( gz, ident, num_namespaces_out, - capturing_namespace, - local_val.inst, - local_val.token_src, + .{ .ref = local_val.inst }, + .{ .token = local_val.token_src }, ) else local_val.inst; return rvalueNoCoercePreRef(gz, ri, value_inst, ident); @@ -8298,9 +8317,8 @@ fn localVarRef( gz, ident, num_namespaces_out, - capturing_namespace, - local_ptr.ptr, - local_ptr.token_src, + .{ .ref = local_ptr.ptr }, + .{ .token = local_ptr.token_src }, ) else local_ptr.ptr; switch (ri.rl) { @@ -8329,6 +8347,8 @@ fn localVarRef( } // We found a match but must continue looking for ambiguous references to decls. found_already = i; + found_needs_tunnel = ns.maybe_generic; + found_namespaces_out = num_namespaces_out; } num_namespaces_out += 1; capturing_namespace = ns; @@ -8343,6 +8363,29 @@ fn localVarRef( // Decl references happen by name rather than ZIR index so that when unrelated // decls are modified, ZIR code containing references to them can be unmodified. + + if (found_namespaces_out > 0 and found_needs_tunnel) { + switch (ri.rl) { + .ref, .ref_coerced_ty => return tunnelThroughClosure( + gz, + ident, + found_namespaces_out, + .{ .decl_ref = name_str_index }, + .{ .node = found_already.? 
}, + ), + else => { + const result = try tunnelThroughClosure( + gz, + ident, + found_namespaces_out, + .{ .decl_val = name_str_index }, + .{ .node = found_already.? }, + ); + return rvalueNoCoercePreRef(gz, ri, result, ident); + }, + } + } + switch (ri.rl) { .ref, .ref_coerced_ty => return gz.addStrTok(.decl_ref, name_str_index, ident_token), else => { @@ -8361,17 +8404,22 @@ fn tunnelThroughClosure( inner_ref_node: Ast.Node.Index, /// The number of namespaces being tunnelled through. At least 1. num_tunnels: u32, - /// The namespace being captured from. - ns: *Scope.Namespace, /// The value being captured. - value: Zir.Inst.Ref, - /// The token of the value's declaration. - token: Ast.TokenIndex, + value: union(enum) { + ref: Zir.Inst.Ref, + decl_val: Zir.NullTerminatedString, + decl_ref: Zir.NullTerminatedString, + }, + /// The location of the value's declaration. + decl_src: union(enum) { + token: Ast.TokenIndex, + node: Ast.Node.Index, + }, ) !Zir.Inst.Ref { - const value_inst = value.toIndex() orelse { - // For trivial values, we don't need a tunnel; just return the ref. - return value; - }; + switch (value) { + .ref => |v| if (v.toIndex() == null) return v, // trivial value; do not need tunnel + .decl_val, .decl_ref => {}, + } const astgen = gz.astgen; const gpa = astgen.gpa; @@ -8382,7 +8430,7 @@ fn tunnelThroughClosure( var sfba = std.heap.stackFallback(@sizeOf(usize) * 2, astgen.arena); var intermediate_tunnels = try sfba.get().alloc(*Scope.Namespace, num_tunnels - 1); - { + const root_ns = ns: { var i: usize = num_tunnels - 1; var scope: *Scope = gz.parent; while (i > 0) { @@ -8392,15 +8440,27 @@ fn tunnelThroughClosure( } scope = scope.parent().?; } - } + while (true) { + if (scope.cast(Scope.Namespace)) |ns| break :ns ns; + scope = scope.parent().?; + } + }; // Now that we know the scopes we're tunneling through, begin adding // captures as required, starting with the outermost namespace. 
+ const root_capture = Zir.Inst.Capture.wrap(switch (value) { + .ref => |v| .{ .instruction = v.toIndex().? }, + .decl_val => |str| .{ .decl_val = str }, + .decl_ref => |str| .{ .decl_ref = str }, + }); var cur_capture_index = std.math.cast( u16, - (try ns.captures.getOrPut(gpa, Zir.Inst.Capture.wrap(.{ .inst = value_inst }))).index, - ) orelse return astgen.failNodeNotes(ns.node, "this compiler implementation only supports up to 65536 captures per namespace", .{}, &.{ - try astgen.errNoteTok(token, "captured value here", .{}), + (try root_ns.captures.getOrPut(gpa, root_capture)).index, + ) orelse return astgen.failNodeNotes(root_ns.node, "this compiler implementation only supports up to 65536 captures per namespace", .{}, &.{ + switch (decl_src) { + .token => |t| try astgen.errNoteTok(t, "captured value here", .{}), + .node => |n| try astgen.errNoteNode(n, "captured value here", .{}), + }, try astgen.errNoteNode(inner_ref_node, "value used here", .{}), }); @@ -8409,7 +8469,10 @@ fn tunnelThroughClosure( u16, (try tunnel_ns.captures.getOrPut(gpa, Zir.Inst.Capture.wrap(.{ .nested = cur_capture_index }))).index, ) orelse return astgen.failNodeNotes(tunnel_ns.node, "this compiler implementation only supports up to 65536 captures per namespace", .{}, &.{ - try astgen.errNoteTok(token, "captured value here", .{}), + switch (decl_src) { + .token => |t| try astgen.errNoteTok(t, "captured value here", .{}), + .node => |n| try astgen.errNoteNode(n, "captured value here", .{}), + }, try astgen.errNoteNode(inner_ref_node, "value used here", .{}), }); } @@ -11752,6 +11815,7 @@ const Scope = struct { decls: std.AutoHashMapUnmanaged(Zir.NullTerminatedString, Ast.Node.Index) = .{}, node: Ast.Node.Index, inst: Zir.Inst.Index, + maybe_generic: bool, /// The astgen scope containing this namespace. /// Only valid during astgen. 
diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig index d46f22fec9bf..f196155a1f11 100644 --- a/lib/std/zig/Zir.zig +++ b/lib/std/zig/Zir.zig @@ -3057,26 +3057,50 @@ pub const Inst = struct { }; /// Represents a single value being captured in a type declaration's closure. - /// If high bit is 0, this represents a `Zir.Inst,Index`. - /// If high bit is 1, this represents an index into the last closure. - pub const Capture = enum(u32) { - _, + pub const Capture = packed struct(u32) { + tag: enum(u2) { + /// `data` is a `u16` index into the parent closure. + nested, + /// `data` is a `Zir.Inst.Index` to an instruction whose value is being captured. + instruction, + /// `data` is a `NullTerminatedString` to a decl name. + decl_val, + /// `data` is a `NullTerminatedString` to a decl name. + decl_ref, + }, + data: u30, pub const Unwrapped = union(enum) { - inst: Zir.Inst.Index, nested: u16, + instruction: Zir.Inst.Index, + decl_val: NullTerminatedString, + decl_ref: NullTerminatedString, }; pub fn wrap(cap: Unwrapped) Capture { return switch (cap) { - .inst => |inst| @enumFromInt(@intFromEnum(inst)), - .nested => |idx| @enumFromInt((1 << 31) | @as(u32, idx)), + .nested => |idx| .{ + .tag = .nested, + .data = idx, + }, + .instruction => |inst| .{ + .tag = .instruction, + .data = @intCast(@intFromEnum(inst)), + }, + .decl_val => |str| .{ + .tag = .decl_val, + .data = @intCast(@intFromEnum(str)), + }, + .decl_ref => |str| .{ + .tag = .decl_ref, + .data = @intCast(@intFromEnum(str)), + }, }; } pub fn unwrap(cap: Capture) Unwrapped { - const raw = @intFromEnum(cap); - const tag: u1 = @intCast(raw >> 31); - return switch (tag) { - 0 => .{ .inst = @enumFromInt(raw) }, - 1 => .{ .nested = @truncate(raw) }, + return switch (cap.tag) { + .nested => .{ .nested = @intCast(cap.data) }, + .instruction => .{ .instruction = @enumFromInt(cap.data) }, + .decl_val => .{ .decl_val = @enumFromInt(cap.data) }, + .decl_ref => .{ .decl_ref = @enumFromInt(cap.data) }, }; } }; diff --git 
a/src/Autodoc.zig b/src/Autodoc.zig index 57bdf1a9797e..e93884eb2c56 100644 --- a/src/Autodoc.zig +++ b/src/Autodoc.zig @@ -459,11 +459,21 @@ const Scope = struct { NotRequested: u32, // instr_index }; - fn getCapture(scope: Scope, idx: u16) struct { Zir.Inst.Index, *Scope } { + fn getCapture(scope: Scope, idx: u16) struct { + union(enum) { inst: Zir.Inst.Index, decl: Zir.NullTerminatedString }, + *Scope, + } { const parent = scope.parent.?; return switch (scope.captures[idx].unwrap()) { - .inst => |inst| .{ inst, parent }, .nested => |parent_idx| parent.getCapture(parent_idx), + .instruction => |inst| .{ + .{ .inst = inst }, + parent, + }, + .decl_val, .decl_ref => |str| .{ + .{ .decl = str }, + parent, + }, }; } @@ -4048,7 +4058,13 @@ fn walkInstruction( }, .closure_get => { const captured, const scope = parent_scope.getCapture(extended.small); - return self.walkInstruction(file, scope, parent_src, captured, need_type, call_ctx); + switch (captured) { + .inst => |cap_inst| return self.walkInstruction(file, scope, parent_src, cap_inst, need_type, call_ctx), + .decl => |str| { + const decl_status = parent_scope.resolveDeclName(str, file, inst.toOptional()); + return .{ .expr = .{ .declRef = decl_status } }; + }, + } }, } }, diff --git a/src/InternPool.zig b/src/InternPool.zig index 36311100da73..6639603cb5a2 100644 --- a/src/InternPool.zig +++ b/src/InternPool.zig @@ -503,22 +503,29 @@ pub const OptionalNullTerminatedString = enum(u32) { }; /// A single value captured in the closure of a namespace type. This is not a plain -/// `Index` because we must differentiate between runtime-known values (where we -/// store the type) and comptime-known values (where we store the value). 
+/// `Index` because we must differentiate between the following cases: +/// * runtime-known value (where we store the type) +/// * comptime-known value (where we store the value) +/// * decl val (so that we can analyze the value lazily) +/// * decl ref (so that we can analyze the reference lazily) pub const CaptureValue = packed struct(u32) { - tag: enum { @"comptime", runtime }, - idx: u31, + tag: enum { @"comptime", runtime, decl_val, decl_ref }, + idx: u30, pub fn wrap(val: Unwrapped) CaptureValue { return switch (val) { .@"comptime" => |i| .{ .tag = .@"comptime", .idx = @intCast(@intFromEnum(i)) }, .runtime => |i| .{ .tag = .runtime, .idx = @intCast(@intFromEnum(i)) }, + .decl_val => |i| .{ .tag = .decl_val, .idx = @intCast(@intFromEnum(i)) }, + .decl_ref => |i| .{ .tag = .decl_ref, .idx = @intCast(@intFromEnum(i)) }, }; } pub fn unwrap(val: CaptureValue) Unwrapped { return switch (val.tag) { .@"comptime" => .{ .@"comptime" = @enumFromInt(val.idx) }, .runtime => .{ .runtime = @enumFromInt(val.idx) }, + .decl_val => .{ .decl_val = @enumFromInt(val.idx) }, + .decl_ref => .{ .decl_ref = @enumFromInt(val.idx) }, }; } @@ -527,6 +534,8 @@ pub const CaptureValue = packed struct(u32) { @"comptime": Index, /// Index refers to the type. runtime: Index, + decl_val: DeclIndex, + decl_ref: DeclIndex, }; pub const Slice = struct { diff --git a/src/Sema.zig b/src/Sema.zig index d4a95027c39a..dba0983739cc 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2671,26 +2671,34 @@ fn analyzeAsInt( /// Given a ZIR extra index which points to a list of `Zir.Inst.Capture`, /// resolves this into a list of `InternPool.CaptureValue` allocated by `arena`. 
-fn getCaptures(sema: *Sema, parent_namespace: ?InternPool.NamespaceIndex, extra_index: usize, captures_len: u32) ![]InternPool.CaptureValue { +fn getCaptures(sema: *Sema, block: *Block, extra_index: usize, captures_len: u32) ![]InternPool.CaptureValue { const zcu = sema.mod; const ip = &zcu.intern_pool; - const parent_captures: InternPool.CaptureValue.Slice = if (parent_namespace) |p| parent: { - break :parent zcu.namespacePtr(p).ty.getCaptures(zcu); - } else undefined; // never used so `undefined` is safe + const parent_captures: InternPool.CaptureValue.Slice = zcu.namespacePtr(block.namespace).ty.getCaptures(zcu); const captures = try sema.arena.alloc(InternPool.CaptureValue, captures_len); for (sema.code.extra[extra_index..][0..captures_len], captures) |raw, *capture| { - const zir_capture: Zir.Inst.Capture = @enumFromInt(raw); + const zir_capture: Zir.Inst.Capture = @bitCast(raw); capture.* = switch (zir_capture.unwrap()) { - .inst => |inst| InternPool.CaptureValue.wrap(capture: { + .nested => |parent_idx| parent_captures.get(ip)[parent_idx], + .instruction => |inst| InternPool.CaptureValue.wrap(capture: { const air_ref = try sema.resolveInst(inst.toRef()); if (try sema.resolveValueResolveLazy(air_ref)) |val| { break :capture .{ .@"comptime" = val.toIntern() }; } break :capture .{ .runtime = sema.typeOf(air_ref).toIntern() }; }), - .nested => |parent_idx| parent_captures.get(ip)[parent_idx], + .decl_val => |str| capture: { + const decl_name = try ip.getOrPutString(sema.gpa, sema.code.nullTerminatedString(str)); + const decl = try sema.lookupIdentifier(block, .unneeded, decl_name); // TODO: could we need this src loc? + break :capture InternPool.CaptureValue.wrap(.{ .decl_val = decl }); + }, + .decl_ref => |str| capture: { + const decl_name = try ip.getOrPutString(sema.gpa, sema.code.nullTerminatedString(str)); + const decl = try sema.lookupIdentifier(block, .unneeded, decl_name); // TODO: could we need this src loc? 
+ break :capture InternPool.CaptureValue.wrap(.{ .decl_ref = decl }); + }, }; } @@ -2727,7 +2735,7 @@ fn zirStructDecl( break :blk decls_len; } else 0; - const captures = try sema.getCaptures(block.namespace, extra_index, captures_len); + const captures = try sema.getCaptures(block, extra_index, captures_len); extra_index += captures_len; if (small.has_backing_int) { @@ -2944,7 +2952,7 @@ fn zirEnumDecl( break :blk decls_len; } else 0; - const captures = try sema.getCaptures(block.namespace, extra_index, captures_len); + const captures = try sema.getCaptures(block, extra_index, captures_len); extra_index += captures_len; const decls = sema.code.bodySlice(extra_index, decls_len); @@ -3209,7 +3217,7 @@ fn zirUnionDecl( break :blk decls_len; } else 0; - const captures = try sema.getCaptures(block.namespace, extra_index, captures_len); + const captures = try sema.getCaptures(block, extra_index, captures_len); extra_index += captures_len; const wip_ty = switch (try ip.getUnionType(gpa, .{ @@ -3315,7 +3323,7 @@ fn zirOpaqueDecl( break :blk decls_len; } else 0; - const captures = try sema.getCaptures(block.namespace, extra_index, captures_len); + const captures = try sema.getCaptures(block, extra_index, captures_len); extra_index += captures_len; const wip_ty = switch (try ip.getOpaqueType(gpa, .{ @@ -17268,6 +17276,8 @@ fn zirClosureGet(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat const capture_ty = switch (captures.get(ip)[extended.small].unwrap()) { .@"comptime" => |index| return Air.internedToRef(index), .runtime => |index| index, + .decl_val => |decl_index| return sema.analyzeDeclVal(block, src, decl_index), + .decl_ref => |decl_index| return sema.analyzeDeclRef(decl_index), }; // The comptime case is handled already above. Runtime case below. 
diff --git a/src/print_zir.zig b/src/print_zir.zig index 2303810ed89f..d96fe4f6c962 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -1427,11 +1427,11 @@ const Writer = struct { try stream.writeAll("{}, "); } else { try stream.writeAll("{ "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, @bitCast(self.code.extra[extra_index])); extra_index += 1; for (1..captures_len) |_| { try stream.writeAll(", "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, @bitCast(self.code.extra[extra_index])); extra_index += 1; } try stream.writeAll(" }, "); @@ -1652,11 +1652,11 @@ const Writer = struct { try stream.writeAll("{}, "); } else { try stream.writeAll("{ "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, @bitCast(self.code.extra[extra_index])); extra_index += 1; for (1..captures_len) |_| { try stream.writeAll(", "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, @bitCast(self.code.extra[extra_index])); extra_index += 1; } try stream.writeAll(" }, "); @@ -1817,11 +1817,11 @@ const Writer = struct { try stream.writeAll("{}, "); } else { try stream.writeAll("{ "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, @bitCast(self.code.extra[extra_index])); extra_index += 1; for (1..captures_len) |_| { try stream.writeAll(", "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, @bitCast(self.code.extra[extra_index])); extra_index += 1; } try stream.writeAll(" }, "); @@ -1930,11 +1930,11 @@ const Writer = struct { try stream.writeAll("{}, "); } else { try stream.writeAll("{ "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, 
@bitCast(self.code.extra[extra_index])); extra_index += 1; for (1..captures_len) |_| { try stream.writeAll(", "); - try self.writeCapture(stream, @enumFromInt(self.code.extra[extra_index])); + try self.writeCapture(stream, @bitCast(self.code.extra[extra_index])); extra_index += 1; } try stream.writeAll(" }, "); @@ -2808,8 +2808,14 @@ const Writer = struct { fn writeCapture(self: *Writer, stream: anytype, capture: Zir.Inst.Capture) !void { switch (capture.unwrap()) { - .inst => |inst| return self.writeInstIndex(stream, inst), .nested => |i| return stream.print("[{d}]", .{i}), + .instruction => |inst| return self.writeInstIndex(stream, inst), + .decl_val => |str| try stream.print("decl_val \"{}\"", .{ + std.zig.fmtEscapes(self.code.nullTerminatedString(str)), + }), + .decl_ref => |str| try stream.print("decl_ref \"{}\"", .{ + std.zig.fmtEscapes(self.code.nullTerminatedString(str)), + }), } }