[Lake] integrate pdr_subscriptions into GQL Data Factory (#469)
* first commit for subscriptions

* hook up pdr_subscriptions to gql_factory

* Tests passing, expanding tests to support multiple tables

* Adding tests and improving handling of empty parquet files

* Subscriptions test

* Updating logic to use predictSubscriptions, take lastPriceValue, and to not query the subgraph more than needed.

* Moving models from contract/ -> subgraph/

* Fixing pylint

* Fixing tests

* Adding @enforce_types
idiom-bytes authored Jan 4, 2024
1 parent bd01d71 commit a0244c9
Showing 18 changed files with 813 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pdr_backend/analytics/predictoor_stats.py
@@ -5,7 +5,7 @@
import polars as pl
from enforce_typing import enforce_types

from pdr_backend.contract.prediction import Prediction
from pdr_backend.subgraph.prediction import Prediction
from pdr_backend.util.csvs import get_plots_dir


2 changes: 1 addition & 1 deletion pdr_backend/analytics/test/conftest.py
@@ -1,6 +1,6 @@
import pytest

from pdr_backend.contract.prediction import (
from pdr_backend.subgraph.prediction import (
mock_daily_predictions,
mock_first_predictions,
mock_second_predictions,
21 changes: 19 additions & 2 deletions pdr_backend/lake/gql_data_factory.py
@@ -9,6 +9,10 @@
get_pdr_predictions_df,
predictions_schema,
)
from pdr_backend.lake.table_pdr_subscriptions import (
get_pdr_subscriptions_df,
subscriptions_schema,
)
from pdr_backend.ppss.ppss import PPSS
from pdr_backend.subgraph.subgraph_predictions import get_all_contract_ids_by_owner
from pdr_backend.util.networkutil import get_sapphire_postfix
@@ -48,6 +52,13 @@ def __init__(self, ppss: PPSS):
"contract_list": contract_list,
},
},
"pdr_subscriptions": {
"fetch_fn": get_pdr_subscriptions_df,
"schema": subscriptions_schema,
"config": {
"contract_list": contract_list,
},
},
}

def get_gql_dfs(self) -> Dict[str, pl.DataFrame]:
@@ -164,7 +175,13 @@ def _load_parquet(self, fin_ut: int) -> Dict[str, pl.DataFrame]:
print(f" filename={filename}")

# load all data from file
df = pl.read_parquet(filename)
# check if file exists
# if file doesn't exist, return an empty dataframe with the expected schema
if os.path.exists(filename):
df = pl.read_parquet(filename)
else:
df = pl.DataFrame(schema=record["schema"])

df = df.filter(
(pl.col("timestamp") >= st_ut) & (pl.col("timestamp") <= fin_ut)
)
@@ -202,7 +219,7 @@ def _save_parquet(self, filename: str, df: pl.DataFrame):
if len(df) > 1:
assert (
df.head(1)["timestamp"].to_list()[0]
< df.tail(1)["timestamp"].to_list()[0]
<= df.tail(1)["timestamp"].to_list()[0]
)

if os.path.exists(filename): # "append" existing file
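With pdr_subscriptions registered in record_config next to pdr_predictions, the factory's generic fetch/save/load loop now yields one dataframe per registered table. A minimal consumption sketch, assuming a PPSS object `ppss` is already configured for a sapphire network (construction omitted; the tests build one via mock_ppss):

    from pdr_backend.lake.gql_data_factory import GQLDataFactory

    gql_data_factory = GQLDataFactory(ppss)
    dfs = gql_data_factory.get_gql_dfs()          # dict keyed by table name, one entry per record_config key
    predictions_df = dfs["pdr_predictions"]       # expected to follow predictions_schema
    subscriptions_df = dfs["pdr_subscriptions"]   # expected to follow subscriptions_schema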
16 changes: 15 additions & 1 deletion pdr_backend/lake/plutil.py
@@ -6,7 +6,7 @@
import shutil
from io import StringIO
from tempfile import mkdtemp
from typing import List
from typing import List, Dict

import numpy as np
import polars as pl
@@ -177,3 +177,17 @@ def text_to_df(s: str) -> pl.DataFrame:
df = pl.scan_csv(filename, separator="|").collect()
shutil.rmtree(tmpdir)
return df


@enforce_types
def _object_list_to_df(objects: List[object], schema: Dict) -> pl.DataFrame:
"""
@description
Convert list objects to a dataframe using their __dict__ structure.
"""
# Get all predictions into a dataframe
obj_dicts = [object.__dict__ for object in objects]
obj_df = pl.DataFrame(obj_dicts, schema=schema)
assert obj_df.schema == schema

return obj_df
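The helper is now shared by the predictions and subscriptions tables. A hedged illustration of the contract it expects (the Row class below is a made-up stand-in for the subgraph models, not part of this commit): each object's __dict__ must supply every key in the schema.

    from polars import Int64, Utf8

    from pdr_backend.lake.plutil import _object_list_to_df

    class Row:  # hypothetical stand-in for a Prediction / Subscription model
        def __init__(self, ID: str, timestamp: int):
            self.ID = ID
            self.timestamp = timestamp

    schema = {"ID": Utf8, "timestamp": Int64}
    df = _object_list_to_df([Row("0xabc", 1699000000000)], schema)
    assert df.schema == schema  # same check the helper makes internally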
18 changes: 3 additions & 15 deletions pdr_backend/lake/table_pdr_predictions.py
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict

import polars as pl
from enforce_typing import enforce_types
@@ -8,10 +8,11 @@
FilterMode,
fetch_filtered_predictions,
)
from pdr_backend.lake.plutil import _object_list_to_df
from pdr_backend.util.networkutil import get_sapphire_postfix
from pdr_backend.util.timeutil import ms_to_seconds

# RAW_PREDICTIONS_SCHEMA
# RAW PREDICTOOR PREDICTIONS SCHEMA
predictions_schema = {
"ID": Utf8,
"pair": Utf8,
@@ -27,19 +28,6 @@
}


def _object_list_to_df(objects: List[object], schema: Dict) -> pl.DataFrame:
"""
@description
Convert list objects to a dataframe using their __dict__ structure.
"""
# Get all predictions into a dataframe
obj_dicts = [object.__dict__ for object in objects]
obj_df = pl.DataFrame(obj_dicts, schema=schema)
assert obj_df.schema == schema

return obj_df


def _transform_timestamp_to_ms(df: pl.DataFrame) -> pl.DataFrame:
df = df.with_columns(
[
59 changes: 59 additions & 0 deletions pdr_backend/lake/table_pdr_subscriptions.py
@@ -0,0 +1,59 @@
from typing import Dict

import polars as pl
from enforce_typing import enforce_types
from polars import Int64, Utf8, Float32

from pdr_backend.subgraph.subgraph_subscriptions import (
fetch_filtered_subscriptions,
)
from pdr_backend.lake.table_pdr_predictions import _transform_timestamp_to_ms
from pdr_backend.lake.plutil import _object_list_to_df
from pdr_backend.util.networkutil import get_sapphire_postfix
from pdr_backend.util.timeutil import ms_to_seconds


# RAW PREDICTOOR SUBSCRIPTIONS SCHEMA
subscriptions_schema = {
"ID": Utf8,
"pair": Utf8,
"timeframe": Utf8,
"source": Utf8,
"tx_id": Utf8,
"last_price_value": Float32,
"timestamp": Int64,
"user": Utf8,
}


@enforce_types
def get_pdr_subscriptions_df(
network: str, st_ut: int, fin_ut: int, config: Dict
) -> pl.DataFrame:
"""
@description
Fetch raw subscription events from predictoor subgraph
Update function for graphql query, returns raw data
+ Transforms ts into ms as required for data factory
"""
network = get_sapphire_postfix(network)

# fetch subscriptions
subscriptions = fetch_filtered_subscriptions(
ms_to_seconds(st_ut), ms_to_seconds(fin_ut), config["contract_list"], network
)

if len(subscriptions) == 0:
print(" No subscriptions fetched. Exit.")
return pl.DataFrame()

# convert subscriptions to df and transform timestamp into ms
subscriptions_df = _object_list_to_df(subscriptions, subscriptions_schema)
subscriptions_df = _transform_timestamp_to_ms(subscriptions_df)

# cull any records outside of our time range and sort them by timestamp
subscriptions_df = subscriptions_df.filter(
pl.col("timestamp").is_between(st_ut, fin_ut)
).sort("timestamp")

return subscriptions_df
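A short usage sketch under assumed inputs (the contract id is illustrative, and the call queries the live subgraph for the given network): st_ut and fin_ut are in milliseconds; the function converts them to seconds for the subgraph query and returns timestamps back in ms.

    config = {"contract_list": ["0x123"]}   # hypothetical contract id
    st_ut = 1698883200000                   # 2023-11-02 00:00 UTC, in ms
    fin_ut = 1699056000000                  # 2023-11-04 00:00 UTC, in ms

    subs_df = get_pdr_subscriptions_df("sapphire-mainnet", st_ut, fin_ut, config)
    # empty pl.DataFrame() if nothing was fetched; otherwise columns follow
    # subscriptions_schema and rows are sorted by timestamp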
8 changes: 7 additions & 1 deletion pdr_backend/lake/test/conftest.py
@@ -1,8 +1,14 @@
import pytest

from pdr_backend.contract.prediction import mock_daily_predictions
from pdr_backend.subgraph.prediction import mock_daily_predictions
from pdr_backend.subgraph.subscription import mock_subscriptions


@pytest.fixture()
def sample_daily_predictions():
return mock_daily_predictions()


@pytest.fixture()
def sample_subscriptions():
return mock_subscriptions()
13 changes: 13 additions & 0 deletions pdr_backend/lake/test/resources.py
@@ -1,4 +1,5 @@
import copy
from typing import Dict

import polars as pl
from enforce_typing import enforce_types
@@ -44,10 +45,22 @@ def _gql_data_factory(tmpdir, feed, st_timestr=None, fin_timestr=None):
network = "sapphire-mainnet"
ppss = mock_ppss([feed], network, str(tmpdir), st_timestr, fin_timestr)
ppss.web3_pp = mock_web3_pp(network)

# setup lake
parquet_dir = str(tmpdir)
lake_ss = _lake_ss(parquet_dir, [feed], st_timestr, fin_timestr)
ppss.lake_ss = lake_ss

gql_data_factory = GQLDataFactory(ppss)
return ppss, gql_data_factory


@enforce_types
def _filter_gql_config(record_config: Dict, record_filter: str) -> Dict:
# Return a filtered version of record_config for testing
return {k: v for k, v in record_config.items() if k == record_filter}


@enforce_types
def _predictoor_ss(predict_feed, input_feeds):
return PredictoorSS(
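Illustrative use, mirroring the tests in the next file: narrow the factory's record_config to a single table so a test only exercises that table's fetch/load path.

    _, gql_data_factory = _gql_data_factory(
        tmpdir, "binanceus ETH/USDT h 5m", st_timestr, fin_timestr
    )
    gql_data_factory.record_config = _filter_gql_config(
        gql_data_factory.record_config, "pdr_predictions"
    )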
101 changes: 95 additions & 6 deletions pdr_backend/lake/test/test_gql_data_factory.py
@@ -5,7 +5,7 @@
from enforce_typing import enforce_types

from pdr_backend.lake.table_pdr_predictions import predictions_schema
from pdr_backend.lake.test.resources import _gql_data_factory
from pdr_backend.lake.test.resources import _gql_data_factory, _filter_gql_config
from pdr_backend.ppss.web3_pp import del_network_override
from pdr_backend.subgraph.subgraph_predictions import FilterMode
from pdr_backend.util.timeutil import timestr_to_ut
@@ -31,8 +31,8 @@ def test_update_gql1(
tmpdir,
sample_daily_predictions,
"2023-11-02_0:00",
"2023-11-04_0:00",
n_preds=2,
"2023-11-04_21:00",
n_preds=3,
)


@@ -52,8 +52,8 @@ def test_update_gql2(
tmpdir,
sample_daily_predictions,
"2023-11-02_0:00",
"2023-11-06_0:00",
n_preds=4,
"2023-11-06_21:00",
n_preds=5,
)


@@ -92,7 +92,7 @@ def test_update_gql_iteratively(

iterations = [
("2023-11-02_0:00", "2023-11-04_0:00", 2),
("2023-11-01_0:00", "2023-11-05_0:00", 3),
("2023-11-01_0:00", "2023-11-05_0:00", 3), # do not append to start
("2023-11-02_0:00", "2023-11-07_0:00", 5),
]

@@ -128,6 +128,11 @@ def _test_update_gql(
fin_timestr,
)

# Update predictions record only
gql_data_factory.record_config = _filter_gql_config(
gql_data_factory.record_config, pdr_predictions_record
)

# setup: filename
# everything will be inside the gql folder
filename = gql_data_factory._parquet_filename(pdr_predictions_record)
@@ -213,6 +218,9 @@ def test_load_and_verify_schema(
st_timestr,
fin_timestr,
)
gql_data_factory.record_config = _filter_gql_config(
gql_data_factory.record_config, pdr_predictions_record
)

fin_ut = timestr_to_ut(fin_timestr)
gql_dfs = gql_data_factory._load_parquet(fin_ut)
@@ -253,6 +261,12 @@ def test_get_gql_dfs_calls(
fin_timestr,
)

# Update predictions record only
default_config = gql_data_factory.record_config
gql_data_factory.record_config = _filter_gql_config(
gql_data_factory.record_config, pdr_predictions_record
)

# calculate ms locally so we can filter raw Predictions
st_ut = timestr_to_ut(st_timestr)
fin_ut = timestr_to_ut(fin_timestr)
@@ -278,3 +292,78 @@

mock_update.assert_called_once()
mock_load_parquet.assert_called_once()

# reset record config
gql_data_factory.record_config = default_config


# ====================================================================
# test loading flow when there are pdr files missing


@enforce_types
@patch("pdr_backend.lake.table_pdr_predictions.fetch_filtered_predictions")
@patch("pdr_backend.lake.table_pdr_subscriptions.fetch_filtered_subscriptions")
@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner")
def test_load_missing_parquet(
mock_get_all_contract_ids_by_owner,
mock_fetch_filtered_subscriptions,
mock_fetch_filtered_predictions,
tmpdir,
sample_daily_predictions,
monkeypatch,
):
"""Test core DataFactory functions are being called"""
del_network_override(monkeypatch)

mock_get_all_contract_ids_by_owner.return_value = ["0x123"]
mock_fetch_filtered_subscriptions.return_value = []
mock_fetch_filtered_predictions.return_value = []

st_timestr = "2023-11-02_0:00"
fin_timestr = "2023-11-04_0:00"

_, gql_data_factory = _gql_data_factory(
tmpdir,
"binanceus ETH/USDT h 5m",
st_timestr,
fin_timestr,
)

# Work 1: Fetch empty dataset
# (1) perform empty fetch
# (2) do not save to parquet
# (3) handle missing parquet file
# (4) assert we get empty dataframes with the expected schema
dfs = gql_data_factory.get_gql_dfs()

predictions_table = "pdr_predictions"
subscriptions_table = "pdr_subscriptions"

assert len(dfs[predictions_table]) == 0
assert len(dfs[subscriptions_table]) == 0

assert (
dfs[predictions_table].schema
== gql_data_factory.record_config[predictions_table]["schema"]
)
assert (
dfs[subscriptions_table].schema
== gql_data_factory.record_config[subscriptions_table]["schema"]
)

# Work 2: Fetch 1 dataset
# (1) perform 1 successful datafactory loops (predictions)
# (2) assert subscriptions parquet doesn't exist / has 0 records
_test_update_gql(
mock_fetch_filtered_predictions,
tmpdir,
sample_daily_predictions,
st_timestr,
fin_timestr,
n_preds=2,
)

dfs = gql_data_factory.get_gql_dfs()
assert len(dfs[predictions_table]) == 2
assert len(dfs[subscriptions_table]) == 0
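The empty-dataframe assertions above rely on the _load_parquet fallback added in this commit. A condensed sketch of that fallback (the file path below is illustrative, not the factory's real naming scheme):

    import os
    import polars as pl

    from pdr_backend.lake.table_pdr_subscriptions import subscriptions_schema

    filename = "parquet_data/pdr_subscriptions.parquet"   # hypothetical path
    if os.path.exists(filename):
        df = pl.read_parquet(filename)
    else:
        df = pl.DataFrame(schema=subscriptions_schema)
        assert df.schema == subscriptions_schema          # typed, zero-row frame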