Skip to content

Commit

Permalink
refactor: refactor scripts, run black etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexb1200 committed Jul 13, 2023
1 parent 59dddcb commit d13dbc9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
10 changes: 5 additions & 5 deletions examples/scripts/register_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# except ImportError:
# from client import DioptraClient
from dioptra.client import DioptraClient
from dioptra.client import DioptraClient, get_dioptra_client


try:
Expand Down Expand Up @@ -67,18 +67,18 @@ def register_queues(queue, api_url):
"""Register the queues used in Dioptra's examples and demos."""

console = RichConsole(Console())
client = DioptraClient(address=api_url)
client = get_dioptra_client(address = api_url)

console.print_title("Dioptra Examples - Register Queues")
console.print_parameter("queue", value=f"[default not bold]{', '.join(queue)}[/]")
console.print_parameter("api_url", value=f"[default not bold]{api_url}[/]")

for name in queue:
response = client.get_queue_by_name(name=name)
response = client.queue.get_queue_by_name(name=name)

if response is None or "Not Found" in response.get("message", []):
response = client.register_queue(name=name)
response_after = client.get_queue_by_name(name=name)
response = client.queue.register_queue(name=name)
response_after = client.queue.get_queue_by_name(name=name)

if response_after is None or "Not Found" in response_after.get("message", []):
raise RuntimeError(
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/register_task_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ def register_task_plugins(plugins_dir, api_url):
response = client.get_custom_task_plugin(name=custom_plugin["name"])

if response is None or "Not Found" in response.get("message", []):
response = client.upload_custom_plugin_package(
response = client.custom_task_plugins.upload_custom_plugin_package(
custom_plugin_name=custom_plugin["name"],
custom_plugin_file=custom_plugin["path"],
)
response_after = client.get_custom_task_plugin(
response_after = client.custom_task_plugins.get_custom_task_plugin(
name=custom_plugin["name"]
)

Expand Down
7 changes: 4 additions & 3 deletions src/dioptra/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def get_dioptra_client(address: str | None = None) -> DioptraClient:
)
scheme, netloc, path, _, _, _ = urlparse(address)

# maybe we should make this into a wrapper class to deal with unexpectedly closed connections, retries, error handling etc.?
session = request.Session()
# maybe we should make this into a wrapper class to deal with unexpectedly
# closed connections, retries, error handling etc.?
session = requests.Session()
# More needed with the address most likely

# experiment_client = ExperimentClient(...)
Expand Down Expand Up @@ -605,7 +606,7 @@ def upload_custom_plugin_package(
class QueueClient(object):
def __init__(
self,
session: Session,
session: requests.Session,
address: str | None = None,
) -> None:
address = (
Expand Down

0 comments on commit d13dbc9

Please sign in to comment.