Skip to content

Commit

Permalink
fix: Change default windows working directory to the "C:\ProgramData\…
Browse files Browse the repository at this point in the history
…Amazon\OpenJD" (#63)

Signed-off-by: Hongli Chen <[email protected]>
  • Loading branch information
Honglichenn authored Jan 31, 2024
1 parent 92f6f43 commit 36263d3
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 10 deletions.
15 changes: 10 additions & 5 deletions src/openjd/sessions/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from os import name as os_name
from os import stat as os_stat
from pathlib import Path
from tempfile import gettempdir, mkstemp
from tempfile import mkstemp
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, Union

Expand All @@ -38,7 +38,7 @@
from ._runner_step_script import StepScriptRunner
from ._session_user import SessionUser
from ._subprocess import LoggingSubprocess
from ._tempdir import TempDir
from ._tempdir import TempDir, custom_gettempdir
from ._types import (
ActionState,
EnvironmentIdentifier,
Expand Down Expand Up @@ -796,7 +796,7 @@ def _openjd_session_root_dir(self) -> Path:
if self._session_root_directory is not None:
return self._session_root_directory

tempdir = Path(gettempdir()) / "openjd"
tempdir = Path(custom_gettempdir(self._logger))

# Note: If this doesn't have group permissions, then we will be unable to access files
# under this directory if the default group of the current user is the group that
Expand Down Expand Up @@ -833,13 +833,18 @@ def _create_working_directory(self) -> TempDir:
)

# Raises: RuntimeError
return TempDir(dir=root_dir, prefix=self._session_id, user=self._user)
return TempDir(dir=root_dir, prefix=self._session_id, user=self._user, logger=self._logger)

def _create_files_directory(self) -> TempDir:
"""Creates the subdirectory of the working directory in which we'll materialize
any embedded files from the Job Template."""
# Raises: RuntimeError
return TempDir(dir=self.working_directory, prefix="embedded_files", user=self._user)
return TempDir(
dir=self.working_directory,
prefix="embedded_files",
user=self._user,
logger=self._logger,
)

def _materialize_path_mapping(
self, version: SchemaVersion, os_env: dict[str, Optional[str]], symtab: SymbolTable
Expand Down
37 changes: 36 additions & 1 deletion src/openjd/sessions/_tempdir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import stat
from logging import LoggerAdapter
from pathlib import Path
from shutil import chown, rmtree
from tempfile import gettempdir, mkdtemp
Expand All @@ -15,6 +16,37 @@
import ntsecuritycon


def custom_gettempdir(logger: Optional[LoggerAdapter] = None) -> str:
"""
Get a platform-specific temporary directory.
For Windows systems, this function returns a specific directory path,
'%PROGRAMDATA%\\Amazon\\'. If this directory does not exist, it will be created.
For non-Windows systems, it returns the system's default temporary directory.
Args:
logger (Optional[LoggerAdapter]): The logger to which all messages should be sent from this and the
subprocess.
Returns:
str: The path to the temporary directory specific to the operating system.
"""
if is_windows():
program_data_path = os.getenv("PROGRAMDATA")
if program_data_path is None:
program_data_path = r"C:\ProgramData"
if logger:
logger.warning(
f'"PROGRAMDATA" is not set. Set the root directory to the {program_data_path}'
)

temp_dir = os.path.join(program_data_path, "Amazon")
os.makedirs(temp_dir, exist_ok=True)
else:
temp_dir = gettempdir()
return os.path.join(temp_dir, "OpenJD")


class TempDir:
"""This class securely creates a temporary directory using the same rules as mkdtemp(),
but with the option of having the directory owned by a user other than this process' user.
Expand All @@ -37,6 +69,7 @@ def __init__(
dir: Optional[Path] = None,
prefix: Optional[str] = None,
user: Optional[SessionUser] = None,
logger: Optional[LoggerAdapter] = None,
):
"""
Arguments:
Expand All @@ -47,6 +80,8 @@ def __init__(
user (Optional[SessionUser]): A group that will own the created directory.
The group-write bit will be set on the directory if this option is supplied.
Defaults to this process' effective user/group.
logger (Optional[LoggerAdapter]): The logger to which all messages should be sent from this and the
subprocess.
Raises:
RuntimeError - If this process cannot create the temporary directory, or change the
Expand All @@ -59,7 +94,7 @@ def __init__(
raise ValueError("user must be a windows-user. Got %s", type(user))

if not dir:
dir = Path(gettempdir())
dir = Path(custom_gettempdir(logger))

dir = dir.resolve()
try:
Expand Down
2 changes: 1 addition & 1 deletion test/openjd/sessions/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_failed_directory_create(self, method: str, caplog: pytest.LogCaptureFix
@pytest.mark.usefixtures("caplog") # built-in fixture
def test_posix_permissions_warning(self, caplog: pytest.LogCaptureFixture) -> None:
# On POSIX systems, we check the sticky bit of the system /tmp dir
# If its not set, then we emit a security warning into the logs.
# If it is not set, then we emit a security warning into the logs.
# This tests that we do in fact emit that message when the sticky bit isn't set.
with patch("openjd.sessions._session.TempDir", MagicMock()):
with patch("openjd.sessions._session.os_name", "posix"):
Expand Down
24 changes: 21 additions & 3 deletions test/openjd/sessions/test_tempdir.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

import os
import shutil
import stat
import tempfile
from pathlib import Path
Expand All @@ -19,7 +20,7 @@
from unittest.mock import patch

from openjd.sessions import PosixSessionUser, WindowsSessionUser
from openjd.sessions._tempdir import TempDir
from openjd.sessions._tempdir import TempDir, custom_gettempdir

from .conftest import has_posix_disjoint_user, has_posix_target_user

Expand All @@ -28,7 +29,7 @@
class TestTempDirPosix:
def test_defaults(self) -> None:
# GIVEN
tmpdir = Path(tempfile.gettempdir()).resolve()
tmpdir = Path(os.path.join(tempfile.gettempdir(), "OpenJD")).resolve()

# WHEN
result = TempDir()
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_given_dir(self, tmp_path: Path) -> None:

def test_given_prefix(self) -> None:
# GIVEN
tmpdir = Path(tempfile.gettempdir())
tmpdir = Path(custom_gettempdir())
prefix = "testprefix"

# WHEN
Expand Down Expand Up @@ -220,6 +221,23 @@ def principal_has_no_permissions_on_object(self, object_path, principal_name):

return len(access_allowed_masks) == 0

@pytest.fixture
def clean_up_directory(self):
created_dirs = []
yield created_dirs
for dir_path in created_dirs:
if os.path.exists(dir_path):
shutil.rmtree(dir_path)

def test_windows_temp_dir(self, monkeypatch, clean_up_directory):
monkeypatch.setenv("PROGRAMDATA", r"C:\ProgramDataForOpenJDTest")
expected_dir = r"C:\ProgramDataForOpenJDTest\Amazon\OpenJD"
clean_up_directory.append(expected_dir)
assert custom_gettempdir() == expected_dir
assert os.path.exists(
Path(expected_dir).parent
), r"Directory C:\ProgramDataForOpenJDTest\Amazon should be created."


@pytest.mark.xfail(
not has_posix_target_user() or not has_posix_disjoint_user(),
Expand Down

0 comments on commit 36263d3

Please sign in to comment.