diff --git a/procfs_linux.go b/procfs_linux.go index ef0f44a..fdb13d0 100644 --- a/procfs_linux.go +++ b/procfs_linux.go @@ -332,10 +332,10 @@ func hasStatxMountId() bool { return hasStatxMountIdBool } -func checkSymlinkOvermount(procRoot *os.File, dir *os.File, path string) error { +func getMountId(dir *os.File, path string) (uint64, error) { // If we don't have statx(STATX_MNT_ID*) support, we can't do anything. if !hasStatxMountId() { - return nil + return 0, nil } var ( @@ -345,31 +345,29 @@ func checkSymlinkOvermount(procRoot *os.File, dir *os.File, path string) error { wantStxMask uint32 = unix.STATX_MNT_ID_UNIQUE | unix.STATX_MNT_ID ) - // Get the mntId of our procfs handle. - err := unix.Statx(int(procRoot.Fd()), "", unix.AT_EMPTY_PATH, int(wantStxMask), &stx) - if err != nil { - return &os.PathError{Op: "statx", Path: dir.Name(), Err: err} - } + err := unix.Statx(int(dir.Fd()), path, unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW, int(wantStxMask), &stx) if stx.Mask&wantStxMask == 0 { // It's not a kernel limitation, for some reason we couldn't get a // mount ID. Assume it's some kind of attack. - return fmt.Errorf("%w: could not get mnt id of dir %s", errUnsafeProcfs, dir.Name()) + err = fmt.Errorf("%w: could not get mount id", errUnsafeProcfs) } - expectedMountId := stx.Mnt_id + if err != nil { + return 0, &os.PathError{Op: "statx(STATX_MNT_ID_...)", Path: dir.Name() + "/" + path, Err: err} + } + return stx.Mnt_id, nil +} - // Get the mntId of the target symlink. - stx = unix.Statx_t{} - err = unix.Statx(int(dir.Fd()), path, unix.AT_SYMLINK_NOFOLLOW|unix.AT_EMPTY_PATH, int(wantStxMask), &stx) +func checkSymlinkOvermount(procRoot *os.File, dir *os.File, path string) error { + // Get the mntId of our procfs handle. + expectedMountId, err := getMountId(procRoot, "") if err != nil { - return &os.PathError{Op: "statx", Path: dir.Name() + "/" + path, Err: err} + return err } - if stx.Mask&wantStxMask == 0 { - // It's not a kernel limitation, for some reason we couldn't get a - // mount ID. Assume it's some kind of attack. - return fmt.Errorf("%w: could not get mnt id of symlink %s", errUnsafeProcfs, path) + // Get the mntId of the target magic-link. + gotMountId, err := getMountId(dir, path) + if err != nil { + return err } - gotMountId := stx.Mnt_id - // As long as the directory mount is alive, even with wrapping mount IDs, // we would expect to see a different mount ID here. (Of course, if we're // using unsafeHostProcRoot() then an attaker could change this after we