Skip to content

Commit

Permalink
chore: Changelog and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
anth-volk committed Oct 7, 2024
1 parent 3c2cd01 commit 38a75b0
Show file tree
Hide file tree
Showing 13 changed files with 165 additions and 100 deletions.
10 changes: 10 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
- bump: major
changes:
added:
- /tracer_analysis endpoint for household tracer outputs
- Database for simulation tracer outputs
changed:
- /analysis endpoint renamed to /simulation_analysis endpoint
- Simulation runs now write tracer output to database
- Simulation analysis runs now return ReadableStreams
- Refactored Claude interaction code to be decoupled from endpoints
5 changes: 4 additions & 1 deletion policyengine_api/ai_prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from policyengine_api.ai_prompts.tracer import tracer_analysis_prompt
from policyengine_api.ai_prompts.simulation import generate_simulation_analysis_prompt, audience_descriptions
from policyengine_api.ai_prompts.simulation import (
generate_simulation_analysis_prompt,
audience_descriptions,
)
8 changes: 4 additions & 4 deletions policyengine_api/ai_prompts/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

audience_descriptions = {
"ELI5": "Write this for a five-year-old who doesn't know anything about economics or policy. Explain fundamental concepts like taxes, poverty rates, and inequality as needed.",
"Normal":
"Write this for a policy analyst who knows a bit about economics and policy.",
"Normal": "Write this for a policy analyst who knows a bit about economics and policy.",
"Wonk": "Write this for a policy analyst who knows a lot about economics and policy. Use acronyms and jargon if it makes the content more concise and informative.",
}


def generate_simulation_analysis_prompt(
time_period,
region,
Expand All @@ -18,7 +18,7 @@ def generate_simulation_analysis_prompt(
is_enhanced_cps,
selected_version,
country_id,
policy_label
policy_label,
):
return f"""
I'm using PolicyEngine, a free, open source tool to compute the impact of
Expand Down Expand Up @@ -125,4 +125,4 @@ def generate_simulation_analysis_prompt(
and the share held by the top 1% (describe the relative changes): {json.dumps(
impact["inequality"],
)}
"""
"""
2 changes: 1 addition & 1 deletion policyengine_api/ai_prompts/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
Keep your explanation concise but informative, suitable for a general audience. Do not start with phrases like "Certainly!" or "Here's an explanation. It will be rendered as markdown, so preface $ with \.
{anthropic.AI_PROMPT}"""
{anthropic.AI_PROMPT}"""
4 changes: 3 additions & 1 deletion policyengine_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@
methods=["GET"],
)(get_economic_impact)

app.route("/<country_id>/simulation_analysis", methods=["POST"])(execute_simulation_analysis)
app.route("/<country_id>/simulation_analysis", methods=["POST"])(
execute_simulation_analysis
)

app.route("/<country_id>/user_policy", methods=["POST"])(set_user_policy)

Expand Down
8 changes: 7 additions & 1 deletion policyengine_api/country.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,13 @@ def build_entities(self) -> dict:
# 4. Delete the code at the end of the function that writes to a file (Done)
# 5. Add code at the end of the function to write to a database (Done)

def calculate(self, household: dict, reform: Union[dict, None], household_id: Optional[int], policy_id: Optional[int] = None):
def calculate(
self,
household: dict,
reform: Union[dict, None],
household_id: Optional[int],
policy_id: Optional[int] = None,
):
if reform is not None and len(reform.keys()) > 0:
system = self.tax_benefit_system.clone()
for parameter_name in reform:
Expand Down
5 changes: 4 additions & 1 deletion policyengine_api/endpoints/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,10 @@ def get_household_under_policy(

try:
result = country.calculate(
household["household_json"], policy["policy_json"], household_id, policy_id
household["household_json"],
policy["policy_json"],
household_id,
policy_id,
)
except Exception as e:
logging.exception(e)
Expand Down
33 changes: 18 additions & 15 deletions policyengine_api/endpoints/simulation_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
from rq import Queue
from redis import Redis
from typing import Optional
from policyengine_api.utils.ai_analysis import trigger_ai_analysis, get_existing_analysis
from policyengine_api.ai_prompts import generate_simulation_analysis_prompt, audience_descriptions
from policyengine_api.utils.ai_analysis import (
trigger_ai_analysis,
get_existing_analysis,
)
from policyengine_api.ai_prompts import (
generate_simulation_analysis_prompt,
audience_descriptions,
)

queue = Queue(connection=Redis())


def execute_simulation_analysis(country_id: str) -> Response:

# Pop the various parameters from the request
payload = request.json

Expand All @@ -22,10 +29,12 @@ def execute_simulation_analysis(country_id: str) -> Response:
policy = payload.get("policy")
region = payload.get("region")
relevant_parameters = payload.get("relevant_parameters")
relevant_parameter_baseline_values = payload.get("relevant_parameter_baseline_values")
relevant_parameter_baseline_values = payload.get(
"relevant_parameter_baseline_values"
)
audience = payload.get("audience")

# Check if the region is enhanced_cps
# Check if the region is enhanced_cps
is_enhanced_cps = "enhanced_cps" in region

# Create prompt based on data
Expand All @@ -40,7 +49,7 @@ def execute_simulation_analysis(country_id: str) -> Response:
is_enhanced_cps,
selected_version,
country_id,
policy_label
policy_label,
)

# Add audience description to end
Expand All @@ -50,23 +59,17 @@ def execute_simulation_analysis(country_id: str) -> Response:
# streaming response
existing_analysis = get_existing_analysis(prompt)
if existing_analysis is not None:
return Response(
status=200,
response=existing_analysis
)
return Response(status=200, response=existing_analysis)

# Otherwise, pass prompt to Claude, then return streaming function
try:
analysis = trigger_ai_analysis(prompt)
return Response(
status=200,
response=analysis
)
return Response(status=200, response=analysis)
except Exception as e:
return Response(
status=500,
response={
"message": "Error computing analysis",
"error": str(e),
}
},
)
29 changes: 13 additions & 16 deletions policyengine_api/endpoints/tracer_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from flask import Response, request
from policyengine_api.country import validate_country
from policyengine_api.ai_prompts import tracer_analysis_prompt
from policyengine_api.utils.ai_analysis import trigger_ai_analysis, get_existing_analysis
from policyengine_api.utils.ai_analysis import (
trigger_ai_analysis,
get_existing_analysis,
)
from policyengine_api.utils.tracer_analysis import parse_tracer_output
from policyengine_api.country import COUNTRY_PACKAGE_VERSIONS

Expand All @@ -13,7 +16,8 @@
# Access the prompt and add the parsed tracer output
# Pass the complete prompt to trigger_ai_analysis (after checking get_existing_analysis for a stored result) and return its response

#TODO: Add the prompt in a new variable; this could even be duplicated from the Streamlit
# TODO: Add the prompt in a new variable; this could even be duplicated from the Streamlit app


def execute_tracer_analysis(
country_id: str,
Expand All @@ -27,7 +31,7 @@ def execute_tracer_analysis(
country_not_found = validate_country(country_id)
if country_not_found:
return country_not_found

payload = request.json

household_id = payload.get("household_id")
Expand All @@ -51,7 +55,7 @@ def execute_tracer_analysis(
status=404,
response={
"message": "Unable to analyze household: no household simulation tracer found",
}
},
)

# Parse the tracer output
Expand All @@ -64,36 +68,29 @@ def execute_tracer_analysis(
response={
"message": "Error parsing tracer output",
"error": str(e),
}
},
)

# Add the parsed tracer output to the prompt
prompt = tracer_analysis_prompt.format(
variable=variable,
tracer_segment=tracer_segment
variable=variable, tracer_segment=tracer_segment
)

# If a calculated record exists for this prompt, return it as a
# streaming response
existing_analysis = get_existing_analysis(prompt)
if existing_analysis is not None:
return Response(
status=200,
response=existing_analysis
)
return Response(status=200, response=existing_analysis)

# Otherwise, pass prompt to Claude, then return streaming function
try:
analysis = trigger_ai_analysis(prompt)
return Response(
status=200,
response=analysis
)
return Response(status=200, response=analysis)
except Exception as e:
return Response(
status=500,
response={
"message": "Error computing analysis",
"error": str(e),
}
},
)
10 changes: 5 additions & 5 deletions policyengine_api/utils/ai_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def trigger_ai_analysis(prompt: str) -> Generator[str, None, None]:

# Configure a Claude client
claude_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

Expand Down Expand Up @@ -38,7 +38,7 @@ def generate():
chunk = buffer[:chunk_size]
buffer = buffer[chunk_size:]
yield json.dumps({"stream": chunk}) + "\n"

if buffer:
yield json.dumps({"stream": buffer}) + "\n"

Expand All @@ -63,7 +63,7 @@ def get_existing_analysis(prompt: str) -> Generator[str, None, None] | None:

if analysis is None:
return None

def generate():

# First, yield prompt so it's accessible on front end
Expand All @@ -75,8 +75,8 @@ def generate():

chunk_size = 5
for i in range(0, len(analysis["analysis"]), chunk_size):
chunk = analysis["analysis"][i:i + chunk_size]
chunk = analysis["analysis"][i : i + chunk_size]
yield json.dumps({"stream": chunk}) + "\n"
time.sleep(0.05)

return generate()
return generate()
9 changes: 5 additions & 4 deletions policyengine_api/utils/tracer_analysis.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import re


def parse_tracer_output(tracer_output, target_variable):
result = []
target_indent = None
capturing = False

# Create a regex pattern to match the exact variable name
# This will match the variable name followed by optional whitespace,
# This will match the variable name followed by optional whitespace,
# then optional angle brackets with any content, then optional whitespace
pattern = rf'^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*'
pattern = rf"^(\s*)({re.escape(target_variable)})\s*(?:<[^>]*>)?\s*"

for line in tracer_output:
# Count leading spaces to determine indentation level
indent = len(line) - len(line.strip())

# Check if this line matches our target variable
match = re.match(pattern, line)
if match and not capturing:
Expand All @@ -27,4 +28,4 @@ def parse_tracer_output(tracer_output, target_variable):
# Capture dependencies (lines with greater indentation)
result.append(line)

return result
return result
Loading

0 comments on commit 38a75b0

Please sign in to comment.