Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify handling missing GatewayConfiguration #1724

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ class GatewayModel(BaseModel):
)
name: Mapped[str] = mapped_column(String(100))
region: Mapped[str] = mapped_column(String(100))
wildcard_domain: Mapped[str] = mapped_column(String(100), nullable=True)
wildcard_domain: Mapped[Optional[str]] = mapped_column(String(100))
# `configuration` is optional for compatibility with pre-0.18.2 gateways.
# Use `get_gateway_configuration` to construct `configuration` for old gateways.
configuration: Mapped[Optional[str]] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
status: Mapped[GatewayStatus] = mapped_column(Enum(GatewayStatus))
Expand Down Expand Up @@ -374,6 +376,8 @@ class GatewayComputeModel(BaseModel):
instance_id: Mapped[str] = mapped_column(String(100))
ip_address: Mapped[str] = mapped_column(String(100))
hostname: Mapped[Optional[str]] = mapped_column(String(100))
# `configuration` is optional for compatibility with pre-0.18.2 gateways.
# Use `get_gateway_compute_configuration` to construct `configuration` for old gateways.
configuration: Mapped[Optional[str]] = mapped_column(Text)
backend_data: Mapped[Optional[str]] = mapped_column(Text)
region: Mapped[str] = mapped_column(String(100))
Expand Down
23 changes: 6 additions & 17 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,23 +372,16 @@ async def register_service(session: AsyncSession, run_model: RunModel):
if gateway.status != GatewayStatus.RUNNING:
raise ServerClientError("Gateway status is not running")

gateway_configuration = None
if gateway.configuration is not None:
gateway_configuration = GatewayConfiguration.__response__.parse_raw(gateway.configuration)

gateway_configuration = get_gateway_configuration(gateway)
service_https = _get_service_https(run_spec, gateway_configuration)
service_protocol = "https" if service_https else "http"

if (
service_https
and gateway_configuration is not None
and gateway_configuration.certificate is None
):
if service_https and gateway_configuration.certificate is None:
raise ServerClientError(
"Cannot run HTTPS service on gateway with no SSL cerfificates configured"
"Cannot run HTTPS service on gateway with no SSL certificates configured"
)

gateway_https = _get_gateway_https(run_spec, gateway_configuration)
gateway_https = _get_gateway_https(gateway_configuration)
gateway_protocol = "https" if gateway_https else "http"

wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None
Expand Down Expand Up @@ -729,19 +722,15 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration):
raise ServerClientError("acm certificate type is supported for aws backend only")


def _get_service_https(run_spec: RunSpec, configuration: Optional[GatewayConfiguration]) -> bool:
def _get_service_https(run_spec: RunSpec, configuration: GatewayConfiguration) -> bool:
if not run_spec.configuration.https:
return False
if configuration is None:
return True
if configuration.certificate is not None and configuration.certificate.type == "acm":
return False
return True


def _get_gateway_https(run_spec: RunSpec, configuration: Optional[GatewayConfiguration]) -> bool:
if configuration is None:
return True
def _get_gateway_https(configuration: GatewayConfiguration) -> bool:
if configuration.certificate is not None and configuration.certificate.type == "acm":
return False
if configuration.certificate is not None and configuration.certificate.type == "lets-encrypt":
Expand Down