Skip to content

Commit

Permalink
fix!: set AWS_ENDPOINT_URL_DEADLINE after installing service model (#96)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: The ServiceModel class no longer has the
"install_command" or "file_path" attributes. Also,
the WORKER_REGION configuration option has been deprecated.
Use REGION instead.

Signed-off-by: Jericho Tolentino <[email protected]>
  • Loading branch information
jericht authored Jun 3, 2024
1 parent 4076a4a commit 6bc4d8f
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 81 deletions.
9 changes: 3 additions & 6 deletions src/deadline_test_fixtures/deadline/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
from dataclasses import dataclass, field, InitVar, replace
from typing import Any, ClassVar, Optional, cast

from .client import DeadlineClient
from ..models import (
PipInstall,
PosixSessionUser,
ServiceModel,
)
from ..util import call_api, wait_for

Expand Down Expand Up @@ -54,9 +52,9 @@ def configure_worker_command(*, config: DeadlineWorkerConfiguration) -> str: #
# fmt: on
]

if config.service_model:
if config.service_model_path:
cmds.append(
f"runuser -l {config.user} -s /bin/bash -c '{config.service_model.install_command}'"
f"runuser -l {config.user} -s /bin/bash -c 'aws configure add-model --service-model file://{config.service_model_path}'"
)

return " && ".join(cmds)
Expand Down Expand Up @@ -128,7 +126,7 @@ class DeadlineWorkerConfiguration:
)
start_service: bool = False
no_install_service: bool = False
service_model: ServiceModel | None = None
service_model_path: str | None = None
file_mappings: list[tuple[str, str]] | None = None
"""Mapping of files to copy from host environment to worker environment"""
pre_install_commands: list[str] | None = None
Expand All @@ -146,7 +144,6 @@ class EC2InstanceWorker(DeadlineWorker):
s3_client: botocore.client.BaseClient
ec2_client: botocore.client.BaseClient
ssm_client: botocore.client.BaseClient
deadline_client: DeadlineClient
configuration: DeadlineWorkerConfiguration

instance_id: Optional[str] = field(init=False, default=None)
Expand Down
7 changes: 6 additions & 1 deletion src/deadline_test_fixtures/example_config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ export CODEARTIFACT_REGION

# --- OPTIONAL --- #

# The AWS region to use
# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
export REGION

# Extra local path for boto to look for AWS models in
# Does not apply to the worker
export AWS_DATA_PATH
Expand All @@ -38,9 +42,10 @@ export AWS_DATA_PATH
# Default is to pip install the latest "deadline-cloud-worker-agent" package
export WORKER_AGENT_WHL_PATH

# DEPRECATED: Use REGION instead
# The AWS region to configure the worker for
# Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
export WORKER_REGION
# export WORKER_REGION

# The POSIX user to configure the worker for
# Defaults to "deadline-worker"
Expand Down
102 changes: 65 additions & 37 deletions src/deadline_test_fixtures/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import botocore.loaders
import boto3
import glob
import json
import logging
import os
import pathlib
import posixpath
import pytest
import tempfile
Expand Down Expand Up @@ -141,6 +143,11 @@ def codeartifact() -> CodeArtifactRepositoryInfo:
)


@pytest.fixture(scope="session")
def region() -> str:
return os.getenv("REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2"))


@pytest.fixture(scope="session")
def service_model() -> Generator[ServiceModel, None, None]:
service_model_s3_uri = os.getenv("DEADLINE_SERVICE_MODEL_S3_URI")
Expand Down Expand Up @@ -168,15 +175,22 @@ def service_model() -> Generator[ServiceModel, None, None]:
if not local_model_path:
local_model_path = _find_latest_service_model_file("deadline")
LOG.info(f"Using service model at: {local_model_path}")
yield ServiceModel.from_json_file(local_model_path)
if local_model_path.endswith(".json"):
yield ServiceModel.from_json_file(local_model_path)
elif local_model_path.endswith(".json.gz"):
yield ServiceModel.from_json_gz_file(local_model_path)
else:
raise RuntimeError(
f"Unsupported service model file format (must be .json or .json.gz): {local_model_path}"
)


@pytest.fixture(scope="session")
def install_service_model(service_model: ServiceModel) -> Generator[str, None, None]:
def install_service_model(service_model: ServiceModel, region: str) -> Generator[str, None, None]:
LOG.info("Installing service model and configuring boto to use it for API calls")
with service_model.install() as model_path:
LOG.info(f"Installed service model to {model_path}")
yield model_path
with service_model.install(region) as service_model_install:
LOG.info(f"Installed service model to {service_model_install}")
yield service_model_install


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -365,13 +379,12 @@ def worker_config(
deadline_resources: DeadlineResources,
codeartifact: CodeArtifactRepositoryInfo,
service_model: ServiceModel,
) -> DeadlineWorkerConfiguration:
region: str,
) -> Generator[DeadlineWorkerConfiguration, None, None]:
"""
Builds the configuration for a DeadlineWorker.
Environment Variables:
WORKER_REGION: The AWS region to configure the worker for
Falls back to AWS_DEFAULT_REGION, then defaults to us-west-2
WORKER_POSIX_USER: The POSIX user to configure the worker for
Defaults to "deadline-worker"
WORKER_POSIX_SHARED_GROUP: The shared POSIX group to configure the worker user and job user with
Expand All @@ -387,6 +400,12 @@ def worker_config(
"""
file_mappings: list[tuple[str, str]] = []

# Deprecated environment variable
if os.getenv("WORKER_REGION") is not None:
raise Exception(
"The environment variable WORKER_REGION is no longer supported. Please use REGION instead."
)

# Prepare the Worker agent Python package
worker_agent_whl_path = os.getenv("WORKER_AGENT_WHL_PATH")
if worker_agent_whl_path:
Expand All @@ -410,35 +429,36 @@ def worker_config(
LOG.info(f"Using Worker agent package {worker_agent_requirement_specifier}")

# Path map the service model
dst_path = posixpath.join("/tmp", "deadline-cloud-service-model.json")
path_mapped_model = ServiceModel(
file_path=dst_path,
api_version=service_model.api_version,
service_name=service_model.service_name,
)
LOG.info(f"The service model will be copied to {dst_path} on the Worker environment")
file_mappings.append((service_model.file_path, dst_path))

return DeadlineWorkerConfiguration(
farm_id=deadline_resources.farm.id,
fleet_id=deadline_resources.fleet.id,
region=os.getenv("WORKER_REGION", os.getenv("AWS_DEFAULT_REGION", "us-west-2")),
user=os.getenv("WORKER_POSIX_USER", "deadline-worker"),
group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"),
allow_shutdown=True,
worker_agent_install=PipInstall(
requirement_specifiers=[worker_agent_requirement_specifier],
codeartifact=codeartifact,
),
service_model=path_mapped_model,
file_mappings=file_mappings or None,
)
with tempfile.TemporaryDirectory() as tmpdir:
src_path = pathlib.Path(tmpdir) / f"{service_model.service_name}-service-2.json"

LOG.info(f"Staging service model to {src_path} for uploading to S3")
with src_path.open(mode="w") as f:
json.dump(service_model.model, f)

dst_path = posixpath.join("/tmp", src_path.name)
LOG.info(f"The service model will be copied to {dst_path} on the Worker environment")
file_mappings.append((str(src_path), dst_path))

yield DeadlineWorkerConfiguration(
farm_id=deadline_resources.farm.id,
fleet_id=deadline_resources.fleet.id,
region=region,
user=os.getenv("WORKER_POSIX_USER", "deadline-worker"),
group=os.getenv("WORKER_POSIX_SHARED_GROUP", "shared-group"),
allow_shutdown=True,
worker_agent_install=PipInstall(
requirement_specifiers=[worker_agent_requirement_specifier],
codeartifact=codeartifact,
),
service_model_path=dst_path,
file_mappings=file_mappings or None,
)


@pytest.fixture(scope="session")
def worker(
request: pytest.FixtureRequest,
deadline_client: DeadlineClient,
worker_config: DeadlineWorkerConfiguration,
) -> Generator[DeadlineWorker, None, None]:
"""
Expand Down Expand Up @@ -484,7 +504,6 @@ def worker(

worker = EC2InstanceWorker(
ec2_client=ec2_client,
deadline_client=deadline_client,
s3_client=s3_client,
bootstrap_bucket_name=bootstrap_resources.bootstrap_bucket_name,
ssm_client=ssm_client,
Expand All @@ -496,6 +515,11 @@ def worker(
)

def stop_worker():
if request.session.testsfailed > 0:
if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() == "true":
LOG.info("KEEP_WORKER_AFTER_FAILURE is set, not stopping worker")
return

try:
worker.stop()
except Exception as e:
Expand All @@ -509,9 +533,8 @@ def stop_worker():
worker.start()
except Exception as e:
LOG.exception(f"Failed to start worker: {e}")
if os.getenv("KEEP_WORKER_AFTER_FAILURE", "false").lower() != "true":
LOG.info("Stopping worker because it failed to start")
stop_worker()
LOG.info("Stopping worker because it failed to start")
stop_worker()
raise

yield worker
Expand Down Expand Up @@ -550,4 +573,9 @@ def _find_latest_service_model_file(service_name: str) -> str:
service_name, loader.determine_latest_version(service_name, "service-2"), "service-2"
)
_, service_model_path = loader.load_data_with_path(full_name)
return f"{service_model_path}.json"
service_model_files = glob.glob(f"{service_model_path}.*")
if len(service_model_files) > 1:
raise RuntimeError(
f"Expected exactly one file to match glob '{service_model_path}.*, but got: {service_model_files}"
)
return service_model_files[0]
61 changes: 37 additions & 24 deletions src/deadline_test_fixtures/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from __future__ import annotations

import gzip
import json
import os
import re
Expand All @@ -10,7 +10,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Literal
from typing import Any, Generator, Literal


@dataclass(frozen=True)
Expand Down Expand Up @@ -87,53 +87,66 @@ def path_mappings(self) -> list[tuple[str, str]]:

@dataclass(frozen=True)
class ServiceModel:
file_path: str
api_version: str
service_name: str
model: dict[str, Any]

@staticmethod
def from_json_file(path: str) -> ServiceModel:
with open(path) as f:
model = json.load(f)
return ServiceModel(
file_path=path,
api_version=model["metadata"]["apiVersion"],
service_name=model["metadata"]["serviceId"],
)
return ServiceModel(model=model)

@staticmethod
def from_json_gz_file(path: str) -> ServiceModel:
with gzip.open(path, mode="r") as f:
model = json.load(f)
return ServiceModel(model=model)

@contextmanager
def install(self) -> Generator[str, None, None]:
def install(self, region: str) -> Generator[str, None, None]:
"""
Copies the model to a temporary directory in the structure expected by boto
and sets the AWS_DATA_PATH environment variable to it
"""
try:
old_aws_data_path = os.environ.get("AWS_DATA_PATH")
src_file = Path(self.file_path)
old_endpoint_url = os.environ.get("AWS_ENDPOINT_URL_DEADLINE")

# Set endpoint URL
os.environ["AWS_ENDPOINT_URL_DEADLINE"] = self.endpoint_url_fmt_str.format(region)

# Install service model
with tempfile.TemporaryDirectory() as tmpdir:
json_path = Path(tmpdir) / self.service_name / self.api_version / "service-2.json"
json_path.parent.mkdir(parents=True)
json_path.write_text(src_file.read_text())
json_path.write_text(json.dumps(self.model))
os.environ["AWS_DATA_PATH"] = tmpdir
yield str(tmpdir)
finally:
if old_aws_data_path:
os.environ["AWS_DATA_PATH"] = old_aws_data_path
else:
del os.environ["AWS_DATA_PATH"]
if old_endpoint_url:
os.environ["AWS_ENDPOINT_URL_DEADLINE"] = old_endpoint_url
else:
del os.environ["AWS_ENDPOINT_URL_DEADLINE"]

@property
def install_command(self) -> str:
return " ".join(
[
"aws",
"configure",
"add-model",
"--service-model",
f"file://{self.file_path}",
*(["--service-name", self.service_name] if self.service_name else []),
]
)
def api_version(self) -> str:
return self.model["metadata"]["apiVersion"]

@property
def service_name(self) -> str:
return self.model["metadata"]["serviceId"]

@property
def endpoint_prefix(self) -> str:
return self.model["metadata"]["endpointPrefix"]

@property
def endpoint_url_fmt_str(self) -> str:
"""Format string for the service endpoint URL with one format field for the region"""
return f"https://{self.endpoint_prefix}.{{}}.amazonaws.com"


@dataclass(frozen=True)
Expand Down
Loading

0 comments on commit 6bc4d8f

Please sign in to comment.