Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sqlite decimal warning #129

Merged
merged 1 commit into from
May 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions OpenOversight/app/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from builtins import input
from datetime import date, datetime
from decimal import Decimal
from getpass import getpass
from typing import Dict, List

Expand Down Expand Up @@ -396,21 +397,18 @@ def process_salary(row, officer, compare=False):
# Get existing salaries for officer and compare to row data
salaries = Salary.query.filter_by(officer_id=officer.id).all()
for salary in salaries:
from decimal import Decimal

print(vars(salary))
print(row)
if (
Decimal("%.2f" % salary.salary)
== Decimal("%.2f" % float(row["salary"]))
round(salary.salary, 2) == round(Decimal(row["salary"]), 2)
and salary.year == int(row["salary_year"])
and salary.is_fiscal_year == is_fiscal_year
and (
(
salary.overtime_pay
and "overtime_pay" in row
and Decimal("%.2f" % salary.overtime_pay)
== Decimal("%.2f" % float(row["overtime_pay"]))
and round(salary.overtime_pay, 2)
== round(Decimal(row["overtime_pay"]), 2)
)
or (
not salary.overtime_pay
Expand All @@ -425,12 +423,12 @@ def process_salary(row, officer, compare=False):
# create new salary
salary = Salary(
officer_id=officer.id,
salary=float(row["salary"]),
salary=round(Decimal(row["salary"]), 2),
year=int(row["salary_year"]),
is_fiscal_year=is_fiscal_year,
)
if "overtime_pay" in row and row["overtime_pay"]:
salary.overtime_pay = float(row["overtime_pay"])
salary.overtime_pay = round(Decimal(row["overtime_pay"]), 2)
db.session.add(salary)
db.session.flush()

Expand Down
11 changes: 6 additions & 5 deletions OpenOversight/app/model_imports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union

import dateutil.parser
Expand Down Expand Up @@ -50,9 +51,9 @@ def parse_int(value: Optional[Union[str, int]]) -> Optional[int]:
return None


def parse_float(value: Optional[Union[str, float]]) -> Optional[float]:
if value == 0.0 or value:
return float(value)
def parse_decimal(value: Optional[Union[str, Decimal]]) -> Optional[Decimal]:
if value == Decimal(0) or value:
return Decimal(value)
return None
Comment on lines +54 to 57
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the goal of this function just to filter None? Other than None, all the potential inputs I can think of seem to be equivalent to just using the Decimal constructor:

input parse_decimal(input) Decimal(input)
0 0 0
1 1 1
1.1 1.1 1.1
0.0 0.0 0.0
0.01 0.01000000000000000020816681711721685132943093776702880859375 0.01000000000000000020816681711721685132943093776702880859375
0.0 0 0
0.0 0 0
-1 -1 -1
0 0 0
1000.01 1000.009999999999990905052982270717620849609375 1000.009999999999990905052982270717620849609375

It seems like the function might be able to be simplified (from a reader's perspective) like this:

Suggested change
def parse_decimal(value: Optional[Union[str, Decimal]]) -> Optional[Decimal]:
if value == Decimal(0) or value:
return Decimal(value)
return None
def parse_decimal(value: Optional[Union[str, Decimal]]) -> Optional[Decimal]:
return Decimal(value) if value is not None else None

What do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not familiar with this part of the code at all but the type hints seem to suggest that this function takes both string and decimal, so I think the original function might have been trying to filter out empty string as well as None. Does that make sense?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, you're right, I totally missed that in my test. Decimal("") throws an error 👍



Expand Down Expand Up @@ -162,8 +163,8 @@ def update_assignment_from_dict(
def create_salary_from_dict(data: Dict[str, Any], force_id: bool = False) -> Salary:
salary = Salary(
officer_id=int(data["officer_id"]),
salary=float(data["salary"]),
overtime_pay=parse_float(data.get("overtime_pay")),
salary=Decimal(data["salary"]),
overtime_pay=parse_decimal(data.get("overtime_pay")),
year=int(data["year"]),
is_fiscal_year=parse_bool(data.get("is_fiscal_year")),
)
Expand Down
30 changes: 28 additions & 2 deletions OpenOversight/app/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import time
from datetime import date
from decimal import Decimal

from authlib.jose import JoseError, JsonWebToken
from flask import current_app
Expand Down Expand Up @@ -225,14 +226,39 @@ def __repr__(self):
)


class Currency(db.TypeDecorator):
"""
Store currency as an integer in sqlite to avoid float conversion
https://stackoverflow.com/questions/10355767/
"""

impl = db.Numeric

def load_dialect_impl(self, dialect):
typ = db.Numeric()
if dialect.name == "sqlite":
typ = db.Integer()
return dialect.type_descriptor(typ)

def process_bind_param(self, value, dialect):
if dialect.name == "sqlite" and value is not None:
value = int(Decimal(value) * 100)
return value

def process_result_value(self, value, dialect):
if dialect.name == "sqlite" and value is not None:
value = Decimal(value) / 100
return value


class Salary(BaseModel):
__tablename__ = "salaries"

id = db.Column(db.Integer, primary_key=True)
officer_id = db.Column(db.Integer, db.ForeignKey("officers.id", ondelete="CASCADE"))
officer = db.relationship("Officer", back_populates="salaries")
salary = db.Column(db.Numeric, index=True, unique=False, nullable=False)
overtime_pay = db.Column(db.Numeric, index=True, unique=False, nullable=True)
salary = db.Column(Currency(), index=True, unique=False, nullable=False)
overtime_pay = db.Column(Currency(), index=True, unique=False, nullable=True)
year = db.Column(db.Integer, index=True, unique=False, nullable=False)
is_fiscal_year = db.Column(db.Boolean, index=False, unique=False, nullable=False)

Expand Down
3 changes: 2 additions & 1 deletion OpenOversight/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
import time
import uuid
from decimal import Decimal
from io import BytesIO
from pathlib import Path
from typing import List
Expand Down Expand Up @@ -111,7 +112,7 @@ def pick_uid():


def pick_salary():
return random.randint(100, 100000000) / 100
return Decimal(random.randint(100, 100000000)) / 100


def generate_officer():
Expand Down
33 changes: 19 additions & 14 deletions OpenOversight/tests/routes/test_officer_and_department.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import random
from datetime import date, datetime
from decimal import Decimal
from io import BytesIO

import pytest
Expand Down Expand Up @@ -1894,7 +1895,7 @@ def test_admin_can_add_salary(mockdata, client, session):
login_admin(client)

form = SalaryForm(
salary=123456.78, overtime_pay=666.66, year=2019, is_fiscal_year=False
salary="123456.78", overtime_pay="666.66", year=2019, is_fiscal_year=False
)

rv = client.post(
Expand All @@ -1906,7 +1907,9 @@ def test_admin_can_add_salary(mockdata, client, session):
assert "Added new salary" in rv.data.decode("utf-8")
assert "<td>$123,456.78</td>" in rv.data.decode("utf-8")

officer = Officer.query.filter(Officer.salaries.any(salary=123456.78)).first()
officer = Officer.query.filter(
Officer.salaries.any(salary=Decimal("123456.78"))
).first()
assert officer is not None


Expand All @@ -1915,7 +1918,7 @@ def test_ac_can_add_salary_in_their_dept(mockdata, client, session):
login_ac(client)

form = SalaryForm(
salary=123456.78, overtime_pay=666.66, year=2019, is_fiscal_year=False
salary="123456.78", overtime_pay="666.66", year=2019, is_fiscal_year=False
)
officer = Officer.query.filter_by(department_id=AC_DEPT).first()

Expand All @@ -1928,7 +1931,9 @@ def test_ac_can_add_salary_in_their_dept(mockdata, client, session):
assert "Added new salary" in rv.data.decode("utf-8")
assert "<td>$123,456.78</td>" in rv.data.decode("utf-8")

officer = Officer.query.filter(Officer.salaries.any(salary=123456.78)).first()
officer = Officer.query.filter(
Officer.salaries.any(salary=Decimal("123456.78"))
).first()
assert officer is not None


Expand All @@ -1937,7 +1942,7 @@ def test_ac_cannot_add_non_dept_salary(mockdata, client, session):
login_ac(client)

form = SalaryForm(
salary=123456.78, overtime_pay=666.66, year=2019, is_fiscal_year=False
salary="123456.78", overtime_pay="666.66", year=2019, is_fiscal_year=False
)
officer = Officer.query.except_(
Officer.query.filter_by(department_id=AC_DEPT)
Expand All @@ -1960,7 +1965,7 @@ def test_admin_can_edit_salary(mockdata, client, session):
Salary.query.filter_by(officer_id=1).delete()

form = SalaryForm(
salary=123456.78, overtime_pay=666.66, year=2019, is_fiscal_year=False
salary="123456.78", overtime_pay="666.66", year=2019, is_fiscal_year=False
)

rv = client.post(
Expand All @@ -1972,7 +1977,7 @@ def test_admin_can_edit_salary(mockdata, client, session):
assert "Added new salary" in rv.data.decode("utf-8")
assert "<td>$123,456.78</td>" in rv.data.decode("utf-8")

form = SalaryForm(salary=150000)
form = SalaryForm(salary="150000")
officer = Officer.query.filter_by(id=1).one()

rv = client.post(
Expand All @@ -1990,7 +1995,7 @@ def test_admin_can_edit_salary(mockdata, client, session):
assert "<td>$150,000.00</td>" in rv.data.decode("utf-8")

officer = Officer.query.filter_by(id=1).one()
assert officer.salaries[0].salary == 150000
assert officer.salaries[0].salary == Decimal(150000)


def test_ac_can_edit_salary_in_their_dept(mockdata, client, session):
Expand All @@ -2003,7 +2008,7 @@ def test_ac_can_edit_salary_in_their_dept(mockdata, client, session):
Salary.query.filter_by(officer_id=officer_id).delete()

form = SalaryForm(
salary=123456.78, overtime_pay=666.66, year=2019, is_fiscal_year=False
salary="123456.78", overtime_pay="666.66", year=2019, is_fiscal_year=False
)

rv = client.post(
Expand All @@ -2015,7 +2020,7 @@ def test_ac_can_edit_salary_in_their_dept(mockdata, client, session):
assert "Added new salary" in rv.data.decode("utf-8")
assert "<td>$123,456.78</td>" in rv.data.decode("utf-8")

form = SalaryForm(salary=150000)
form = SalaryForm(salary="150000")
officer = Officer.query.filter_by(id=officer_id).one()

rv = client.post(
Expand All @@ -2033,7 +2038,7 @@ def test_ac_can_edit_salary_in_their_dept(mockdata, client, session):
assert "<td>$150,000.00</td>" in rv.data.decode("utf-8")

officer = Officer.query.filter_by(id=officer_id).one()
assert officer.salaries[0].salary == 150000
assert officer.salaries[0].salary == Decimal(150000)


def test_ac_cannot_edit_non_dept_salary(mockdata, client, session):
Expand All @@ -2047,7 +2052,7 @@ def test_ac_cannot_edit_non_dept_salary(mockdata, client, session):
Salary.query.filter_by(officer_id=officer_id).delete()

form = SalaryForm(
salary=123456.78, overtime_pay=666.66, year=2019, is_fiscal_year=False
salary="123456.78", overtime_pay="666.66", year=2019, is_fiscal_year=False
)

login_admin(client)
Expand All @@ -2061,7 +2066,7 @@ def test_ac_cannot_edit_non_dept_salary(mockdata, client, session):
assert "<td>$123,456.78</td>" in rv.data.decode("utf-8")

login_ac(client)
form = SalaryForm(salary=150000)
form = SalaryForm(salary="150000")
officer = Officer.query.filter_by(id=officer_id).one()

rv = client.post(
Expand All @@ -2078,7 +2083,7 @@ def test_ac_cannot_edit_non_dept_salary(mockdata, client, session):
assert rv.status_code == 403

officer = Officer.query.filter_by(id=officer_id).one()
assert float(officer.salaries[0].salary) == 123456.78
assert officer.salaries[0].salary == Decimal("123456.78")


def test_get_department_ranks_with_specific_department_id(mockdata, client, session):
Expand Down
5 changes: 4 additions & 1 deletion OpenOversight/tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import traceback
import uuid
from decimal import Decimal

import pandas as pd
import pytest
Expand Down Expand Up @@ -423,7 +424,9 @@ def test_csv_new_salary(csvfile):
officer = Officer.query.filter_by(id=officer_id).one()
assert len(list(officer.salaries)) == 2
for salary in officer.salaries:
assert float(salary.salary) == 123456.78 or float(salary.salary) == 150000.00
assert salary.salary == Decimal("123456.78") or salary.salary == Decimal(
"150000.00"
)


def test_bulk_add_officers__success(session, department_with_ranks, csv_path):
Expand Down
36 changes: 30 additions & 6 deletions OpenOversight/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import datetime
import time
from decimal import Decimal

from pytest import raises
import pytest
from mock import MagicMock

from OpenOversight.app.models import (
Assignment,
Currency,
Department,
Face,
Image,
Expand Down Expand Up @@ -86,7 +89,7 @@ def test_salary_repr(mockdata):

def test_password_not_printed(mockdata):
user = User(password="bacon")
with raises(AttributeError):
with pytest.raises(AttributeError):
user.password


Expand Down Expand Up @@ -231,7 +234,7 @@ def test_area_coordinator_with_dept_is_valid(mockdata):


def test_locations_must_have_valid_zip_codes(mockdata):
with raises(ValueError):
with pytest.raises(ValueError):
Location(
street_name="Brookford St",
cross_street1="Mass Ave",
Expand Down Expand Up @@ -260,7 +263,7 @@ def test_locations_can_be_saved_with_valid_zip_codes(mockdata):


def test_locations_must_have_valid_states(mockdata):
with raises(ValueError):
with pytest.raises(ValueError):
Location(
street_name="Brookford St",
cross_street1="Mass Ave",
Expand Down Expand Up @@ -290,7 +293,7 @@ def test_locations_can_be_saved_with_valid_states(mockdata):


def test_license_plates_must_have_valid_states(mockdata):
with raises(ValueError):
with pytest.raises(ValueError):
LicensePlate(number="603EEE", state="JK")


Expand All @@ -310,7 +313,7 @@ def test_license_plates_can_be_saved_with_valid_states(mockdata):

def test_links_must_have_valid_urls(mockdata):
bad_url = "www.rachel.com"
with raises(ValueError):
with pytest.raises(ValueError):
Link(link_type="video", url=bad_url)


Expand Down Expand Up @@ -381,3 +384,24 @@ def test_images_added_with_user_id(mockdata):
db.session.commit()
saved = Image.query.filter_by(user_id=user_id).first()
assert saved is not None


@pytest.mark.parametrize(
"dialect_name,original_value,intermediate_value",
[
("sqlite", None, None),
("sqlite", Decimal("123.45"), 12345),
("postgresql", None, None),
("postgresql", Decimal("123.45"), Decimal("123.45")),
],
)
def test_currency_type_decorator(dialect_name, original_value, intermediate_value):
currency = Currency()
dialect = MagicMock()
dialect.name = dialect_name

value = currency.process_bind_param(original_value, dialect)
assert intermediate_value == value

value = currency.process_result_value(value, dialect)
assert original_value == value
4 changes: 4 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ import +args:
lint:
pre-commit run --all-files

# Run Flask-Migrate tasks in the web container
db +migrateargs:
just run --no-deps web flask db {{ migrateargs }}

# Run unit tests in the web container
test *pytestargs:
just run --no-deps web pytest -n auto -m "not acceptance" {{ pytestargs }}
Expand Down