From 38a75b0916f6beda372b5d061d7d4c3e8cc6b5eb Mon Sep 17 00:00:00 2001
From: Anthony Volk
Date: Mon, 7 Oct 2024 21:46:40 +0200
Subject: [PATCH] chore: Changelog and lint

---
 changelog_entry.yaml                          |  10 ++
 policyengine_api/ai_prompts/__init__.py       |   5 +-
 policyengine_api/ai_prompts/simulation.py     |   8 +-
 policyengine_api/ai_prompts/tracer.py         |   2 +-
 policyengine_api/api.py                       |   4 +-
 policyengine_api/country.py                   |   8 +-
 policyengine_api/endpoints/household.py       |   5 +-
 .../endpoints/simulation_analysis.py          |  33 +++---
 policyengine_api/endpoints/tracer_analysis.py |  29 +++--
 policyengine_api/utils/ai_analysis.py         |  10 +-
 policyengine_api/utils/tracer_analysis.py     |   9 +-
 tests/python/test_ai_analysis.py              |  40 ++++---
 tests/python/test_tracer.py                   | 102 +++++++++++-------
 13 files changed, 165 insertions(+), 100 deletions(-)

diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29b..5234fdfc 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,10 @@
+- bump: major
+  changes:
+    added:
+      - /tracer_analysis endpoint for household tracer outputs
+      - Database for simulation tracer outputs
+    changed:
+      - /analysis endpoint renamed to /simulation_analysis endpoint
+      - Simulation runs now write tracer output to database
+      - Simulation analysis runs now return ReadableStreams
+      - Refactored Claude interaction code to be decoupled from endpoints
\ No newline at end of file
diff --git a/policyengine_api/ai_prompts/__init__.py b/policyengine_api/ai_prompts/__init__.py
index 9a14fae1..a4219d0a 100644
--- a/policyengine_api/ai_prompts/__init__.py
+++ b/policyengine_api/ai_prompts/__init__.py
@@ -1,2 +1,5 @@
 from policyengine_api.ai_prompts.tracer import tracer_analysis_prompt
-from policyengine_api.ai_prompts.simulation import generate_simulation_analysis_prompt, audience_descriptions
\ No newline at end of file
+from policyengine_api.ai_prompts.simulation import (
+    generate_simulation_analysis_prompt,
+    audience_descriptions,
+)
diff --git a/policyengine_api/ai_prompts/simulation.py b/policyengine_api/ai_prompts/simulation.py
index e2a855f8..d381be51 100644
--- a/policyengine_api/ai_prompts/simulation.py
+++ b/policyengine_api/ai_prompts/simulation.py
@@ -2,11 +2,11 @@

 audience_descriptions = {
     "ELI5": "Write this for a five-year-old who doesn't know anything about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.",
-    "Normal":
-        "Write this for a policy analyst who knows a bit about economics and policy.",
+    "Normal": "Write this for a policy analyst who knows a bit about economics and policy.",
     "Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.",
 }

+
 def generate_simulation_analysis_prompt(
     time_period,
     region,
@@ -18,7 +18,7 @@
     is_enhanced_cps,
     selected_version,
     country_id,
-    policy_label
+    policy_label,
 ):
     return f"""
     I'm using PolicyEngine, a free, open source tool to compute the impact of
@@ -125,4 +125,4 @@
     and the share held by the top 1% (describe the relative changes): {json.dumps(
         impact["inequality"],
     )}
-    """
\ No newline at end of file
+    """
diff --git a/policyengine_api/ai_prompts/tracer.py b/policyengine_api/ai_prompts/tracer.py
index fbee7035..7f87f3a2 100644
--- a/policyengine_api/ai_prompts/tracer.py
+++ b/policyengine_api/ai_prompts/tracer.py
@@ -13,4 +13,4 @@
 Keep your explanation concise but informative, suitable for a general audience.
 Do not start with phrases like "Certainly!" or "Here's an explanation. It will
 be rendered as markdown, so preface $ with \.
-{anthropic.AI_PROMPT}"""
\ No newline at end of file
+{anthropic.AI_PROMPT}"""
diff --git a/policyengine_api/api.py b/policyengine_api/api.py
index 31bb74de..272beb49 100644
--- a/policyengine_api/api.py
+++ b/policyengine_api/api.py
@@ -96,7 +96,9 @@
     methods=["GET"],
 )(get_economic_impact)

-app.route("/<country_id>/simulation_analysis", methods=["POST"])(execute_simulation_analysis)
+app.route("/<country_id>/simulation_analysis", methods=["POST"])(
+    execute_simulation_analysis
+)

 app.route("/<country_id>/user_policy", methods=["POST"])(set_user_policy)
diff --git a/policyengine_api/country.py b/policyengine_api/country.py
index d25a118f..53848ea9 100644
--- a/policyengine_api/country.py
+++ b/policyengine_api/country.py
@@ -310,7 +310,13 @@ def build_entities(self) -> dict:
     # 4. Delete the code at the end of the function that writes to a file (Done)
     # 5. Add code at the end of the function to write to a database (Done)

-    def calculate(self, household: dict, reform: Union[dict, None], household_id: Optional[int], policy_id: Optional[int] = None):
+    def calculate(
+        self,
+        household: dict,
+        reform: Union[dict, None],
+        household_id: Optional[int],
+        policy_id: Optional[int] = None,
+    ):
         if reform is not None and len(reform.keys()) > 0:
             system = self.tax_benefit_system.clone()
             for parameter_name in reform:
diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py
index 4398830b..72832a5c 100644
--- a/policyengine_api/endpoints/household.py
+++ b/policyengine_api/endpoints/household.py
@@ -349,7 +349,10 @@ def get_household_under_policy(

     try:
         result = country.calculate(
-            household["household_json"], policy["policy_json"], household_id, policy_id
+            household["household_json"],
+            policy["policy_json"],
+            household_id,
+            policy_id,
         )
     except Exception as e:
         logging.exception(e)
diff --git a/policyengine_api/endpoints/simulation_analysis.py b/policyengine_api/endpoints/simulation_analysis.py
index 93d0f537..e5386a41 100644
--- a/policyengine_api/endpoints/simulation_analysis.py
+++ b/policyengine_api/endpoints/simulation_analysis.py
@@ -4,13 +4,20 @@
 from rq import Queue
 from redis import Redis
 from typing import Optional
-from policyengine_api.utils.ai_analysis import trigger_ai_analysis, get_existing_analysis
-from policyengine_api.ai_prompts import generate_simulation_analysis_prompt, audience_descriptions
+from policyengine_api.utils.ai_analysis import (
+    trigger_ai_analysis,
+    get_existing_analysis,
+)
+from policyengine_api.ai_prompts import (
+    generate_simulation_analysis_prompt,
+    audience_descriptions,
+)

 queue = Queue(connection=Redis())

+
 def execute_simulation_analysis(country_id: str) -> Response:
-
+
     # Pop the various parameters from the request
     payload = request.json

@@ -22,10 +29,12 @@
     policy = payload.get("policy")
     region = payload.get("region")
     relevant_parameters = payload.get("relevant_parameters")
-    relevant_parameter_baseline_values = payload.get("relevant_parameter_baseline_values")
+    relevant_parameter_baseline_values = payload.get(
+        "relevant_parameter_baseline_values"
+    )
     audience = payload.get("audience")

-    # Check if the region is enhanced_cps
+    # Check if the region is enhanced_cps
     is_enhanced_cps = "enhanced_cps" in region

     # Create prompt based on data
@@ -40,7 +49,7 @@
         is_enhanced_cps,
         selected_version,
         country_id,
-        policy_label
+        policy_label,
     )

     # Add audience description to end
@@ -50,23 +59,17 @@
     # If a calculated record exists for this prompt, return it as a
     # streaming response
     existing_analysis = get_existing_analysis(prompt)
     if existing_analysis is not None:
-        return Response(
-            status=200,
-            response=existing_analysis
-        )
+        return Response(status=200, response=existing_analysis)

     # Otherwise, pass prompt to Claude, then return streaming function
     try:
         analysis = trigger_ai_analysis(prompt)
-        return Response(
-            status=200,
-            response=analysis
-        )
+        return Response(status=200, response=analysis)
     except Exception as e:
         return Response(
             status=500,
             response={
                 "message": "Error computing analysis",
                 "error": str(e),
-            }
+            },
         )
diff --git a/policyengine_api/endpoints/tracer_analysis.py b/policyengine_api/endpoints/tracer_analysis.py
index d7af5800..bd6aa7f4 100644
--- a/policyengine_api/endpoints/tracer_analysis.py
+++ b/policyengine_api/endpoints/tracer_analysis.py
@@ -3,7 +3,10 @@
 from flask import Response, request
 from policyengine_api.country import validate_country
 from policyengine_api.ai_prompts import tracer_analysis_prompt
-from policyengine_api.utils.ai_analysis import trigger_ai_analysis, get_existing_analysis
+from policyengine_api.utils.ai_analysis import (
+    trigger_ai_analysis,
+    get_existing_analysis,
+)
 from policyengine_api.utils.tracer_analysis import parse_tracer_output
 from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS

@@ -13,7 +16,8 @@
 # Access the prompt and add the parsed tracer output
 # Pass the complete prompt to the get_analysis function and return its response

-#TODO: Add the prompt in a new variable; this could even be duplicated from the Streamlit
+# TODO: Add the prompt in a new variable; this could even be duplicated from the Streamlit
+

 def execute_tracer_analysis(
     country_id: str,
@@ -27,7 +31,7 @@ def execute_tracer_analysis(
     country_not_found = validate_country(country_id)
     if country_not_found:
         return country_not_found
-
+
     payload = request.json

     household_id = payload.get("household_id")
@@ -51,7 +55,7 @@ def execute_tracer_analysis(
             status=404,
             response={
                 "message": "Unable to analyze household: no household simulation tracer found",
-            }
+            },
         )

     # Parse the tracer output
@@ -64,36 +68,29 @@ def execute_tracer_analysis(
             response={
                 "message": "Error parsing tracer output",
                 "error": str(e),
-            }
+            },
         )

     # Add the parsed tracer output to the prompt
     prompt = tracer_analysis_prompt.format(
-        variable=variable,
-        tracer_segment=tracer_segment
+        variable=variable, tracer_segment=tracer_segment
     )

     # If a calculated record exists for this prompt, return it as a
     # streaming response
     existing_analysis = get_existing_analysis(prompt)
     if existing_analysis is not None:
-        return Response(
-            status=200,
-            response=existing_analysis
-        )
+        return Response(status=200, response=existing_analysis)

     # Otherwise, pass prompt to Claude, then return streaming function
     try:
         analysis = trigger_ai_analysis(prompt)
-        return Response(
-            status=200,
-            response=analysis
-        )
+        return Response(status=200, response=analysis)
     except Exception as e:
         return Response(
             status=500,
             response={
                 "message": "Error computing analysis",
                 "error": str(e),
-            }
+            },
         )
diff --git a/policyengine_api/utils/ai_analysis.py b/policyengine_api/utils/ai_analysis.py
index 70183589..2c329838 100644
--- a/policyengine_api/utils/ai_analysis.py
+++ b/policyengine_api/utils/ai_analysis.py
@@ -7,7 +7,7 @@


 def trigger_ai_analysis(prompt: str) -> Generator[str, None, None]:
-
+
     # Configure a Claude client
     claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
@@ -38,7 +38,7 @@ def generate():
             chunk = buffer[:chunk_size]
             buffer = buffer[chunk_size:]
             yield json.dumps({"stream": chunk}) + "\n"
-
+
         if buffer:
             yield json.dumps({"stream": buffer}) + "\n"
@@ -63,7 +63,7 @@ def get_existing_analysis(prompt: str) -> Generator[str, None, None] | None:

     if analysis is None:
         return None
-
+
     def generate():

         # First, yield prompt so it's accessible on front end
@@ -75,8 +75,8 @@ def generate():

         chunk_size = 5
         for i in range(0, len(analysis["analysis"]), chunk_size):
-            chunk = analysis["analysis"][i:i + chunk_size]
+            chunk = analysis["analysis"][i : i + chunk_size]
             yield json.dumps({"stream": chunk}) + "\n"
             time.sleep(0.05)

-    return generate()
\ No newline at end of file
+    return generate()
diff --git a/policyengine_api/utils/tracer_analysis.py b/policyengine_api/utils/tracer_analysis.py
index 60af0751..a9e595d4 100644
--- a/policyengine_api/utils/tracer_analysis.py
+++ b/policyengine_api/utils/tracer_analysis.py
@@ -1,19 +1,20 @@
 import re

+
 def parse_tracer_output(tracer_output, target_variable):
     result = []
     target_indent = None
     capturing = False

     # Create a regex pattern to match the exact variable name
-    # This will match the variable name followed by optional whitespace,
+    # This will match the variable name followed by optional whitespace,
     # then optional angle brackets with any content, then optional whitespace
-    pattern = rf'^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*'
+    pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*"

     for line in tracer_output:
         # Count leading spaces to determine indentation level
         indent = len(line) - len(line.strip())
-
+
         # Check if this line matches our target variable
         match = re.match(pattern, line)
         if match and not capturing:
@@ -27,4 +28,4 @@ def parse_tracer_output(tracer_output, target_variable):
         # Capture dependencies (lines with greater indentation)
         result.append(line)

-    return result
\ No newline at end of file
+    return result
diff --git a/tests/python/test_ai_analysis.py b/tests/python/test_ai_analysis.py
index 8900d853..9f7bec68 100644
--- a/tests/python/test_ai_analysis.py
+++ b/tests/python/test_ai_analysis.py
@@ -2,16 +2,22 @@
 from unittest.mock import patch, MagicMock
 import json
 import os
-from policyengine_api.utils.ai_analysis import trigger_ai_analysis, get_existing_analysis
+from policyengine_api.utils.ai_analysis import (
+    trigger_ai_analysis,
+    get_existing_analysis,
+)

-@patch('policyengine_api.utils.ai_analysis.anthropic.Anthropic')
-@patch('policyengine_api.utils.ai_analysis.local_database')
+
+@patch("policyengine_api.utils.ai_analysis.anthropic.Anthropic")
+@patch("policyengine_api.utils.ai_analysis.local_database")
 def test_trigger_ai_analysis(mock_db, mock_anthropic):
     mock_client = MagicMock()
     mock_anthropic.return_value = mock_client
     mock_stream = MagicMock()
     mock_stream.text_stream = ["Test ", "response ", "from ", "AI"]
-    mock_client.messages.stream.return_value.__enter__.return_value = mock_stream
+    mock_client.messages.stream.return_value.__enter__.return_value = (
+        mock_stream
+    )

     prompt = "Test prompt"
     generator = trigger_ai_analysis(prompt)
@@ -40,11 +46,14 @@ def test_trigger_ai_analysis(mock_db, mock_anthropic):
         messages=[{"role": "user", "content": prompt}],
     )

-@patch('policyengine_api.utils.ai_analysis.local_database')
-@patch('policyengine_api.utils.ai_analysis.time.sleep')
+
+@patch("policyengine_api.utils.ai_analysis.local_database")
+@patch("policyengine_api.utils.ai_analysis.time.sleep")
 def test_get_existing_analysis_found(mock_sleep, mock_db):
-    mock_db.query.return_value.fetchone.return_value = {"analysis": "Existing analysis"}
-
+    mock_db.query.return_value.fetchone.return_value = {
+        "analysis": "Existing analysis"
+    }
+
     prompt = "Test prompt"
     generator = get_existing_analysis(prompt)
@@ -66,10 +75,11 @@ def test_get_existing_analysis_found(mock_sleep, mock_db):
     # Check that sleep was called for each chunk
     assert mock_sleep.call_count == 4

-@patch('policyengine_api.utils.ai_analysis.local_database')
+
+@patch("policyengine_api.utils.ai_analysis.local_database")
 def test_get_existing_analysis_not_found(mock_db):
     mock_db.query.return_value.fetchone.return_value = None
-
+
     prompt = "Test prompt"
     result = get_existing_analysis(prompt)
@@ -79,13 +89,15 @@ def test_get_existing_analysis_not_found(mock_db):
         (prompt,),
     )

+
 # Additional test to check environment variable
 def test_anthropic_api_key():
-    with patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'test_key'}):
-        assert os.getenv("ANTHROPIC_API_KEY") == 'test_key'
+    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_key"}):
+        assert os.getenv("ANTHROPIC_API_KEY") == "test_key"
+

 # Test error handling in trigger_ai_analysis
-@patch('policyengine_api.utils.ai_analysis.anthropic.Anthropic')
+@patch("policyengine_api.utils.ai_analysis.anthropic.Anthropic")
 def test_trigger_ai_analysis_error(mock_anthropic):
     mock_client = MagicMock()
     mock_anthropic.return_value = mock_client
@@ -100,4 +112,4 @@ def test_trigger_ai_analysis_error(mock_anthropic):

     # The generator should stop after the initial yield due to the error
     with pytest.raises(Exception, match="API Error"):
-        list(generator)
\ No newline at end of file
+        list(generator)
diff --git a/tests/python/test_tracer.py b/tests/python/test_tracer.py
index 1cc9679c..aa66b304 100644
--- a/tests/python/test_tracer.py
+++ b/tests/python/test_tracer.py
@@ -5,12 +5,14 @@
 from policyengine_api.utils.tracer_analysis import parse_tracer_output
 from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS

+
 @pytest.fixture
 def app():
     app = Flask(__name__)
-    app.config['TESTING'] = True
+    app.config["TESTING"] = True
     return app

+
 # Test cases for parse_tracer_output function
 def test_parse_tracer_output():
@@ -22,22 +24,27 @@ def test_parse_tracer_output():
         " non_market_income <500>",
         " pension_income <500>",
     ]
-
+
     result = parse_tracer_output(tracer_output, "only_government_benefit")
     assert result == tracer_output
-
+
     result = parse_tracer_output(tracer_output, "market_income")
     assert result == tracer_output[1:4]
-
+
     result = parse_tracer_output(tracer_output, "non_market_income")
     assert result == tracer_output[4:]

+
 # Test cases for execute_tracer_analysis function
-@patch('policyengine_api.endpoints.tracer_analysis.local_database')
-@patch('policyengine_api.endpoints.tracer_analysis.trigger_ai_analysis')
-def test_execute_tracer_analysis_success(mock_trigger_ai_analysis, mock_db, app, rest_client):
+@patch("policyengine_api.endpoints.tracer_analysis.local_database")
+@patch("policyengine_api.endpoints.tracer_analysis.trigger_ai_analysis")
+def test_execute_tracer_analysis_success(
+    mock_trigger_ai_analysis, mock_db, app, rest_client
+):
     mock_db.query.return_value.fetchone.return_value = {
-        "tracer_output": json.dumps(["disposable_income <1000>", " market_income <1000>"])
+        "tracer_output": json.dumps(
+            ["disposable_income <1000>", " market_income <1000>"]
+        )
     }
     mock_trigger_ai_analysis.return_value = "AI analysis result"
     test_household_id = 1500
@@ -45,35 +52,49 @@ def test_execute_tracer_analysis_success(
     # Set this to US current law
     test_policy_id = 2

-    with app.test_request_context('/us/tracer_analysis', json={
-        "household_id": test_household_id,
-        "policy_id": test_policy_id,
-        "variable": "disposable_income"
-    }):
+    with app.test_request_context(
+        "/us/tracer_analysis",
+        json={
+            "household_id": test_household_id,
+            "policy_id": test_policy_id,
+            "variable": "disposable_income",
+        },
+    ):
         response = execute_tracer_analysis("us")

     assert response.status_code == 200
     assert b"AI analysis result" in response.data

-@patch('policyengine_api.endpoints.tracer_analysis.local_database')
+
+@patch("policyengine_api.endpoints.tracer_analysis.local_database")
 def test_execute_tracer_analysis_no_tracer(mock_db, app, rest_client):
     mock_db.query.return_value.fetchone.return_value = None

-    with app.test_request_context('/us/tracer_analysis', json={
-        "household_id": "test_household",
-        "policy_id": "test_policy",
-        "variable": "disposable_income"
-    }):
+    with app.test_request_context(
+        "/us/tracer_analysis",
+        json={
+            "household_id": "test_household",
+            "policy_id": "test_policy",
+            "variable": "disposable_income",
+        },
+    ):
         response = execute_tracer_analysis("us")
-
-    assert response.status_code == 404
-    assert "no household simulation tracer found" in response.response["message"]
+
+    assert response.status_code == 404
+    assert (
+        "no household simulation tracer found" in response.response["message"]
+    )
+

-@patch('policyengine_api.endpoints.tracer_analysis.local_database')
-@patch('policyengine_api.endpoints.tracer_analysis.trigger_ai_analysis')
-def test_execute_tracer_analysis_ai_error(mock_trigger_ai_analysis, mock_db, app, rest_client):
+@patch("policyengine_api.endpoints.tracer_analysis.local_database")
+@patch("policyengine_api.endpoints.tracer_analysis.trigger_ai_analysis")
+def test_execute_tracer_analysis_ai_error(
+    mock_trigger_ai_analysis, mock_db, app, rest_client
+):
     mock_db.query.return_value.fetchone.return_value = {
-        "tracer_output": json.dumps(["disposable_income <1000>", " market_income <1000>"])
+        "tracer_output": json.dumps(
+            ["disposable_income <1000>", " market_income <1000>"]
+        )
     }
     mock_trigger_ai_analysis.side_effect = Exception(KeyError)
     test_household_id = 1500
@@ -82,22 +103,29 @@ def test_execute_tracer_analysis_ai_error(
     # Set this to US current law
     test_policy_id = 2

-    with app.test_request_context('/us/tracer_analysis', json={
-        "household_id": test_household_id,
-        "policy_id": test_policy_id,
-        "variable": "disposable_income"
-    }):
+    with app.test_request_context(
+        "/us/tracer_analysis",
+        json={
+            "household_id": test_household_id,
+            "policy_id": test_policy_id,
+            "variable": "disposable_income",
+        },
+    ):
         response = execute_tracer_analysis("us")
-
+
     assert response.status_code == 500
     assert "Error computing analysis" in response.response["message"]

+
 # Test invalid country
 def test_invalid_country(rest_client):
-    response = rest_client.post('/invalid_country/tracer_analysis', json={
-        "household_id": "test_household",
-        "policy_id": "test_policy",
-        "variable": "disposable_income"
-    })
+    response = rest_client.post(
+        "/invalid_country/tracer_analysis",
+        json={
+            "household_id": "test_household",
+            "policy_id": "test_policy",
+            "variable": "disposable_income",
+        },
+    )
     assert response.status_code == 404
-    assert b"Country invalid_country not found" in response.data
\ No newline at end of file
+    assert b"Country invalid_country not found" in response.data