From 479427bbf9f5d6a19d2c9808e4ca65b648f3c0ab Mon Sep 17 00:00:00 2001 From: "Dilyara Zharikova (Baymurzina)" Date: Wed, 18 Jan 2023 20:02:06 +0800 Subject: [PATCH] fix: image captioning (#289) * fix: image captioning * fix: download model * fix: reqs --- assistant_dists/dream_multimodal/pipeline_conf.json | 2 +- services/image_captioning/Dockerfile | 2 +- services/image_captioning/requirements.txt | 3 ++- services/image_captioning/server.py | 10 ++++++---- services/image_captioning/test.py | 8 +++----- state_formatters/dp_formatters.py | 4 ++++ 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/assistant_dists/dream_multimodal/pipeline_conf.json b/assistant_dists/dream_multimodal/pipeline_conf.json index 10ba62b2db..0dea234e27 100644 --- a/assistant_dists/dream_multimodal/pipeline_conf.json +++ b/assistant_dists/dream_multimodal/pipeline_conf.json @@ -124,7 +124,7 @@ "timeout": 3, "url": "http://image-captioning:8123/respond" }, - "dialog_formatter": "state_formatters.dp_formatters:image_formatter_service", + "dialog_formatter": "state_formatters.dp_formatters:image_captioning_formatter", "response_formatter": "state_formatters.dp_formatters:simple_formatter_service", "state_manager_method": "add_annotation" } diff --git a/services/image_captioning/Dockerfile b/services/image_captioning/Dockerfile index 666fb1f952..01d783b291 100644 --- a/services/image_captioning/Dockerfile +++ b/services/image_captioning/Dockerfile @@ -32,7 +32,7 @@ RUN apt-get install wget -y RUN mkdir -p /opt/conda/lib/python3.7/site-packages/data/models -RUN gdown 1WBQl0WlzvdctslJyLNgedYpRrWAZC69X -O /opt/conda/lib/python3.7/site-packages/data/models/caption.pt +RUN wget http://files.deeppavlov.ai/dream_data/image_captioning/caption.pt -O /opt/conda/lib/python3.7/site-packages/data/models/caption.pt COPY . /ofa diff --git a/services/image_captioning/requirements.txt b/services/image_captioning/requirements.txt index 2716c1dcf8..96171a047a 100644 --- a/services/image_captioning/requirements.txt +++ b/services/image_captioning/requirements.txt @@ -7,4 +7,5 @@ sentry-sdk[flask]==0.14.1 healthcheck==1.3.3 jinja2<=3.0.3 Werkzeug<=2.0.3 -gdown==4.5.1 \ No newline at end of file +gdown==4.5.1 +protobuf==3.20.1 \ No newline at end of file diff --git a/services/image_captioning/server.py b/services/image_captioning/server.py index ed891b1928..ec6572bbad 100644 --- a/services/image_captioning/server.py +++ b/services/image_captioning/server.py @@ -125,7 +125,7 @@ def apply_half(t): def respond(): st_time = time.time() - img_paths = request.json.get("text", []) + img_paths = request.json.get("image_paths", []) captions = [] try: for img_path in img_paths: @@ -140,12 +140,14 @@ def respond(): with torch.no_grad(): caption, scores = eval_step(task, generator, models, sample) - captions.append(caption) + captions.append(caption[0]) except Exception as exc: logger.exception(exc) sentry_sdk.capture_exception(exc) + captions = [{}] * len(img_paths) total_time = time.time() - st_time - logger.info(f"captioning exec time: {total_time:.3f}s") - return jsonify({"caption": captions}) + logger.info(f"image-captioning exec time: {total_time:.3f}s") + logger.info(f"image-captioning result: {captions}") + return jsonify(captions) diff --git a/services/image_captioning/test.py b/services/image_captioning/test.py index 3e74ebe7e6..308bb52ec7 100644 --- a/services/image_captioning/test.py +++ b/services/image_captioning/test.py @@ -4,15 +4,13 @@ def test_respond(): url = "http://0.0.0.0:8123/respond" - img_path = ["example.jpg"] + image_paths = ["example.jpg"] - request_data = {"text": img_path} + request_data = {"image_paths": image_paths} result = requests.post(url, json=request_data).json() - caption = result["caption"][0][0]["caption"] - print(caption) obligatory_word = "bird" - assert obligatory_word in caption, f"Expected the word '{obligatory_word}' to present in caption" + assert obligatory_word in result[0]["caption"], f"Expected the word '{obligatory_word}' to present in caption" print("\n", "Success!!!") diff --git a/state_formatters/dp_formatters.py b/state_formatters/dp_formatters.py index 6062009595..60f5811afe 100755 --- a/state_formatters/dp_formatters.py +++ b/state_formatters/dp_formatters.py @@ -996,3 +996,7 @@ def context_formatter_dialog(dialog: Dict) -> List[Dict]: dialog = utils.replace_with_annotated_utterances(dialog, mode="punct_sent") contexts = [[uttr["text"] for uttr in dialog["utterances"][-num_last_utterances:]]] return [{"contexts": contexts}] + +def image_captioning_formatter(dialog: Dict) -> List[Dict]: + # Used by: image_captioning + return [{"image_paths": [dialog["human_utterances"][-1].get("attributes", {}).get("image")]}]