Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
notJoon committed Oct 15, 2024
1 parent 3142978 commit d15d3f1
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 222 deletions.
191 changes: 117 additions & 74 deletions gnovm/pkg/doctest/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,15 @@ const (
IGNORE = "ignore" // Do not run the code block
SHOULD_PANIC = "should_panic" // Expect a panic
ASSERT = "assert" // Assert the result and expected output are equal
gnoLang = "gno"
)

const (
goLang = "go"
gnoLang = "gno"
)
var (
cache = newCache(maxCacheSize)
regexCache = make(map[string]*regexp.Regexp)

// GetStdlibsDir returns the path to the standard libraries directory.
func GetStdlibsDir() string {
_, filename, _, ok := runtime.Caller(0)
if !ok {
panic("cannot get current file path")
}
return filepath.Join(filepath.Dir(filename), "..", "..", "stdlibs")
}

// cache stores the results of code execution.
var cache = newCache(maxCacheSize)

// hashCodeBlock generates a SHA256 hash for the given code block.
func hashCodeBlock(c codeBlock) string {
h := sha256.New()
h.Write([]byte(c.content))
return hex.EncodeToString(h.Sum(nil))
}
addrRegex = regexp.MustCompile(`gno\.land/r/g[a-z0-9]+/[a-z.]+`)
)

// ExecuteCodeBlock executes a parsed code block and executes it in a gno VM.
func ExecuteCodeBlock(c codeBlock, stdlibDir string) (string, error) {
Expand All @@ -62,24 +46,19 @@ func ExecuteCodeBlock(c codeBlock, stdlibDir string) (string, error) {
}

// Extract the actual language from the lang field
lang := extractLanguage(c.lang)

if lang != goLang && lang != gnoLang {
lang := strings.Split(c.lang, ",")[0]
if lang != gnoLang {
return fmt.Sprintf("SKIPPED (Unsupported language: %s)", lang), nil
}

if lang == goLang {
lang = gnoLang
}

hashKey := hashCodeBlock(c)

// get the result from the cache if it exists
if result, found := cache.get(hashKey); found {
return handleCachedResult(result, c)
}

ctx, acck, _, vmk, stdlibCtx := setupEnvironment()
ctx, acck, _, vmk, stdlibCtx := setupEnv()

files := []*std.MemFile{
{Name: fmt.Sprintf("%d.%s", c.index, lang), Body: c.content},
Expand All @@ -96,6 +75,9 @@ func ExecuteCodeBlock(c codeBlock, stdlibDir string) (string, error) {
return handlePanicMessage(err, c.options.PanicMessage)
}

// remove package path from the result and replace with `main`.
res = replacePackagePath(res)

if err != nil {
return "", err
}
Expand All @@ -112,26 +94,46 @@ func ExecuteCodeBlock(c codeBlock, stdlibDir string) (string, error) {
return compareResults(res, c.expectedOutput, c.expectedError)
}

func extractLanguage(lang string) string {
return strings.Split(lang, ",")[0]
}
// ExecuteMatchingCodeBlock executes all code blocks in the given content that match the given pattern.
// It returns a slice of execution results as strings and any error encountered during the execution.
func ExecuteMatchingCodeBlock(
ctx context.Context,
content string,
pattern string,
) ([]string, error) {
codeBlocks, err := GetCodeBlocks(content)
if err != nil {
return nil, err
}

func handleCachedResult(result string, c codeBlock) (string, error) {
res := strings.TrimSpace(result)
results := make([]string, 0, len(codeBlocks))
for _, block := range codeBlocks {
if err := ctx.Err(); err != nil {
return nil, err
}

if c.expectedOutput == "" && c.expectedError == "" {
return fmt.Sprintf("%s (cached)", res), nil
}
if !matchPattern(block.name, pattern) {
continue
}

res, err := compareResults(res, c.expectedOutput, c.expectedError)
if err != nil {
return "", err
result, err := ExecuteCodeBlock(block, GetStdlibsDir())
if err != nil {
return nil, fmt.Errorf("failed to execute code block %s: %w", block.name, err)
}
results = append(results, fmt.Sprintf("\n=== %s ===\n\n%s\n", block.name, result))
}

return fmt.Sprintf("%s (cached)", res), nil
return results, nil
}

func setupEnvironment() (sdk.Context, authm.AccountKeeper, bankm.BankKeeper, *vm.VMKeeper, sdk.Context) {
// ref: gno.land/pkg/sdk/vm/common_test.go
func setupEnv() (
sdk.Context,
authm.AccountKeeper,
bankm.BankKeeper,
*vm.VMKeeper,
sdk.Context,
) {
baseKey := store.NewStoreKey("baseKey")
iavlKey := store.NewStoreKey("iavlKey")

Expand All @@ -142,7 +144,12 @@ func setupEnvironment() (sdk.Context, authm.AccountKeeper, bankm.BankKeeper, *vm
ms.MountStoreWithDB(iavlKey, iavl.StoreConstructor, db)
ms.LoadLatestVersion()

ctx := sdk.NewContext(sdk.RunTxModeDeliver, ms, &bft.Header{ChainID: "test-chain-id"}, log.NewNoopLogger())
ctx := sdk.NewContext(
sdk.RunTxModeDeliver,
ms,
&bft.Header{ChainID: "test-chain-id"},
log.NewNoopLogger(),
)
acck := authm.NewAccountKeeper(iavlKey, std.ProtoBaseAccount)
bank := bankm.NewBankKeeper(acck)
stdlibsDir := GetStdlibsDir()
Expand All @@ -161,14 +168,37 @@ func setupEnvironment() (sdk.Context, authm.AccountKeeper, bankm.BankKeeper, *vm
return ctx, acck, bank, vmk, stdlibCtx
}

func handleCachedResult(result string, c codeBlock) (string, error) {
res := strings.TrimSpace(result)

if c.expectedOutput == "" && c.expectedError == "" {
return fmt.Sprintf("%s (cached)", res), nil
}

res, err := compareResults(res, c.expectedOutput, c.expectedError)
if err != nil {
return "", err
}

return fmt.Sprintf("%s (cached)", res), nil
}

func handlePanicMessage(err error, panicMessage string) (string, error) {
if err == nil {
return "", fmt.Errorf("expected panic with message: %s, but executed successfully", panicMessage)
return "", fmt.Errorf(
"expected panic with message: %s, but executed successfully",
panicMessage,
)
}
if !strings.Contains(err.Error(), panicMessage) {
return "", fmt.Errorf("expected panic with message: %s, but got: %s", panicMessage, err.Error())

if strings.Contains(err.Error(), panicMessage) {
return fmt.Sprintf("panicked as expected: %v", err), nil
}
return fmt.Sprintf("panicked as expected: %v", err), nil

return "", fmt.Errorf(
"expected panic with message: %s, but got: %s",
panicMessage, err.Error(),
)
}

// compareResults compares the actual output of code execution with the expected output or error.
Expand Down Expand Up @@ -206,37 +236,15 @@ func compareRegex(actual, pattern string) (string, error) {
}

if !re.MatchString(actual) {
return "", fmt.Errorf("output did not match regex pattern:\npattern: %s\nactual: %s", pattern, actual)
return "", fmt.Errorf(
"output did not match regex pattern:\npattern: %s\nactual: %s",
pattern, actual,
)
}

return actual, nil
}

// ExecuteMatchingCodeBlock executes all code blocks in the given content that match the given pattern.
// It returns a slice of execution results as strings and any error encountered during the execution.
func ExecuteMatchingCodeBlock(ctx context.Context, content string, pattern string) ([]string, error) {
codeBlocks := GetCodeBlocks(content)
var results []string

for _, block := range codeBlocks {
if err := ctx.Err(); err != nil {
return nil, err
}

if matchPattern(block.name, pattern) {
result, err := ExecuteCodeBlock(block, GetStdlibsDir())
if err != nil {
return nil, fmt.Errorf("failed to execute code block %s: %w", block.name, err)
}
results = append(results, fmt.Sprintf("\n=== %s ===\n\n%s\n", block.name, result))
}
}

return results, nil
}

var regexCache = make(map[string]*regexp.Regexp)

// getCompiledRegex retrieves or compiles a regex pattern.
// it uses a cache to store compiled regex patterns for reuse.
func getCompiledRegex(pattern string) (*regexp.Regexp, error) {
Expand Down Expand Up @@ -274,3 +282,38 @@ func matchPattern(name, pattern string) bool {

return re.MatchString(name)
}

// for display purpose, replace address string with `main.xxx` when printing type.
// ref: https://github.com/gnolang/gno/pull/2357#discussion_r1704398563
func replacePackagePath(input string) string {
result := addrRegex.ReplaceAllStringFunc(input, func(match string) string {
parts := strings.Split(match, "/")
if len(parts) < 4 {
return match
}
lastPart := parts[len(parts)-1]
subParts := strings.Split(lastPart, ".")
if len(subParts) < 2 {
return "main." + lastPart
}
return "main." + subParts[len(subParts)-1]
})

return result
}

// GetStdlibsDir returns the path to the standard libraries directory.
func GetStdlibsDir() string {
_, filename, _, ok := runtime.Caller(0)
if !ok {
panic("cannot get current file path")
}
return filepath.Join(filepath.Dir(filename), "..", "..", "stdlibs")
}

// hashCodeBlock generates a SHA256 hash for the given code block.
func hashCodeBlock(c codeBlock) string {
h := sha256.New()
h.Write([]byte(c.content))
return hex.EncodeToString(h.Sum(nil))
}
Loading

0 comments on commit d15d3f1

Please sign in to comment.