Skip to content

Commit

Permalink
feat[tool]: add integrity hash to initcode (#4234)
Browse files Browse the repository at this point in the history
this commit adds the integrity hash of the source code to the
initcode. it extends the existing cbor metadata payload in the
initcode, so that verifiers can compare the integrity hash to the
artifact produced by a source bundle.

the integrity hash is put in the initcode to preserve bytecode space of
the runtime code.

refactor:
- change existing `insert_compiler_metadata=` flag to the more generic
`compiler_metadata=None`, which is more extensible.
  • Loading branch information
charles-cooper authored Oct 4, 2024
1 parent 4f47497 commit 9655119
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 29 deletions.
18 changes: 16 additions & 2 deletions tests/functional/builtins/codegen/test_raw_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ def __default__():
assert env.message_call(caller.address, data=sig) == b""


def _strip_initcode_suffix(bytecode):
bs = bytes.fromhex(bytecode.removeprefix("0x"))
to_strip = int.from_bytes(bs[-2:], "big")
return bs[:-to_strip].hex()


# check max_outsize=0 does same thing as not setting max_outsize.
# compile to bytecode and compare bytecode directly.
def test_max_outsize_0():
Expand All @@ -276,7 +282,11 @@ def test_raw_call(_target: address):
"""
output1 = compile_code(code1, output_formats=["bytecode", "bytecode_runtime"])
output2 = compile_code(code2, output_formats=["bytecode", "bytecode_runtime"])
assert output1 == output2
assert output1["bytecode_runtime"] == output2["bytecode_runtime"]

bytecode1 = output1["bytecode"]
bytecode2 = output2["bytecode"]
assert _strip_initcode_suffix(bytecode1) == _strip_initcode_suffix(bytecode2)


# check max_outsize=0 does same thing as not setting max_outsize,
Expand All @@ -298,7 +308,11 @@ def test_raw_call(_target: address) -> bool:
"""
output1 = compile_code(code1, output_formats=["bytecode", "bytecode_runtime"])
output2 = compile_code(code2, output_formats=["bytecode", "bytecode_runtime"])
assert output1 == output2
assert output1["bytecode_runtime"] == output2["bytecode_runtime"]

bytecode1 = output1["bytecode"]
bytecode2 = output2["bytecode"]
assert _strip_initcode_suffix(bytecode1) == _strip_initcode_suffix(bytecode2)


# test functionality of max_outsize=0
Expand Down
37 changes: 28 additions & 9 deletions tests/unit/compiler/test_bytecode_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,17 @@ def test_bytecode_runtime():


def test_bytecode_signature():
out = vyper.compile_code(simple_contract_code, output_formats=["bytecode_runtime", "bytecode"])
out = vyper.compile_code(
simple_contract_code, output_formats=["bytecode_runtime", "bytecode", "integrity"]
)

runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))

metadata = _parse_cbor_metadata(initcode)
runtime_len, data_section_lengths, immutables_len, compiler = metadata
integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata

assert integrity_hash.hex() == out["integrity"]

assert runtime_len == len(runtime_code)
assert data_section_lengths == []
Expand All @@ -73,14 +77,18 @@ def test_bytecode_signature_dense_jumptable():
settings = Settings(optimize=OptimizationLevel.CODESIZE)

out = vyper.compile_code(
many_functions, output_formats=["bytecode_runtime", "bytecode"], settings=settings
many_functions,
output_formats=["bytecode_runtime", "bytecode", "integrity"],
settings=settings,
)

runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))

metadata = _parse_cbor_metadata(initcode)
runtime_len, data_section_lengths, immutables_len, compiler = metadata
integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata

assert integrity_hash.hex() == out["integrity"]

assert runtime_len == len(runtime_code)
assert data_section_lengths == [5, 35]
Expand All @@ -92,14 +100,18 @@ def test_bytecode_signature_sparse_jumptable():
settings = Settings(optimize=OptimizationLevel.GAS)

out = vyper.compile_code(
many_functions, output_formats=["bytecode_runtime", "bytecode"], settings=settings
many_functions,
output_formats=["bytecode_runtime", "bytecode", "integrity"],
settings=settings,
)

runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))

metadata = _parse_cbor_metadata(initcode)
runtime_len, data_section_lengths, immutables_len, compiler = metadata
integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata

assert integrity_hash.hex() == out["integrity"]

assert runtime_len == len(runtime_code)
assert data_section_lengths == [8]
Expand All @@ -108,13 +120,17 @@ def test_bytecode_signature_sparse_jumptable():


def test_bytecode_signature_immutables():
out = vyper.compile_code(has_immutables, output_formats=["bytecode_runtime", "bytecode"])
out = vyper.compile_code(
has_immutables, output_formats=["bytecode_runtime", "bytecode", "integrity"]
)

runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))

metadata = _parse_cbor_metadata(initcode)
runtime_len, data_section_lengths, immutables_len, compiler = metadata
integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata

assert integrity_hash.hex() == out["integrity"]

assert runtime_len == len(runtime_code)
assert data_section_lengths == []
Expand All @@ -129,7 +145,10 @@ def test_bytecode_signature_deployed(code, get_contract, env):
deployed_code = env.get_code(c.address)

metadata = _parse_cbor_metadata(c.bytecode)
runtime_len, data_section_lengths, immutables_len, compiler = metadata
integrity_hash, runtime_len, data_section_lengths, immutables_len, compiler = metadata

out = vyper.compile_code(code, output_formats=["integrity"])
assert integrity_hash.hex() == out["integrity"]

assert compiler == {"vyper": list(vyper.version.version_tuple)}

Expand Down
6 changes: 2 additions & 4 deletions vyper/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,13 @@ def _build_source_map_output(compiler_data, bytecode, pc_maps):


def build_source_map_output(compiler_data: CompilerData) -> dict:
bytecode, pc_maps = compile_ir.assembly_to_evm(
compiler_data.assembly, insert_compiler_metadata=False
)
bytecode, pc_maps = compile_ir.assembly_to_evm(compiler_data.assembly, compiler_metadata=None)
return _build_source_map_output(compiler_data, bytecode, pc_maps)


def build_source_map_runtime_output(compiler_data: CompilerData) -> dict:
bytecode, pc_maps = compile_ir.assembly_to_evm(
compiler_data.assembly_runtime, insert_compiler_metadata=False
compiler_data.assembly_runtime, compiler_metadata=None
)
return _build_source_map_output(compiler_data, bytecode, pc_maps)

Expand Down
17 changes: 9 additions & 8 deletions vyper/compiler/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from functools import cached_property
from pathlib import Path, PurePath
from typing import Optional
from typing import Any, Optional

from vyper import ast as vy_ast
from vyper.ast import natspec
Expand Down Expand Up @@ -249,12 +249,15 @@ def assembly_runtime(self) -> list:

@cached_property
def bytecode(self) -> bytes:
insert_compiler_metadata = not self.no_bytecode_metadata
return generate_bytecode(self.assembly, insert_compiler_metadata=insert_compiler_metadata)
metadata = None
if not self.no_bytecode_metadata:
module_t = self.compilation_target._metadata["type"]
metadata = bytes.fromhex(module_t.integrity_sum)
return generate_bytecode(self.assembly, compiler_metadata=metadata)

@cached_property
def bytecode_runtime(self) -> bytes:
return generate_bytecode(self.assembly_runtime, insert_compiler_metadata=False)
return generate_bytecode(self.assembly_runtime, compiler_metadata=None)

@cached_property
def blueprint_bytecode(self) -> bytes:
Expand Down Expand Up @@ -351,7 +354,7 @@ def _find_nested_opcode(assembly, key):
return any(_find_nested_opcode(x, key) for x in sublists)


def generate_bytecode(assembly: list, insert_compiler_metadata: bool) -> bytes:
def generate_bytecode(assembly: list, compiler_metadata: Optional[Any]) -> bytes:
"""
Generate bytecode from assembly instructions.
Expand All @@ -365,6 +368,4 @@ def generate_bytecode(assembly: list, insert_compiler_metadata: bool) -> bytes:
bytes
Final compiled bytecode.
"""
return compile_ir.assembly_to_evm(assembly, insert_compiler_metadata=insert_compiler_metadata)[
0
]
return compile_ir.assembly_to_evm(assembly, compiler_metadata=compiler_metadata)[0]
15 changes: 9 additions & 6 deletions vyper/ir/compile_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,22 +1155,24 @@ def _relocate_segments(assembly):


# TODO: change API to split assembly_to_evm and assembly_to_source/symbol_maps
def assembly_to_evm(assembly, pc_ofst=0, insert_compiler_metadata=False):
def assembly_to_evm(assembly, pc_ofst=0, compiler_metadata=None):
bytecode, source_maps, _ = assembly_to_evm_with_symbol_map(
assembly, pc_ofst=pc_ofst, insert_compiler_metadata=insert_compiler_metadata
assembly, pc_ofst=pc_ofst, compiler_metadata=compiler_metadata
)
return bytecode, source_maps


def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadata=False):
def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, compiler_metadata=None):
"""
Assembles assembly into EVM
assembly: list of asm instructions
pc_ofst: when constructing the source map, the amount to offset all
pcs by (no effect until we add deploy code source map)
insert_compiler_metadata: whether to append vyper metadata to output
(should be true for runtime code)
compiler_metadata: any compiler metadata to add. pass `None` to indicate
no metadata to be added (should always be `None` for
runtime code). the value is opaque, and will be passed
directly to `cbor2.dumps()`.
"""
line_number_map = {
"breakpoints": set(),
Expand Down Expand Up @@ -1278,10 +1280,11 @@ def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadat
pc += 1

bytecode_suffix = b""
if insert_compiler_metadata:
if compiler_metadata is not None:
# this will hold true when we are in initcode
assert immutables_len is not None
metadata = (
compiler_metadata,
len(runtime_code),
data_section_lengths,
immutables_len,
Expand Down

0 comments on commit 9655119

Please sign in to comment.