Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow running services without https #1217

Merged
merged 2 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions gateway/src/dstack/gateway/core/nginx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
class SiteConfig(BaseModel):
type: str
domain: str
https: bool = True

def render(self) -> str:
template = importlib.resources.read_text(
Expand Down Expand Up @@ -73,12 +74,15 @@ async def set_acme_settings(
server=server, eab_kid=eab_kid, eab_hmac_key=eab_hmac_key
)

async def register_service(self, project: str, service_id: str, domain: str, auth: bool):
async def register_service(
self, project: str, service_id: str, domain: str, https: bool, auth: bool
):
config_name = self.get_config_name(domain)
conf = ServiceConfig(
project=project,
service_id=service_id,
domain=domain,
https=https,
auth=auth,
)

Expand All @@ -88,16 +92,18 @@ async def register_service(self, project: str, service_id: str, domain: str, aut

logger.debug("Registering service domain %s", domain)

await run_async(self.run_certbot, domain)
if https:
await run_async(self.run_certbot, domain)
await run_async(self.write_conf, conf.render(), config_name)
self.configs[config_name] = conf

logger.info("Service domain %s is registered now", domain)

async def register_entrypoint(self, domain: str, prefix: str):
async def register_entrypoint(self, domain: str, prefix: str, https: bool):
config_name = self.get_config_name(domain)
conf = EntrypointConfig(
domain=domain,
https=https,
proxy_path=prefix,
)

Expand All @@ -107,7 +113,8 @@ async def register_entrypoint(self, domain: str, prefix: str):

logger.debug("Registering entrypoint domain %s", domain)

await run_async(self.run_certbot, domain)
if https:
await run_async(self.run_certbot, domain)
await run_async(self.write_conf, conf.render(), config_name)
self.configs[config_name] = conf

Expand Down
20 changes: 14 additions & 6 deletions gateway/src/dstack/gateway/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Replica(BaseModel):
class Service(BaseModel):
id: str
domain: str
https: bool = True
auth: bool
options: dict
replicas: List[Replica] = []
Expand All @@ -56,6 +57,7 @@ class Store(PersistentModel):
projects: DefaultDict[str, Set[str]] = defaultdict(set)
entrypoints: Dict[str, Tuple[str, str]] = {}
nginx: Nginx = Field(default_factory=Nginx)
gateway_https: Optional[bool] = True
_lock: Lock = Lock()
_subscribers: List["StoreSubscriber"] = []
_ssh_keys_dir = PrivateAttr(
Expand Down Expand Up @@ -103,6 +105,7 @@ async def register_service(self, project: str, service: Service, ssh_private_key
project,
service.id,
service.domain,
service.https,
service.auth,
)
stack.push_async_callback(
Expand Down Expand Up @@ -267,19 +270,24 @@ async def unregister_replica(self, project: str, service_id: str, replica_id: st
service.domain,
)

async def register_entrypoint(self, project: str, domain: str, module: str):
async def register_entrypoint(self, project: str, domain: str, https: bool, module: str):
async with self._lock:
if domain in self.entrypoints:
if self.entrypoints[domain] == (project, module):
if self.entrypoints[domain] == (project, module) and self.gateway_https == https:
return
raise GatewayError(
f"Domain {domain} is already registered as {self.entrypoints[domain]}"
)
# If the gateway's https settings changed, re-register the endpoint.
elif self.entrypoints[domain] == (project, module) and self.gateway_https != https:
await self.nginx.unregister_domain(domain)
else:
raise GatewayError(
f"Domain {domain} is already registered as {self.entrypoints[domain]}"
)

logger.debug("%s: registering entrypoint %s for module %s", project, domain, module)

await self.nginx.register_entrypoint(domain, f"/api/{module}/{project}")
await self.nginx.register_entrypoint(domain, f"/api/{module}/{project}", https)
self.entrypoints[domain] = (project, module)
self.gateway_https = https

logger.info("%s: entrypoint %s is now registered", project, domain)

Expand Down
3 changes: 2 additions & 1 deletion gateway/src/dstack/gateway/registry/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async def post_register_service(
Service(
id=body.run_id,
domain=body.domain.lower(),
https=body.https,
auth=body.auth,
options=body.options,
),
Expand Down Expand Up @@ -73,5 +74,5 @@ async def post_register_entrypoint(
body: RegisterEntrypointRequest,
store: Annotated[Store, Depends(get_store)],
) -> OkResponse:
await store.register_entrypoint(project.lower(), body.domain.lower(), body.module)
await store.register_entrypoint(project.lower(), body.domain.lower(), body.https, body.module)
return OkResponse()
2 changes: 2 additions & 0 deletions gateway/src/dstack/gateway/registry/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
class RegisterServiceRequest(BaseModel):
run_id: str
domain: str
https: bool = True
auth: bool = True
options: dict = {}
ssh_private_key: str
Expand All @@ -23,3 +24,4 @@ class RegisterReplicaRequest(BaseModel):
class RegisterEntrypointRequest(BaseModel):
module: Literal["openai"]
domain: str
https: bool = True
2 changes: 2 additions & 0 deletions gateway/src/dstack/gateway/resources/nginx/entrypoint.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ server {
proxy_read_timeout 300s;
}
listen 80;
{% if https %}
listen 443 ssl;
ssl_certificate /etc/letsencrypt/live/{{ domain }}/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/{{ domain }}/privkey.pem;
Expand All @@ -22,4 +23,5 @@ server {
if ($force_https) {
return 301 https://$host$request_uri;
}
{% endif %}
}
2 changes: 2 additions & 0 deletions gateway/src/dstack/gateway/resources/nginx/service.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ server {
{% endif %}

listen 80;
{% if https %}
listen 443 ssl;
ssl_certificate /etc/letsencrypt/live/{{ domain }}/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/{{ domain }}/privkey.pem;
Expand All @@ -72,4 +73,5 @@ server {
if ($force_https) {
return 301 https://$host$request_uri;
}
{% endif %}
}
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class ServiceConfigurationParams(CoreModel):
Optional[AnyModel],
Field(description="Mapping of the model for the OpenAI-compatible endpoint"),
] = None
https: Annotated[bool, Field(description="Enable HTTPS")] = True
auth: Annotated[bool, Field(description="Enable the authorization")] = True
replicas: Annotated[
Union[conint(ge=1), constr(regex=r"^[0-9]+..[1-9][0-9]*$"), Range[int]],
Expand Down Expand Up @@ -300,6 +301,7 @@ class ServiceConfiguration(
home_dir (str): The absolute path to the home directory inside the container. Defaults to `/root`.
resources (Optional[ResourcesSpec]): The requirements to run the configuration.
model (Optional[ModelMapping]): Mapping of the model for the OpenAI-compatible endpoint.
https (bool): Enable HTTPS. Defaults to `True`.
auth (bool): Enable the authorization. Defaults to `True`.
replicas Range[int]: The range of the number of replicas. Defaults to `1`.
scaling: Optional[ScalingSpec]: The auto-scaling configuration.
Expand Down
14 changes: 11 additions & 3 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,13 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->
async def register_service(session: AsyncSession, run_model: RunModel):
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)

service_https = run_spec.configuration.https
service_protocol = "https" if service_https else "http"

# Currently, gateway endpoint is always https
gateway_https = True
gateway_protocol = "https" if gateway_https else "http"

# TODO(egor-s): allow to configure gateway name
gateway_name: Optional[str] = None
if gateway_name is None:
Expand All @@ -333,12 +340,11 @@ async def register_service(session: AsyncSession, run_model: RunModel):
wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None
if wildcard_domain is None:
raise ServerClientError("Domain is required for gateway")
# we force port 443 for now
service_spec = ServiceSpec(url=f"https://{run_model.run_name}.{wildcard_domain}")
service_spec = ServiceSpec(url=f"{service_protocol}://{run_model.run_name}.{wildcard_domain}")
if run_spec.configuration.model is not None:
service_spec.model = ServiceModelSpec(
name=run_spec.configuration.model.name,
base_url=f"https://gateway.{wildcard_domain}",
base_url=f"{gateway_protocol}://gateway.{wildcard_domain}",
type=run_spec.configuration.model.type,
)
service_spec.options = get_service_options(run_spec.configuration)
Expand All @@ -357,6 +363,8 @@ async def register_service(session: AsyncSession, run_model: RunModel):
project=run_model.project.name,
run_id=run_model.id,
domain=urlparse(service_spec.url).hostname,
service_https=service_https,
gateway_https=gateway_https,
auth=run_spec.configuration.auth,
options=service_spec.options,
ssh_private_key=run_model.project.ssh_private_key,
Expand Down
8 changes: 6 additions & 2 deletions src/dstack/_internal/server/services/gateways/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,20 @@ async def register_service(
project: str,
run_id: uuid.UUID,
domain: str,
service_https: bool,
gateway_https: bool,
auth: bool,
options: dict,
ssh_private_key: str,
):
if "openai" in options:
entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}"
await self.register_openai_entrypoint(project, entrypoint)
await self.register_openai_entrypoint(project, entrypoint, gateway_https)

payload = {
"run_id": run_id.hex,
"domain": domain,
"https": service_https,
"auth": auth,
"options": options,
"ssh_private_key": ssh_private_key,
Expand Down Expand Up @@ -126,12 +129,13 @@ async def unregister_replica(self, project: str, run_id: uuid.UUID, job_id: uuid
resp.raise_for_status()
self.is_server_ready = True

async def register_openai_entrypoint(self, project: str, domain: str):
async def register_openai_entrypoint(self, project: str, domain: str, https: bool):
resp = await self._client.post(
self._url(f"/api/registry/{project}/entrypoints/register"),
json={
"module": "openai",
"domain": domain,
"https": https,
},
)
if resp.status_code == 400:
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@

UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None
DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None
SKIP_GATEWAY_UPDATE = bool(os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None))
SKIP_GATEWAY_UPDATE = os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None) is not None
Loading