From 7b77a08671aa8114f4501e524787b5cbc7455840 Mon Sep 17 00:00:00 2001 From: TShapinsky Date: Thu, 28 Dec 2023 14:55:31 -0700 Subject: [PATCH] add api testing independent of alfalfa-client and fix some bugs revealed by tests --- alfalfa_web/server/api-v2.js | 15 ++- alfalfa_web/server/api.js | 9 +- setup.cfg | 2 +- tests/api/__init__.py | 1 + tests/api/conftest.py | 37 ++++++ tests/api/test_model.py | 122 ++++++++++++++++++++ tests/api/test_point.py | 206 +++++++++++++++++++++++++++++++++ tests/api/test_run.py | 214 +++++++++++++++++++++++++++++++++++ 8 files changed, 597 insertions(+), 9 deletions(-) create mode 100644 tests/api/__init__.py create mode 100644 tests/api/conftest.py create mode 100644 tests/api/test_model.py create mode 100644 tests/api/test_point.py create mode 100644 tests/api/test_run.py diff --git a/alfalfa_web/server/api-v2.js b/alfalfa_web/server/api-v2.js index aec45fa6..e74b107d 100644 --- a/alfalfa_web/server/api-v2.js +++ b/alfalfa_web/server/api-v2.js @@ -593,7 +593,14 @@ router.get("/runs/:runId/points/:pointId", (req, res, next) => { * description: The point was successfully updated */ router.put("/runs/:runId/points/:pointId", (req, res, next) => { - // TODO Confirm that point isn't an OUTPUT type + const { value } = req.body; + + if (req.point.point_type == "OUTPUT") { + return res + .status(400) + .json({ message: `Point '${req.point.ref_id}' is of type '${req.point.point_type}' and cannot be written to` }); + } + if (value !== null) { const error = validate( { value }, @@ -687,7 +694,7 @@ router.delete("/runs/:runId", (req, res, next) => { * schema: * $ref: '#/components/schemas/Error' */ -router.post("/runs/:runId/start", async (req, res, next) => { +router.post("/runs/:runId/start", (req, res, next) => { const { body } = req; const timeValidator = /^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$/; @@ -777,7 +784,7 @@ router.post("/runs/:runId/advance", (req, res, next) => { */ router.post("/runs/:runId/stop", (req, res, next) => { // If the run is already stopping or stopped there is no need to send message - if (["STOPPING", "STOPPED", "COMPLETE"].includes(req.run.status)) { + if (["STOPPING", "COMPLETE", "ERROR", "READY"].includes(req.run.status)) { res.sendStatus(204); } api @@ -1220,7 +1227,7 @@ router.get("/simulations", async (req, res, next) => { .catch(next); }); -router.get("*", (req, res) => res.status(404).json({ message: "Page not found" })); +router.all("*", (req, res) => res.status(404).json({ message: "Page not found" })); router.use(errorHandler); diff --git a/alfalfa_web/server/api.js b/alfalfa_web/server/api.js index b9768999..06b3b411 100644 --- a/alfalfa_web/server/api.js +++ b/alfalfa_web/server/api.js @@ -104,7 +104,7 @@ class AlfalfaAPI { getPointsById = async (run, pointIds) => { return Promise.all( pointIds.map((pointId) => { - this.getPointById(run, pointId); + return this.getPointById(run, pointId); }) ); }; @@ -173,7 +173,7 @@ class AlfalfaAPI { removeRun = async (run) => { // Delete run - const { deletedCount } = await this.run.deleteOne({ _id: run._id }); + const { deletedCount } = await this.runs.deleteOne({ _id: run._id }); if (deletedCount == 1) { // Delete points @@ -258,7 +258,8 @@ class AlfalfaAPI { }; listModels = async () => { - return this.getModels().map(this.formatModel); + const models = await this.getModels(); + return models.map(this.formatModel); }; getModels = async () => { @@ -285,7 +286,7 @@ class AlfalfaAPI { new GetObjectCommand({ Bucket: process.env.S3_BUCKET, Key: `uploads/${model.ref_id}/${model.model_name}`, - ResponseContentDisposition: `attachment; filename="${run.ref_id}.tar.gz"` + ResponseContentDisposition: `attachment; filename="${model.ref_id}.tar.gz"` }), { expiresIn: 86400 diff --git a/setup.cfg b/setup.cfg index 6d58e6b5..e2836e28 100644 --- a/setup.cfg +++ b/setup.cfg @@ -71,7 +71,7 @@ extras = True addopts = --cov alfalfa_worker --cov-report term-missing --verbose - -m "not integration and not fmu and not docker and not scale" + -m "not integration and not fmu and not docker and not scale and not api" markers = integration: marks tests as integration tests (deselect with '-m "not integration"') fmu: mark tests that require fmu support, e.g., pyfmi (deselect with '-m "not fmu"') diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 00000000..2ff5e5b8 --- /dev/null +++ b/tests/api/__init__.py @@ -0,0 +1 @@ +"""This module includes tests for the mechanical aspects of the api, but not necessarily all of the functionality.""" diff --git a/tests/api/conftest.py b/tests/api/conftest.py new file mode 100644 index 00000000..e6e995e9 --- /dev/null +++ b/tests/api/conftest.py @@ -0,0 +1,37 @@ +import os +from pathlib import Path + +import pytest +from alfalfa_client import AlfalfaClient +from requests import HTTPError + + +@pytest.fixture +def base_url(): + return 'http://localhost/api/v2' + + +@pytest.fixture +def alfalfa_client(): + return AlfalfaClient() + + +@pytest.fixture +def model_path(): + return Path(os.path.dirname(__file__)) / "models" / "small_office" + + +@pytest.fixture +def model_id(alfalfa_client: AlfalfaClient, model_path): + return alfalfa_client.upload_model(model_path) + + +@pytest.fixture +def run_id(alfalfa_client: AlfalfaClient, model_path): + run_id = alfalfa_client.submit(model_path) + yield run_id + try: + if alfalfa_client.status(run_id) not in ["COMPLETE", "ERROR", "STOPPING", "READY"]: + alfalfa_client.stop(run_id) + except HTTPError: + pass diff --git a/tests/api/test_model.py b/tests/api/test_model.py new file mode 100644 index 00000000..d81b61df --- /dev/null +++ b/tests/api/test_model.py @@ -0,0 +1,122 @@ +from collections import OrderedDict +from io import StringIO +from uuid import uuid4 + +import pytest +import requests +from requests_toolbelt import MultipartEncoder + + +@pytest.mark.api +def test_model_upload_download(base_url): + + # Send request with empty data expecting a 400 error + response = requests.post(f"{base_url}/models/upload") + + assert response.status_code == 400 + response_body = response.json() + assert "message" in response_body + + # Send request with proper data expecting a 200 + model_name = "test_model.zip" + request_body = {"modelName": model_name} + response = requests.post(f"{base_url}/models/upload", json=request_body) + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert "url" in payload + assert "modelId" in payload + model_id = payload["modelId"] + + # upload a model to the s3 bucket with the pre-signed request + form_data = OrderedDict(payload["fields"]) + file_contents = "This is a test string to be used as a file stand in" + form_data['file'] = ('filename', StringIO(file_contents)) + + encoder = MultipartEncoder(fields=form_data) + response = requests.post(payload['url'], data=encoder, headers={'Content-Type': encoder.content_type}) + assert response.status_code == 204 + + # download model and check it is the same + response = requests.get(f"{base_url}/models/{model_id}/download") + assert response.status_code == 200 + + contents = b"" + for chunk in response.iter_content(): + contents += chunk + assert contents.decode('utf-8') == file_contents, "Downloaded model does not match uploaded model" + + +@pytest.mark.api +def test_model_retrieval(base_url): + # create a model + model_name = "test_model.zip" + request_body = {"modelName": model_name} + response = requests.post(f"{base_url}/models/upload", json=request_body) + response_body = response.json() + + model_id = response_body["payload"]["modelId"] + response_body["payload"]["url"] + + # request all models + response = requests.get(f"{base_url}/models") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) > 0, "No models in model list" + + contains_model = False + for model in payload: + if model["id"] == model_id: + assert model["modelName"] == model_name + contains_model = True + + assert contains_model, "Could not find uploaded model in list" + + # request only uploaded model + response = requests.get(f"{base_url}/models/{model_id}") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert payload["id"] == model_id + assert payload["modelName"] == model_name + assert "created" in payload + assert "modified" in payload + + # request with invalid model id + model_id = "not_a_model_id" + response = requests.get(f"{base_url}/models/{model_id}") + + assert response.status_code == 400 + response_body = response.json() + assert "message" in response_body + + +@pytest.mark.api +def test_model_not_found(base_url): + # request non-existent model + response = requests.get(f"{base_url}/models/{uuid4()}") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # download non-existent model + response = requests.get(f"{base_url}/models/{uuid4()}/download") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # create run from model which does not exist + response = requests.post(f"{base_url}/models/{uuid4()}/createRun") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body diff --git a/tests/api/test_point.py b/tests/api/test_point.py new file mode 100644 index 00000000..45fcd7d1 --- /dev/null +++ b/tests/api/test_point.py @@ -0,0 +1,206 @@ +from datetime import datetime + +import pytest +import requests + + +@pytest.mark.api +def test_point_retrieval(base_url, run_id, alfalfa_client): + alfalfa_client.start(run_id, datetime(2020, 1, 1, 0, 0), datetime(2020, 1, 2, 0, 0), external_clock=True) + # get points + response = requests.get(f"{base_url}/runs/{run_id}/points") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) > 0, "No points in payload" + + outputs = [] + bidirectionals = [] + inputs = [] + + for point in payload: + assert "id" in point + assert "name" in point + assert "type" in point + assert "value" not in point + + if point["type"] == "OUTPUT": + outputs.append(point) + elif point["type"] == "BIDIRECTIONAL": + bidirectionals.append(point) + elif point["type"] == "INPUT": + inputs.append(point) + + # get points of a specific type + request_body = { + 'pointTypes': ["OUTPUT"] + } + response = requests.post(f"{base_url}/runs/{run_id}/points", json=request_body) + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) == len(outputs), "Filter did not return correct number of points" + + for point in payload: + assert point in outputs, "Filter did not return correct points" + + # get points by list + request_body = { + 'points': [point["id"] for point in inputs] + } + response = requests.post(f"{base_url}/runs/{run_id}/points", json=request_body) + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) == len(inputs), "Filter did not return correct number of points" + + for point in payload: + assert point in inputs, "Filter did not return correct points" + + # Wait for run to be running and advance one timestep + alfalfa_client.wait(run_id, "RUNNING") + alfalfa_client.advance(run_id) + + # get the values from each output individually + for point in outputs + bidirectionals: + response = requests.get(f"{base_url}/runs/{run_id}/points/{point['id']}") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert "value" in payload + + # get all non outputs individually + for point in inputs: + response = requests.get(f"{base_url}/runs/{run_id}/points/{point['id']}") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert point == payload, "Point data changed or incorrect" + + # get values for all output points + response = requests.get(f"{base_url}/runs/{run_id}/points/values") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) == len(outputs + bidirectionals), "Mismatch in output points" + + expected_ids = [point["id"] for point in (outputs + bidirectionals)] + + for id in payload: + assert id in expected_ids, "Point not expected in response" + + # get point values for all bidirectional points + request_body = { + 'pointTypes': ["BIDIRECTIONAL"] + } + response = requests.post(f"{base_url}/runs/{run_id}/points/values", json=request_body) + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) == len(bidirectionals), "Filter did not return correct number of points" + + expected_ids = [point["id"] for point in bidirectionals] + + for id in payload: + assert id in expected_ids, "Filter returned incorrect point" + + # get point values for all outputs by id + request_body = { + 'points': [point["id"] for point in outputs] + } + response = requests.post(f"{base_url}/runs/{run_id}/points/values", json=request_body) + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) == len(outputs), "Filter did not return correct number of points" + + expected_ids = [point["id"] for point in outputs] + + for id in payload: + assert id in expected_ids, "Filter returned incorrect point" + + +@pytest.mark.api +def test_point_writes(base_url, run_id): + # get all points + response = requests.get(f"{base_url}/runs/{run_id}/points") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + + outputs = [] + bidirectionals = [] + inputs = [] + + for point in payload: + if point["type"] == "OUTPUT": + outputs.append(point) + elif point["type"] == "BIDIRECTIONAL": + bidirectionals.append(point) + elif point["type"] == "INPUT": + inputs.append(point) + + all_points = payload + + # write to points individually + for point in all_points: + request_body = { + 'value': 5 + } + response = requests.put(f"{base_url}/runs/{run_id}/points/{point['id']}", json=request_body) + + if point in outputs: + assert response.status_code == 400 + response_body = response.json() + assert "message" in response_body + else: + assert response.status_code == 204 + + # write to point with invalid value + request_body = { + 'value': "hello" + } + response = requests.put(f"{base_url}/runs/{run_id}/points/{inputs[0]['id']}", json=request_body) + + assert response.status_code == 400 + response_body = response.json() + assert "message" in response_body + + # write to all valid points + request_body = { + 'points': dict([(point["id"], 5) for point in (inputs + bidirectionals)]) + } + response = requests.put(f"{base_url}/runs/{run_id}/points/values", json=request_body) + + assert response.status_code == 204 + + # write to all points (some invalid) + request_body = { + 'points': dict([(point["id"], 5) for point in all_points]) + } + response = requests.put(f"{base_url}/runs/{run_id}/points/values", json=request_body) + + assert response.status_code == 400 + response_body = response.json() + assert "message" in response_body + assert "payload" in response_body + payload = response_body["payload"] + assert len(payload) == len(outputs), "Error cardinality mismatch" diff --git a/tests/api/test_run.py b/tests/api/test_run.py new file mode 100644 index 00000000..ecddfc67 --- /dev/null +++ b/tests/api/test_run.py @@ -0,0 +1,214 @@ +import tarfile +import time +from datetime import datetime +from uuid import uuid4 + +import pytest +import requests + + +@pytest.mark.api +def test_run_create(base_url, model_id): + # create run from model + response = requests.post(f"{base_url}/models/{model_id}/createRun") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert "runId" in payload + + +@pytest.mark.api +def test_run_retrieval(base_url, run_id): + # request all runs + response = requests.get(f"{base_url}/runs") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + + contains_run = False + + for run in payload: + if run["id"] == run_id: + contains_run = True + + assert contains_run, "Could not find run in list" + + # request run information + response = requests.get(f"{base_url}/runs/{run_id}") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + + assert "id" in payload + assert "name" in payload + assert "status" in payload + assert "datetime" in payload + assert "simType" in payload + assert "errorLog" in payload + + # delete run + response = requests.delete(f"{base_url}/runs/{run_id}") + + assert response.status_code == 204 + + # request run information (expect 404) + response = requests.get(f"{base_url}/runs/{run_id}") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + +@pytest.mark.api +def test_run_not_found(base_url): + # request run which does not exist + response = requests.get(f"{base_url}/runs/{uuid4()}") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # delete run which does not exist + response = requests.delete(f"{base_url}/runs/{uuid4()}") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # start run which does not exist + response = requests.post(f"{base_url}/runs/{uuid4()}/start") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # request time from run which does not exist + response = requests.get(f"{base_url}/runs/{uuid4()}/time") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # advance run which does not exist + response = requests.post(f"{base_url}/runs/{uuid4()}/advance") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # stop run which does not exist + response = requests.post(f"{base_url}/runs/{uuid4()}/stop") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + # download run which does not exist + response = requests.get(f"{base_url}/runs/{uuid4()}/download") + + assert response.status_code == 404 + response_body = response.json() + assert "message" in response_body + + +@pytest.mark.api +def test_run_start_stop(base_url, run_id): + # start run with invalid arguments + response = requests.post(f"{base_url}/runs/{run_id}/start") + + assert response.status_code == 400 + response_body = response.json() + assert "message" in response_body + + start_datetime = datetime(2020, 1, 1, 0, 0) + + # start run + request_body = { + 'startDatetime': str(start_datetime), + 'endDatetime': str(datetime(2020, 1, 1, 23, 59, 59)), + 'externalClock': True + } + + response = requests.post(f"{base_url}/runs/{run_id}/start", json=request_body) + + assert response.status_code == 204 + + # get status and wait for "RUNNING" + timeout = time.time() + 60 + while True: + response = requests.get(f"{base_url}/runs/{run_id}") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert "status" in payload + status = payload["status"] + + if status == "RUNNING": + break + if time.time() > timeout: + pytest.fail("Timed out waiting for run to start") + + # get time of run + response = requests.get(f"{base_url}/runs/{run_id}/time") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert "time" in payload + + assert start_datetime == datetime.strptime(payload["time"], '%Y-%m-%d %H:%M:%S') + + # advance run + response = requests.post(f"{base_url}/runs/{run_id}/advance") + + assert response.status_code == 204 + + # stop run + response = requests.post(f"{base_url}/runs/{run_id}/stop") + + assert response.status_code == 204 + + # get status and wait for "STOPPING" + timeout = time.time() + 60 + while True: + response = requests.get(f"{base_url}/runs/{run_id}") + + assert response.status_code == 200 + response_body = response.json() + assert "payload" in response_body + payload = response_body["payload"] + assert "status" in payload + status = payload["status"] + + if status == "STOPPING": + break + if time.time() > timeout: + pytest.fail("Timed out waiting for run to stop") + + # stop already stopped run + response = requests.post(f"{base_url}/runs/{run_id}/stop") + + assert response.status_code == 204 + + +@pytest.mark.api +def test_run_download(base_url, run_id, tmp_path): + run_file = tmp_path / f"{run_id}.tar.gz" + # download run + response = requests.get(f"{base_url}/runs/{run_id}/download") + + assert response.status_code == 200 + with run_file.open("wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + assert tarfile.is_tarfile(str(run_file)), "Downloaded file is not a valid tar.gz"