Skip to content

Commit

Permalink
compiler: change representation of closures
Browse files Browse the repository at this point in the history
This changes the representation of closures in Zir and Sema. Rather than
a pair of instructions `closure_capture` and `closure_get`, the system
now works as follows:

* Each ZIR type declaration (`struct_decl` etc) contains a list of
  captures in the form of ZIR indices (or, for efficiency, direct
  references to parent captures). This is an ordered list; indexes into
  it are used to refer to captured values.
* The `extended(closure_get)` ZIR instruction refers to a value in this
  list via a 16-bit index (limiting this index to 16 bits allows us to
  store this in `extended`).
* `Module.Namespace` has a new field `captures` which contains the list
  of values captured in a given namespace. This is initialized based on
  the ZIR capture list whenever a type declaration is analyzed.

This change eliminates `CaptureScope` from semantic analysis, which is a
nice simplification; but the main motivation here is that this change is
a prerequisite for ziglang#18816.
  • Loading branch information
mlugg committed Mar 6, 2024
1 parent 90ab8ea commit a6ca20b
Show file tree
Hide file tree
Showing 7 changed files with 563 additions and 344 deletions.
245 changes: 149 additions & 96 deletions lib/std/zig/AstGen.zig

Large diffs are not rendered by default.

147 changes: 99 additions & 48 deletions lib/std/zig/Zir.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1004,17 +1004,6 @@ pub const Inst = struct {
@"resume",
@"await",

/// When a type or function refers to a comptime value from an outer
/// scope, that forms a closure over comptime value. The outer scope
/// will record a capture of that value, which encodes its current state
/// and marks it to persist. Uses `un_tok` field. Operand is the
/// instruction value to capture.
closure_capture,
/// The inner scope of a closure uses closure_get to retrieve the value
/// stored by the outer scope. Uses `inst_node` field. Operand is the
/// closure_capture instruction ref.
closure_get,

/// A defer statement.
/// Uses the `defer` union field.
@"defer",
Expand Down Expand Up @@ -1251,8 +1240,6 @@ pub const Inst = struct {
.@"await",
.ret_err_value_code,
.extended,
.closure_get,
.closure_capture,
.ret_ptr,
.ret_type,
.@"try",
Expand Down Expand Up @@ -1542,8 +1529,6 @@ pub const Inst = struct {
.@"resume",
.@"await",
.ret_err_value_code,
.closure_get,
.closure_capture,
.@"break",
.break_inline,
.condbr,
Expand Down Expand Up @@ -1829,9 +1814,6 @@ pub const Inst = struct {
.@"resume" = .un_node,
.@"await" = .un_node,

.closure_capture = .un_tok,
.closure_get = .inst_node,

.@"defer" = .@"defer",
.defer_err_code = .defer_err_code,

Expand Down Expand Up @@ -2074,6 +2056,10 @@ pub const Inst = struct {
/// `operand` is payload index to `RestoreErrRetIndex`.
/// `small` is undefined.
restore_err_ret_index,
/// Retrieves a value from the current type declaration scope's closure.
/// `operand` is `src_node: i32`.
/// `small` is closure index.
closure_get,
/// Used as a placeholder instruction which is just a dummy index for Sema to replace
/// with a specific value. For instance, this is used for the capture of an `errdefer`.
/// This should never appear in a body.
Expand Down Expand Up @@ -2949,7 +2935,7 @@ pub const Inst = struct {
/// These are stored in trailing data in `extra` for each prong.
pub const ProngInfo = packed struct(u32) {
body_len: u28,
capture: Capture,
capture: ProngInfo.Capture,
is_inline: bool,
has_tag_capture: bool,

Expand Down Expand Up @@ -3013,27 +2999,29 @@ pub const Inst = struct {
};

/// Trailing:
/// 0. fields_len: u32, // if has_fields_len
/// 1. decls_len: u32, // if has_decls_len
/// 2. backing_int_body_len: u32, // if has_backing_int
/// 3. backing_int_ref: Ref, // if has_backing_int and backing_int_body_len is 0
/// 4. backing_int_body_inst: Inst, // if has_backing_int and backing_int_body_len is > 0
/// 5. decl: Index, // for every decls_len; points to a `declaration` instruction
/// 6. flags: u32 // for every 8 fields
/// 0. captures_len: u32 // if has_captures_len
/// 1. fields_len: u32, // if has_fields_len
/// 2. decls_len: u32, // if has_decls_len
/// 3. capture: Capture // for every captures_len
/// 4. backing_int_body_len: u32, // if has_backing_int
/// 5. backing_int_ref: Ref, // if has_backing_int and backing_int_body_len is 0
/// 6. backing_int_body_inst: Inst, // if has_backing_int and backing_int_body_len is > 0
/// 7. decl: Index, // for every decls_len; points to a `declaration` instruction
/// 8. flags: u32 // for every 8 fields
/// - sets of 4 bits:
/// 0b000X: whether corresponding field has an align expression
/// 0b00X0: whether corresponding field has a default expression
/// 0b0X00: whether corresponding field is comptime
/// 0bX000: whether corresponding field has a type expression
/// 7. fields: { // for every fields_len
/// 9. fields: { // for every fields_len
/// field_name: u32, // if !is_tuple
/// doc_comment: NullTerminatedString, // .empty if no doc comment
/// field_type: Ref, // if corresponding bit is not set. none means anytype.
/// field_type_body_len: u32, // if corresponding bit is set
/// align_body_len: u32, // if corresponding bit is set
/// init_body_len: u32, // if corresponding bit is set
/// }
/// 8. bodies: { // for every fields_len
/// 10. bodies: { // for every fields_len
/// field_type_body_inst: Inst, // for each field_type_body_len
/// align_body_inst: Inst, // for each align_body_len
/// init_body_inst: Inst, // for each init_body_len
Expand All @@ -3052,6 +3040,7 @@ pub const Inst = struct {
}

pub const Small = packed struct {
has_captures_len: bool,
has_fields_len: bool,
has_decls_len: bool,
has_backing_int: bool,
Expand All @@ -3063,10 +3052,35 @@ pub const Inst = struct {
any_default_inits: bool,
any_comptime_fields: bool,
any_aligned_fields: bool,
_: u3 = undefined,
_: u2 = undefined,
};
};

/// Represents a single value being captured in a type declaration's closure.
/// The value is packed into a `u32` and discriminated by its high bit:
/// * If the high bit is 0, this represents a `Zir.Inst.Index`.
/// * If the high bit is 1, the low 16 bits are an index into the last closure.
pub const Capture = enum(u32) {
    _,

    /// Bit 31 of the raw value selects between the two encodings.
    const nested_flag: u32 = 1 << 31;

    /// The decoded form of a `Capture`.
    pub const Unwrapped = union(enum) {
        /// A ZIR instruction index; encoded with the high bit clear.
        inst: Zir.Inst.Index,
        /// An index into the last closure; encoded with the high bit set.
        nested: u16,
    };

    /// Packs `cap` into its `u32` representation.
    /// Asserts that an `.inst` index does not have bit 31 set, since such a
    /// value would be indistinguishable from a `.nested` capture.
    pub fn wrap(cap: Unwrapped) Capture {
        switch (cap) {
            .inst => |inst| {
                const raw: u32 = @intFromEnum(inst);
                // Without this check, a large index would round-trip through
                // `unwrap` as `.nested` with its upper bits discarded.
                @import("std").debug.assert((raw & nested_flag) == 0);
                return @enumFromInt(raw);
            },
            .nested => |idx| return @enumFromInt(nested_flag | @as(u32, idx)),
        }
    }

    /// Decodes `cap` back into an `Unwrapped` value by inspecting the high bit.
    pub fn unwrap(cap: Capture) Unwrapped {
        const raw = @intFromEnum(cap);
        if ((raw & nested_flag) != 0) {
            // Only the low 16 bits carry the nested index; drop the flag bit.
            return .{ .nested = @truncate(raw) };
        }
        return .{ .inst = @enumFromInt(raw) };
    }
};

pub const NameStrategy = enum(u2) {
/// Use the same name as the parent declaration name.
/// e.g. `const Foo = struct {...};`.
Expand Down Expand Up @@ -3098,14 +3112,16 @@ pub const Inst = struct {

/// Trailing:
/// 0. tag_type: Ref, // if has_tag_type
/// 1. body_len: u32, // if has_body_len
/// 2. fields_len: u32, // if has_fields_len
/// 3. decls_len: u32, // if has_decls_len
/// 4. decl: Index, // for every decls_len; points to a `declaration` instruction
/// 5. inst: Index // for every body_len
/// 6. has_bits: u32 // for every 32 fields
/// 1. captures_len: u32, // if has_captures_len
/// 2. body_len: u32, // if has_body_len
/// 3. fields_len: u32, // if has_fields_len
/// 4. decls_len: u32, // if has_decls_len
/// 5. capture: Capture // for every captures_len
/// 6. decl: Index, // for every decls_len; points to a `declaration` instruction
/// 7. inst: Index // for every body_len
/// 8. has_bits: u32 // for every 32 fields
/// - the bit is whether corresponding field has a value expression
/// 7. fields: { // for every fields_len
/// 9. fields: { // for every fields_len
/// field_name: u32,
/// doc_comment: u32, // .empty if no doc_comment
/// value: Ref, // if corresponding bit is set
Expand All @@ -3125,29 +3141,32 @@ pub const Inst = struct {

pub const Small = packed struct {
has_tag_type: bool,
has_captures_len: bool,
has_body_len: bool,
has_fields_len: bool,
has_decls_len: bool,
name_strategy: NameStrategy,
nonexhaustive: bool,
_: u9 = undefined,
_: u8 = undefined,
};
};

/// Trailing:
/// 0. tag_type: Ref, // if has_tag_type
/// 1. body_len: u32, // if has_body_len
/// 2. fields_len: u32, // if has_fields_len
/// 3. decls_len: u32, // if has_decls_len
/// 4. decl: Index, // for every decls_len; points to a `declaration` instruction
/// 5. inst: Index // for every body_len
/// 6. has_bits: u32 // for every 8 fields
/// 1. captures_len: u32 // if has_captures_len
/// 2. body_len: u32, // if has_body_len
/// 3. fields_len: u32, // if has_fields_len
/// 4. decls_len: u32, // if has_decls_len
/// 5. capture: Capture // for every captures_len
/// 6. decl: Index, // for every decls_len; points to a `declaration` instruction
/// 7. inst: Index // for every body_len
/// 8. has_bits: u32 // for every 8 fields
/// - sets of 4 bits:
/// 0b000X: whether corresponding field has a type expression
/// 0b00X0: whether corresponding field has an align expression
/// 0b0X00: whether corresponding field has a tag value expression
/// 0bX000: unused
/// 7. fields: { // for every fields_len
/// 9. fields: { // for every fields_len
/// field_name: NullTerminatedString, // null terminated string index
/// doc_comment: NullTerminatedString, // .empty if no doc comment
/// field_type: Ref, // if corresponding bit is set
Expand All @@ -3170,6 +3189,7 @@ pub const Inst = struct {

pub const Small = packed struct {
has_tag_type: bool,
has_captures_len: bool,
has_body_len: bool,
has_fields_len: bool,
has_decls_len: bool,
Expand All @@ -3183,13 +3203,15 @@ pub const Inst = struct {
/// true | false | union(T) { }
auto_enum_tag: bool,
any_aligned_fields: bool,
_: u6 = undefined,
_: u5 = undefined,
};
};

/// Trailing:
/// 0. decls_len: u32, // if has_decls_len
/// 1. decl: Index, // for every decls_len; points to a `declaration` instruction
/// 0. captures_len: u32, // if has_captures_len
/// 1. decls_len: u32, // if has_decls_len
/// 2. capture: Capture, // for every captures_len
/// 3. decl: Index, // for every decls_len; points to a `declaration` instruction
pub const OpaqueDecl = struct {
src_node: i32,

Expand All @@ -3198,9 +3220,10 @@ pub const Inst = struct {
}

pub const Small = packed struct {
has_captures_len: bool,
has_decls_len: bool,
name_strategy: NameStrategy,
_: u13 = undefined,
_: u12 = undefined,
};
};

Expand Down Expand Up @@ -3502,13 +3525,20 @@ pub fn declIterator(zir: Zir, decl_inst: Zir.Inst.Index) DeclIterator {
.struct_decl => {
const small: Inst.StructDecl.Small = @bitCast(extended.small);
var extra_index: u32 = @intCast(extended.operand + @typeInfo(Inst.StructDecl).Struct.fields.len);
const captures_len = if (small.has_captures_len) captures_len: {
const captures_len = zir.extra[extra_index];
extra_index += 1;
break :captures_len captures_len;
} else 0;
extra_index += @intFromBool(small.has_fields_len);
const decls_len = if (small.has_decls_len) decls_len: {
const decls_len = zir.extra[extra_index];
extra_index += 1;
break :decls_len decls_len;
} else 0;

extra_index += captures_len;

if (small.has_backing_int) {
const backing_int_body_len = zir.extra[extra_index];
extra_index += 1; // backing_int_body_len
Expand All @@ -3529,6 +3559,11 @@ pub fn declIterator(zir: Zir, decl_inst: Zir.Inst.Index) DeclIterator {
const small: Inst.EnumDecl.Small = @bitCast(extended.small);
var extra_index: u32 = @intCast(extended.operand + @typeInfo(Inst.EnumDecl).Struct.fields.len);
extra_index += @intFromBool(small.has_tag_type);
const captures_len = if (small.has_captures_len) captures_len: {
const captures_len = zir.extra[extra_index];
extra_index += 1;
break :captures_len captures_len;
} else 0;
extra_index += @intFromBool(small.has_body_len);
extra_index += @intFromBool(small.has_fields_len);
const decls_len = if (small.has_decls_len) decls_len: {
Expand All @@ -3537,6 +3572,8 @@ pub fn declIterator(zir: Zir, decl_inst: Zir.Inst.Index) DeclIterator {
break :decls_len decls_len;
} else 0;

extra_index += captures_len;

return .{
.extra_index = extra_index,
.decls_remaining = decls_len,
Expand All @@ -3547,6 +3584,11 @@ pub fn declIterator(zir: Zir, decl_inst: Zir.Inst.Index) DeclIterator {
const small: Inst.UnionDecl.Small = @bitCast(extended.small);
var extra_index: u32 = @intCast(extended.operand + @typeInfo(Inst.UnionDecl).Struct.fields.len);
extra_index += @intFromBool(small.has_tag_type);
const captures_len = if (small.has_captures_len) captures_len: {
const captures_len = zir.extra[extra_index];
extra_index += 1;
break :captures_len captures_len;
} else 0;
extra_index += @intFromBool(small.has_body_len);
extra_index += @intFromBool(small.has_fields_len);
const decls_len = if (small.has_decls_len) decls_len: {
Expand All @@ -3555,6 +3597,8 @@ pub fn declIterator(zir: Zir, decl_inst: Zir.Inst.Index) DeclIterator {
break :decls_len decls_len;
} else 0;

extra_index += captures_len;

return .{
.extra_index = extra_index,
.decls_remaining = decls_len,
Expand All @@ -3569,6 +3613,13 @@ pub fn declIterator(zir: Zir, decl_inst: Zir.Inst.Index) DeclIterator {
extra_index += 1;
break :decls_len decls_len;
} else 0;
const captures_len = if (small.has_captures_len) captures_len: {
const captures_len = zir.extra[extra_index];
extra_index += 1;
break :captures_len captures_len;
} else 0;

extra_index += captures_len;

return .{
.extra_index = extra_index,
Expand Down
Loading

0 comments on commit a6ca20b

Please sign in to comment.