Skip to content

Commit

Permalink
refactor: fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Oct 25, 2023
1 parent 90e3e81 commit 3b58897
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 157 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ copyright-regexp = "Copyright\\s\\d{4}([-,]\\d{4})*\\s+%(author)s"

[tool.mypy]
pretty = true
mypy_path = "./src:./lib/:./tests"
mypy_path = "./src:./lib/:./tests/integration"
# Exclude non-hydra libraries
exclude = 'lib/charms/((?!openfga_k8s).)'
follow_imports = "silent"
Expand Down
172 changes: 87 additions & 85 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import logging
import secrets
from typing import Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import urlparse

import requests
Expand All @@ -37,12 +37,21 @@
IngressPerAppRevokedEvent,
)
from lightkube.models.core_v1 import ServicePort
from ops import (
ActionEvent,
ConfigChangedEvent,
HookEvent,
LeaderElectedEvent,
PebbleReadyEvent,
StartEvent,
StopEvent,
UpdateStatusEvent,
)
from ops.charm import CharmBase, RelationChangedEvent, RelationDepartedEvent
from ops.jujuversion import JujuVersion
from ops.main import main
from ops.model import ActiveStatus, BlockedStatus, ModelError, Relation, WaitingStatus
from ops.pebble import Error, ExecError
from requests.models import Response
from ops.pebble import Error, ExecError, Layer

from constants import (
DATABASE_NAME,
Expand All @@ -62,13 +71,16 @@
from openfga import OpenFGA
from state import State, requires_state, requires_state_setter

if TYPE_CHECKING:
from ops.pebble import LayerDict

logger = logging.getLogger(__name__)


class OpenFGAOperatorCharm(CharmBase):
"""OpenFGA Operator Charm."""

def __init__(self, *args):
def __init__(self, *args: Any) -> None:
super().__init__(*args)

self._state = State(self.app, lambda: self.model.get_relation("peer"))
Expand Down Expand Up @@ -139,19 +151,19 @@ def __init__(self, *args):
)
self.service_patcher = KubernetesServicePatch(self, [port_http, port_grpc])

def _on_openfga_pebble_ready(self, event):
def _on_openfga_pebble_ready(self, event: PebbleReadyEvent) -> None:
"""Workload pebble ready."""
self._update_workload(event)

def _on_config_changed(self, event):
def _on_config_changed(self, event: ConfigChangedEvent) -> None:
"""Configuration changed."""
self._update_workload(event)

def _on_start(self, event):
def _on_start(self, event: StartEvent) -> None:
"""Start OpenFGA."""
self._update_workload(event)

def _on_stop(self, _):
def _on_stop(self, _: StopEvent) -> None:
"""Stop OpenFGA."""
if self._container.can_connect():
try:
Expand All @@ -163,16 +175,16 @@ def _on_stop(self, _):
self._container.stop(SERVICE_NAME)
self.unit.status = WaitingStatus("service stopped")

def _on_update_status(self, _):
def _on_update_status(self, _: UpdateStatusEvent) -> None:
"""Update the status of the charm."""
self._ready()

@property
def _domain_name(self):
def _domain_name(self) -> str:
if url := self.ingress.url:
# Remove scheme part from url
url = urlparse(url)
dns_name = url.netloc + url.path
parsed = urlparse(url)
dns_name = parsed.netloc + parsed.path
else:
dns_name = "{}.{}-endpoints.{}.svc.cluster.local".format(
self.unit.name.replace("/", "-"), self.app.name, self.model.name
Expand Down Expand Up @@ -213,7 +225,7 @@ def _migration_peer_data_key(self) -> Optional[str]:
return f"{PEER_KEY_DB_MIGRATE_VERSION}_{self.database.relations[0].id}"

@property
def _pebble_layer(self):
def _pebble_layer(self) -> Layer:
env_vars = map_config_to_env_vars(self)
env_vars["OPENFGA_PLAYGROUND_ENABLED"] = "false"
env_vars["OPENFGA_DATASTORE_ENGINE"] = "postgres"
Expand All @@ -230,9 +242,9 @@ def _pebble_layer(self):
self.unit.status = BlockedStatus(
"{} configuration value not set".format(setting),
)
return {}
return Layer()

return {
pebble_layer: LayerDict = {
"summary": "openfga layer",
"description": "pebble layer for openfga",
"services": {
Expand All @@ -259,9 +271,11 @@ def _pebble_layer(self):
},
},
}
return Layer(pebble_layer)

@requires_state_setter
def _create_token(self, event):
def _create_token(self) -> None:
if not self.unit.is_leader():
return
token = secrets.token_urlsafe(32)
if JujuVersion.from_environ().has_secrets:
if not self._state.token_secret_id:
Expand All @@ -273,7 +287,7 @@ def _create_token(self, event):
if not self._state.token:
self._state.token = token

def _get_token(self):
def _get_token(self) -> Optional[str]:
if JujuVersion.from_environ().has_secrets:
if self._state.token_secret_id:
secret = self.model.get_secret(id=self._state.token_secret_id)
Expand All @@ -285,20 +299,20 @@ def _get_token(self):
return self._state.token

@requires_state_setter
def _on_leader_elected(self, event):
def _on_leader_elected(self, event: LeaderElectedEvent) -> None:
"""Leader elected."""
self._update_workload(event)

@requires_state
def _update_workload(self, event):
def _update_workload(self, event: HookEvent) -> None:
"""' Update workload with all available configuration data."""
# make sure we can connect to the container
if not self._container.can_connect():
logger.info("cannot connect to the openfga container")
event.defer()
return

self._create_token(event)
self._create_token()
if not self.model.relations[DATABASE_RELATION_NAME]:
self.unit.status = BlockedStatus("Missing required relation with postgresql")
return
Expand Down Expand Up @@ -334,7 +348,7 @@ def _update_workload(self, event):
self._container.restart(SERVICE_NAME)
self.unit.status = ActiveStatus()

def _on_peer_relation_changed(self, event):
def _on_peer_relation_changed(self, event: RelationChangedEvent) -> None:
self._update_workload(event)

@requires_state_setter
Expand All @@ -355,7 +369,11 @@ def _on_database_created(self, event: DatabaseCreatedEvent) -> None:
logger.error("Automigration job failed, please use the schema-upgrade action")
return

setattr(self._state, self._migration_peer_data_key, self.openfga.get_version())
if not (peer_key := self._migration_peer_data_key):
logger.error("Missing database relation")
return

setattr(self._state, peer_key, self.openfga.get_version())
self._update_workload(event)

@requires_state_setter
Expand All @@ -372,17 +390,22 @@ def _migration_is_needed(self) -> Optional[bool]:
if not self._state.is_ready():
return None

return (
getattr(self._state, self._migration_peer_data_key, None) != self.openfga.get_version()
)
if not (key := self._migration_peer_data_key):
return None

return getattr(self._state, key, None) != self.openfga.get_version()

def _run_sql_migration(self) -> bool:
"""Runs database migration.
Returns True if migration was run successfully, else returns false.
"""
if not (dsn := self._dsn):
logger.info("No database integration")
return False

try:
self.openfga.run_migration(self._dsn)
self.openfga.run_migration(dsn)
logger.info("Successfully executed the database migration")
except Error as e:
err_msg = e.stderr if isinstance(e, ExecError) else e
Expand All @@ -404,11 +427,12 @@ def _ready(self) -> bool:
return False

plan = self._container.get_plan()
if not plan.services.get(SERVICE_NAME):
service = plan.services.get(SERVICE_NAME)
if not service:
self.unit.status = WaitingStatus("waiting for service")
return False

env_vars = plan.services.get(SERVICE_NAME).environment
env_vars = service.environment
for setting in REQUIRED_SETTINGS:
if not env_vars.get(setting, ""):
self.unit.status = BlockedStatus(
Expand Down Expand Up @@ -436,10 +460,12 @@ def _is_openfga_server_running(self) -> bool:
return True

@requires_state_setter
def _on_openfga_relation_changed(self, event: RelationChangedEvent):
def _on_openfga_relation_changed(self, event: RelationChangedEvent) -> None:
"""Open FGA relation changed."""
# the requires side will put the store_name in its
# application bucket
if not event.app:
return
store_name = event.relation.data[event.app].get("store_name", "")
if not store_name:
return
Expand Down Expand Up @@ -479,20 +505,17 @@ def _on_openfga_relation_changed(self, event: RelationChangedEvent):

event.relation.data[self.app].update(data)

def _get_address(self, relation: Relation):
def _get_address(self, relation: Relation) -> str:
"""Returns the ip address to be used with the specified relation."""
return self.model.get_binding(relation).network.ingress_address.exploded

def _create_openfga_store(self, token: str, store_name: str):
def _create_openfga_store(self, token: str, store_name: str) -> Optional[str]:
logger.info("creating store: {}".format(store_name))

address = f"http://localhost:{OPENFGA_SERVER_HTTP_PORT}"
headers = {"Authorization": "Bearer {}".format(token)}

# we need to check if the store with the specified name already
# exists, otherwise OpenFGA will happily create a new store with
# the same name, but different id.
stores = self._list_stores(address, headers)
stores = self._list_stores(token)
for store in stores:
if store["name"] == store_name:
logger.info(
Expand All @@ -502,62 +525,38 @@ def _create_openfga_store(self, token: str, store_name: str):
)
return store["id"]

# to create a new store we issue a POST request to /stores
# endpoint
response = requests.post(
"{}/stores".format(address),
json={"name": store_name},
headers=headers,
)
if response.status_code == 200 or response.status_code == 201:
# if we successfully created the store, we return its id.
data = response.json()
return data["id"]

logger.error(
"failed to create the openfga store: {} {}".format(
response.status_code,
response.json(),
)
)
return ""
try:
store = self.openfga.create_store(token, store_name)
except requests.exceptions.RequestException as e:
logger.error(f"Failed to request OpenFGA API: {e}")
return None

return store["id"]

def _list_stores(self, openfga_host: str, headers, continuation_token="") -> list:
def _list_stores(self, token: str, continuation_token: Optional[str] = None) -> list:
# to list stores we need to issue a GET request to the /stores
# endpoint
response: Response = requests.get(
"{}/stores".format(openfga_host),
headers=headers,
)
if response.status_code != 200:
logger.error("to list existing openfga store: {}".format(response.json()))
return None

data = response.json()
logger.info("received list stores response {}".format(data))
stores = []
for store in data["stores"]:
stores.append({"id": store["id"], "name": store["name"]})
data = self.openfga.list_stores(token, continuation_token=continuation_token)
logger.debug("received list stores response {}".format(data))
stores = [{"id": store["id"], "name": store["name"]} for store in data["stores"]]

# if the response contains a continuation_token, we
# need an additional request to fetch all the stores
ctoken = data["continuation_token"]
if not ctoken:
return stores
else:
return stores.append(
self._list_stores(
openfga_host,
headers,
continuation_token=ctoken,
)
if ctoken := data["continuation_token"]:
# TODO(nsklikas): Python does not support tail recursion. We should
# gather all the stores iteratively. We need to first decide if we
# want to keep this logic. (are stores with the same name really a problem?)
return stores + self._list_stores(
token,
continuation_token=ctoken,
)
return stores

@requires_state_setter
def _on_schema_upgrade_action(self, event):
def _on_schema_upgrade_action(self, event: ActionEvent) -> None:
"""Performs a schema upgrade on the configurable database."""
if not self._container.can_connect():
event.set_results({"error": "cannot connect to the workload container"})
event.fail("Cannot connect to the workload container")
return

if self._run_sql_migration():
Expand All @@ -568,17 +567,20 @@ def _on_schema_upgrade_action(self, event):
return

logger.info("schema upgraded")
setattr(self._state, self._migration_peer_data_key, self.openfga.get_version())
if not (peer_key := self._migration_peer_data_key):
logger.error("Missing database relation")
return
setattr(self._state, peer_key, self.openfga.get_version())
self._update_workload(event)

def _on_ingress_ready(self, event: IngressPerAppReadyEvent):
def _on_ingress_ready(self, event: IngressPerAppReadyEvent) -> None:
self._update_workload(event)

def _on_ingress_revoked(self, event: IngressPerAppRevokedEvent):
def _on_ingress_revoked(self, event: IngressPerAppRevokedEvent) -> None:
self._update_workload(event)


def map_config_to_env_vars(charm, **additional_env):
def map_config_to_env_vars(charm: CharmBase, **additional_env: str) -> Dict:
"""Map config values to environment variables.
Maps the config values provided in config.yaml into environment
Expand Down
Loading

0 comments on commit 3b58897

Please sign in to comment.