Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #471: Update ohlcv-data-factory when predictoor agent initializes #551

Merged
merged 5 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions pdr_backend/predictoor/approach3/predictoor_agent3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,22 @@

@enforce_types
class PredictoorAgent3(BasePredictoorAgent):
def __init__(self, ppss):
super().__init__(ppss)
self.get_data_components()

@enforce_types
def get_data_components(self):
# Compute aimodel_ss
lake_ss = self.ppss.lake_ss

# From lake_ss, build X/y
pq_data_factory = OhlcvDataFactory(lake_ss)
mergedohlcv_df = pq_data_factory.get_mergedohlcv_df()

return mergedohlcv_df

@enforce_types
def get_prediction(
self, timestamp: int # pylint: disable=unused-argument
) -> Tuple[bool, float]:
Expand All @@ -24,12 +40,7 @@ def get_prediction(
predval -- bool -- if True, it's predicting 'up'. If False, 'down'
stake -- int -- amount to stake, in units of Eth
"""
# Compute aimodel_ss
lake_ss = self.ppss.lake_ss

# From lake_ss, build X/y
pq_data_factory = OhlcvDataFactory(lake_ss)
mergedohlcv_df = pq_data_factory.get_mergedohlcv_df()
mergedohlcv_df = self.get_data_components()

model_data_factory = AimodelDataFactory(self.ppss.predictoor_ss)
X, y, _ = model_data_factory.create_xy(mergedohlcv_df, testshift=0)
Expand Down
27 changes: 26 additions & 1 deletion pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
from unittest.mock import patch

from pdr_backend.predictoor.approach3.predictoor_agent3 import PredictoorAgent3
from pdr_backend.predictoor.test.predictoor_agent_runner import run_agent_test
from pdr_backend.predictoor.test.predictoor_agent_runner import (
run_agent_test,
get_agent,
)


def test_predictoor_agent3(tmpdir, monkeypatch):
run_agent_test(str(tmpdir), monkeypatch, PredictoorAgent3)


@patch(
"pdr_backend.predictoor.approach3.predictoor_agent3.PredictoorAgent3.get_data_components"
)
def test_predictoor_agent3_data_component(
mock_get_data_components, tmpdir, monkeypatch
):
"""
@description
Test that PredictoorAgent3.get_data_components() is called once.
"""
# initialize agent
feed, _, _, _ = get_agent(str(tmpdir), monkeypatch, PredictoorAgent3)

# assert get_data_components() is called once during init
mock_get_data_components.assert_called_once()

# assert agent was initialized with development feed
assert "BTC/USDT|binanceus|5m" in feed.name
25 changes: 23 additions & 2 deletions pdr_backend/predictoor/test/predictoor_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@


@enforce_types
def run_agent_test(tmpdir: str, monkeypatch, predictoor_agent_class):
def get_agent(tmpdir: str, monkeypatch, predictoor_agent_class):
"""
@description
Initialize the agent, and return it along with the feed and ppss
that it uses.
"""
monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY)
feed, ppss = mock_feed_ppss("5m", "binanceus", "BTC/USDT", tmpdir=tmpdir)
feed, ppss = mock_feed_ppss(
"5m", "binanceus", "BTC/USDT", network="development", tmpdir=tmpdir
)
inplace_mock_query_feed_contracts(ppss.web3_pp, feed)

_mock_pdr_contract = inplace_mock_w3_and_contract_with_tracking(
Expand All @@ -38,6 +45,20 @@ def run_agent_test(tmpdir: str, monkeypatch, predictoor_agent_class):
# real work: initialize
agent = predictoor_agent_class(ppss)

return (feed, ppss, agent, _mock_pdr_contract)


@enforce_types
def run_agent_test(tmpdir: str, monkeypatch, predictoor_agent_class):
"""
@description
Run the agent for a while, and then do some basic sanity checks.
"""
_, ppss, agent, _mock_pdr_contract = get_agent(
tmpdir, monkeypatch, predictoor_agent_class
)
# now we're done the mocking, time for the real work!!

# real work: main iterations
for _ in range(500):
agent.take_step()
Expand Down
13 changes: 11 additions & 2 deletions system_tests/test_predictoor_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,16 @@ def _test_predictoor_system(mock_feeds, mock_predictoor_contract, approach):
mock_feeds, mock_predictoor_contract
)

merged_ohlcv_df = Mock()

with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp), patch(
"pdr_backend.publisher.publish_assets.get_address", return_value="0x1"
), patch("pdr_backend.ppss.ppss.PredictoorSS", return_value=mock_predictoor_ss):
), patch(
"pdr_backend.ppss.ppss.PredictoorSS", return_value=mock_predictoor_ss
), patch(
"pdr_backend.lake.ohlcv_data_factory.OhlcvDataFactory.get_mergedohlcv_df",
return_value=merged_ohlcv_df,
):
# Mock sys.argv
sys.argv = ["pdr", "predictoor", str(approach), "ppss.yaml", "development"]

Expand Down Expand Up @@ -65,7 +72,9 @@ def test_predictoor_approach_1_system(

@patch("pdr_backend.ppss.ppss.PPSS.verify_feed_dependencies")
def test_predictoor_approach_3_system(
mock_verify_feed_dependencies, mock_feeds, mock_predictoor_contract
mock_verify_feed_dependencies,
mock_feeds,
mock_predictoor_contract,
):
_ = mock_verify_feed_dependencies
_test_predictoor_system(mock_feeds, mock_predictoor_contract, 3)
Loading