diff --git a/cmd/quadlet/main.go b/cmd/quadlet/main.go index 7037d4be73..bd19875599 100644 --- a/cmd/quadlet/main.go +++ b/cmd/quadlet/main.go @@ -56,12 +56,6 @@ var ( } ) -var ( - unitDirAdminUser string - resolvedUnitDirAdminUser string - systemUserDirLevel int -) - // We log directly to /dev/kmsg, because that is the only way to get information out // of the generator into the system logs. func logToKmsg(s string) bool { @@ -115,57 +109,86 @@ func Debugf(format string, a ...interface{}) { // For user generators these can live in $XDG_RUNTIME_DIR/containers/systemd, /etc/containers/systemd/users, /etc/containers/systemd/users/$UID, and $XDG_CONFIG_HOME/containers/systemd func getUnitDirs(rootless bool) []string { // Allow overriding source dir, this is mainly for the CI tests + if varExist, dirs := getDirsFromEnv(); varExist { + return dirs + } + + resolvedUnitDirAdminUser := resolveUnitDirAdminUser() + userLevelFilter := getUserLevelFilter(resolvedUnitDirAdminUser) + + if rootless { + systemUserDirLevel := len(strings.Split(resolvedUnitDirAdminUser, string(os.PathSeparator))) + nonNumericFilter := getNonNumericFilter(resolvedUnitDirAdminUser, systemUserDirLevel) + return getRootlessDirs(nonNumericFilter, userLevelFilter) + } + + return getRootDirs(userLevelFilter) +} + +func getDirsFromEnv() (bool, []string) { unitDirsEnv := os.Getenv("QUADLET_UNIT_DIRS") - dirs := make([]string, 0) + if len(unitDirsEnv) == 0 { + return false, nil + } - unitDirAdminUser = filepath.Join(quadlet.UnitDirAdmin, "users") - var err error - if resolvedUnitDirAdminUser, err = filepath.EvalSymlinks(unitDirAdminUser); err != nil { - if !errors.Is(err, fs.ErrNotExist) { - Debugf("Error occurred resolving path %q: %s", unitDirAdminUser, err) + dirs := make([]string, 0) + for _, eachUnitDir := range strings.Split(unitDirsEnv, ":") { + if !filepath.IsAbs(eachUnitDir) { + Logf("%s not a valid file path", eachUnitDir) + return true, nil } - resolvedUnitDirAdminUser = unitDirAdminUser + dirs = appendSubPaths(dirs, eachUnitDir, false, nil) } - systemUserDirLevel = len(strings.Split(resolvedUnitDirAdminUser, string(os.PathSeparator))) + return true, dirs +} - if len(unitDirsEnv) > 0 { - for _, eachUnitDir := range strings.Split(unitDirsEnv, ":") { - if !filepath.IsAbs(eachUnitDir) { - Logf("%s not a valid file path", eachUnitDir) - return nil - } - dirs = appendSubPaths(dirs, eachUnitDir, false, nil) - } - return dirs +func getRootlessDirs(nonNumericFilter, userLevelFilter func(string, bool) bool) []string { + dirs := make([]string, 0) + + runtimeDir, found := os.LookupEnv("XDG_RUNTIME_DIR") + if found { + dirs = appendSubPaths(dirs, path.Join(runtimeDir, "containers/systemd"), false, nil) } - if rootless { - runtimeDir, found := os.LookupEnv("XDG_RUNTIME_DIR") - if found { - dirs = appendSubPaths(dirs, path.Join(runtimeDir, "containers/systemd"), false, nil) - } + configDir, err := os.UserConfigDir() + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: %v", err) + return nil + } + dirs = appendSubPaths(dirs, path.Join(configDir, "containers/systemd"), false, nil) - configDir, err := os.UserConfigDir() - if err != nil { - fmt.Fprintf(os.Stderr, "Warning: %v", err) - return nil - } - dirs = appendSubPaths(dirs, path.Join(configDir, "containers/systemd"), false, nil) - u, err := user.Current() - if err == nil { - dirs = appendSubPaths(dirs, filepath.Join(quadlet.UnitDirAdmin, "users"), true, nonNumericFilter) - dirs = appendSubPaths(dirs, filepath.Join(quadlet.UnitDirAdmin, "users", u.Uid), true, userLevelFilter) - } else { - fmt.Fprintf(os.Stderr, "Warning: %v", err) - } - return append(dirs, filepath.Join(quadlet.UnitDirAdmin, "users")) + u, err := user.Current() + if err == nil { + dirs = appendSubPaths(dirs, filepath.Join(quadlet.UnitDirAdmin, "users"), true, nonNumericFilter) + dirs = appendSubPaths(dirs, filepath.Join(quadlet.UnitDirAdmin, "users", u.Uid), true, userLevelFilter) + } else { + fmt.Fprintf(os.Stderr, "Warning: %v", err) } + return append(dirs, filepath.Join(quadlet.UnitDirAdmin, "users")) +} + +func getRootDirs(userLevelFilter func(string, bool) bool) []string { + dirs := make([]string, 0) + dirs = appendSubPaths(dirs, quadlet.UnitDirTemp, false, userLevelFilter) dirs = appendSubPaths(dirs, quadlet.UnitDirAdmin, false, userLevelFilter) return appendSubPaths(dirs, quadlet.UnitDirDistro, false, nil) } +func resolveUnitDirAdminUser() string { + unitDirAdminUser := filepath.Join(quadlet.UnitDirAdmin, "users") + var err error + var resolvedUnitDirAdminUser string + if resolvedUnitDirAdminUser, err = filepath.EvalSymlinks(unitDirAdminUser); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + Debugf("Error occurred resolving path %q: %s", unitDirAdminUser, err) + } + resolvedUnitDirAdminUser = unitDirAdminUser + } + return resolvedUnitDirAdminUser +} + func appendSubPaths(dirs []string, path string, isUserFlag bool, filterPtr func(string, bool) bool) []string { resolvedPath, err := filepath.EvalSymlinks(path) if err != nil { @@ -197,33 +220,37 @@ func appendSubPaths(dirs []string, path string, isUserFlag bool, filterPtr func( return dirs } -func nonNumericFilter(_path string, isUserFlag bool) bool { - // when running in rootless, recursive walk directories that are non numeric - // ignore sub dirs under the `users` directory which correspond to a user id - if strings.HasPrefix(_path, resolvedUnitDirAdminUser) { - listDirUserPathLevels := strings.Split(_path, string(os.PathSeparator)) - if len(listDirUserPathLevels) > systemUserDirLevel { - if !(regexp.MustCompile(`^[0-9]*$`).MatchString(listDirUserPathLevels[systemUserDirLevel])) { - return true +func getNonNumericFilter(resolvedUnitDirAdminUser string, systemUserDirLevel int) func(string, bool) bool { + return func(path string, isUserFlag bool) bool { + // when running in rootless, recursive walk directories that are non numeric + // ignore sub dirs under the `users` directory which correspond to a user id + if strings.HasPrefix(path, resolvedUnitDirAdminUser) { + listDirUserPathLevels := strings.Split(path, string(os.PathSeparator)) + if len(listDirUserPathLevels) > systemUserDirLevel { + if !(regexp.MustCompile(`^[0-9]*$`).MatchString(listDirUserPathLevels[systemUserDirLevel])) { + return true + } } + } else { + return true } - } else { - return true + return false } - return false } -func userLevelFilter(_path string, isUserFlag bool) bool { - // if quadlet generator is run rootless, do not recurse other user sub dirs - // if quadlet generator is run as root, ignore users sub dirs - if strings.HasPrefix(_path, resolvedUnitDirAdminUser) { - if isUserFlag { +func getUserLevelFilter(resolvedUnitDirAdminUser string) func(string, bool) bool { + return func(_path string, isUserFlag bool) bool { + // if quadlet generator is run rootless, do not recurse other user sub dirs + // if quadlet generator is run as root, ignore users sub dirs + if strings.HasPrefix(_path, resolvedUnitDirAdminUser) { + if isUserFlag { + return true + } + } else { return true } - } else { - return true + return false } - return false } func isExtSupported(filename string) bool { diff --git a/cmd/quadlet/main_test.go b/cmd/quadlet/main_test.go index 6ee7bebf4e..74fec0a022 100644 --- a/cmd/quadlet/main_test.go +++ b/cmd/quadlet/main_test.go @@ -10,6 +10,7 @@ import ( "path" "path/filepath" "strconv" + "strings" "syscall" "testing" @@ -60,6 +61,9 @@ func TestUnitDirs(t *testing.T) { if os.Getenv("_UNSHARED") != "true" { unitDirs := getUnitDirs(false) + + resolvedUnitDirAdminUser := resolveUnitDirAdminUser() + userLevelFilter := getUserLevelFilter(resolvedUnitDirAdminUser) rootDirs := []string{} rootDirs = appendSubPaths(rootDirs, quadlet.UnitDirTemp, false, userLevelFilter) rootDirs = appendSubPaths(rootDirs, quadlet.UnitDirAdmin, false, userLevelFilter) @@ -71,6 +75,9 @@ func TestUnitDirs(t *testing.T) { rootlessDirs := []string{} + systemUserDirLevel := len(strings.Split(resolvedUnitDirAdminUser, string(os.PathSeparator))) + nonNumericFilter := getNonNumericFilter(resolvedUnitDirAdminUser, systemUserDirLevel) + runtimeDir, found := os.LookupEnv("XDG_RUNTIME_DIR") if found { rootlessDirs = appendSubPaths(rootlessDirs, path.Join(runtimeDir, "containers/systemd"), false, nil)