diff --git a/windows/mkwinsyscall/mkwinsyscall.go b/windows/mkwinsyscall/mkwinsyscall.go index b080c5399..7fe4efa90 100644 --- a/windows/mkwinsyscall/mkwinsyscall.go +++ b/windows/mkwinsyscall/mkwinsyscall.go @@ -480,15 +480,14 @@ func newFn(s string) (*Fn, error) { return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") } s = trim(s[1:]) - a := strings.Split(s, ".") - switch len(a) { - case 1: - f.dllfuncname = a[0] - case 2: - f.dllname = a[0] - f.dllfuncname = a[1] - default: - return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") + if i := strings.LastIndex(s, "."); i >= 0 { + f.dllname = s[:i] + f.dllfuncname = s[i+1:] + } else { + f.dllfuncname = s + } + if f.dllfuncname == "" { + return nil, fmt.Errorf("function name is not specified in %q", s) } if n := f.dllfuncname; strings.HasSuffix(n, "?") { f.dllfuncname = n[:len(n)-1] @@ -505,7 +504,23 @@ func (f *Fn) DLLName() string { return f.dllname } -// DLLName returns DLL function name for function f. +// DLLVar returns a valid Go identifier that represents DLLName. +func (f *Fn) DLLVar() string { + id := strings.Map(func(r rune) rune { + switch r { + case '.', '-': + return '_' + default: + return r + } + }, f.DLLName()) + if !token.IsIdentifier(id) { + panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName())) + } + return id +} + +// DLLFuncName returns DLL function name for function f. func (f *Fn) DLLFuncName() string { if f.dllfuncname == "" { return f.Name @@ -650,6 +665,13 @@ func (f *Fn) HelperName() string { return "_" + f.Name } +// DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when +// naming a variable that refers to the DLL. +type DLL struct { + Name string + Var string +} + // Source files and functions. type Source struct { Funcs []*Fn @@ -699,17 +721,19 @@ func ParseFiles(fs []string) (*Source, error) { } // DLLs return dll names for a source set src. -func (src *Source) DLLs() []string { +func (src *Source) DLLs() []DLL { uniq := make(map[string]bool) - r := make([]string, 0) + r := make([]DLL, 0) for _, f := range src.Funcs { - name := f.DLLName() - if _, found := uniq[name]; !found { - uniq[name] = true - r = append(r, name) + id := f.DLLVar() + if _, found := uniq[id]; !found { + uniq[id] = true + r = append(r, DLL{f.DLLName(), id}) } } - sort.Strings(r) + sort.Slice(r, func(i, j int) bool { + return r[i].Var < r[j].Var + }) return r } @@ -936,10 +960,10 @@ var ( {{/* help functions */}} -{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{newlazydll .}} +{{define "dlls"}}{{range .DLLs}} mod{{.Var}} = {{newlazydll .Name}} {{end}}{{end}} -{{define "funcnames"}}{{range .DLLFuncNames}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") +{{define "funcnames"}}{{range .DLLFuncNames}} proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}") {{end}}{{end}} {{define "helperbody"}} diff --git a/windows/mkwinsyscall/mkwinsyscall_test.go b/windows/mkwinsyscall/mkwinsyscall_test.go new file mode 100644 index 000000000..cabbf4031 --- /dev/null +++ b/windows/mkwinsyscall/mkwinsyscall_test.go @@ -0,0 +1,50 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "go/format" + "os" + "path/filepath" + "testing" +) + +func TestDLLFilenameEscaping(t *testing.T) { + tests := []struct { + name string + filename string + }{ + {"no escaping necessary", "kernel32"}, + {"escape period", "windows.networking"}, + {"escape dash", "api-ms-win-wsl-api-l1-1-0"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Write a made-up syscall into a temp file for testing. + const prefix = "package windows\n//sys Example() = " + const suffix = ".Example" + name := filepath.Join(t.TempDir(), "syscall.go") + if err := os.WriteFile(name, []byte(prefix+tt.filename+suffix), 0666); err != nil { + t.Fatal(err) + } + + // Ensure parsing, generating, and formatting run without errors. + // This is good enough to show that escaping is working. + src, err := ParseFiles([]string{name}) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + if err := src.Generate(&buf); err != nil { + t.Fatal(err) + } + if _, err := format.Source(buf.Bytes()); err != nil { + t.Log(buf.String()) + t.Fatal(err) + } + }) + } +}