diff --git a/pkg/elemental/elemental.go b/pkg/elemental/elemental.go index a4cd2dc8..f5f595d0 100644 --- a/pkg/elemental/elemental.go +++ b/pkg/elemental/elemental.go @@ -19,11 +19,12 @@ package elemental import ( "errors" "fmt" - v1 "github.com/kairos-io/kairos-agent/v2/pkg/types/v1" - "github.com/kairos-io/kairos-agent/v2/pkg/utils/fs" "path/filepath" "strings" + v1 "github.com/kairos-io/kairos-agent/v2/pkg/types/v1" + fsutils "github.com/kairos-io/kairos-agent/v2/pkg/utils/fs" + agentConfig "github.com/kairos-io/kairos-agent/v2/pkg/config" cnst "github.com/kairos-io/kairos-agent/v2/pkg/constants" "github.com/kairos-io/kairos-agent/v2/pkg/partitioner" @@ -560,12 +561,12 @@ func (e Elemental) SetDefaultGrubEntry(partMountPoint string, imgMountPoint stri func (e Elemental) FindKernelInitrd(rootDir string) (kernel string, initrd string, err error) { kernelNames := []string{"uImage", "Image", "zImage", "vmlinuz", "image"} initrdNames := []string{"initrd", "initramfs"} - kernel, err = utils.FindFileWithPrefix(e.config.Fs, filepath.Join(rootDir, "boot"), kernelNames...) + kernel, err = utils.FindFileWithPrefixRecursively(e.config.Fs, filepath.Join(rootDir, "boot"), kernelNames...) if err != nil { e.config.Logger.Errorf("No Kernel file found") return "", "", err } - initrd, err = utils.FindFileWithPrefix(e.config.Fs, filepath.Join(rootDir, "boot"), initrdNames...) + initrd, err = utils.FindFileWithPrefixRecursively(e.config.Fs, filepath.Join(rootDir, "boot"), initrdNames...) if err != nil { e.config.Logger.Errorf("No initrd file found") return "", "", err diff --git a/pkg/utils/common.go b/pkg/utils/common.go index f4d77902..8c4c0b8b 100644 --- a/pkg/utils/common.go +++ b/pkg/utils/common.go @@ -21,7 +21,6 @@ import ( "crypto/sha256" "errors" "fmt" - sdkTypes "github.com/kairos-io/kairos-sdk/types" "io" random "math/rand" "net/url" @@ -32,6 +31,8 @@ import ( "strings" "time" + sdkTypes "github.com/kairos-io/kairos-sdk/types" + "github.com/kairos-io/kairos-sdk/state" agentConfig "github.com/kairos-io/kairos-agent/v2/pkg/config" @@ -437,37 +438,15 @@ func ValidTaggedContainerReference(ref string) bool { return true } +func FindFileWithPrefixRecursively(fs v1.FS, path string, prefixes ...string) (string, error) { + return findFileWithPrefix(true, fs, filepath.Join(path, path), prefixes...) +} + // FindFileWithPrefix looks for a file in the given path matching one of the given // prefixes. Returns the found file path including the given path. It does not // check subfolders recusively func FindFileWithPrefix(fs v1.FS, path string, prefixes ...string) (string, error) { - files, err := fs.ReadDir(path) - if err != nil { - return "", err - } - for _, f := range files { - if f.IsDir() { - continue - } - for _, p := range prefixes { - if strings.HasPrefix(f.Name(), p) { - if f.Mode()&os.ModeSymlink == os.ModeSymlink { - found, err := fs.Readlink(filepath.Join(path, f.Name())) - if err == nil { - if !filepath.IsAbs(found) { - found = filepath.Join(path, found) - } - if exists, _ := fsutils.Exists(fs, found); exists { - return found, nil - } - } - } else { - return filepath.Join(path, f.Name()), nil - } - } - } - } - return "", fmt.Errorf("No file found with prefixes: %v", prefixes) + return findFileWithPrefix(false, fs, filepath.Join(path, path), prefixes...) } // CalcFileChecksum opens the given file and returns the sha256 checksum of it. @@ -595,3 +574,39 @@ func SystemdBootConfWriter(fs v1.FS, filePath string, conf map[string]string) er return writer.Flush() } + +func findFileWithPrefix(recursively bool, fs v1.FS, path string, prefixes ...string) (string, error) { + files, err := fs.ReadDir(path) + if err != nil { + return "", err + } + for _, f := range files { + if f.IsDir() { + if recursively { + f, err := findFileWithPrefix(recursively, fs, filepath.Join(path, f.Name()), prefixes...) + if err == nil { + return f, nil + } + } + continue + } + for _, p := range prefixes { + if strings.HasPrefix(f.Name(), p) { + if f.Mode()&os.ModeSymlink == os.ModeSymlink { + found, err := fs.Readlink(filepath.Join(path, f.Name())) + if err == nil { + if !filepath.IsAbs(found) { + found = filepath.Join(path, found) + } + if exists, _ := fsutils.Exists(fs, found); exists { + return found, nil + } + } + } else { + return filepath.Join(path, f.Name()), nil + } + } + } + } + return "", fmt.Errorf("No file found with prefixes: %v", prefixes) +}