Skip to content

Commit

Permalink
Add experimental ImportResolver (#2298)
Browse files Browse the repository at this point in the history
Closes #2294

Signed-off-by: Bjørn Erik Pedersen <[email protected]>
  • Loading branch information
bep authored Aug 5, 2024
1 parent 5ad3f06 commit a186448
Show file tree
Hide file tree
Showing 11 changed files with 276 additions and 26 deletions.
19 changes: 19 additions & 0 deletions experimental/importresolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package experimental

import (
"context"

"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/internal/expctxkeys"
)

// ImportResolver is an experimental func type that, if set,
// will be used as the first step in resolving imports.
// See issue 2294.
// If the import name is not found, it should return nil.
type ImportResolver func(name string) api.Module

// WithImportResolver returns a new context with the given ImportResolver.
func WithImportResolver(ctx context.Context, resolver ImportResolver) context.Context {
return context.WithValue(ctx, expctxkeys.ImportResolverKey{}, resolver)
}
101 changes: 101 additions & 0 deletions experimental/importresolver_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package experimental_test

import (
"bytes"
"context"
_ "embed"
"fmt"
"log"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
)

var (
// These wasm files were generated by the following:
// cd testdata
// wat2wasm --debug-names inoutdispatcher.wat
// wat2wasm --debug-names inoutdispatcherclient.wat

//go:embed testdata/inoutdispatcher.wasm
inoutdispatcherWasm []byte
//go:embed testdata/inoutdispatcherclient.wasm
inoutdispatcherclientWasm []byte
)

func Example_importResolver() {
ctx := context.Background()

r := wazero.NewRuntime(ctx)
defer r.Close(ctx)

// The client imports the inoutdispatcher module that reads from stdin and writes to stdout.
// This means that we need multiple instances of the inoutdispatcher module to have different stdin/stdout.
// This example demonstrates a way to do that.
type mod struct {
in bytes.Buffer
out bytes.Buffer

client api.Module
}

wasi_snapshot_preview1.MustInstantiate(ctx, r)

const numInstances = 3
mods := make([]*mod, numInstances)
for i := range mods {
mods[i] = &mod{}
m := mods[i]
idm, err := r.CompileModule(ctx, inoutdispatcherWasm)
if err != nil {
log.Panicln(err)
}
idcm, err := r.CompileModule(ctx, inoutdispatcherclientWasm)
if err != nil {
log.Panicln(err)
}

const inoutDispatcherModuleName = "inoutdispatcher"

dispatcherInstance, err := r.InstantiateModule(ctx, idm,
wazero.NewModuleConfig().
WithStdin(&m.in).
WithStdout(&m.out).
WithName("")) // Makes it an anonymous module.
if err != nil {
log.Panicln(err)
}

ctx = experimental.WithImportResolver(ctx, func(name string) api.Module {
if name == inoutDispatcherModuleName {
return dispatcherInstance
}
return nil
})

m.client, err = r.InstantiateModule(ctx, idcm, wazero.NewModuleConfig().WithName(fmt.Sprintf("m%d", i)))
if err != nil {
log.Panicln(err)
}

}

for i, m := range mods {
m.in.WriteString(fmt.Sprintf("Module instance #%d", i))
_, err := m.client.ExportedFunction("dispatch").Call(ctx)
if err != nil {
log.Panicln(err)
}
}

for i, m := range mods {
fmt.Printf("out%d: %s\n", i, m.out.String())
}

// Output:
// out0: Module instance #0
// out1: Module instance #1
// out2: Module instance #2
}
63 changes: 63 additions & 0 deletions experimental/importresolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package experimental_test

import (
"context"
"fmt"
"testing"

"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/internal/testing/binaryencoding"
"github.com/tetratelabs/wazero/internal/testing/require"
"github.com/tetratelabs/wazero/internal/wasm"
)

func TestImportResolver(t *testing.T) {
ctx := context.Background()

r := wazero.NewRuntime(ctx)
defer r.Close(ctx)

for i := 0; i < 5; i++ {
var callCount int
start := func(ctx context.Context) {
callCount++
}
modImport, err := r.NewHostModuleBuilder(fmt.Sprintf("env%d", i)).
NewFunctionBuilder().WithFunc(start).Export("start").
Compile(ctx)
require.NoError(t, err)
// Anonymous module, it will be resolved by the import resolver.
instanceImport, err := r.InstantiateModule(ctx, modImport, wazero.NewModuleConfig().WithName(""))
require.NoError(t, err)

resolveImport := func(name string) api.Module {
if name == "env" {
return instanceImport
}
return nil
}

// Set the import resolver in the context.
ctx = experimental.WithImportResolver(context.Background(), resolveImport)

one := uint32(1)
binary := binaryencoding.EncodeModule(&wasm.Module{
TypeSection: []wasm.FunctionType{{}},
ImportSection: []wasm.Import{{Module: "env", Name: "start", Type: wasm.ExternTypeFunc, DescFunc: 0}},
FunctionSection: []wasm.Index{0},
CodeSection: []wasm.Code{
{Body: []byte{wasm.OpcodeCall, 0, wasm.OpcodeEnd}}, // Call the imported env.start.
},
StartSection: &one,
})

modMain, err := r.CompileModule(ctx, binary)
require.NoError(t, err)

_, err = r.InstantiateModule(ctx, modMain, wazero.NewModuleConfig())
require.NoError(t, err)
require.Equal(t, 1, callCount)
}
}
Binary file added experimental/testdata/inoutdispatcher.wasm
Binary file not shown.
38 changes: 38 additions & 0 deletions experimental/testdata/inoutdispatcher.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
(module
(import "wasi_snapshot_preview1" "fd_read" (func $fd_read (param i32 i32 i32 i32) (result i32)))
(import "wasi_snapshot_preview1" "fd_write" (func $fd_write (param i32 i32 i32 i32) (result i32)))
(memory 1 1 )
(func (export "dispatch")
;; Buffer of 100 chars to read into.
(i32.store (i32.const 4) (i32.const 12))
(i32.store (i32.const 8) (i32.const 100))

(block $done
(loop $read
;; Read from stdin.
(call $fd_read
(i32.const 0) ;; fd; 0 is stdin.
(i32.const 4) ;; iovs
(i32.const 1) ;; iovs_len
(i32.const 8) ;; nread
)

;; If nread is 0, we're done.
(if (i32.eq (i32.load (i32.const 8)) (i32.const 0))
(then br $done)
)

;; Write to stdout.
(drop (call $fd_write
(i32.const 1) ;; fd; 1 is stdout.
(i32.const 4) ;; iovs
(i32.const 1) ;; iovs_len
(i32.const 0) ;; nwritten
))
(br $read)

)
)
)

)
Binary file added experimental/testdata/inoutdispatcherclient.wasm
Binary file not shown.
7 changes: 7 additions & 0 deletions experimental/testdata/inoutdispatcherclient.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

(module
(import "inoutdispatcher" "dispatch" (func $dispatch))
(func (export "dispatch")
(call $dispatch)
)
)
6 changes: 6 additions & 0 deletions internal/expctxkeys/importresolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package expctxkeys

// ImportResolverKey is a context.Context Value key.
// Its associated value should be an ImportResolver.
// See issue 2294.
type ImportResolverKey struct{}
20 changes: 15 additions & 5 deletions internal/wasm/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func (s *Store) instantiate(
return nil, err
}

if err = m.resolveImports(module); err != nil {
if err = m.resolveImports(ctx, module); err != nil {
return nil, err
}

Expand Down Expand Up @@ -410,12 +410,22 @@ func (s *Store) instantiate(
return
}

func (m *ModuleInstance) resolveImports(module *Module) (err error) {
func (m *ModuleInstance) resolveImports(ctx context.Context, module *Module) (err error) {
// Check if ctx contains an ImportResolver.
resolveImport, _ := ctx.Value(expctxkeys.ImportResolverKey{}).(experimental.ImportResolver)

for moduleName, imports := range module.ImportPerModule {
var importedModule *ModuleInstance
importedModule, err = m.s.module(moduleName)
if err != nil {
return err
if resolveImport != nil {
if v := resolveImport(moduleName); v != nil {
importedModule = v.(*ModuleInstance)
}
}
if importedModule == nil {
importedModule, err = m.s.module(moduleName)
if err != nil {
return err
}
}

for _, i := range imports {
Expand Down
39 changes: 22 additions & 17 deletions internal/wasm/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,13 +701,13 @@ func Test_resolveImports(t *testing.T) {

t.Run("module not instantiated", func(t *testing.T) {
m := &ModuleInstance{s: newStore()}
err := m.resolveImports(&Module{ImportPerModule: map[string][]*Import{"unknown": {{}}}})
err := m.resolveImports(context.Background(), &Module{ImportPerModule: map[string][]*Import{"unknown": {{}}}})
require.EqualError(t, err, "module[unknown] not instantiated")
})
t.Run("export instance not found", func(t *testing.T) {
m := &ModuleInstance{s: newStore()}
m.s.nameToModule[moduleName] = &ModuleInstance{Exports: map[string]*Export{}, ModuleName: moduleName}
err := m.resolveImports(&Module{ImportPerModule: map[string][]*Import{moduleName: {{Name: "unknown"}}}})
err := m.resolveImports(context.Background(), &Module{ImportPerModule: map[string][]*Import{moduleName: {{Name: "unknown"}}}})
require.EqualError(t, err, "\"unknown\" is not exported in module \"test\"")
})
t.Run("func", func(t *testing.T) {
Expand Down Expand Up @@ -743,7 +743,7 @@ func Test_resolveImports(t *testing.T) {
}

m := &ModuleInstance{Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}, s: s, Source: module}
err := m.resolveImports(module)
err := m.resolveImports(context.Background(), module)
require.NoError(t, err)

me := m.Engine.(*mockModuleEngine)
Expand Down Expand Up @@ -773,7 +773,7 @@ func Test_resolveImports(t *testing.T) {
}

m := &ModuleInstance{Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}, s: s, Source: module}
err := m.resolveImports(module)
err := m.resolveImports(context.Background(), module)
require.EqualError(t, err, "import func[test.target]: signature mismatch: v_f32 != v_v")
})
})
Expand All @@ -787,6 +787,7 @@ func Test_resolveImports(t *testing.T) {
Exports: map[string]*Export{name: {Type: ExternTypeGlobal, Index: 0}}, ModuleName: moduleName,
}
err := m.resolveImports(
context.Background(),
&Module{
ImportPerModule: map[string][]*Import{moduleName: {{Name: name, Type: ExternTypeGlobal, DescGlobal: g.Type}}},
},
Expand All @@ -805,11 +806,13 @@ func Test_resolveImports(t *testing.T) {
ModuleName: moduleName,
}
m := &ModuleInstance{Globals: make([]*GlobalInstance, 1), s: s}
err := m.resolveImports(&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{Mutable: true}},
}},
})
err := m.resolveImports(
context.Background(),
&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{Mutable: true}},
}},
})
require.EqualError(t, err, "import global[test.target]: mutability mismatch: true != false")
})
t.Run("type mismatch", func(t *testing.T) {
Expand All @@ -823,11 +826,13 @@ func Test_resolveImports(t *testing.T) {
ModuleName: moduleName,
}
m := &ModuleInstance{Globals: make([]*GlobalInstance, 1), s: s}
err := m.resolveImports(&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{ValType: ValueTypeF64}},
}},
})
err := m.resolveImports(
context.Background(),
&Module{
ImportPerModule: map[string][]*Import{moduleName: {
{Module: moduleName, Name: name, Type: ExternTypeGlobal, DescGlobal: GlobalType{ValType: ValueTypeF64}},
}},
})
require.EqualError(t, err, "import global[test.target]: value type mismatch: f64 != i32")
})
})
Expand All @@ -846,7 +851,7 @@ func Test_resolveImports(t *testing.T) {
Engine: importedME,
}
m := &ModuleInstance{s: s, Engine: &mockModuleEngine{resolveImportsCalled: map[Index]Index{}}}
err := m.resolveImports(&Module{
err := m.resolveImports(context.Background(), &Module{
ImportPerModule: map[string][]*Import{
moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: &Memory{Max: max}}},
},
Expand All @@ -866,7 +871,7 @@ func Test_resolveImports(t *testing.T) {
ModuleName: moduleName,
}
m := &ModuleInstance{s: s}
err := m.resolveImports(&Module{
err := m.resolveImports(context.Background(), &Module{
ImportPerModule: map[string][]*Import{
moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}},
},
Expand All @@ -886,7 +891,7 @@ func Test_resolveImports(t *testing.T) {
max := uint32(10)
importMemoryType := &Memory{Max: max}
m := &ModuleInstance{s: s}
err := m.resolveImports(&Module{
err := m.resolveImports(context.Background(), &Module{
ImportPerModule: map[string][]*Import{moduleName: {{Module: moduleName, Name: name, Type: ExternTypeMemory, DescMem: importMemoryType}}},
})
require.EqualError(t, err, "import memory[test.target]: maximum size mismatch: 10 < 65536")
Expand Down
Loading

0 comments on commit a186448

Please sign in to comment.