diff --git a/.github/workflows/check_mainnet.yml b/.github/workflows/check_mainnet.yml index 59089a0e7..b5c178dd9 100644 --- a/.github/workflows/check_mainnet.yml +++ b/.github/workflows/check_mainnet.yml @@ -17,18 +17,20 @@ jobs: with: python-version: "3.11" + - name: Fetch the address file and move it to contracts directory + run: | + wget https://raw.githubusercontent.com/oceanprotocol/contracts/main/addresses/address.json + mkdir -p ~/.ocean/ocean-contracts/artifacts/ + mv address.json ~/.ocean/ocean-contracts/artifacts/ + - name: Install Dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - name: Notify Slack - env: - RPC_URL: "https://sapphire.oasis.io" - SUBGRAPH_URL: "https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph" - PRIVATE_KEY: "0xb23c44b8118eb7a7f70d21b0d20aed9b05d85d22ac6a0e57697c564da1c35554" run: | - output=$(python scripts/check_network.py 1 | grep -E 'FAIL|WARNING|error' || true) + output=$(python pdr check_network ppss.yaml sapphire-mainnet | grep -E 'FAIL|WARNING|error' || true) fact=$(curl -s https://catfact.ninja/fact | jq -r '.fact') if [ -z "$output" ]; then echo "No output, so no message will be sent to Slack" diff --git a/.github/workflows/check_testnet.yml b/.github/workflows/check_testnet.yml index 16a58d4f1..1016b22b9 100644 --- a/.github/workflows/check_testnet.yml +++ b/.github/workflows/check_testnet.yml @@ -17,18 +17,20 @@ jobs: with: python-version: "3.11" + - name: Fetch the address file and move it to contracts directory + run: | + wget https://raw.githubusercontent.com/oceanprotocol/contracts/main/addresses/address.json + mkdir -p ~/.ocean/ocean-contracts/artifacts/ + mv address.json ~/.ocean/ocean-contracts/artifacts/ + - name: Install Dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - name: Notify Slack - env: - RPC_URL: "https://testnet.sapphire.oasis.dev" - SUBGRAPH_URL: "https://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph" - PRIVATE_KEY: "0xb23c44b8118eb7a7f70d21b0d20aed9b05d85d22ac6a0e57697c564da1c35554" run: | - output=$(python scripts/check_network.py 1 | grep -E 'FAIL|WARNING|error' | grep -v "1h" || true) + output=$(python pdr check_network ppss.yaml sapphire-testnet | grep -E 'FAIL|WARNING|error' | grep -v "1h" || true) joke=$(curl -s https://official-joke-api.appspot.com/jokes/general/random | jq -r '.[0].setup, .[0].punchline') if [ -z "$output" ]; then echo "No output, so no message will be sent to Slack" diff --git a/.github/workflows/cron_topup.yml b/.github/workflows/cron_topup.yml index 77863f6b9..2c7d434d3 100644 --- a/.github/workflows/cron_topup.yml +++ b/.github/workflows/cron_topup.yml @@ -3,7 +3,7 @@ name: Topup accounts on: schedule: - cron: "0 * * * *" - + jobs: topup-mainnet: runs-on: ubuntu-latest @@ -15,18 +15,20 @@ jobs: uses: actions/setup-python@v2 with: python-version: "3.11" - + - name: Fetch the address file and move it to contracts directory + run: | + wget https://raw.githubusercontent.com/oceanprotocol/contracts/main/addresses/address.json + mkdir -p ~/.ocean/ocean-contracts/artifacts/ + mv address.json ~/.ocean/ocean-contracts/artifacts/ - name: Install Dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - name: Set env variables run: | - echo "SUBGRAPH_URL=http://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph" >> $GITHUB_ENV - echo "RPC_URL=https://sapphire.oasis.io" 
>> $GITHUB_ENV echo "PRIVATE_KEY=${{ secrets.TOPUP_SCRIPT_PK }}" >> $GITHUB_ENV - name: Run top-up script - run: python3 scripts/topup.py + run: python3 pdr topup ppss.yaml sapphire-mainnet topup-testnet: runs-on: ubuntu-latest @@ -41,10 +43,13 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + - name: Fetch the address file and move it to contracts directory + run: | + wget https://raw.githubusercontent.com/oceanprotocol/contracts/main/addresses/address.json + mkdir -p ~/.ocean/ocean-contracts/artifacts/ + mv address.json ~/.ocean/ocean-contracts/artifacts/ - name: Set env variables run: | - echo "SUBGRAPH_URL=http://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph" >> $GITHUB_ENV - echo "RPC_URL=https://testnet.sapphire.oasis.dev" >> $GITHUB_ENV echo "PRIVATE_KEY=${{ secrets.TOPUP_SCRIPT_PK }}" >> $GITHUB_ENV - name: Run top-up script - run: python3 scripts/topup.py + run: python3 pdr topup ppss.yaml sapphire-testnet diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 126ad65f0..aa9a998f2 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -29,7 +29,6 @@ jobs: name: Checkout Barge with: repository: "oceanprotocol/barge" - ref: "main" path: "barge" - name: Run Barge @@ -57,5 +56,10 @@ jobs: - name: Test with pytest id: pytest run: | - coverage run --omit="*test*" -m pytest + coverage run --source=pdr_backend --omit=*/test/*,*/test_ganache/*,*/test_noganache/* -m pytest coverage report + coverage xml + - name: Publish code coverage + uses: paambaati/codeclimate-action@v2.7.5 + env: + CC_TEST_REPORTER_ID: ${{secrets.CC_TEST_REPORTER_ID}} diff --git a/.gitignore b/.gitignore index 53326f39a..2f2e7459f 100644 --- a/.gitignore +++ b/.gitignore @@ -164,9 +164,12 @@ cython_debug/ .test_cache/ .cache/ -# predictoor dynamic modeling +# predictoor-specific out*.txt +my_ppss.yaml csvs/ +parquet_data/ + # pdr_backend accuracy output pdr_backend/accuracy/output/*.json # pm2 configs diff --git a/.pylintrc b/.pylintrc index 11eb7b615..e97c0d636 100644 --- a/.pylintrc +++ b/.pylintrc @@ -119,7 +119,8 @@ disable=too-many-locals, consider-using-dict-items, consider-using-generator, dangerous-default-value, - unidiomatic-typecheck + unidiomatic-typecheck, + unsubscriptable-object # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/README.md b/README.md index 37cd78601..71fcf0e72 100644 --- a/README.md +++ b/README.md @@ -5,44 +5,79 @@ SPDX-License-Identifier: Apache-2.0 # pdr-backend +⚠️ As of v0.2, the CLI replaces previous `main.py` calls. Update your flows accordingly. + ## Run bots (agents) - **[Run predictoor bot](READMEs/predictoor.md)** - make predictions, make $ - **[Run trader bot](READMEs/trader.md)** - consume predictions, trade, make $ + (If you're a predictoor or trader, you can safely ignore the rest of this README.) +## Settings: PPSS + +A "ppss" yaml file, like [`ppss.yaml`](ppss.yaml), holds parameters for all bots and simulation flows. +- We follow the idiom "pp" = problem setup (what to solve), "ss" = solution strategy (how to solve). +- `PRIVATE_KEY` is an exception; it's set as an envvar. + +When you run a bot from the CLI, you specify your PPSS YAML file. + +## CLI + +(First, [install pdr-backend](READMEs/predictoor.md#install-pdr-backend-repo) first.) 
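For orientation, here's the general shape of a bot invocation under this setup. This is a sketch only: it combines the PPSS file and the `PRIVATE_KEY` envvar described above, and the specific approach number and network are illustrative (they're the ones used in the predictoor README).

```console
# sketch: copy the reference settings, set the single envvar, then run a bot
cp ppss.yaml my_ppss.yaml                        # edit parameters as you see fit
export PRIVATE_KEY=<your_private_key>            # the one setting that stays an envvar
pdr predictoor 3 my_ppss.yaml sapphire-testnet   # APPROACH=3, NETWORK=sapphire-testnet
```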
+ +To see CLI options, in console: +```console +pdr +``` + +This will output something like: +```text +Usage: pdr sim|predictoor|trader|.. + +Main tools: + pdr sim YAML_FILE + pdr predictoor APPROACH YAML_FILE NETWORK + pdr trader APPROACH YAML_FILE NETWORK +... +``` ## Atomic READMEs - [Get tokens](READMEs/get-tokens.md): [testnet faucet](READMEs/testnet-faucet.md), [mainnet ROSE](READMEs/get-rose-on-sapphire.md) & [OCEAN](READMEs/get-ocean-on-sapphire.md) -- [Envvars](READMEs/envvars.md) -- [Predictoor subgraph](READMEs/subgraph.md) -- [Dynamic model codebase](READMEs/dynamic-model-codebase.md) -- [Static models in predictoors](READMEs/static-model.md) +- [Claim payout for predictoor bot](READMEs/payout.md) +- [Predictoor subgraph](READMEs/subgraph.md). [Subgraph filters](READMEs/filters.md) +- [Run barge locally](READMEs/barge.md) ## Flows for core team -- **Backend dev** - for `pdr-backend` itself - - [Main backend-dev README](READMEs/backend-dev.md) +- Backend-dev - for `pdr-backend` itself + - [Local dev flow](READMEs/dev.md) + - [VPS dev flow](READMEs/vps.md) - [Release process](READMEs/release-process.md) - - [Run barge locally](READMEs/barge.md) - - [Run barge remotely on VPS](READMEs/vps.md) - - [MacOS gotchas](READMEs/macos.md) wrt Docker & ports -- **[Run dfbuyer bot](READMEs/dfbuyer.md)** - runs Predictoor DF rewards -- **[Run publisher](READMEs/publisher.md)** - publish new feeds -- **[Scripts](scripts/)** for performance stats, more + - [Clean code guidelines](READMEs/clean-code.md) +- [Run dfbuyer bot](READMEs/dfbuyer.md) - runs Predictoor DF rewards +- [Run publisher](READMEs/publisher.md) - publish new feeds +- [Run trueval](READMEs/trueval.md) - run trueval bot ## Repo structure This repo implements all bots in Predictoor ecosystem. -Each bot has a directory: -- `predictoor` - submits individual predictions -- `trader` - buys aggregated predictions, then trades -- other bots: `trueval` report true values to contract, `dfbuyer` implement Predictoor Data Farming, `publisher` to publish +Each bot has a directory. Alphabetically: +- `dfbuyer` - buy feeds on behalf of Predictoor DF +- `predictoor` - submit individual predictions +- `publisher` - publish pdr data feeds +- `trader` - buy aggregated predictions, then trade +- `trueval` - report true values to contract -Other directories: -- `util` - tools for use by any agent -- `models` - classes that wrap Predictoor contracts; for setup (BaseConfig); and for data feeds (Feed) +Other directories, alphabetically: +- `accuracy` - calculates % correct, for display in predictoor.ai webapp +- `data_eng` - data engineering & modeling +- `models` - class-based data structures, and classes to wrap contracts +- `payout` - OCEAN & ROSE payout +- `ppss` - settings +- `sim` - simulation flow +- `util` - function-based tools diff --git a/READMEs/backend-dev.md b/READMEs/backend-dev.md deleted file mode 100644 index 62f50ebb7..000000000 --- a/READMEs/backend-dev.md +++ /dev/null @@ -1,93 +0,0 @@ - - -# Usage for Backend Devs - -This is for core devs to improve pdr-backend repo itself. - -## Install pdr-backend - -Follow directions to install pdr-backend in [predictoor.md](predictoor.md) - -## Local Network - -First, [install barge](barge.md#install-barge). - -Then, run barge. 
In barge console: -```console -# Run barge with just predictoor contracts, queryable, but no agents -./start_ocean.sh --no-provider --no-dashboard --predictoor --with-thegraph -``` - -Open a new "work" console and: -```console -# Setup virtualenv -cd pdr-backend -source venv/bin/activate - -# Set envvars -export PRIVATE_KEY="0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" -export ADDRESS_FILE="${HOME}/.ocean/ocean-contracts/artifacts/address.json" - -export RPC_URL=http://127.0.0.1:8545 -export SUBGRAPH_URL="http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph" -#OR: export SUBGRAPH_URL="http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph" -``` - -([envvars.md](envvars.md) has details.) - -### Local Usage: Testing & linting - -In work console, run tests: -```console -#(ensure envvars set as above) - -#run a single test -pytest pdr_backend/util/test/test_constants.py::test_constants1 - -#run all tests in a file -pytest pdr_backend/util/test/test_constants.py - -#run all regular tests; see details on pytest markers to select specific suites -pytest -``` - -In work console, run linting checks: -```console -#run static type-checking. By default, uses config mypy.ini. Note: pytest does dynamic type-checking. -mypy ./ - -#run linting on code style -pylint pdr_backend/* - -#auto-fix some pylint complaints -black ./ -``` - -### Local Usage: Run a custom agent - -Let's say you want to change the trader agent, and use off-the-shelf agents for everything else. Here's how. - -In barge console: -```console -# (Hit ctrl-c to stop existing barge) - -# Run all agents except trader -./start_ocean.sh --predictoor --with-thegraph --with-pdr-trueval --with-pdr-predictoor --with-pdr-publisher --with-pdr-dfbuyer -``` - -In work console: -```console -#(ensure envvars set as above) - -# run trader agent -python pdr_backend/trader/main.py -``` - -(You can track at finer resolution by writing more logs to the [code](../pdr_backend/predictoor/approach3/predictoor_agent3.py), or [querying Predictoor subgraph](subgraph.md).) - -## Remote Usage - -Combine local setup above with remote setup envvars like in [predictoor.md](predictoor.md). diff --git a/READMEs/barge-calls.md b/READMEs/barge-calls.md new file mode 100644 index 000000000..fd515b96d --- /dev/null +++ b/READMEs/barge-calls.md @@ -0,0 +1,41 @@ +### Barge flow of calls + +From getting barge going, here's how it calls specific pdr-backend components and passes arguments. + +- user calls `/barge/start_ocean.sh` to get barge going + - then, `start_ocean.sh` fills `COMPOSE_FILES` incrementally. 
Eg `COMPOSE_FILES+=" -f ${COMPOSE_DIR}/pdr-publisher.yml"`
+  - `barge/compose-files/pdr-publisher.yml` sets:
+    - `pdr-publisher: image: oceanprotocol/pdr-backend:${PDR_BACKEND_VERSION:-latest}`
+    - `pdr-publisher: command: publisher`
+    - `pdr-publisher: networks: backend: ipv4_address: 172.15.0.43`
+    - `pdr-publisher: environment:`
+      - `RPC_URL: ${NETWORK_RPC_URL}` (= `http://localhost:8545` via `start_ocean.sh`)
+      - `ADDRESS_FILE: /root/.ocean/ocean-contracts/artifacts/address.json`
+      - (many `PRIVATE_KEY_*`)
+
+  - then, `start_ocean.sh` pulls the `$COMPOSE_FILES` as needed:
+    - `[ ${FORCEPULL} = "true" ] && eval docker-compose "$DOCKER_COMPOSE_EXTRA_OPTS" --project-name=$PROJECT_NAME "$COMPOSE_FILES" pull`
+
+  - then, `start_ocean.sh` runs docker-compose including all `$COMPOSE_FILES`:
+    - `eval docker-compose "$DOCKER_COMPOSE_EXTRA_OPTS" --project-name=$PROJECT_NAME "$COMPOSE_FILES" up --remove-orphans`
+    - it executes each of the `"command"` entries in compose files.
+      - (Eg for pdr-publisher.yml, `"command" = "publisher ppss.yaml development"`)
+      - Which then goes to `pdr-backend/entrypoint.sh` via `"python /app/pdr_backend/pdr $@"`
+      - (where `@` is unpacked as eg `publisher ppss.yaml development`) [Ref](https://superuser.com/questions/1586997/what-does-symbol-mean-in-the-context-of#:).
+      - Then it goes through the usual CLI at `pdr-backend/pdr_backend/util/cli_module.py`
+
+
+### How to make changes to calls
+
+If you made a change to the pdr-backend CLI interface, then barge must call using the updated CLI command.
+
+How:
+- change the relevant compose file's `"command"`. Eg change `barge/compose-files/pdr-publisher.yml`'s "command" value to `publisher ppss.yaml development`
+- also, change envvar setup as needed. Eg in compose file, remove the `RPC_URL` and `ADDRESS_FILE` entries.
+- ultimately, ask: "does Docker have everything it needs to successfully run the component?"
+
+### All Barge READMEs
+
+- [barge.md](barge.md): the main Barge README
+- [barge-calls.md](barge-calls.md): order of execution from Barge and pdr-backend code
+- [release-process.md](release-process.md): pdr-backend Dockerhub images get published with each push to `main`, and sometimes other branches. In turn these are used by Barge.
diff --git a/READMEs/barge.md b/READMEs/barge.md
index 3347e8f75..5488b0b5a 100644
--- a/READMEs/barge.md
+++ b/READMEs/barge.md
@@ -7,6 +7,8 @@ SPDX-License-Identifier: Apache-2.0
 
 Barge is a Docker container to run a local Ganache network having Predictoor contracts and (optionally) local bots. This README describes how to install Barge, and provides reference on running it with various agents.
 
+⚠️ If you're on MacOS or Windows, we recommend using a remotely-run Barge. See [vps flow](vps.md).
+
 ## Contents
 
 Main:
@@ -39,11 +41,7 @@ docker system prune -a --volumes
 ```
 
 **Then, get Docker running.** To run barge, you need the Docker engine running. Here's how:
-- If you're on Linux: you're good, there's nothing extra to do.
-- If you're on MacOS:
-  - via console: `open -a Docker`
-  - or, via app: open Finder app, find Docker, click to open app. (You don't need to press "play" or anything else. The app being open is enough.)
-  - ⚠️ MacOS may give Docker issues. [Here](macos.md) are workarounds.
+- If you're on Linux: you're good, there's nothing extra to do
 
 Congrats! Barge is installed and ready to be run.
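As a concrete tie-in to the barge-calls flow documented above: here is, roughly, what the `pdr-publisher` container ends up executing, assuming the compose `command` value and the `entrypoint.sh` behavior described there (a sketch, not an exact transcript of the container's logs).

```console
# compose's command ("publisher ppss.yaml development") is handed to entrypoint.sh,
# which expands "python /app/pdr_backend/pdr $@" -- so the container effectively runs:
python /app/pdr_backend/pdr publisher ppss.yaml development
# ...and that call dispatches through pdr_backend/util/cli_module.py like any other `pdr` CLI invocation
```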
@@ -95,8 +93,10 @@ For each other subcomponent of Barge, you need to change its respective repo sim And for Barge core functionality, make changes to the [barge repo](https://github.com/oceanprotocol/barge) itself. -## Other READMEs +More info: [Barge flow of calls](barge-calls.md) + +## All Barge READMEs -- [Parent predictoor README: predictoor.md](./predictoor.md) -- [Parent trader README: trader.md](./trader.md) -- [Root README](../README.md) +- [barge.md](barge.md): the main Barge README +- [barge-calls.md](barge-calls.md): order of execution from Barge and pdr-backend code +- [release-process.md](release-process.md): pdr-backend Dockerhub images get published with each push to `main`, and sometimes other branches. In turn these are used by Barge. diff --git a/READMEs/clean-code.md b/READMEs/clean-code.md new file mode 100644 index 000000000..8b0ab0aef --- /dev/null +++ b/READMEs/clean-code.md @@ -0,0 +1,99 @@ + + +# Clean Code ✨ Guidelines + +Guidelines for core devs to have clean code. + +Main policy on PRs: + +> ✨ **To merge a PR, it must be clean.** ✨ (Rather than: "merge the PR, then clean up") + +Clean code enables us to proceed with maximum velocity. + +## Summary + +Clean code means: +- No DRY violations +- Great labels +- Dynamic type-checking +- Tik-tok, refactor-add +- No "TODOs" +- Passes smell test +- Have tests always. TDD +- Tests are fast +- Tests are clean too + +This ensures minimal tech debt, so we can proceed at maximum velocity. Senior engineers get to spend their time contributing to features, rather than cleaning up. + +Everyone can "up their game" by being diligent about this until it becomes second nature; and by reading books on it. + +The following sections elaborate: +- [What does clean code look like?](#what-does-clean-code-look-like) +- [Benefits of clean code](#benefits-of-clean-code) +- [Reading list](#reading-list) + +## What does clean code look like? + +**No DRY violations.** This alone avoids many complexity issues. + +**Great labels.** This makes a huge difference to complexity and DX too. +- Rule of thumb: be as specific as possible, while keeping character count sane. Eg "ohlcv_data_factory" vs "parquet_data_factory". +- In functions of 2-5 lines you can get away with super-short labels like "c" as long as it's local, and the context makes it obvious. +- Generally: "how easily can a developer understand the code?" Think of yourself that's writing for readers, where the reader is developers (including yourself). + +**Dynamic type-checking**, via @enforce_typing + type hints on variables. +- Given that bugs often show up as type violations, think of dynamic type-checking as a robot that automatically hunts down those bugs on your behalf. Doing dynamic type-checking will save you a ton of time. +- Small exception: in 2% of cases it's overkill, eg if your type is complex or if you're fighting with mock or mypy; then skip it there. + +**Tik-tok, refactor-add.** That is: for many features, it's best to spend 90% of the effort to refactor first (and merge that PR). Then the actual change for the feature itself is near-trivial, eg 10% of the effort. +- It's "tik tok", where "tik" = refactor (unit tests typically don't change), and "tok" = make change. +- Inspiration: this is Intel's approach to CPU development, where "tik" = change manufacturing process, "tok" = change CPU design. + +**No "TODOs".** If you have a "TODO" that makes sense for the PR, put it in the PR. Otherwise, create a separate issue. 
+
+**Does it pass the smell test?** If you feel like your code smells, you have work to do.
+
+**Have tests always. Use TDD.**
+- Coverage should be >90%.
+- You should be using test-driven development (TDD), ie write the tests at the same time as the code itself in very rapid cycles of 30 s - 5 min. The outcome: module & test go together like a hand & glove.
+- Anti-pattern outcome: tests are very awkward, having to jump through hoops to do tests.
+- If you encounter this anti-pattern in your new code or see it in existing code, refactor to get to "hand & glove".
+
+**Tests run fast.** Know that if a test module publishes a feed on barge, it adds another 40 seconds to overall test runtime. So, mock aggressively. I recently did this for trader/, changing runtime from 4 min to 2 s (!).
+
+**Tests are clean too.** They're half the code after all! That is, for tests: no DRY violations; great labels; dynamic type-checking; no TODOs; passes "smell test"
+
+## Benefits of clean code
+
+_Aka "Why this policy?"_
+
+- Helps **ensure minimal technical debt**
+- Which in turn means we can **proceed at maximum velocity**. We can make changes, refactor, etc with impunity
+- I was going to add the point "we're no longer under time pressure to ship features." However that should never be an excuse, because having tech debt slows us down from adding new features! Compromising quality _hurts_ speed, not helps it. Quality comes for free.
+- From past experience, often "clean up later" meant "never" in practical terms. Eg sometimes it's put into a separate github issue, and that issue gets ignored.
+- Senior engineers should not find themselves on a treadmill cleaning up after merged PRs. This is a high opportunity cost of them not spending time on what they're best at (eg ML).
+- What senior engineers _should_ do: use PR code reviews to show others how to clean up. And hopefully this is a learning opportunity for everyone over time too:)
+
+## Reading list
+
+To "up your game", here are great books on software engineering, in suggested reading order.
+
+- Code Complete 2, by Steve McConnell. Classic book on code construction, filled to the brim with practical tips. [Link](https://www.goodreads.com/book/show/4845.Code_Complete)
+- Clean Code: A Handbook of Agile Software Craftsmanship, by Robert C. Martin. [Link](https://www.goodreads.com/book/show/3735293-clean-code)
+- A Philosophy of Software Design, by John Ousterhout. Best book on managing complexity. Emphasizes DRY. If you've been in the coding trenches for a while, this feels like a breath of fresh air and helps you to up your game further. [Link](https://www.goodreads.com/book/show/39996759-a-philosophy-of-software-design).
+- Refactoring: Improving the Design of Existing Code, by Martin Fowler. This book is a big "unlock" on how to apply refactoring everywhere like a ninja. [Link](https://www.goodreads.com/book/show/44936.Refactoring).
+- Head First Design Patterns, by Eric Freeman et al. Every good SW engineer should have design patterns in their toolbox. This is a good first book on design patterns. [Link](https://www.goodreads.com/book/show/58128.Head_First_Design_Patterns)
+- Design Patterns: Elements of Reusable Object-Oriented Software, by GOF. This is "the bible" on design patterns. It's only so-so on approachability, but nonetheless the content makes it worth it. But start with "Head First Design Patterns". 
[Link](https://www.goodreads.com/book/show/85009.Design_Patterns) +- The Pragmatic Programmer: From Journeyman to Master, by Andy Hunt and Dave Thomas. Some people really love this, I found it so-so. But definitely worth a read to round out your SW engineering. https://www.goodreads.com/book/show/4099.The_Pragmatic_Programmer + +A final one. In general, _when you're coding, you're writing_. Therefore, books on crisp writing are also books about coding (!). The very top of this list is [Strunk & White Elements of Style](https://www.goodreads.com/book/show/33514.The_Elements_of_Style). It's sharper than a razor blade. + + +## Recap + +Each PR should always be both "make it work" _and_ "make it good (clean)". ✨ + +It will pay off, quickly. diff --git a/READMEs/dev.md b/READMEs/dev.md new file mode 100644 index 000000000..3c0bbaebe --- /dev/null +++ b/READMEs/dev.md @@ -0,0 +1,111 @@ + + +# Usage for Backend Devs + +This is for core devs to improve pdr-backend repo itself. + +## Install pdr-backend + +Follow directions to install pdr-backend in [predictoor.md](predictoor.md) + +## Setup Barge + +**Local barge.** If you're on ubuntu, you can run barge locally. +- First, [install barge](barge.md#install-barge). +- Then, run it. In barge console: `./start_ocean.sh --no-provider --no-dashboard --predictoor --with-thegraph` + +**Or, remote barge.** If you're on MacOS or Windows, run barge on VPS. +- Follow the instructions in [vps.md](vps.md) + +### Setup dev environment + +Open a new "work" console and: +```console +# Setup virtualenv +cd pdr-backend +source venv/bin/activate + +# Set PRIVATE_KEY +export PRIVATE_KEY="0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" + +# Unit tests default to using "development" network -- a locally-run barge. +# If you need another network such as barge on VPS, then override the endpoints for the development network +``` + +All other settings are in [`ppss.yaml`](../ppss.yaml). Some of these are used in unit tests. Whereas most READMEs make a copy `my_ppss.yaml`, for development we typically want to operate directly on `ppss.yaml`. + +### Local Usage: Testing & linting + +In work console, run tests: +```console +# (ensure PRIVATE_KEY set as above) + +# run a single test. The "-s" is for more output. +# note that pytest does dynamic type-checking too:) +pytest pdr_backend/util/test_noganache/test_util_constants.py::test_util_constants -s + +# run all tests in a file +pytest pdr_backend/util/test_noganache/test_util_constants.py -s + +# run a single test that flexes network connection +pytest pdr_backend/util/test_ganache/test_contract.py::test_get_contract_filename -s + +# run all regular tests; see details on pytest markers to select specific suites +pytest +``` + +In work console, run linting checks: +```console +# mypy does static type-checking and more. Configure it via mypy.ini +mypy ./ + +# run linting on code style. Configure it via .pylintrc. +pylint pdr_backend/* + +# auto-fix some pylint complaints like whitespace +black ./ +``` + +Check code coverage: +```console +coverage run --omit="*test*" -m pytest # Run all. For subset, add eg: pdr_backend/lake +coverage report # show results +``` + +### Local Usage: Run a custom agent + +Let's say you want to change the trader agent, and use off-the-shelf agents for everything else. Here's how. 
+ +In barge console: +```console +# (Hit ctrl-c to stop existing barge) + +# Run all agents except trader +./start_ocean.sh --predictoor --with-thegraph --with-pdr-trueval --with-pdr-predictoor --with-pdr-publisher --with-pdr-dfbuyer +``` + +In work console: +```console +#(ensure envvars set as above) + +# run trader agent, approach 1 +pdr trader 1 ppss.yaml development +# or +pdr trader 1 ppss.yaml barge-pytest +``` + +(You can track at finer resolution by writing more logs to the [code](../pdr_backend/predictoor/approach3/predictoor_agent3.py), or [querying Predictoor subgraph](subgraph.md).) + +## Remote Usage + +In the CLI, simply point to a different network: +```console +# run on testnet +pdr trader ppss.yaml sapphire-testnet + +# or, run on mainnet +pdr trader ppss.yaml sapphire-mainnet +``` diff --git a/READMEs/dfbuyer.md b/READMEs/dfbuyer.md index 9ef7ef8cb..fcf8bd4a2 100644 --- a/READMEs/dfbuyer.md +++ b/READMEs/dfbuyer.md @@ -27,27 +27,31 @@ Open a new console and: cd pdr-backend source venv/bin/activate -# Set envvars +# Set envvar export PRIVATE_KEY="0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" -export ADDRESS_FILE="${HOME}/.ocean/ocean-contracts/artifacts/address.json" - -export RPC_URL=http://127.0.0.1:8545 -export SUBGRAPH_URL="http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph" -#OR: export SUBGRAPH_URL="http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph" ``` +Copy [`ppss.yaml`](../ppss.yaml) into your own file `my_ppss.yaml` and change parameters as you see fit. The section "dfbuyer_ss" has parameters for this bot. + Then, run dfbuyer bot. In console: ```console -python pdr_backend/dfbuyer/main.py +pdr dfbuyer my_ppss.yaml development ``` -There are other environment variables that you might want to set, such as the **weekly spending limit**. To get more information about them check out the [environment variables documentation](./envvars.md). - -The bot will consume "WEEKLY_SPENDING_LIMIT" worth of assets each week. This amount is distributed equally among all DF eligible assets. +The bot will consume "weekly_spending_limit" worth of assets each week. This amount is distributed equally among all DF eligible assets. (This parameter is set in the yaml file.) ![flow](https://user-images.githubusercontent.com/25263018/269256707-566b9f5d-7e97-4549-b483-2a6700826769.png) + ## Remote Usage -Combine local setup above with remote setup envvars like in [predictoor.md](predictoor.md). +In the CLI, simply point to a different network: +```console +# run on testnet +pdr dfbuyer my_ppss.yaml sapphire-testnet + +# or, run on mainnet +pdr dfbuyer my_ppss.yaml sapphire-mainnet +``` + diff --git a/READMEs/dynamic-model-codebase.md b/READMEs/dynamic-model-codebase.md deleted file mode 100644 index 44ac24326..000000000 --- a/READMEs/dynamic-model-codebase.md +++ /dev/null @@ -1,93 +0,0 @@ - - -# About Dynamic Model Codebase - -Dynamic modeling is used in two places: - -1. Simulation of modeling & trading -> [`pdr_backend/predictoor/simulation/`](../pdr_backend/predictoor/simulation/). -2. Run predictoor bot - [`pdr_backend/predictoor/approach3/`](../pdr_backend/predictoor/approach3/) - -Contents of this README: -- [Code and Simulation](#code-and-simulation) -- [Code and Predictoor bot](#code-and-predictoor-bot) -- [Description of each file](#description-of-files) -- [HOWTO](#howtos) add new data, change model, etc - -## Code and Simulation - -The simulation flow is used by [predictoor.md](predictoor.md) and [trader.md](trader.md). 
- -Simulation is invoked by: `python pdr_backend/predictoor/simulation/runtrade.py` - -What `runtrade.py` does: -- Set simulation parameters. -- Grab historical price data from exchanges and stores in `csvs/` dir. It re-uses any previously saved data. -- Run through many 5min epochs. At each epoch: - - Build a model - - Predict up/down - - Trade. - - (It logs this all to screen, and to `out*.txt`.) - - Plot total profit versus time, more - -## Code and predictoor bot - -The predictoor bot flow is used by [predictoor.md](predictoor.md). - -The bot is invoked by: `python pdr_backend/predictoor/main.py 3` - -- It runs [`predictoor_agent3.py::PredictoorAgent3`](../pdr_backend/predictoor/approach3/predictoor_agent3.py) found in `pdr_backend/predictoor/approach3` -- It's configured by envvars and [`predictoor_config3.py::PredictoorConfig3`](../pdr_backend/predictoor/approach3/predictoor_config3.py) -- It predicts according to `PredictoorAgent3:get_prediction()`. - -## Description of files - -**Do simulation, including modeling & trading:** -- [`runtrade.py`](../pdr_backend/simulation/runtrade.py) - top-level file to invoke trade engine -- [`trade_engine.py`](../pdr_backend/simulation/trade_engine.py) - simple, naive trading engine - -**Build & use predictoor bot:** -- [`predictoor_agent3.py`](../pdr_backend/predictoor/approach3/predictoor_agent3.py) - main agent. Builds model -- [`predictoor_config3.py`](../pdr_backend/predictoor/approach3/predictoor_config3.py) - solution strategy parameters for the bot - -**Build & use the model:** (used by simulation and bot) -- [`model_factory.py`](../pdr_backend/model_eng/model_factory.py) - converts X/y data --> AI/ML model -- [`model_ss.py`](../pdr_backend/model_eng/model_ss.py) - solution strategy parameters for model_factory - -**Build & use data:** (used by model) -- [`data_factory.py`](../pdr_backend/data_eng/data_factory.py) - converts historical data -> historical dataframe -> X/y model data -- [`data_ss.py`](../pdr_backend/data_eng/data_ss.py) - solution strategy parameters for data_factory, ie sets what data to use - -## HOWTOs - -**On PP and SS:** -- This is a naming idiom that you'll see in in module names, class names, variable names -- "SS" = controllable by user, if in a real-world setting. "Solution Strategy" -- "PP" = uncontrollable by user "". - -**HOWTO change parameters for each flow:** -- **For running simulation flow:** change lines in [`runtrade.py`](../pdr_backend/simulation/runtrade.py). Almost every line is changeable, to change training data, model, trade parameters, and trade strategy. Details on each below. -- **For running predictoor bot flow:** change [`predictoor_config3.py`](../pdr_backend/predictoor/approach3/predictoor_config3.py) solution strategy parameters for the bot - -**HOWTO set what training data to use:** -- Change args to `data_ss.py:DataSS()` constructor. -- Includes: how far to look back historically for training samples, max # training samples, how far to look back when making a single inference. - -**HOWTO set what model to use:** -- Change args to `model_ss.py:ModelSS()` constructor. -- Includes: the model. "LIN" = linear. - -**HOWTO set trade parameters:** -- Change args to `trade_pp.py:TradePP()` constructor. -- Includes: % trading fee - -**HOWTO set trade strategy:** -- Change args to `trade_ss.py:TradeSS()` constructor. -- Includes: how much $ to trade with at each point - -**HOWTO set simulation strategy:** -- Change args to `sim_ss.py:SimSS()` constructor. 
-- Includes: where to log, whether to plot - diff --git a/READMEs/envvars.md b/READMEs/envvars.md deleted file mode 100644 index 81a6b7a0c..000000000 --- a/READMEs/envvars.md +++ /dev/null @@ -1,50 +0,0 @@ -# Environment Variables (Envvars) - -This page describes core envvars that are used by all agents, then envvars that are specific to each agent. - -## Core Envvars - -### Network Configuration - -- **RPC_URL**: The RPC URL of the network. - - Check out the [Sapphire Documentation](https://docs.oasis.io/dapp/sapphire/) -- **SUBGRAPH_URL**: The Ocean subgraph URL. - - **TESTNET**: https://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph - - **MAINNET**: https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph -- **PRIVATE_KEY**: Private key of the wallet to use. **Must start with `0x`.** - -### Filters - -- **PAIR_FILTER**: Pairs to filter (comma-separated). Fetches all available pairs if empty. Example: `BTC/USDT,ETH/USDT` -- **TIMEFRAME_FILTER**: Timeframes to filter (comma-separated). Fetches all available timeframes if empty. Example: `5m,1h` -- **SOURCE_FILTER**: Price sources to filter (comma-separated). Fetches all available sources if empty. Example: `binance,kraken` -- **OWNER_ADDRS**: Addresses of contract deployers to filter (comma-separated). **Typically set to the address of the OPF deployer wallet.** - - **TESTNET**: `0xe02a421dfc549336d47efee85699bd0a3da7d6ff` - - **MAINNET**: `0x4ac2e51f9b1b0ca9e000dfe6032b24639b172703` - -## Agent-Specific Envvars - -These are envvars that are specific to a given agent. - -### Trueval Agent - -- **SLEEP_TIME**: The pause duration (in seconds) between batch processing. Example: `5` -- **BATCH_SIZE**: Maximum number of truevals to handle in a batch. Example: `3` - -### Trader Agent - -- **TRADER_MIN_BUFFER**: Sets a threshold (in seconds) for trade decisions. Example: if value is `180` and there's 179 seconds left, no trade. If 181 seconds left, then trade. - -### Predictoor Agent - -- **SECONDS_TILL_EPOCH_END**: Determines how soon to start predicting. Example: if value is `60` then it will start submitting predictions 60 seconds before. It will continue to periodically submit predictions until there's no time left. -- **STAKE_AMOUNT**: The amount to stake in units of Eth. - - For approach 1 stake amount is randomly determined this has no effect. - - For approach 2 stake amount is determined by: `STAKE_AMOUNT * confidence` where confidence is between 0 and 1. - - For approach 3 this is the stake amount. - -### DFBuyer Agent - -- **CONSUME_BATCH_SIZE**: Max number of consumes to process in a single transaction. Example: `10` -- **WEEKLY_SPENDING_LIMIT**: The target amount of tokens to be spent on consumes per week. Should be set to amount of Predictoor DF rewards for that week. Denominated in OCEAN. Example: `37000` -- **CONSUME_INTERVAL_SECONDS**: Time interval between each "buy", denominated in seconds. Example: `86400` (1 day) for it to consume daily. Daily is a good frequency, balancing tx cost with liveness. diff --git a/READMEs/filter.md b/READMEs/filter.md deleted file mode 100644 index a2e4f0451..000000000 --- a/READMEs/filter.md +++ /dev/null @@ -1,19 +0,0 @@ - - -# Using Predictoor Subgraph - -### Querying - -You can query [subgraph](http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph/graphql) and see [this populated data PR](https://github.com/oceanprotocol/ocean-subgraph/pull/678) here for entities. 
- -### Filtering - -Here are additional envvars used to filter: - -- PAIR_FILTER = if we do want to act upon only same pair, like "BTC/USDT,ETH/USDT" -- TIMEFRAME_FILTER = if we do want to act upon only same timeframes, like "5m,15m" -- SOURCE_FILTER = if we do want to act upon only same sources, like "binance,kraken" -- OWNER_ADDRS = if we do want to act upon only same publishers, like "0x123,0x124" diff --git a/READMEs/macos.md b/READMEs/macos.md deleted file mode 100644 index b1c6af816..000000000 --- a/READMEs/macos.md +++ /dev/null @@ -1,46 +0,0 @@ - - -# MacOS Gotchas - -Here are potential issues related to MacOS, and workarounds. - -### Issue: MacOS * Docker - -Summary: -- On MacOS, Docker may freeze -- Fix by reverting to Docker 4.22.1 - -Symptoms of the issue: -- it stops logging; Docker cpu usage is 0%; it hangs when you type `docker ps` in console - -More info: -- Docker 4.24.1 (Sep 28, 2023) freezes, and 4.22.1 (Aug 24, 2023) works. For us, anyway. -- [Docker releases](https://docs.docker.com/desktop/release-notes) - -To fix: detailed instructions: -- In console: `./cleanup.sh; docker system prune -a --volumes` -- Download [Docker 4.22.1](https://docs.docker.com/desktop/release-notes/#4221) -- Open the download, drag "Docker" to "Applications". Choose "Replace existing" when prompted. -- Run Docker Desktop. Confirm the version via "About". -- If you have the wrong version, then [fully uninstall Docker](https://www.makeuseof.com/how-to-uninstall-docker-desktop-mac/) and try again. - -### Issue: MacOS * Subgraph Ports - -Summary: -- The subgraph container reports a subgraph url like: `http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph` -- But agents or the browser can't see that directly, because MacOS doesn't support per-container IP addressing -- Fix for envvars: `export SUBGRAPH_URL=http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph` -- Fix for browser: open [http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph](http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph) - -Symptoms of the issue: -- If running an agent, we'll see output like: `HTTPConnectionPool(host='localhost', port=9000): Max retries exceeded with url: /subgraphs/name/oceanprotocol/ocean-subgraph (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 61] Connection refused'))` -- If loading the subgraph url in the browser, it hangs / doesn't load the page - -Details: -- https://github.com/oceanprotocol/barge/blob/main/compose-files/thegraph.yml#L6 exposes the port 8000 in the container to the host network, thus it's accessible. -- "172.15.0.15" is the internal ip of the container; other containers in the network can access it. The docker network is isolated from host network unless you expose a port, therefore you cannot access it. -- and whereas Linux can make it more convenient via "bridge", MacOS doesn't support that [[Ref]](https://docker-docs.uclv.cu/docker-for-mac/networking/#there-is-no-docker0-bridge-on-macos). -- In other words, per-container IP addressing is not possible on MacOS [[Ref]](https://docker-docs.uclv.cu/docker-for-mac/networking/#per-container-ip-addressing-is-not-possible). diff --git a/READMEs/payout.md b/READMEs/payout.md index e201459ea..fbd002362 100644 --- a/READMEs/payout.md +++ b/READMEs/payout.md @@ -19,8 +19,10 @@ Ensure you pause or stop any ongoing prediction submissions. You can use `Ctrl-C #### 2. 
Execute Payout -- Running locally: Simply run the python script with the command: `python pdr_backend/predictoor/main.py payout`. -- Using Container Image: Simply execute the command: `predictoor payout`. +From console: +```console +pdr claim_OCEAN ppss.yaml +``` #### 3. Completion @@ -38,8 +40,10 @@ It's good practice to run the payout module again. This ensures any failed block #### 2. Claim ROSE Rewards -- Running locally: Simply run the python script with the command: `python pdr_backend/predictoor/main.py roseclaim`. -- Using Container Image: Simply execute the command: `predictoor roseclaim`. +From console: +```console +pdr claim_ROSE ppss.yaml +``` #### 3. Completion diff --git a/READMEs/predictoor-data.md b/READMEs/predictoor-data.md deleted file mode 100644 index d617ffc01..000000000 --- a/READMEs/predictoor-data.md +++ /dev/null @@ -1,41 +0,0 @@ - - -# Get Predictoor bot performance data - -This README presents how to get some data related to the Predictoor bot performance. - -Great job on becoming a Predictoor. You should be submitting predictions and claiming some rewards by now. - -Next you might want now to see exactly how accurate yout bot is and how much you have earned, let's do it! - -### Steps to Get Predictoor Data - -#### 1. Preparation - -Ensure you have claimed payout at least once before you continue. - -#### 2. How to get predictoor data - -- Make sure you still have the `RPC_URL` env variable set. - -- Run the folowing python script with the command: `python scripts/get_predictoor_info.py WALLET_ADDRESS START_DATE END_DATE NETWORK OUTPUT_DIRECTORY`. - -- Used parameters: - - `WALLET_ADDRESS`: the wallet address used for submitting the predictions. - - `STARTE_DATE`: format yyyy-mm-dd - the date starting from which to query data. - - `END_DATE`: format yyyy-mm-dd - the date to query data until. - - `NETWORK`: mainnet | testnet - the network to get the data from. - - `OUTPUT_DIRECTORY`: where to create the csv files with the grabbed data. - -#### 3. What is the output - -- in console you are going to see the **Accuracy**, **Total Stake**, **Total Payout** and **Number of predictions** for each pair and also the mentioned values over all pairs. - -- in the specified output directory: you are going to find generate CSV files with the following file name format: **'{PAIR}{TIMEFRAME}{EXCHANGE}'**, containing: **Predicted Value, True Value, Timestamp, Stake, Payout** - - - - diff --git a/READMEs/predictoor.md b/READMEs/predictoor.md index 511c92769..d2bf5458c 100644 --- a/READMEs/predictoor.md +++ b/READMEs/predictoor.md @@ -32,6 +32,9 @@ source venv/bin/activate # Install modules in the environment pip install -r requirements.txt + +#add pwd to bash path +export PATH=$PATH:. ``` If you're running MacOS, then in console: @@ -41,29 +44,29 @@ codesign --force --deep --sign - venv/sapphirepy_bin/sapphirewrapper-arm64.dylib ## Simulate Modeling and Trading -Simulation allows us to quickly build intuition, and assess the performance of the data / model / trading strategy (backtest). +Simulation allows us to quickly build intuition, and assess the performance of the data / predicting / trading strategy (backtest). + +Copy [`ppss.yaml`](../ppss.yaml) into your own file `my_ppss.yaml` and change parameters as you see fit. Let's simulate! In console: ```console -python pdr_backend/simulation/runtrade.py +pdr sim my_ppss.yaml ``` -What `runtrade.py` does: +What it does: 1. Set simulation parameters. -1. Grab historical price data from exchanges and stores in `csvs/` dir. 
It re-uses any previously saved data.
+1. Grab historical price data from exchanges and stores in `parquet_data/` dir. It re-uses any previously saved data.
 1. Run through many 5min epochs. At each epoch:
    - Build a model
    - Predict up/down
    - Trade.
-   - Plot total profit versus time.
+   - Plot total profit versus time, and more.
    - (It logs this all to screen, and to `out*.txt`.)
 
-The baseline settings use a linear model inputting prices of the previous 10 epochs as inputs, a simulated 0% trading fee, and a trading strategy of "buy if predict up; sell 5min later". You can play with different values in [runtrade.py](../pdr_backend/simulation/runtrade.py).
+The baseline settings use a linear model inputting prices of the previous 10 epochs as inputs, a simulated 0% trading fee, and a trading strategy of "buy if predict up; sell 5min later". You can play with different values in [sim_engine.py](../pdr_backend/sim/sim_engine.py).
 
 Profit isn't guaranteed: fees, slippage and more eats into them. Model accuracy makes a huge difference too.
 
-([This README](dynamic-model-codebase.md) has more info about the simulator's code structure.)
-
 ## Run Predictoor Bot on Sapphire Testnet
 
 Predictoor contracts run on [Oasis Sapphire](https://docs.oasis.io/dapp/sapphire/) testnet and mainnet. Sapphire is a privacy-preserving EVM-compatible L1 chain.
 
@@ -77,33 +80,20 @@ Then, copy & paste your private key as an envvar. In console:
 export PRIVATE_KEY=
 ```
 
-Now, set other envvars. In console:
-```console
-#other envvars for testnet and mainnet
-export ADDRESS_FILE="${HOME}/.ocean/ocean-contracts/artifacts/address.json"
-export PAIR_FILTER=BTC/USDT
-export TIMEFRAME_FILTER=5m
-export SOURCE_FILTER=binance
-
-#testnet-specific envvars
-export RPC_URL=https://testnet.sapphire.oasis.dev
-export SUBGRAPH_URL=https://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph
-export STAKE_TOKEN=0x973e69303259B0c2543a38665122b773D28405fB # (fake) OCEAN token address
-export OWNER_ADDRS=0xe02a421dfc549336d47efee85699bd0a3da7d6ff # OPF deployer address
-```
-
-([envvars.md](envvars.md) has details.)
+Update `my_ppss.yaml` as desired.
 
 Then, run a bot with modeling-on-the fly (approach 3). In console:
 ```console
-python pdr_backend/predictoor/main.py 3
+pdr predictoor 3 my_ppss.yaml sapphire-testnet
 ```
 
 Your bot is running, congrats! Sit back and watch it in action. It will loop continuously.
-- It has behavior to maximize accuracy without missing submission deadlines, as follows. 60 seconds before predictions are due, it will build a model then submit a prediction. It will repeat submissions every few seconds until the deadline.
-- It does this for every 5-minute epoch.
 
-(You can track at finer resolution by writing more logs to the [code](../pdr_backend/predictoor/approach3/predictoor_agent3.py), or [querying Predictoor subgraph](subgraph.md).)
+At every 5m/1h epoch, it builds & submits predictions more than once, to maximize accuracy without missing submission deadlines. Specifically: 60 s before predictions are due, it builds a model then submits a prediction. It repeats this until the deadline.
+
+The CLI has a tool to track performance. Type `pdr get_predictoor_info -h` for details.
+
+You can track behavior at finer resolution by writing more logs to the [code](../pdr_backend/predictoor/approach3/predictoor_agent3.py), or [querying Predictoor subgraph](subgraph.md).
 
 ## Run Predictoor Bot on Sapphire Mainnet
 
@@ -117,21 +107,11 @@ Then, copy & paste your private key as an envvar.
(You can skip this if it's sam export PRIVATE_KEY= ``` -Now, set other envvars. In console: -```console -#envvars for testnet and mainnet -#(can skip this, since same as testnet) - -#mainnet-specific envvars -export RPC_URL=https://sapphire.oasis.io -export SUBGRAPH_URL=https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph -export STAKE_TOKEN=0x39d22B78A7651A76Ffbde2aaAB5FD92666Aca520 # OCEAN token address -export OWNER_ADDRS=0x4ac2e51f9b1b0ca9e000dfe6032b24639b172703 # OPF deployer address -``` +Update `my_ppss.yaml` as desired. Then, run the bot. In console: ```console -python pdr_backend/predictoor/main.py 3 +pdr predictoor 3 my_ppss.yaml sapphire-mainnet ``` This is where there's real $ at stake. Good luck! @@ -144,11 +124,6 @@ When running predictoors on mainnet, you have the potential to earn $. **[Here](payout.md)** are instructions to claim your earnings. -## Check Performance - -After you run your bot and claimed the payout, you might want to check how your bot is perfoming before you start improving it. - -Follow **[this instructions](predictoor-data.md)** to get an overview of your performance so far. # Go Beyond @@ -164,9 +139,6 @@ Once you're familiar with the above, you can make your own model and optimize it 1. Bring your model as a Predictoor bot to testnet then mainnet. -([This README](dynamic-model-codebase.md) has more info about the simulator's code structure.) - - ## Run Many Bots at Once [PM2](https://pm2.keymetrics.io/docs/usage/quick-start/) is "a daemon process manager that will help you manage and keep your application online." diff --git a/READMEs/publisher.md b/READMEs/publisher.md index 3cfa58489..c100ef99f 100644 --- a/READMEs/publisher.md +++ b/READMEs/publisher.md @@ -21,17 +21,13 @@ Then, run barge. In barge console: ./start_ocean.sh --no-provider --no-dashboard --predictoor --with-thegraph ``` -Open a new "work" console and: +Open a new console and: ```console # Setup virtualenv cd pdr-backend source venv/bin/activate -# Set envvars - note that publisher needs WAY more private keys than others -export ADDRESS_FILE="${HOME}/.ocean/ocean-contracts/artifacts/address.json" -export RPC_URL=http://127.0.0.1:8545 -export SUBGRAPH_URL="http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph" -#OR: export SUBGRAPH_URL="http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph" +# Set envvars - note that publisher needs >>1 private keys export PREDICTOOR_PRIVATE_KEY= export PREDICTOOR2_PRIVATE_KEY= export PREDICTOOR3_PRIVATE_KEY= @@ -41,10 +37,21 @@ export PDR_WEBSOCKET_KEY= export PDR_MM_USER= """ -# publish! main.py & publish.py do all the work -python pdr_backend/publisher/main.py +Copy [`ppss.yaml`](../ppss.yaml) into your own file `my_ppss.yaml` and change parameters as you see fit. The section "publisher_ss" has parameters for this bot. + +Then, run publisher bot. In console: +```console +pdr publisher my_ppss.yaml development ``` + ## Remote Usage -Combine local setup above with remote setup envvars like in [predictoor.md](predictoor.md). 
\ No newline at end of file +In the CLI, simply point to a different network: +```console +# run on testnet +pdr publisher my_ppss.yaml sapphire-testnet + +# or, run on mainnet +pdr publisher my_ppss.yaml sapphire-mainnet +``` diff --git a/READMEs/release-process.md b/READMEs/release-process.md index 08809a426..830f3a9a7 100644 --- a/READMEs/release-process.md +++ b/READMEs/release-process.md @@ -51,7 +51,7 @@ To elaborate: we have an automated docker build for pdr-backend `main` branch an If you want to add Docker branches, go to https://hub.docker.com/repository/docker/oceanprotocol/pdr-backend/builds/edit -Then: on "Build rules", add your branch. Below is an example, where Alex is buildint "pdr-backend: alex" from branch "feature/alex". +Then: on "Build rules", add your branch. Below is an example, where Alex is building "pdr-backend: alex" from branch "feature/alex". ![](./images/dockerbranch.png) @@ -65,7 +65,7 @@ First, build your image locally with a custom label, eg `yaml-cli2`. ```console cd ~/code/pdr-backend -docker build -t 'oceanprotocol/pdr-backend:yaml-cli2' . +docker build . -t 'oceanprotocol/pdr-backend:yaml-cli2' . ``` Then, start barge, using the custom label: @@ -83,3 +83,10 @@ Pros of local testing: - no need of cleanups. If a PR is merged, we don't need to delete that branch from dockerhub autobuild. - no need of access to dockerhub - dockerhub should be used for production ready images only + +### All Barge READMEs + +- [barge.md](barge.md): the main Barge README +- [barge-calls.md](barge-calls.md): order of execution from Barge and pdr-backend code +- [release-process.md](release-process.md): pdr-backend Dockerhub images get published with each push to `main`, and sometimes other branches. In turn these are used by Barge. + diff --git a/READMEs/setup-remote.md b/READMEs/setup-remote.md deleted file mode 100644 index 276ede6b8..000000000 --- a/READMEs/setup-remote.md +++ /dev/null @@ -1,89 +0,0 @@ - - -# Remote Setup - -**NOTE: THIS IS NOT COMPLETE! It will need heavy revisions to work on Oasis, and proper testing.** - -Here, we do setup for Oasis Sapphire testnet (Sapptest). It's similar for Oasis Sapphire mainnet (Sappmain). - -We assume you've already [installed pdr-backend](install.md). - -For brevity, we refer to - -Here, we will: -1. Create two accounts - `REMOTE_TEST_PRIVATE_KEY1` and `2` -2. Get fake ROSE on Sapptest -3. Get fake OCEAN "" -4. Set envvars -5. Set up Alice and Bob wallets in Python - -Let's go! - -## 1. Create EVM Accounts (One-Time) - -An EVM account is singularly defined by its private key. Its address is a function of that key. Let's generate two accounts! - -In a new or existing console, run Python. -```console -python -``` - -In the Python console: - -```python -from eth_account.account import Account -account1 = Account.create() -account2 = Account.create() - -print(f""" -REMOTE_TEST_PRIVATE_KEY1={account1.key.hex()}, ADDRESS1={account1.address} -REMOTE_TEST_PRIVATE_KEY2={account2.key.hex()}, ADDRESS2={account2.address} -""") -``` - -Then, hit Ctrl-C to exit the Python console. - -Now, you have two EVM accounts (address & private key). Save them somewhere safe, like a local file or a password manager. - -These accounts will work on any EVM-based chain: production chains like Eth mainnet and Polygon, and testnets like Goerli and Sapptest. Here, we'll use them for Sapptest. - -## 2. Get (fake) ROSE on Sapptest - -We need the network's native token to pay for transactions on the network. 
ETH is the native token for Ethereum mainnet, ROSE is the native token for Polygon, and (fake) ROSE is the native token for Sapptest. - -To get free (fake) ROSE on Sapptest: -1. Go to the faucet (FIXME_URL) Ensure you've selected "Sapptest" network and "ROSE" token. -2. Request funds for ADDRESS1 -3. Request funds for ADDRESS2 - -You can confirm receiving funds by going to the following url, and seeing your reported ROSE balance: `FIXME_URL/` - -## 3. Get (fake) OCEAN on Sapptest - -In Predictoor, OCEAN is used as follows: -- by traders, to purchase data feeds -- by predictoors, for staking on predictions, and for earnings from predictions - -- OCEAN is an ERC20 token with a finite supply, rooted in Ethereum mainnet at address [`0x967da4048cD07aB37855c090aAF366e4ce1b9F48`](https://etherscan.io/token/0x967da4048cD07aB37855c090aAF366e4ce1b9F48). -- OCEAN on other production chains derives from the Ethereum mainnet OCEAN. OCEAN on Sappmain [`FIXME_token_address`](FIXME_URL). -- (Fake) OCEAN is on each testnet. Fake OCEAN on Sapptest is at [`FIXME_token_address`](FIXME_URL). - -To get free (fake) OCEAN on Sapptest: -1. Go to the faucet FIXME_URL -2. Request funds for ADDRESS1 -3. Request funds for ADDRESS2 - -You can confirm receiving funds by going to the following url, and seeing your reported OCEAN balance: `FIXME_URL?a=` - -## 4. Set envvars - -In your working console: -```console -export REMOTE_TEST_PRIVATE_KEY1= -export REMOTE_TEST_PRIVATE_KEY2= -``` - -Check out the [environment variables documentation](./envvars.md) to learn more about the environment variables that could be set. diff --git a/READMEs/static-model.md b/READMEs/static-model.md deleted file mode 100644 index 751188690..000000000 --- a/READMEs/static-model.md +++ /dev/null @@ -1,61 +0,0 @@ - - -# Run Static Model Predictoor - -The default flow for predictoors is approach3: dynamically building models. - -_This_ README is for approach2, which uses static models. Static models are developed and saved in a different repo. - -NOTE: this approach may be deprecated in the future. - -There are two macro steps: -1. [Develop & backtest static models](#develop-and-backtest-models) -1. [Use model in Predictoor bot](#use-model-in-predictoor-bot) - -The first step is done in a _separate_ repo. - -Let's go through each step in turn. - -## Develop and backtest models - -Normally you'd have to develop your own model. - -However to get you going, we've developed a simple model, at [`pdr-model-simple`](https://github.com/oceanprotocol/pdr-model-simple) repo. - -The second step will show how to connect this model to the bot. - -## Use model in Predictoor bot - -The bot itself will run from [`predictoor/approach2/main.py`](../pdr_backend/predictoor/approach2/main.py), using `predict.py` in the same dir. That bot needs to sees the model developed elsewhere. - -Here's how to get everything going. - -In work console: -```console -# go to a directory where you'll want to clone to. Here's one example. -cd ~/code/ - -#clone model repo -git clone https://github.com/oceanprotocol/pdr-model-simple - -#the script below needs this envvar, to know where to import model.py from -export MODELDIR=$(pwd)/pdr-model-simple/ - -#pip install anything that pdr-model-simple/model.py needs -pip install scikit-learn ta - -#run static model predictoor bot -python pdr_backend/predictoor/main.py 2 -``` - -## Your own static model - -Once you're familiar with the above, you can make your own model: fork `pdr-model-simple` and change it as you wish. 
Finally, link your predictoor bot to your new model repo, like shown above. - -## Other READMEs - -- [Parent predictoor README: predictoor.md](./predictoor.md) -- [Root README](../README.md) diff --git a/READMEs/subgraph.md b/READMEs/subgraph.md index a388deb04..a833b732b 100644 --- a/READMEs/subgraph.md +++ b/READMEs/subgraph.md @@ -9,13 +9,16 @@ SPDX-License-Identifier: Apache-2.0 You can query an Ocean subgraph at one of the following: -Local (= `$SUBGRAPH_URL`): +The subgraph url for each network is in the ppss yaml under "subgraph url". + +Typically, these are something like: +- Local (barge) - http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph - OR http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph - -Remote: -- Sapphire testnet, at https://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph -- Sapphire mainnet, at https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph +- Sapphire testnet + - https://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph +- Sapphire mainnet + - https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph ### Typical query @@ -64,9 +67,6 @@ Agents like predictoor and trader do queries via [pdr_backend/util/subgraph.py]( - They call to a subgraph, at a given url, with a particular query - and they may filter further -**Filters**, You can set some envvars to filter query results for agents. Check out the [envvar documentation](./envvars.md#filters) to learn more about these filters and how to set them. - ### Appendix - [ocean-subgraph repo](https://github.com/oceanprotocol/ocean-subgraph) -- [ocean-subgraph PR#678](https://github.com/oceanprotocol/ocean-subgraph/pull/678) lists full entities. (Things may have changed a bit since then) diff --git a/READMEs/trader.md b/READMEs/trader.md index ffccda9c0..b90fdb0f6 100644 --- a/READMEs/trader.md +++ b/READMEs/trader.md @@ -29,6 +29,9 @@ source venv/bin/activate # Install modules in the environment pip install -r requirements.txt + +#add pwd to bash path +export PATH=$PATH:. ``` If you're running MacOS, then in console: @@ -38,29 +41,29 @@ codesign --force --deep --sign - venv/sapphirepy_bin/sapphirewrapper-arm64.dylib ## Simulate Modeling and Trading -Simulation allows us to quickly build intuition, and assess the performance of the data / model / trading strategy (backtest). +Simulation allows us to quickly build intuition, and assess the performance of the data / predicting / trading strategy (backtest). + +Copy [`ppss.yaml`](../ppss.yaml) into your own file `my_ppss.yaml` and change parameters as you see fit. Let's simulate! In console: ```console -python pdr_backend/simulation/runtrade.py +pdr sim my_ppss.yaml ``` -What `runtrade.py` does: +What it does: 1. Set simulation parameters. -1. Grab historical price data from exchanges and stores in `csvs/` dir. It re-uses any previously saved data. +1. Grab historical price data from exchanges and stores in `parquet_data/` dir. It re-uses any previously saved data. 1. Run through many 5min epochs. At each epoch: - Build a model - Predict up/down - Trade. + - Plot total profit versus time, and more. - (It logs this all to screen, and to `out*.txt`.) - - Plot total profit versus time. -The baseline settings use a linear model inputting prices of the previous 10 epochs as inputs, a simulated 0% trading fee, and a trading strategy of "buy if predict up; sell 5min later". 
You can play with different values in [runtrade.py](../pdr_backend/simulation/runtrade.py). +The baseline settings use a linear model inputting prices of the previous 10 epochs as inputs, a simulated 0% trading fee, and a trading strategy of "buy if predict up; sell 5min later". You can play with different values in [runsim.py](../pdr_backend/sim/runsim.py). Profit isn't guaranteed: fees, slippage and more eats into them. Model accuracy makes a huge difference too. -([This README](dynamic-model-codebase.md) has more info about the simulator's code structure.) - ## Run Trader Bot on Sapphire Testnet Predictoor contracts run on [Oasis Sapphire](https://docs.oasis.io/dapp/sapphire/) testnet and mainnet. Sapphire is a privacy-preserving EVM-compatible L1 chain. @@ -74,31 +77,16 @@ Then, copy & paste your private key as an envvar. In console: export PRIVATE_KEY= ``` -Now, set other envvars. In console: -```console -#other envvars for testnet and mainnet -export ADDRESS_FILE="${HOME}/.ocean/ocean-contracts/artifacts/address.json" -export PAIR_FILTER=BTC/USDT -export TIMEFRAME_FILTER=5m -export SOURCE_FILTER=binance - -#testnet-specific envvars -export RPC_URL=https://testnet.sapphire.oasis.dev -export SUBGRAPH_URL=https://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph -export STAKE_TOKEN=0x973e69303259B0c2543a38665122b773D28405fB # (fake) OCEAN token address -export OWNER_ADDRS=0xe02a421dfc549336d47efee85699bd0a3da7d6ff # OPF deployer address -``` - -([envvars.md](envvars.md) has details.) +Update `my_ppss.yaml` as desired. Then, run a simple trading bot. In console: ```console -python pdr_backend/trader/main.py +pdr trader 2 my_ppss.yaml sapphire-testnet ``` Your bot is running, congrats! Sit back and watch it in action. -(You can track at finer resolution by writing more logs to the [code](../pdr_backend/trader/trader_agent.py), or [querying Predictoor subgraph](subgraph.md).) +You can track behavior at finer resolution by writing more logs to the [code](../pdr_backend/trader/trader_agent.py), or [querying Predictoor subgraph](subgraph.md). ## Run Trader Bot on Sapphire Mainnet @@ -111,21 +99,11 @@ Then, copy & paste your private key as an envvar. (You can skip this if it's sam export PRIVATE_KEY= ``` -Now, set other envvars. In console: -```console -#envvars for testnet and mainnet -#(can skip this, since same as testnet) - -#mainnet-specific envvars -export RPC_URL=https://sapphire.oasis.io -export SUBGRAPH_URL=https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph -export STAKE_TOKEN=0x39d22B78A7651A76Ffbde2aaAB5FD92666Aca520 # OCEAN token address -export OWNER_ADDRS=0x4ac2e51f9b1b0ca9e000dfe6032b24639b172703 # OPF deployer address -``` +Update `my_ppss.yaml` as desired. Then, run the bot. In console: ```console -python pdr_backend/trader/main.py +pdr trader 2 my_ppss.yaml sapphire-mainnet ``` This is where there's real $ at stake. Good luck! @@ -145,10 +123,6 @@ Once you're familiar with the above, you can set your own trading strategy and o 1. Change trader bot code as you wish, while iterating with simulation. 1. Bring your trader bot to testnet then mainnet. 
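+
+To make that iteration loop concrete, here is a minimal, hypothetical sketch of a custom decision rule. The `decide_trade` hook and `prob_up` argument are illustrative only, not the actual `TraderAgent` interface; see [`trader_agent.py`](../pdr_backend/trader/trader_agent.py) for the real code to modify.
+
+```python
+# Hypothetical sketch -- the method name and argument are illustrative,
+# not the real TraderAgent API. See pdr_backend/trader/trader_agent.py.
+from pdr_backend.trader.trader_agent import TraderAgent
+
+
+class MyTraderAgent(TraderAgent):
+    """Only trade when the predicted probability of 'up' is strong."""
+
+    def decide_trade(self, prob_up: float) -> bool:  # illustrative hook name
+        # stricter than the baseline "buy if predict up"
+        return prob_up > 0.6
+```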
-To help, here's the code structure of the bot: -- It runs [`trader_agent.py::TraderAgent`](../pdr_backend/trader/trader_agent.py) found in `pdr_backend/trader/` -- It's configured by envvars and [`trader_config.py::TraderConfig`](../pdr_backend/trader/trader_config.py) - ## Run Bots Remotely To scale up compute or run without tying up your local machine, you can run bots remotely. Get started [here](remotebot.md). diff --git a/READMEs/trueval.md b/READMEs/trueval.md new file mode 100644 index 000000000..a52b863eb --- /dev/null +++ b/READMEs/trueval.md @@ -0,0 +1,53 @@ + + +# Run a Trueval Bot + +This README describes how to run a trueval bot. + +## Install pdr-backend + +Follow directions in [predictoor.md](predictoor.md) + +## Local Network + +First, [install barge](barge.md#install-barge). + +Then, run barge. In barge console: +```console +#run barge with all bots (agents) except trueval +./start_ocean.sh --no-provider --no-dashboard --predictoor --with-thegraph --with-pdr-dfbuyer --with-pdr-predictoor --with-pdr-publisher --with-pdr-trader +``` + +Open a new console and: +``` +# Setup virtualenv +cd pdr-backend +source venv/bin/activate + +# Set envvar +export PRIVATE_KEY="0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" +``` + +Copy [`ppss.yaml`](../ppss.yaml) into your own file `my_ppss.yaml` and change parameters as you see fit. The "trueval_ss" has parameters for this bot. + +Then, run trueval bot. In console: +```console +pdr trueval my_ppss.yaml development +``` + + +## Remote Usage + +In the CLI, simply point to a different network: +```console +# run on testnet +pdr trueval my_ppss.yaml sapphire-testnet + +# or, run on mainnet +pdr trueval my_ppss.yaml sapphire-mainnet +``` + + diff --git a/READMEs/vps.md b/READMEs/vps.md index ceb048855..461559915 100644 --- a/READMEs/vps.md +++ b/READMEs/vps.md @@ -3,9 +3,12 @@ Copyright 2023 Ocean Protocol Foundation SPDX-License-Identifier: Apache-2.0 --> -# Run Barge Remotely on VPS +# VPS Backend Dev -This README shows how to run Barge on an Azure Ubuntu VPS (Virtual Private Server). This is for use in backend dev, running predictoor bot, running trader bot. +This README shows how to +- Set up an Azure Ubuntu VPS (Virtual Private Server) +- Run Barge on the VPS +- Run sim/bots or pytest using the VPS. (In fact, one VPS per flow) ## 1. Locally, install pdr-backend @@ -20,6 +23,7 @@ cd pdr-backend # Create & activate virtualenv python -m venv venv source venv/bin/activate +export PATH=$PATH:. # Install modules in the environment pip install -r requirements.txt @@ -48,8 +52,13 @@ ssh -i ~/Desktop/myKey.pem azureuser@74.234.16.165 ### In Azure Portal, Open ports of VPS Running Barge, the VPS exposes these urls: -- RPC is at http://4.245.224.119:8545 or http://74.234.16.165:8545 -- Subgraph is at http://4.245.224.119:9000/subgraphs/name/oceanprotocol/ocean-subgraph or http://74.234.16.165:9000/subgraphs/name/oceanprotocol/ocean-subgraph +- RPC is at: + - http://4.245.224.119:8545 + - or http://74.234.16.165:8545 +- Subgraph is at: + - http://4.245.224.119:9000/subgraphs/name/oceanprotocol/ocean-subgraph + - or http://74.234.16.165:9000/subgraphs/name/oceanprotocol/ocean-subgraph + - Go there, then copy & paste in the query from [subgraph.md](subgraph.md) BUT you will not be able to see these yet, because the VPS' ports are not yet open enough. 
Here's how: - Go to Azure Portal for your group @@ -128,10 +137,11 @@ ssh -i ~/Desktop/myKey.pem azureuser@74.234.16.165 In VPS console: ```console # cleanup past barge -rm -rf ~/.ocean cd ~/code/barge +docker stop $(docker ps -a -q) ./cleanup.sh -docker system prune -a --volumes +rm -rf ~/.ocean +docker system prune -a -f --volumes # run barge... # set ganache block time to 5 seconds, try increasing this value if barge is lagging @@ -146,26 +156,30 @@ export GANACHE_BLOCKTIME=5 ./start_ocean.sh --no-provider --no-dashboard --predictoor --with-thegraph ``` -Wait. +Track progress until the addresses are published: +- open a new console +- ssh into the VPS +- then: `docker logs -f ocean_ocean-contracts_1`. Monitor until it says it's published. +- then: `Ctrl-C`, and confirm via: `cat .ocean/ocean-contracts/artifacts/address.json |grep dev`. It should give one line. Then, copy VPS' `address.json` file to local. In local console: ```console cd # OPTION 1: for predictoor bot -scp -i ~/Desktop/myKey.pem azureuser@4.245.224.119:.ocean/ocean-contracts/artifacts/address.json MyVmBargePredictoor.address.json +scp -i ~/Desktop/myKey.pem azureuser@4.245.224.119:.ocean/ocean-contracts/artifacts/address.json barge-predictoor-bot.address.json # OR, OPTION 2: for unit testing -scp -i ~/Desktop/myKey.pem azureuser@74.234.16.165:.ocean/ocean-contracts/artifacts/address.json MyVmBargeUnitTest.address.json +scp -i ~/Desktop/myKey.pem azureuser@74.234.16.165:.ocean/ocean-contracts/artifacts/address.json barge-pytest.address.json ``` -Note how we give it a unique name, vs just "address.json". This keeps it distinct from the address file for _second_ Barge VM we run for pytesting (details below) +We give the address file a unique name, vs just "address.json". This keeps it distinct from the address file for _second_ Barge VM we run for pytest (details below). -Confirm that `MyVmBargePredictoor.address.json` has a "development" entry. In local console: +Confirm that `barge-predictoor-bot.address.json` has a "development" entry. In local console: ```console -grep development ~/MyVmBargePredictoor.address.json +grep development ~/barge-predictoor-bot.address.json # or -grep development ~/MyVmBargeUnitTest.address.json +grep development ~/barge-pytest.address.json ``` It should return: @@ -175,46 +189,44 @@ It should return: If it returns nothing, then contracts have not yet been deployed to ganache. It's either (i) you need to wait longer (ii) Barge had an issue and you need to restart it or debug. +Further debugging: +- List live docker processes: `docker ps` +- List all docker processes: `docker ps -a` +- List names of processes: `docker ps -a | cut -c 347-`. It lists `ocean_pdr-publisher_1`, `ocean_ocean-contracts_1`, ` ocean_pdr-publisher_1`, .. +- See log for publishing Ocean contracts: `docker logs ocean_ocean-contracts_1`. With realtime update: `docker logs -f ocean_ocean-contracts_1` +- See log for publishing prediction feeds: `docker logs ocean_pdr-publisher_1`. With realtime update: `docker logs -f ocean_pdr-publisher_1` +- Get detailed info about pdr-publisher image: `docker inspect ocean_pdr-publisher_1` + ## 4. Locally, Run Predictoor Bot (OPTION 1) +### Set envvars + In local console: ```console # set up virtualenv (if needed) cd ~/code/pdr-backend source venv/bin/activate +export PATH=$PATH:. -# set envvars export PRIVATE_KEY="0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" # addr for key=0xc594.. 
is 0xe2DD09d719Da89e5a3D0F2549c7E24566e947260 -export ADDRESS_FILE="${HOME}/MyVmBargePredictoor.address.json" # from scp to local - -export RPC_URL=http://4.245.224.119:8545 # from VPS -export SUBGRAPH_URL=http://4.245.224.119:9000/subgraphs/name/oceanprotocol/ocean-subgraph # from VPS - -# for predictoor bot. Setting to empty means no filters. -export PAIR_FILTER= -export TIMEFRAME_FILTER= -export SOURCE_FILTER= - -export OWNER_ADDRS=0xe2DD09d719Da89e5a3D0F2549c7E24566e947260 # OPF deployer address. Taken from ocean.py setup-local.md FACTORY_DEPLOYER_PRIVATE_KEY ``` -([envvars.md](envvars.md) has details.) +### Set PPSS -You also need to set the `STAKE_TOKEN` envvar to the OCEAN address in barge. In local console: +Let's configure the yaml file. In console: ```console -grep --after-context=10 development ~/address.json|grep Ocean|sed -e 's/.*0x/export STAKE_TOKEN=0x/'| sed -e 's/",//' +cp ppss.yaml my_ppss.yaml ``` -It should return something like the following. Copy that into the prompt and hit enter: -```console -export STAKE_TOKEN=0x282d8efCe846A88B159800bd4130ad77443Fa1A1 -``` +In `my_ppss.yaml` file, in `web3_pp` -> `development` section: +- change the urls and addresses as needed to reflect your VPS +- including: set the `stake_token` value to the output of the following: `grep --after-context=10 development ~/barge-predictoor-bot.address.json|grep Ocean|sed -e 's/.*0x/export STAKE_TOKEN=0x/'| sed -e 's/",//'`. (Or get the value from `~/barge-predictoor-bot.address.json`, in `"development"` -> `"Ocean"` entry.) -(Alternatively: open `~/address.json` file, find the "development" : "Ocean" entry, and paste it into prompt with `export STAKE_TOKEN=`) +### Run pdr bot Then, run a bot with modeling-on-the fly (approach 3). In console: ```console -python pdr_backend/predictoor/main.py 3 +pdr predictoor 3 my_ppss.yaml development ``` Your bot is running, congrats! Sit back and watch it in action. It will loop continuously. @@ -228,28 +240,44 @@ Your bot is running, congrats! Sit back and watch it in action. It will loop con ### Set up a second VPS / Barge -In steps 2 & 3 above, we set up a _first_ VPS & Barge, for predictoor bot. +In steps 2 & 3 above, we had set up a _first_ VPS & Barge, for predictoor bot. - Assume its IP address is 4.245.224.119 Now, repeat 2 & 3 above, to up a _second_ VPS & Barge, for local testing. - Give it the same key as the first barge. - Assume its IP address is 74.234.16.165 -- The "OR" options in sections 2 above use this second IP address. Therefore you can go through the flow easily. +- The "OR" options in sections 2 above use this second IP address. Therefore you can go through the flow with simple copy-and-pastes. -### Set Local Envars +### Set envvars -To envvars that use the second Barge. In local console: +In local console: ```console -export PRIVATE_KEY="0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" -export ADDRESS_FILE="${HOME}/MyVmBargeUnitTest.address.json" # from scp to local +# set up virtualenv (if needed) +cd ~/code/pdr-backend +source venv/bin/activate +export PATH=$PATH:. -export RPC_URL=http://74.234.16.165:8545 # from VPS -export SUBGRAPH_URL=http://74.234.16.165:9000/subgraphs/name/oceanprotocol/ocean-subgraph # from VPS +# same private key as 'run predictoor bot' +export PRIVATE_KEY="0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" # addr for key=0xc594.. 
is 0xe2DD09d719Da89e5a3D0F2549c7E24566e947260 ``` +### Set PPSS + +Whereas most READMEs copy `ppss.yaml` to `my_ppss.yaml`, for development we typically want to operate directly on the original one. + +In `ppss.yaml` file, in `web3_pp` -> `barge-pytest` section: (note the different barge section) +- change the urls and addresses as needed to reflect your VPS +- including: set the `stake_token` value to the output of the following: `grep --after-context=10 development ~/barge-pytest.address.json|grep Ocean|sed -e 's/.*0x/stake_token: \"0x/'| sed -e 's/",//'`. (Or get the value from `~/barge-pytest.address.json`, in `"development"` -> `"Ocean"` entry.) + + +### Run tests + In work console, run tests: ```console +# (ensure PRIVATE_KEY set as above) + # run a single test. The "-s" is for more output. +# note that pytest does dynamic type-checking too:) pytest pdr_backend/util/test_noganache/test_util_constants.py::test_util_constants -s # run all tests in a file @@ -264,12 +292,18 @@ pytest In work console, run linting checks: ```console -# run static type-checking. By default, uses config mypy.ini. Note: pytest does dynamic type-checking. +# mypy does static type-checking and more. Configure it via mypy.ini mypy ./ -# run linting on code style +# run linting on code style. Configure it via .pylintrc. pylint pdr_backend/* -# auto-fix some pylint complaints +# auto-fix some pylint complaints like whitespace black ./ ``` + +Check code coverage: +```console +coverage run --omit="*test*" -m pytest # Run all. For subset, add eg: pdr_backend/lake +coverage report # show results +``` diff --git a/entrypoint.sh b/entrypoint.sh index 049751f97..168aa1544 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,20 +1,5 @@ #!/bin/bash -MODULE_NAME=$1 -COMMAND=$2 - -if [ -z "$MODULE_NAME" ] -then - echo "No module specified. Please provide a module name as an argument." - exit 1 -fi - -if [ ! -d "/app/pdr_backend/$MODULE_NAME" ] -then - echo "Module $MODULE_NAME does not exist." - exit 1 -fi - if [ "${WAIT_FOR_CONTRACTS}" = "true" ] # Development only then @@ -40,4 +25,4 @@ echo "Delaying startup for ${DELAY} seconds.." sleep $DELAY echo "Running $MODULE_NAME..." 
-python /app/pdr_backend/$MODULE_NAME/main.py $COMMAND +python /app/pdr $@ diff --git a/mypy.ini b/mypy.ini index 08449830c..0bf1884c8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -44,6 +44,9 @@ ignore_missing_imports = True [mypy-pdr_backend.predictoor.examples.*] ignore_missing_imports = True +[mypy-polars.*] +ignore_missing_imports = True + [mypy-pylab.*] ignore_missing_imports = True diff --git a/pdr b/pdr new file mode 100755 index 000000000..af274c96f --- /dev/null +++ b/pdr @@ -0,0 +1,6 @@ +#!/usr/bin/env python + +from pdr_backend.cli import cli_module + +if __name__ == "__main__": + cli_module._do_main() diff --git a/pdr_backend/accuracy/app.py b/pdr_backend/accuracy/app.py index 68c5fabce..9d01e1469 100644 --- a/pdr_backend/accuracy/app.py +++ b/pdr_backend/accuracy/app.py @@ -1,16 +1,210 @@ -import threading import json +import threading from datetime import datetime, timedelta -from typing import Tuple +from typing import Any, Dict, List, Optional, Tuple + from enforce_typing import enforce_types from flask import Flask, jsonify -from pdr_backend.util.subgraph_predictions import get_all_contract_ids_by_owner -from pdr_backend.util.subgraph_slot import calculate_statistics_for_all_assets -from pdr_backend.util.subgraph_predictions import fetch_contract_id_and_spe +from pdr_backend.subgraph.subgraph_predictions import ( + fetch_contract_id_and_spe, + get_all_contract_ids_by_owner, + ContractIdAndSPE, +) +from pdr_backend.subgraph.subgraph_slot import fetch_slots_for_all_assets, PredictSlot app = Flask(__name__) JSON_FILE_PATH = "pdr_backend/accuracy/output/accuracy_data.json" +SECONDS_IN_A_DAY = 86400 + + +@enforce_types +def calculate_prediction_result( + round_sum_stakes_up: float, round_sum_stakes: float +) -> Optional[bool]: + """ + Calculates the prediction result based on the sum of stakes. + + Args: + round_sum_stakes_up: The summed stakes for the 'up' prediction. + round_sum_stakes: The summed stakes for all prediction. + + Returns: + A boolean indicating the predicted direction. + """ + + # checks for to be sure that the division is not by zero + round_sum_stakes_up_float = float(round_sum_stakes_up) + round_sum_stakes_float = float(round_sum_stakes) + + if round_sum_stakes_float == 0.0: + return None + + if round_sum_stakes_up_float == 0.0: + return False + + return (round_sum_stakes_up_float / round_sum_stakes_float) > 0.5 + + +@enforce_types +def process_single_slot( + slot: PredictSlot, end_of_previous_day_timestamp: int +) -> Optional[Tuple[float, float, int, int]]: + """ + Processes a single slot and calculates the staked amounts for yesterday and today, + as well as the count of correct predictions. + + Args: + slot: A PredictSlot TypedDict containing information about a single prediction slot. + end_of_previous_day_timestamp: The Unix timestamp marking the end of the previous day. + + Returns: + A tuple containing staked amounts for yesterday, today, and the counts of correct + predictions and slots evaluated, or None if no stakes were made today. 
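+
+    Example (mirrors SAMPLE_PREDICT_SLOT in accuracy/test/test_app.py):
+        a slot with ID "0xAsset-12345", roundSumStakesUp=150.0,
+        roundSumStakes=100.0 and a single trueValue of True, processed with
+        end_of_previous_day_timestamp=12340, returns (0.0, 100.0, 1, 1):
+        nothing staked yesterday, 100.0 staked today, one correct
+        prediction, one slot evaluated.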
+ """ + + staked_yesterday = staked_today = 0.0 + correct_predictions_count = slots_evaluated = 0 + + if float(slot.roundSumStakes) == 0.0: + return None + + # split the id to get the slot timestamp + timestamp = int(slot.ID.split("-")[1]) # Using dot notation for attribute access + + if ( + end_of_previous_day_timestamp - SECONDS_IN_A_DAY + < timestamp + < end_of_previous_day_timestamp + ): + staked_yesterday += float(slot.roundSumStakes) + elif timestamp > end_of_previous_day_timestamp: + staked_today += float(slot.roundSumStakes) + + prediction_result = calculate_prediction_result( + slot.roundSumStakesUp, slot.roundSumStakes + ) + + if prediction_result is None: + return ( + staked_yesterday, + staked_today, + correct_predictions_count, + slots_evaluated, + ) + + true_values: List[Dict[str, Any]] = slot.trueValues or [] + true_value: Optional[bool] = true_values[0]["trueValue"] if true_values else None + + if len(true_values) > 0 and prediction_result == true_value: + correct_predictions_count += 1 + + if len(true_values) > 0 and true_value is not None: + slots_evaluated += 1 + + return staked_yesterday, staked_today, correct_predictions_count, slots_evaluated + + +@enforce_types +def aggregate_statistics( + slots: List[PredictSlot], end_of_previous_day_timestamp: int +) -> Tuple[float, float, int, int]: + """ + Aggregates statistics across all provided slots for an asset. + + Args: + slots: A list of PredictSlot TypedDicts containing information + about multiple prediction slots. + end_of_previous_day_timestamp: The Unix timestamp marking the end of the previous day. + + Returns: + A tuple containing the total staked amounts for yesterday, today, + and the total counts of correct predictions and slots evaluated. + """ + + total_staked_yesterday = ( + total_staked_today + ) = total_correct_predictions = total_slots_evaluated = 0 + for slot in slots: + slot_results = process_single_slot(slot, end_of_previous_day_timestamp) + if slot_results: + ( + staked_yesterday, + staked_today, + correct_predictions_count, + slots_evaluated, + ) = slot_results + total_staked_yesterday += staked_yesterday + total_staked_today += staked_today + total_correct_predictions += correct_predictions_count + total_slots_evaluated += slots_evaluated + return ( + total_staked_yesterday, + total_staked_today, + total_correct_predictions, + total_slots_evaluated, + ) + + +@enforce_types +def calculate_statistics_for_all_assets( + asset_ids: List[str], + contracts_list: List[ContractIdAndSPE], + start_ts_param: int, + end_ts_param: int, + network: str = "mainnet", +) -> Dict[str, Dict[str, Any]]: + """ + Calculates statistics for all provided assets based on + slot data within a specified time range. + + Args: + asset_ids: A list of asset identifiers for which statistics will be calculated. + start_ts_param: The Unix timestamp for the start of the time range. + end_ts_param: The Unix timestamp for the end of the time range. + network: The blockchain network to query ('mainnet' or 'testnet'). + + Returns: + A dictionary mapping asset IDs to another dictionary with + calculated statistics such as average accuracy and total staked amounts. 
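+
+    Note: contracts_list (a List[ContractIdAndSPE], not described above) supplies
+    the "ID" and "name" fields used to attach a token_name to each asset's stats.
+
+    Example (mirrors accuracy/test/test_app.py): with asset_ids=["0xAsset"] and a
+    contracts_list entry {"ID": "0xAsset", "seconds_per_epoch": 300,
+    "name": "TEST/USDT"}, the result maps "0xAsset" to a dict holding token_name,
+    average_accuracy, total_staked_yesterday and total_staked_today.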
+ """ + + slots_by_asset = fetch_slots_for_all_assets( + asset_ids, start_ts_param, end_ts_param, network + ) + + overall_stats = {} + for asset_id, slots in slots_by_asset.items(): + ( + staked_yesterday, + staked_today, + correct_predictions_count, + slots_evaluated, + ) = aggregate_statistics(slots, end_ts_param - SECONDS_IN_A_DAY) + average_accuracy = ( + 0 + if correct_predictions_count == 0 + else (correct_predictions_count / slots_evaluated) * 100 + ) + + # filter contracts to get the contract with the current asset id + contract_item = next( + ( + contract_item + for contract_item in contracts_list + if contract_item["ID"] == asset_id + ), + None, + ) + + overall_stats[asset_id] = { + "token_name": contract_item["name"] if contract_item else None, + "average_accuracy": average_accuracy, + "total_staked_yesterday": staked_yesterday, + "total_staked_today": staked_today, + } + + return overall_stats @enforce_types @@ -77,7 +271,10 @@ def save_statistics_to_file(): "0x4ac2e51f9b1b0ca9e000dfe6032b24639b172703", network_param ) - contract_information = fetch_contract_id_and_spe(contract_addresses, network_param) + contracts_list_unfiltered = fetch_contract_id_and_spe( + contract_addresses, + network_param, + ) while True: try: @@ -85,13 +282,13 @@ def save_statistics_to_file(): for statistic_type in statistic_types: seconds_per_epoch = statistic_type["seconds_per_epoch"] - contracts = list( + contracts_list = list( filter( lambda item, spe=seconds_per_epoch: int( item["seconds_per_epoch"] ) == spe, - contract_information, + contracts_list_unfiltered, ) ) @@ -99,10 +296,14 @@ def save_statistics_to_file(): statistic_type["alias"] ) - contract_ids = [contract["id"] for contract in contracts] - # Get statistics for all contracts + contract_ids = [contract_item["ID"] for contract_item in contracts_list] + statistics = calculate_statistics_for_all_assets( - contract_ids, contracts, start_ts_param, end_ts_param, network_param + contract_ids, + contracts_list, + start_ts_param, + end_ts_param, + network_param, ) output.append( diff --git a/pdr_backend/accuracy/test/test_app.py b/pdr_backend/accuracy/test/test_app.py new file mode 100644 index 000000000..ecf175db6 --- /dev/null +++ b/pdr_backend/accuracy/test/test_app.py @@ -0,0 +1,91 @@ +from typing import List +from unittest.mock import patch + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subgraph_predictions import ContractIdAndSPE +from pdr_backend.accuracy.app import ( + calculate_prediction_result, + process_single_slot, + aggregate_statistics, + calculate_statistics_for_all_assets, +) +from pdr_backend.subgraph.subgraph_slot import PredictSlot + +# Sample data for tests +SAMPLE_PREDICT_SLOT = PredictSlot( + ID="0xAsset-12345", + slot="12345", + trueValues=[{"ID": "1", "trueValue": True}], + roundSumStakesUp=150.0, + roundSumStakes=100.0, +) + + +@enforce_types +def test_calculate_prediction_result(): + # Test the calculate_prediction_prediction_result function with expected inputs + result = calculate_prediction_result(150.0, 200.0) + assert result + + result = calculate_prediction_result(100.0, 250.0) + assert not result + + +@enforce_types +def test_process_single_slot(): + # Test the process_single_slot function + ( + staked_yesterday, + staked_today, + correct_predictions, + slots_evaluated, + ) = process_single_slot( + slot=SAMPLE_PREDICT_SLOT, end_of_previous_day_timestamp=12340 + ) + + assert staked_yesterday == 0.0 + assert staked_today == 100.0 + assert correct_predictions == 1 + assert slots_evaluated 
== 1 + + +@enforce_types +def test_aggregate_statistics(): + # Test the aggregate_statistics function + ( + total_staked_yesterday, + total_staked_today, + total_correct_predictions, + total_slots_evaluated, + ) = aggregate_statistics( + slots=[SAMPLE_PREDICT_SLOT], end_of_previous_day_timestamp=12340 + ) + assert total_staked_yesterday == 0.0 + assert total_staked_today == 100.0 + assert total_correct_predictions == 1 + assert total_slots_evaluated == 1 + + +@enforce_types +@patch("pdr_backend.accuracy.app.fetch_slots_for_all_assets") +def test_calculate_statistics_for_all_assets(mock_fetch_slots): + # Mocks + mock_fetch_slots.return_value = {"0xAsset": [SAMPLE_PREDICT_SLOT] * 1000} + contracts_list: List[ContractIdAndSPE] = [ + {"ID": "0xAsset", "seconds_per_epoch": 300, "name": "TEST/USDT"} + ] + + # Main work + statistics = calculate_statistics_for_all_assets( + asset_ids=["0xAsset"], + contracts_list=contracts_list, + start_ts_param=1000, + end_ts_param=2000, + network="mainnet", + ) + + print("test_calculate_statistics_for_all_assets", statistics) + # Verify + assert statistics["0xAsset"]["average_accuracy"] == 100.0 + mock_fetch_slots.assert_called_once_with(["0xAsset"], 1000, 2000, "mainnet") diff --git a/pdr_backend/aimodel/aimodel_data_factory.py b/pdr_backend/aimodel/aimodel_data_factory.py new file mode 100644 index 000000000..c0b65be32 --- /dev/null +++ b/pdr_backend/aimodel/aimodel_data_factory.py @@ -0,0 +1,135 @@ +import sys +from typing import Tuple + +import numpy as np +import pandas as pd +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.ppss.predictoor_ss import PredictoorSS +from pdr_backend.util.mathutil import fill_nans, has_nan + + +@enforce_types +class AimodelDataFactory: + """ + Roles: + - From mergedohlcv_df, create (X, y, x_df) for model building + + Where + rawohlcv files -> rawohlcv_dfs -> mergedohlcv_df, via ohlcv_data_factory + + X -- 2d array of [sample_i, var_i] : value -- inputs for model + y -- 1d array of [sample_i] -- target outputs for model + + x_df -- *pandas* DataFrame with cols like: + "binanceus:ETH-USDT:open:t-3", + "binanceus:ETH-USDT:open:t-2", + "binanceus:ETH-USDT:open:t-1", + "binanceus:ETH-USDT:high:t-3", + "binanceus:ETH-USDT:high:t-2", + "binanceus:ETH-USDT:high:t-1", + ... + (no "timestamp" or "datetime" column) + (and index = 0, 1, .. -- nothing special) + + Finally: + - "timestamp" values are ut: int is unix time, UTC, in ms (not s) + """ + + def __init__(self, ss: PredictoorSS): + self.ss = ss + + def create_xy( + self, + mergedohlcv_df: pl.DataFrame, + testshift: int, + do_fill_nans: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]: + """ + @arguments + mergedohlcv_df -- *polars* DataFrame. See class docstring + testshift -- to simulate across historical test data + do_fill_nans -- if any values are nan, fill them? (Via interpolation) + If you turn this off and mergedohlcv_df has nans, then X/y/etc gets nans + + @return -- + X -- 2d array of [sample_i, var_i] : value -- inputs for model + y -- 1d array of [sample_i] -- target outputs for model + x_df -- *pandas* DataFrame. See class docstring. + """ + # preconditions + assert isinstance(mergedohlcv_df, pl.DataFrame), pl.__class__ + assert "timestamp" in mergedohlcv_df.columns + assert "datetime" not in mergedohlcv_df.columns + + # every column should be ordered with oldest first, youngest last. + # let's verify! 
The timestamps should be in ascending order + uts = mergedohlcv_df["timestamp"].to_list() + assert uts == sorted(uts, reverse=False) + + # condition inputs + if do_fill_nans and has_nan(mergedohlcv_df): + mergedohlcv_df = fill_nans(mergedohlcv_df) + ss = self.ss.aimodel_ss + + # main work + x_df = pd.DataFrame() # build this up + + target_hist_cols = [ + f"{feed.exchange}:{feed.pair}:{feed.signal}" for feed in ss.feeds + ] + + for hist_col in target_hist_cols: + assert hist_col in mergedohlcv_df.columns, f"missing data col: {hist_col}" + z = mergedohlcv_df[hist_col].to_list() # [..., z(t-3), z(t-2), z(t-1)] + maxshift = testshift + ss.autoregressive_n + N_train = min(ss.max_n_train, len(z) - maxshift - 1) + if N_train <= 0: + print( + f"Too little data. len(z)={len(z)}, maxshift={maxshift}" + " (= testshift + autoregressive_n = " + f"{testshift} + {ss.autoregressive_n})\n" + "To fix: broaden time, shrink testshift, " + "or shrink autoregressive_n" + ) + sys.exit(1) + for delayshift in range(ss.autoregressive_n, 0, -1): # eg [2, 1, 0] + shift = testshift + delayshift + x_col = hist_col + f":t-{delayshift+1}" + assert (shift + N_train + 1) <= len(z) + # 1 point for test, the rest for train data + x_df[x_col] = _slice(z, -shift - N_train - 1, -shift) + + X = x_df.to_numpy() + + # y is set from yval_{exch_str, signal_str, pair_str} + # eg y = [BinEthC_-1, BinEthC_-2, ..., BinEthC_-450, BinEthC_-451] + ref_ss = self.ss + hist_col = f"{ref_ss.exchange_str}:{ref_ss.pair_str}:{ref_ss.signal_str}" + z = mergedohlcv_df[hist_col].to_list() + y = np.array(_slice(z, -testshift - N_train - 1, -testshift)) + + # postconditions + assert X.shape[0] == y.shape[0] + assert X.shape[0] <= (ss.max_n_train + 1) + assert X.shape[1] == ss.n + assert isinstance(x_df, pd.DataFrame) + + assert "timestamp" not in x_df.columns + assert "datetime" not in x_df.columns + + # return + return X, y, x_df + + +@enforce_types +def _slice(x: list, st: int, fin: int) -> list: + """Python list slice returns an empty list on x[st:fin] if st<0 and fin=0 + This overcomes that issue, for cases when st<0""" + assert st < 0 + assert fin <= 0 + assert st < fin + if fin == 0: + return x[st:] + return x[st:fin] diff --git a/pdr_backend/model_eng/model_factory.py b/pdr_backend/aimodel/aimodel_factory.py similarity index 80% rename from pdr_backend/model_eng/model_factory.py rename to pdr_backend/aimodel/aimodel_factory.py index 6133b3615..c0e230fab 100644 --- a/pdr_backend/model_eng/model_factory.py +++ b/pdr_backend/aimodel/aimodel_factory.py @@ -3,13 +3,13 @@ from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import RBF -from pdr_backend.model_eng.model_ss import ModelSS +from pdr_backend.ppss.aimodel_ss import AimodelSS @enforce_types -class ModelFactory: - def __init__(self, model_ss: ModelSS): - self.model_ss = model_ss +class AimodelFactory: + def __init__(self, aimodel_ss: AimodelSS): + self.aimodel_ss = aimodel_ss def build(self, X_train, y_train): model = self._model() @@ -17,7 +17,7 @@ def build(self, X_train, y_train): return model def _model(self): - a = self.model_ss.model_approach + a = self.aimodel_ss.approach if a == "LIN": return linear_model.LinearRegression() if a == "GPR": diff --git a/pdr_backend/aimodel/test/conftest.py b/pdr_backend/aimodel/test/conftest.py new file mode 100644 index 000000000..ea8fb19c1 --- /dev/null +++ b/pdr_backend/aimodel/test/conftest.py @@ -0,0 +1,19 @@ +import pytest +from enforce_typing import enforce_types + +from 
pdr_backend.aimodel.aimodel_factory import AimodelFactory +from pdr_backend.ppss.aimodel_ss import AimodelSS + + +@enforce_types +@pytest.fixture() +def aimodel_factory(): + aimodel_ss = AimodelSS( + { + "approach": "LIN", + "max_n_train": 7, + "autoregressive_n": 3, + "input_feeds": ["binance BTC/USDT c"], + } + ) + return AimodelFactory(aimodel_ss) diff --git a/pdr_backend/aimodel/test/test_aimodel_data_factory.py b/pdr_backend/aimodel/test/test_aimodel_data_factory.py new file mode 100644 index 000000000..8fa35b0fb --- /dev/null +++ b/pdr_backend/aimodel/test/test_aimodel_data_factory.py @@ -0,0 +1,371 @@ +import numpy as np +import pandas as pd +import polars as pl +import pytest +from enforce_typing import enforce_types + +from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory +from pdr_backend.lake.merge_df import merge_rawohlcv_dfs +from pdr_backend.lake.test.resources import ( + BINANCE_BTC_DATA, + BINANCE_ETH_DATA, + ETHUSDT_RAWOHLCV_DFS, + KRAKEN_BTC_DATA, + KRAKEN_ETH_DATA, + _predictoor_ss, + _predictoor_ss_1feed, + _df_from_raw_data, + _mergedohlcv_df_ETHUSDT, +) +from pdr_backend.ppss.aimodel_ss import AimodelSS +from pdr_backend.ppss.predictoor_ss import PredictoorSS +from pdr_backend.util.mathutil import fill_nans, has_nan + + +def test_create_xy__0(): + predictoor_ss = PredictoorSS( + { + "predict_feed": "binanceus ETH/USDT c 5m", + "bot_only": { + "s_until_epoch_end": 60, + "stake_amount": 1, + }, + "aimodel_ss": { + "input_feeds": ["binanceus ETH/USDT oc"], + "approach": "LIN", + "max_n_train": 4, + "autoregressive_n": 2, + }, + } + ) + mergedohlcv_df = pl.DataFrame( + { + # every column is ordered from youngest to oldest + "timestamp": [1, 2, 3, 4, 5, 6, 7, 8], # not used by AimodelDataFactory + # The underlying AR process is: close[t] = close[t-1] + open[t-1] + "binanceus:ETH/USDT:open": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + "binanceus:ETH/USDT:close": [2.0, 3.1, 4.2, 5.3, 6.4, 7.5, 8.6, 9.7], + } + ) + + target_X = np.array( + [ + [0.1, 0.1, 3.1, 4.2], # oldest + [0.1, 0.1, 4.2, 5.3], + [0.1, 0.1, 5.3, 6.4], + [0.1, 0.1, 6.4, 7.5], + [0.1, 0.1, 7.5, 8.6], + ] + ) # newest + target_y = np.array([5.3, 6.4, 7.5, 8.6, 9.7]) # oldest # newest + target_x_df = pd.DataFrame( + { + "binanceus:ETH/USDT:open:t-3": [0.1, 0.1, 0.1, 0.1, 0.1], + "binanceus:ETH/USDT:open:t-2": [0.1, 0.1, 0.1, 0.1, 0.1], + "binanceus:ETH/USDT:close:t-3": [3.1, 4.2, 5.3, 6.4, 7.5], + "binanceus:ETH/USDT:close:t-2": [4.2, 5.3, 6.4, 7.5, 8.6], + } + ) + + factory = AimodelDataFactory(predictoor_ss) + X, y, x_df = factory.create_xy(mergedohlcv_df, testshift=0) + + _assert_pd_df_shape(predictoor_ss.aimodel_ss, X, y, x_df) + assert np.array_equal(X, target_X) + assert np.array_equal(y, target_y) + assert x_df.equals(target_x_df) + + +@enforce_types +def test_create_xy__1exchange_1coin_1signal(tmpdir): + ss, _, aimodel_data_factory = _predictoor_ss_1feed( + tmpdir, "binanceus ETH/USDT h 5m" + ) + mergedohlcv_df = merge_rawohlcv_dfs(ETHUSDT_RAWOHLCV_DFS) + + # =========== have testshift = 0 + target_X = np.array( + [ + [11.0, 10.0, 9.0], # oldest + [10.0, 9.0, 8.0], + [9.0, 8.0, 7.0], + [8.0, 7.0, 6.0], + [7.0, 6.0, 5.0], + [6.0, 5.0, 4.0], + [5.0, 4.0, 3.0], + [4.0, 3.0, 2.0], + ] + ) # newest + + target_y = np.array( + [ + 8.0, # oldest + 7.0, + 6.0, + 5.0, + 4.0, + 3.0, + 2.0, + 1.0, # newest + ] + ) + target_x_df = pd.DataFrame( + { + "binanceus:ETH/USDT:high:t-4": [11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0], + "binanceus:ETH/USDT:high:t-3": [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 
3.0], + "binanceus:ETH/USDT:high:t-2": [9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0], + } + ) + + X, y, x_df = aimodel_data_factory.create_xy(mergedohlcv_df, testshift=0) + + _assert_pd_df_shape(ss.aimodel_ss, X, y, x_df) + assert np.array_equal(X, target_X) + assert np.array_equal(y, target_y) + assert x_df.equals(target_x_df) + + # =========== now, have testshift = 1 + target_X = np.array( + [ + [12.0, 11.0, 10.0], # oldest + [11.0, 10.0, 9.0], + [10.0, 9.0, 8.0], + [9.0, 8.0, 7.0], + [8.0, 7.0, 6.0], + [7.0, 6.0, 5.0], + [6.0, 5.0, 4.0], + [5.0, 4.0, 3.0], + ] + ) # newest + target_y = np.array( + [ + 9.0, # oldest + 8.0, + 7.0, + 6.0, + 5.0, + 4.0, + 3.0, + 2.0, # newest + ] + ) + target_x_df = pd.DataFrame( + { + "binanceus:ETH/USDT:high:t-4": [12.0, 11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0], + "binanceus:ETH/USDT:high:t-3": [11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0], + "binanceus:ETH/USDT:high:t-2": [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0], + } + ) + + X, y, x_df = aimodel_data_factory.create_xy(mergedohlcv_df, testshift=1) + + _assert_pd_df_shape(ss.aimodel_ss, X, y, x_df) + assert np.array_equal(X, target_X) + assert np.array_equal(y, target_y) + assert x_df.equals(target_x_df) + + # =========== now have a different max_n_train + target_X = np.array( + [ + [9.0, 8.0, 7.0], # oldest + [8.0, 7.0, 6.0], + [7.0, 6.0, 5.0], + [6.0, 5.0, 4.0], + [5.0, 4.0, 3.0], + [4.0, 3.0, 2.0], + ] + ) # newest + target_y = np.array([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]) # oldest # newest + target_x_df = pd.DataFrame( + { + "binanceus:ETH/USDT:high:t-4": [9.0, 8.0, 7.0, 6.0, 5.0, 4.0], + "binanceus:ETH/USDT:high:t-3": [8.0, 7.0, 6.0, 5.0, 4.0, 3.0], + "binanceus:ETH/USDT:high:t-2": [7.0, 6.0, 5.0, 4.0, 3.0, 2.0], + } + ) + + assert "max_n_train" in ss.aimodel_ss.d + ss.aimodel_ss.d["max_n_train"] = 5 + + X, y, x_df = aimodel_data_factory.create_xy(mergedohlcv_df, testshift=0) + + _assert_pd_df_shape(ss.aimodel_ss, X, y, x_df) + assert np.array_equal(X, target_X) + assert np.array_equal(y, target_y) + assert x_df.equals(target_x_df) + + +@enforce_types +def test_create_xy__2exchanges_2coins_2signals(): + rawohlcv_dfs = { + "binanceus": { + "BTC/USDT": _df_from_raw_data(BINANCE_BTC_DATA), + "ETH/USDT": _df_from_raw_data(BINANCE_ETH_DATA), + }, + "kraken": { + "BTC/USDT": _df_from_raw_data(KRAKEN_BTC_DATA), + "ETH/USDT": _df_from_raw_data(KRAKEN_ETH_DATA), + }, + } + + ss = _predictoor_ss( + "binanceus ETH/USDT h 5m", + ["binanceus BTC/USDT,ETH/USDT hl", "kraken BTC/USDT,ETH/USDT hl"], + ) + assert ss.aimodel_ss.autoregressive_n == 3 + assert ss.aimodel_ss.n == (4 + 4) * 3 + + mergedohlcv_df = merge_rawohlcv_dfs(rawohlcv_dfs) + + aimodel_data_factory = AimodelDataFactory(ss) + X, y, x_df = aimodel_data_factory.create_xy(mergedohlcv_df, testshift=0) + + _assert_pd_df_shape(ss.aimodel_ss, X, y, x_df) + found_cols = x_df.columns.tolist() + target_cols = [ + "binanceus:BTC/USDT:high:t-4", + "binanceus:BTC/USDT:high:t-3", + "binanceus:BTC/USDT:high:t-2", + "binanceus:ETH/USDT:high:t-4", + "binanceus:ETH/USDT:high:t-3", + "binanceus:ETH/USDT:high:t-2", + "binanceus:BTC/USDT:low:t-4", + "binanceus:BTC/USDT:low:t-3", + "binanceus:BTC/USDT:low:t-2", + "binanceus:ETH/USDT:low:t-4", + "binanceus:ETH/USDT:low:t-3", + "binanceus:ETH/USDT:low:t-2", + "kraken:BTC/USDT:high:t-4", + "kraken:BTC/USDT:high:t-3", + "kraken:BTC/USDT:high:t-2", + "kraken:ETH/USDT:high:t-4", + "kraken:ETH/USDT:high:t-3", + "kraken:ETH/USDT:high:t-2", + "kraken:BTC/USDT:low:t-4", + "kraken:BTC/USDT:low:t-3", + "kraken:BTC/USDT:low:t-2", + 
"kraken:ETH/USDT:low:t-4", + "kraken:ETH/USDT:low:t-3", + "kraken:ETH/USDT:low:t-2", + ] + assert found_cols == target_cols + + # test binanceus:ETH/USDT:high like in 1-signal + assert target_cols[3:6] == [ + "binanceus:ETH/USDT:high:t-4", + "binanceus:ETH/USDT:high:t-3", + "binanceus:ETH/USDT:high:t-2", + ] + Xa = X[:, 3:6] + assert Xa[-1, :].tolist() == [4, 3, 2] and y[-1] == 1 + assert Xa[-2, :].tolist() == [5, 4, 3] and y[-2] == 2 + assert Xa[0, :].tolist() == [11, 10, 9] and y[0] == 8 + + assert x_df.iloc[-1].tolist()[3:6] == [4, 3, 2] + assert x_df.iloc[-2].tolist()[3:6] == [5, 4, 3] + assert x_df.iloc[0].tolist()[3:6] == [11, 10, 9] + + assert x_df["binanceus:ETH/USDT:high:t-2"].tolist() == [ + 9, + 8, + 7, + 6, + 5, + 4, + 3, + 2, + ] + assert Xa[:, 2].tolist() == [9, 8, 7, 6, 5, 4, 3, 2] + + +@enforce_types +def test_create_xy__check_timestamp_order(tmpdir): + mergedohlcv_df, factory = _mergedohlcv_df_ETHUSDT(tmpdir) + + # timestamps should be descending order + uts = mergedohlcv_df["timestamp"].to_list() + assert uts == sorted(uts, reverse=False) + + # happy path + factory.create_xy(mergedohlcv_df, testshift=0) + + # failure path + bad_uts = sorted(uts, reverse=True) # bad order + bad_mergedohlcv_df = mergedohlcv_df.with_columns(pl.Series("timestamp", bad_uts)) + with pytest.raises(AssertionError): + factory.create_xy(bad_mergedohlcv_df, testshift=0) + + +@enforce_types +def test_create_xy__input_type(tmpdir): + mergedohlcv_df, aimodel_data_factory = _mergedohlcv_df_ETHUSDT(tmpdir) + + assert isinstance(mergedohlcv_df, pl.DataFrame) + assert isinstance(aimodel_data_factory, AimodelDataFactory) + + # create_xy() input should be pl + aimodel_data_factory.create_xy(mergedohlcv_df, testshift=0) + + # create_xy() inputs shouldn't be pd + with pytest.raises(AssertionError): + aimodel_data_factory.create_xy(mergedohlcv_df.to_pandas(), testshift=0) + + +@enforce_types +def test_create_xy__handle_nan(tmpdir): + # create mergedohlcv_df + _, _, aimodel_data_factory = _predictoor_ss_1feed(tmpdir, "binanceus ETH/USDT h 5m") + mergedohlcv_df = merge_rawohlcv_dfs(ETHUSDT_RAWOHLCV_DFS) + + # initial mergedohlcv_df should be ok + assert not has_nan(mergedohlcv_df) + + # now, corrupt mergedohlcv_df with NaN values + nan_indices = [1686805800000, 1686806700000, 1686808800000] + mergedohlcv_df = mergedohlcv_df.with_columns( + [ + pl.when(mergedohlcv_df["timestamp"].is_in(nan_indices)) + .then(pl.lit(None, pl.Float64)) + .otherwise(mergedohlcv_df["binanceus:ETH/USDT:high"]) + .alias("binanceus:ETH/USDT:high") + ] + ) + assert has_nan(mergedohlcv_df) + + # =========== initial testshift (0) + # run create_xy() and force the nans to stick around + # -> we want to ensure that we're building X/y with risk of nan + X, y, x_df = aimodel_data_factory.create_xy( + mergedohlcv_df, testshift=0, do_fill_nans=False + ) + assert has_nan(X) and has_nan(y) and has_nan(x_df) + + # nan approach 1: fix externally + mergedohlcv_df2 = fill_nans(mergedohlcv_df) + assert not has_nan(mergedohlcv_df2) + + # nan approach 2: explicitly tell create_xy to fill nans + X, y, x_df = aimodel_data_factory.create_xy( + mergedohlcv_df, testshift=0, do_fill_nans=True + ) + assert not has_nan(X) and not has_nan(y) and not has_nan(x_df) + + # nan approach 3: create_xy fills nans by default (best) + X, y, x_df = aimodel_data_factory.create_xy(mergedohlcv_df, testshift=0) + assert not has_nan(X) and not has_nan(y) and not has_nan(x_df) + + +# ==================================================================== +# utilities + + 
+@enforce_types +def _assert_pd_df_shape( + ss: AimodelSS, X: np.ndarray, y: np.ndarray, x_df: pd.DataFrame +): + assert X.shape[0] == y.shape[0] + assert X.shape[0] == (ss.max_n_train + 1) # 1 for test, rest for train + assert X.shape[1] == ss.n + + assert len(x_df) == X.shape[0] + assert len(x_df.columns) == ss.n diff --git a/pdr_backend/aimodel/test/test_aimodel_factory.py b/pdr_backend/aimodel/test/test_aimodel_factory.py new file mode 100644 index 000000000..6081c07d1 --- /dev/null +++ b/pdr_backend/aimodel/test/test_aimodel_factory.py @@ -0,0 +1,99 @@ +import warnings +from unittest.mock import Mock + +import numpy as np +import pytest +from enforce_typing import enforce_types + +from pdr_backend.aimodel.aimodel_factory import AimodelFactory +from pdr_backend.ppss.aimodel_ss import APPROACHES, AimodelSS + + +@enforce_types +def test_aimodel_factory_basic(): + for approach in APPROACHES: + aimodel_ss = AimodelSS( + { + "approach": approach, + "max_n_train": 7, + "autoregressive_n": 3, + "input_feeds": ["binance BTC/USDT c"], + } + ) + factory = AimodelFactory(aimodel_ss) + assert isinstance(factory.aimodel_ss, AimodelSS) + + (X_train, y_train, X_test, y_test) = _data() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # ignore ConvergenceWarning, more + model = factory.build(X_train, y_train) + + y_test_hat = model.predict(X_test) + assert y_test_hat.shape == y_test.shape + + +@enforce_types +def test_aimodel_accuracy_from_xy(aimodel_factory): + (X_train, y_train, X_test, y_test) = _data() + + aimodel = aimodel_factory.build(X_train, y_train) + + y_train_hat = aimodel.predict(X_train) + assert sum(abs(y_train - y_train_hat)) < 1e-10 # near-perfect since linear + + y_test_hat = aimodel.predict(X_test) + assert sum(abs(y_test - y_test_hat)) < 1e-10 + + +@enforce_types +def _data() -> tuple: + X_train = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) + y_train = f(X_train) + + X_test = np.array([[3, 5]]) + y_test = f(X_test) + + return (X_train, y_train, X_test, y_test) + + +@enforce_types +def f(X: np.ndarray) -> np.ndarray: + # y = 3 * x0 + 2 * x1 + y = 3.0 + 1.0 * X[:, 0] + 2.0 * X[:, 1] + return y + + +@enforce_types +def test_aimodel_accuracy_from_create_xy(aimodel_factory): + # This is from a test function in test_model_data_factory.py + + # The underlying AR process is: close[t] = close[t-1] + open[t-1] + X_train = np.array( + [ + [0.1, 0.1, 3.1, 4.2], # oldest + [0.1, 0.1, 4.2, 5.3], + [0.1, 0.1, 5.3, 6.4], + [0.1, 0.1, 6.4, 7.5], + [0.1, 0.1, 7.5, 8.6], + ] + ) # newest + y_train = np.array([5.3, 6.4, 7.5, 8.6, 9.7]) # oldest # newest + + aimodel = aimodel_factory.build(X_train, y_train) + + y_train_hat = aimodel.predict(X_train) + assert sum(abs(y_train - y_train_hat)) < 1e-10 # near-perfect since linear + + +@enforce_types +def test_aimodel_factory_bad_approach(): + aimodel_ss = Mock(spec=AimodelSS) + aimodel_ss.approach = "BAD" + factory = AimodelFactory(aimodel_ss) + + X_train, y_train, _, _ = _data() + + # forcefully change the model + with pytest.raises(ValueError): + factory.build(X_train, y_train) diff --git a/pdr_backend/analytics/check_network.py b/pdr_backend/analytics/check_network.py new file mode 100644 index 000000000..8788362e9 --- /dev/null +++ b/pdr_backend/analytics/check_network.py @@ -0,0 +1,185 @@ +import math +from typing import Union + +from enforce_typing import enforce_types + +from pdr_backend.cli.timeframe import s_to_timeframe_str +from pdr_backend.contract.token import Token +from pdr_backend.ppss.ppss import PPSS +from 
pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.subgraph.subgraph_consume_so_far import get_consume_so_far_per_contract +from pdr_backend.util.constants import S_PER_DAY, S_PER_WEEK +from pdr_backend.util.constants_opf_addrs import get_opf_addresses +from pdr_backend.util.contract import get_address +from pdr_backend.util.mathutil import from_wei +from pdr_backend.util.timeutil import current_ut_s + +_N_FEEDS = 20 # magic number alert. FIX ME, shouldn't be hardcoded + + +@enforce_types +def print_stats(contract_dict: dict, field_name: str, threshold: float = 0.9): + n_slots = len(contract_dict["slots"]) + n_slots_with_field = sum( + 1 for slot in contract_dict["slots"] if len(slot[field_name]) > 0 + ) + if n_slots == 0: + n_slots = 1 + + status = "PASS" if n_slots_with_field / n_slots > threshold else "FAIL" + token_name = contract_dict["token"]["name"] + + s_per_epoch = int(contract_dict["secondsPerEpoch"]) + timeframe_str = s_to_timeframe_str(s_per_epoch) + print( + f"{token_name} {timeframe_str}: " + f"{n_slots_with_field}/{n_slots} {field_name} - {status}" + ) + + +@enforce_types +def check_dfbuyer( + dfbuyer_addr: str, + contract_query_result: dict, + subgraph_url: str, + token_amt: int, +): + cur_ut = current_ut_s() + start_ut = int((cur_ut // S_PER_WEEK) * S_PER_WEEK) + + contracts_sg_dict = contract_query_result["data"]["predictContracts"] + contract_addresses = [ + contract_sg_dict["id"] for contract_sg_dict in contracts_sg_dict + ] + amt_consume_so_far = get_consume_so_far_per_contract( + subgraph_url, + dfbuyer_addr, + start_ut, + contract_addresses, + ) + expect_amt_consume = get_expected_consume(cur_ut, token_amt) + print( + "Checking consume amounts (dfbuyer)" + f", expecting {expect_amt_consume} consume per contract" + ) + for addr in contract_addresses: + x = amt_consume_so_far[addr] + log_text = "PASS" if x >= expect_amt_consume else "FAIL" + print( + f" {log_text}... 
got {x} consume for contract: {addr}" + f", expected {expect_amt_consume}" + ) + + +@enforce_types +def get_expected_consume(for_ut: int, token_amt: int) -> Union[float, int]: + """ + @arguments + for_ut -- unix time, in ms, in UTC time zone + token_amt -- # tokens + + @return + exp_consume -- + """ + amt_per_feed_per_week = token_amt / 7 / _N_FEEDS + week_start_ut = (math.floor(for_ut / S_PER_WEEK)) * S_PER_WEEK + time_passed = for_ut - week_start_ut + n_weeks = int(time_passed / S_PER_DAY) + 1 + return n_weeks * amt_per_feed_per_week + + +@enforce_types +def check_network_main(ppss: PPSS, lookback_hours: int): + web3_pp = ppss.web3_pp + + cur_ut = current_ut_s() + start_ut = cur_ut - lookback_hours * 60 * 60 + query = """ + { + predictContracts{ + id + token{ + name + } + subscriptions(orderBy: expireTime orderDirection:desc first:10){ + user { + id + } + expireTime + } + slots(where:{slot_lt:%s, slot_gt:%s} orderBy: slot orderDirection:desc first:1000){ + slot + roundSumStakesUp + roundSumStakes + predictions(orderBy: timestamp orderDirection:desc){ + stake + user { + id + } + timestamp + payout{ + payout + predictedValue + trueValue + } + } + trueValues{ + trueValue + } + } + secondsPerEpoch + } + } + """ % ( + cur_ut, + start_ut, + ) + result = query_subgraph(web3_pp.subgraph_url, query, timeout=10.0) + + # check no of contracts + no_of_contracts = len(result["data"]["predictContracts"]) + if no_of_contracts >= 11: + print(f"Number of Predictoor contracts: {no_of_contracts} - OK") + else: + print(f"Number of Predictoor contracts: {no_of_contracts} - FAILED") + + print("-" * 60) + + # check number of predictions + print("Predictions:") + for contract in result["data"]["predictContracts"]: + print_stats(contract, "predictions") + + print() + + # Check number of truevals + print("True Values:") + for contract in result["data"]["predictContracts"]: + print_stats(contract, "trueValues") + print("\nChecking account balances") + + OCEAN_address = get_address(web3_pp, "Ocean") + OCEAN = Token(web3_pp, OCEAN_address) + + addresses = get_opf_addresses(web3_pp.network) + for name, address in addresses.items(): + ocean_bal = from_wei(OCEAN.balanceOf(address)) + native_bal = from_wei(web3_pp.web3_config.w3.eth.get_balance(address)) + + ocean_warning = ( + " WARNING LOW OCEAN BALANCE!" + if ocean_bal < 10 and name != "trueval" + else " OK " + ) + native_warning = " WARNING LOW NATIVE BALANCE!" 
if native_bal < 10 else " OK " + + print( + f"{name}: OCEAN: {ocean_bal:.2f}{ocean_warning}" + f", Native: {native_bal:.2f}{native_warning}" + ) + + # ---------------- dfbuyer ---------------- + + dfbuyer_addr = addresses["dfbuyer"].lower() + token_amt = 44460 + check_dfbuyer(dfbuyer_addr, result, web3_pp.subgraph_url, token_amt) diff --git a/pdr_backend/analytics/get_predictions_info.py b/pdr_backend/analytics/get_predictions_info.py new file mode 100644 index 000000000..4749bc25c --- /dev/null +++ b/pdr_backend/analytics/get_predictions_info.py @@ -0,0 +1,57 @@ +from typing import Union + +from enforce_typing import enforce_types + +from pdr_backend.analytics.predictoor_stats import get_cli_statistics +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_predictions import ( + FilterMode, + fetch_filtered_predictions, + get_all_contract_ids_by_owner, +) +from pdr_backend.util.csvs import save_analysis_csv +from pdr_backend.util.networkutil import get_sapphire_postfix +from pdr_backend.util.timeutil import ms_to_seconds, timestr_to_ut + + +@enforce_types +def get_predictions_info_main( + ppss: PPSS, + feed_addrs_str: Union[str, None], + start_timestr: str, + end_timestr: str, + pq_dir: str, +): + network = get_sapphire_postfix(ppss.web3_pp.network) + start_ut: int = ms_to_seconds(timestr_to_ut(start_timestr)) + end_ut: int = ms_to_seconds(timestr_to_ut(end_timestr)) + + # filter by feed contract address + feed_contract_list = get_all_contract_ids_by_owner( + owner_address=ppss.web3_pp.owner_addrs, + network=network, + ) + feed_contract_list = [f.lower() for f in feed_contract_list] + + if feed_addrs_str: + keep = feed_addrs_str.lower().split(",") + feed_contract_list = [f for f in feed_contract_list if f in keep] + + # fetch predictions + predictions = fetch_filtered_predictions( + start_ut, + end_ut, + feed_contract_list, + network, + FilterMode.CONTRACT, + payout_only=True, + trueval_only=True, + ) + + if not predictions: + print("No records found. 
Please adjust start and end times.") + return + + save_analysis_csv(predictions, pq_dir) + + get_cli_statistics(predictions) diff --git a/pdr_backend/analytics/get_predictoors_info.py b/pdr_backend/analytics/get_predictoors_info.py new file mode 100644 index 000000000..e40b6f347 --- /dev/null +++ b/pdr_backend/analytics/get_predictoors_info.py @@ -0,0 +1,42 @@ +from typing import Union + +from enforce_typing import enforce_types + +from pdr_backend.analytics.predictoor_stats import get_cli_statistics +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_predictions import ( + FilterMode, + fetch_filtered_predictions, +) +from pdr_backend.util.csvs import save_prediction_csv +from pdr_backend.util.networkutil import get_sapphire_postfix +from pdr_backend.util.timeutil import ms_to_seconds, timestr_to_ut + + +@enforce_types +def get_predictoors_info_main( + ppss: PPSS, + pdr_addrs_str: Union[str, None], + start_timestr: str, + end_timestr: str, + csv_output_dir: str, +): + network = get_sapphire_postfix(ppss.web3_pp.network) + start_ut: int = ms_to_seconds(timestr_to_ut(start_timestr)) + end_ut: int = ms_to_seconds(timestr_to_ut(end_timestr)) + + pdr_addrs_filter = [] + if pdr_addrs_str: + pdr_addrs_filter = pdr_addrs_str.lower().split(",") + + predictions = fetch_filtered_predictions( + start_ut, + end_ut, + pdr_addrs_filter, + network, + FilterMode.PREDICTOOR, + ) + + save_prediction_csv(predictions, csv_output_dir) + + get_cli_statistics(predictions) diff --git a/pdr_backend/analytics/get_traction_info.py b/pdr_backend/analytics/get_traction_info.py new file mode 100644 index 000000000..5633f2395 --- /dev/null +++ b/pdr_backend/analytics/get_traction_info.py @@ -0,0 +1,42 @@ +"""This module currently gives traction wrt predictoors. +At some point, we can expand it into traction info wrt traders & txs too. +""" + +from enforce_typing import enforce_types + +from pdr_backend.analytics.predictoor_stats import ( + get_slot_statistics, + get_traction_statistics, + plot_slot_daily_statistics, + plot_traction_cum_sum_statistics, + plot_traction_daily_statistics, +) +from pdr_backend.lake.gql_data_factory import GQLDataFactory +from pdr_backend.ppss.ppss import PPSS + + +@enforce_types +def get_traction_info_main( + ppss: PPSS, start_timestr: str, end_timestr: str, pq_dir: str +): + lake_ss = ppss.lake_ss + lake_ss.d["st_timestr"] = start_timestr + lake_ss.d["fin_timestr"] = end_timestr + + gql_data_factory = GQLDataFactory(ppss) + gql_dfs = gql_data_factory.get_gql_dfs() + + if len(gql_dfs) == 0: + print("No records found. 
Please adjust start and end times.") + return + + predictions_df = gql_dfs["pdr_predictions"] + + # calculate predictoor traction statistics and draw plots + stats_df = get_traction_statistics(predictions_df) + plot_traction_cum_sum_statistics(stats_df, pq_dir) + plot_traction_daily_statistics(stats_df, pq_dir) + + # calculate slot statistics and draw plots + slots_df = get_slot_statistics(predictions_df) + plot_slot_daily_statistics(slots_df, pq_dir) diff --git a/pdr_backend/analytics/predictoor_stats.py b/pdr_backend/analytics/predictoor_stats.py new file mode 100644 index 000000000..0cb007992 --- /dev/null +++ b/pdr_backend/analytics/predictoor_stats.py @@ -0,0 +1,433 @@ +import os +from typing import Dict, List, Set, Tuple, TypedDict + +import matplotlib.pyplot as plt +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.subgraph.prediction import Prediction +from pdr_backend.util.csvs import get_plots_dir + + +class PairTimeframeStat(TypedDict): + pair: str + timeframe: str + accuracy: float + stake: float + payout: float + number_of_predictions: int + + +class PredictoorStat(TypedDict): + predictoor_address: str + accuracy: float + stake: float + payout: float + number_of_predictions: int + details: Set[Tuple[str, str, str]] + + +@enforce_types +def aggregate_prediction_statistics( + all_predictions: List[Prediction], +) -> Tuple[Dict[str, Dict], int]: + """ + Aggregates statistics from a list of prediction objects. It organizes statistics + by currency pair and timeframe and predictor address. For each category, it + tallies the total number of predictions, the number of correct predictions, + and the total stakes and payouts. It also returns the total number of correct + predictions across all categories. + + Args: + all_predictions (List[Prediction]): A list of Prediction objects to aggregate. + + Returns: + Tuple[Dict[str, Dict], int]: A tuple containing a dictionary of aggregated + statistics and the total number of correct predictions. 
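+
+    Example (illustrative only -- the address, pair and numbers are made up):
+        stats = {
+            "pair_timeframe": {
+                ("BTC/USDT", "5m"): {"correct": 1, "total": 2, "stake": 10.0, "payout": 12.0},
+            },
+            "predictor": {
+                "0xabc...": {
+                    "correct": 1,
+                    "total": 2,
+                    "stake": 10.0,
+                    "payout": 12.0,
+                    "details": {("BTC/USDT", "5m", "binance")},
+                },
+            },
+        }
+        correct_predictions = 1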
+ """ + stats: Dict[str, Dict] = {"pair_timeframe": {}, "predictor": {}} + correct_predictions = 0 + + for prediction in all_predictions: + pair_timeframe_key = (prediction.pair, prediction.timeframe) + predictor_key = prediction.user + source = prediction.source + + is_correct = prediction.prediction == prediction.trueval + + if pair_timeframe_key not in stats["pair_timeframe"]: + stats["pair_timeframe"][pair_timeframe_key] = { + "correct": 0, + "total": 0, + "stake": 0, + "payout": 0.0, + } + + if predictor_key not in stats["predictor"]: + stats["predictor"][predictor_key] = { + "correct": 0, + "total": 0, + "stake": 0, + "payout": 0.0, + "details": set(), + } + + if is_correct: + correct_predictions += 1 + stats["pair_timeframe"][pair_timeframe_key]["correct"] += 1 + stats["predictor"][predictor_key]["correct"] += 1 + + stats["pair_timeframe"][pair_timeframe_key]["total"] += 1 + stats["pair_timeframe"][pair_timeframe_key]["stake"] += prediction.stake + stats["pair_timeframe"][pair_timeframe_key]["payout"] += prediction.payout + + stats["predictor"][predictor_key]["total"] += 1 + stats["predictor"][predictor_key]["stake"] += prediction.stake + stats["predictor"][predictor_key]["payout"] += prediction.payout + stats["predictor"][predictor_key]["details"].add( + (prediction.pair, prediction.timeframe, source) + ) + + return stats, correct_predictions + + +@enforce_types +def get_endpoint_statistics( + all_predictions: List[Prediction], +) -> Tuple[float, List[PairTimeframeStat], List[PredictoorStat]]: + """ + Calculates the overall accuracy of predictions, and aggregates detailed prediction + statistics by currency pair and timeframe with predictoor. + + The function first determines the overall accuracy of all given predictions. + It then organizes individual prediction statistics into two separate lists: + one for currency pair and timeframe statistics, and another for predictor statistics. + + Args: + all_predictions (List[Prediction]): A list of Prediction objects to be analyzed. + + Returns: + Tuple[float, List[Dict[str, Any]], List[Dict[str, Any]]]: A tuple containing the + overall accuracy as a float, a list of dictionaries with statistics for each + currency pair and timeframe, and a list of dictionaries with statistics for each + predictor. 
+ """ + total_predictions = len(all_predictions) + stats, correct_predictions = aggregate_prediction_statistics(all_predictions) + + overall_accuracy = ( + correct_predictions / total_predictions * 100 if total_predictions else 0 + ) + + pair_timeframe_stats: List[PairTimeframeStat] = [] + for key, stat_pair_timeframe_item in stats["pair_timeframe"].items(): + pair, timeframe = key + accuracy = ( + stat_pair_timeframe_item["correct"] + / stat_pair_timeframe_item["total"] + * 100 + if stat_pair_timeframe_item["total"] + else 0 + ) + pair_timeframe_stat: PairTimeframeStat = { + "pair": pair, + "timeframe": timeframe, + "accuracy": accuracy, + "stake": stat_pair_timeframe_item["stake"], + "payout": stat_pair_timeframe_item["payout"], + "number_of_predictions": stat_pair_timeframe_item["total"], + } + pair_timeframe_stats.append(pair_timeframe_stat) + + predictoor_stats: List[PredictoorStat] = [] + for predictoor_addr, stat_predictoor_item in stats["predictor"].items(): + accuracy = ( + stat_predictoor_item["correct"] / stat_predictoor_item["total"] * 100 + if stat_predictoor_item["total"] + else 0 + ) + predictoor_stat: PredictoorStat = { + "predictoor_address": predictoor_addr, + "accuracy": accuracy, + "stake": stat_predictoor_item["stake"], + "payout": stat_predictoor_item["payout"], + "number_of_predictions": stat_predictoor_item["total"], + "details": set(stat_predictoor_item["details"]), + } + predictoor_stats.append(predictoor_stat) + + return overall_accuracy, pair_timeframe_stats, predictoor_stats + + +@enforce_types +def get_cli_statistics(all_predictions: List[Prediction]) -> None: + total_predictions = len(all_predictions) + + stats, correct_predictions = aggregate_prediction_statistics(all_predictions) + + if total_predictions == 0: + print("No predictions found.") + return + + if correct_predictions == 0: + print("No correct predictions found.") + return + + print(f"Overall Accuracy: {correct_predictions/total_predictions*100:.2f}%") + + for key, stat_pair_timeframe_item in stats["pair_timeframe"].items(): + pair, timeframe = key + accuracy = ( + stat_pair_timeframe_item["correct"] + / stat_pair_timeframe_item["total"] + * 100 + ) + print(f"Accuracy for Pair: {pair}, Timeframe: {timeframe}: {accuracy:.2f}%") + print(f"Total stake: {stat_pair_timeframe_item['stake']}") + print(f"Total payout: {stat_pair_timeframe_item['payout']}") + print(f"Number of predictions: {stat_pair_timeframe_item['total']}\n") + + for predictoor_addr, stat_predictoor_item in stats["predictor"].items(): + accuracy = stat_predictoor_item["correct"] / stat_predictoor_item["total"] * 100 + print(f"Accuracy for Predictoor Address: {predictoor_addr}: {accuracy:.2f}%") + print(f"Stake: {stat_predictoor_item['stake']}") + print(f"Payout: {stat_predictoor_item['payout']}") + print(f"Number of predictions: {stat_predictoor_item['total']}") + print("Details of Predictions:") + for detail in stat_predictoor_item["details"]: + print(f"Pair: {detail[0]}, Timeframe: {detail[1]}, Source: {detail[2]}") + print("\n") + + +@enforce_types +def get_traction_statistics(preds_df: pl.DataFrame) -> pl.DataFrame: + # Calculate predictoor traction statistics + # Predictoor addresses are aggregated historically + stats_df = ( + preds_df.with_columns( + [ + # use strftime(%Y-%m-%d %H:00:00) to get hourly intervals + pl.from_epoch("timestamp", time_unit="s") + .dt.strftime("%Y-%m-%d") + .alias("datetime"), + ] + ) + .group_by("datetime") + .agg( + [ + pl.col("user").unique().alias("daily_unique_predictoors"), + 
pl.col("user").unique().count().alias("daily_unique_predictoors_count"), + pl.lit(1).alias("index"), + ] + ) + .sort("datetime") + .with_columns( + [ + pl.col("daily_unique_predictoors") + .cumulative_eval(pl.element().explode().unique().count()) + .over("index") + .alias("cum_daily_unique_predictoors_count") + ] + ) + .select( + [ + "datetime", + "daily_unique_predictoors_count", + "cum_daily_unique_predictoors_count", + ] + ) + ) + + return stats_df + + +@enforce_types +def plot_traction_daily_statistics(stats_df: pl.DataFrame, pq_dir: str) -> None: + assert "datetime" in stats_df.columns + assert "daily_unique_predictoors_count" in stats_df.columns + + charts_dir = get_plots_dir(pq_dir) + + dates = stats_df["datetime"].to_list() + ticks = int(len(dates) / 5) if len(dates) > 5 else 2 + + # draw unique_predictoors + chart_path = os.path.join(charts_dir, "daily_unique_predictoors.png") + plt.figure(figsize=(10, 6)) + plt.plot( + stats_df["datetime"].to_pandas(), + stats_df["daily_unique_predictoors_count"], + marker="o", + linestyle="-", + ) + plt.xlabel("Date") + plt.ylabel("# Unique Predictoor Addresses") + plt.title("Daily # Unique Predictoor Addresses") + plt.xticks(range(0, len(dates), ticks), dates[::ticks], rotation=90) + plt.tight_layout() + plt.savefig(chart_path) + plt.close() + print("Chart created:", chart_path) + + +@enforce_types +def plot_traction_cum_sum_statistics(stats_df: pl.DataFrame, pq_dir: str) -> None: + assert "datetime" in stats_df.columns + assert "cum_daily_unique_predictoors_count" in stats_df.columns + + charts_dir = get_plots_dir(pq_dir) + + dates = stats_df["datetime"].to_list() + ticks = int(len(dates) / 5) if len(dates) > 5 else 2 + + # draw cum_unique_predictoors + chart_path = os.path.join(charts_dir, "daily_cumulative_unique_predictoors.png") + plt.figure(figsize=(10, 6)) + plt.plot( + stats_df["datetime"].to_pandas(), + stats_df["cum_daily_unique_predictoors_count"], + marker="o", + linestyle="-", + ) + plt.xlabel("Date") + plt.ylabel("# Unique Predictoor Addresses") + plt.title("Cumulative # Unique Predictoor Addresses") + plt.xticks(range(0, len(dates), ticks), dates[::ticks], rotation=90) + plt.tight_layout() + plt.savefig(chart_path) + plt.close() + print("Chart created:", chart_path) + + +@enforce_types +def get_slot_statistics(preds_df: pl.DataFrame) -> pl.DataFrame: + # Create a key to group predictions + slots_df = ( + preds_df.with_columns( + [ + (pl.col("pair").cast(str) + "-" + pl.col("timeframe").cast(str)).alias( + "pair_timeframe" + ), + ( + pl.col("pair").cast(str) + + "-" + + pl.col("timeframe").cast(str) + + "-" + + pl.col("slot").cast(str) + ).alias("pair_timeframe_slot"), + ] + ) + .group_by("pair_timeframe_slot") + .agg( + [ + pl.col("pair").first(), + pl.col("timeframe").first(), + pl.col("slot").first(), + pl.col("pair_timeframe").first(), + # use strftime(%Y-%m-%d %H:00:00) to get hourly intervals + pl.from_epoch("timestamp", time_unit="s") + .first() + .dt.strftime("%Y-%m-%d") + .alias("datetime"), + pl.col("user") + .unique() + .count() + .alias("n_predictoors"), # n unique predictoors + pl.col("payout").sum().alias("slot_payout"), # Sum of slot payout + pl.col("stake").sum().alias("slot_stake"), # Sum of slot stake + ] + ) + .sort(["pair", "timeframe", "slot"]) + ) + + return slots_df + + +def calculate_slot_daily_statistics( + slots_df: pl.DataFrame, +) -> pl.DataFrame: + def get_mean_slots_slots_df(slots_df: pl.DataFrame) -> pl.DataFrame: + return slots_df.select( + [ + pl.col("pair_timeframe").first(), + 
pl.col("datetime").first(), + pl.col("slot_stake").mean().alias("mean_stake"), + pl.col("slot_payout").mean().alias("mean_payout"), + pl.col("n_predictoors").mean().alias("mean_n_predictoors"), + ] + ) + + # for each take a sample of up-to 5 + # then for each calc daily mean_stake, mean_payout, ... + # then for each sum those numbers across all feeds + slots_daily_df = ( + slots_df.group_by(["pair_timeframe", "datetime"]) + .map_groups( + lambda df: get_mean_slots_slots_df(df.sample(5)) + if len(df) > 5 + else get_mean_slots_slots_df(df) + ) + .group_by("datetime") + .agg( + [ + pl.col("mean_stake").sum().alias("daily_average_stake"), + pl.col("mean_payout").sum().alias("daily_average_payout"), + pl.col("mean_n_predictoors") + .mean() + .alias("daily_average_predictoor_count"), + ] + ) + .sort("datetime") + ) + + return slots_daily_df + + +def plot_slot_daily_statistics(slots_df: pl.DataFrame, pq_dir: str) -> None: + assert "pair_timeframe" in slots_df.columns + assert "slot" in slots_df.columns + assert "n_predictoors" in slots_df.columns + + # calculate slot daily statistics + slots_daily_df = calculate_slot_daily_statistics(slots_df) + + charts_dir = get_plots_dir(pq_dir) + + dates = slots_daily_df["datetime"].to_list() + ticks = int(len(dates) / 5) if len(dates) > 5 else 2 + + # draw daily predictoor stake in $OCEAN + chart_path = os.path.join(charts_dir, "daily_average_stake.png") + plt.figure(figsize=(10, 6)) + plt.plot( + slots_daily_df["datetime"].to_pandas(), + slots_daily_df["daily_average_stake"], + marker="o", + linestyle="-", + ) + plt.xlabel("Date") + plt.ylabel("Average $OCEAN Staked") + plt.title("Daily average $OCEAN staked per slot, across all Feeds") + plt.xticks(range(0, len(dates), ticks), dates[::ticks], rotation=90) + plt.tight_layout() + plt.savefig(chart_path) + plt.close() + print("Chart created:", chart_path) + + # draw daily predictoor payouts in $OCEAN + chart_path = os.path.join(charts_dir, "daily_slot_average_predictoors.png") + plt.figure(figsize=(10, 6)) + plt.plot( + slots_daily_df["datetime"].to_pandas(), + slots_daily_df["daily_average_predictoor_count"], + marker="o", + linestyle="-", + ) + plt.xlabel("Date") + plt.ylabel("Average Predictoors") + plt.title("Average # Predictoors competing per slot, per feed") + plt.xticks(range(0, len(dates), ticks), dates[::ticks], rotation=90) + plt.tight_layout() + plt.savefig(chart_path) + plt.close() + print("Chart created:", chart_path) diff --git a/pdr_backend/analytics/test/conftest.py b/pdr_backend/analytics/test/conftest.py new file mode 100644 index 000000000..ceab8f803 --- /dev/null +++ b/pdr_backend/analytics/test/conftest.py @@ -0,0 +1,22 @@ +import pytest + +from pdr_backend.subgraph.prediction import ( + mock_daily_predictions, + mock_first_predictions, + mock_second_predictions, +) + + +@pytest.fixture() +def _sample_first_predictions(): + return mock_first_predictions() + + +@pytest.fixture() +def _sample_second_predictions(): + return mock_second_predictions() + + +@pytest.fixture() +def _sample_daily_predictions(): + return mock_daily_predictions() diff --git a/pdr_backend/analytics/test/test_check_network.py b/pdr_backend/analytics/test/test_check_network.py new file mode 100644 index 000000000..ef725f85c --- /dev/null +++ b/pdr_backend/analytics/test/test_check_network.py @@ -0,0 +1,170 @@ +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types + +from pdr_backend.analytics.check_network import ( + _N_FEEDS, + check_dfbuyer, + check_network_main, + get_expected_consume, 
+) +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.util.constants import S_PER_DAY, S_PER_WEEK +from pdr_backend.util.mathutil import to_wei + +PATH = "pdr_backend.analytics.check_network" + +MOCK_CUR_UT = 1702826080 + + +@enforce_types +@patch( + f"{PATH}.get_consume_so_far_per_contract", + side_effect=Mock(return_value={"0x1": 120}), +) +@patch( + f"{PATH}.get_expected_consume", + side_effect=Mock(return_value=100), +) +@patch( + f"{PATH}.current_ut_s", + side_effect=Mock(return_value=MOCK_CUR_UT), +) +def test_check_dfbuyer( # pylint: disable=unused-argument + mock_current_ut_ms, + mock_get_expected_consume_, + mock_get_consume_so_far_per_contract_, + capsys, +): + dfbuyer_addr = "0x1" + contract_query_result = {"data": {"predictContracts": [{"id": "0x1"}]}} + subgraph_url = "test_dfbuyer" + token_amt = 3 + check_dfbuyer(dfbuyer_addr, contract_query_result, subgraph_url, token_amt) + captured = capsys.readouterr() + + target_str = ( + "Checking consume amounts (dfbuyer), " + "expecting 100 consume per contract\n " + "PASS... got 120 consume for contract: 0x1, expected 100\n" + ) + assert target_str in captured.out + + cur_ut = MOCK_CUR_UT + start_ut = int((cur_ut // S_PER_WEEK) * S_PER_WEEK) + mock_get_consume_so_far_per_contract_.assert_called_once_with( + subgraph_url, dfbuyer_addr, start_ut, ["0x1"] + ) + mock_get_expected_consume_.assert_called_once_with(int(cur_ut), token_amt) + + +@enforce_types +def test_get_expected_consume(): + # Test case 1: Beginning of week + for_ut = S_PER_WEEK # Start of second week + token_amt = 140 + expected = token_amt / 7 / _N_FEEDS # Expected consume for one interval + assert get_expected_consume(for_ut, token_amt) == expected + + # Test case 2: End of first interval + for_ut = S_PER_WEEK + S_PER_DAY # Start of second day of second week + expected = 2 * (token_amt / 7 / _N_FEEDS) # Expected consume for two intervals + assert get_expected_consume(for_ut, token_amt) == expected + + # Test case 3: Middle of week + for_ut = S_PER_WEEK + 3 * S_PER_DAY # Start of fourth day of second week + expected = 4 * (token_amt / 7 / _N_FEEDS) # Expected consume for four intervals + assert get_expected_consume(for_ut, token_amt) == expected + + # Test case 4: End of week + for_ut = 2 * S_PER_WEEK - 1 # Just before end of second week + expected = 7 * (token_amt / 7 / _N_FEEDS) # Expected consume for seven intervals + assert get_expected_consume(for_ut, token_amt) == expected + + +@enforce_types +@patch(f"{PATH}.check_dfbuyer") +@patch(f"{PATH}.get_opf_addresses") +@patch(f"{PATH}.query_subgraph") +@patch(f"{PATH}.Token") +def test_check_network_main( # pylint: disable=unused-argument + mock_token, + mock_query_subgraph, + mock_get_opf_addresses, + mock_check_dfbuyer, + tmpdir, + monkeypatch, +): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_get_opf_addresses.return_value = { + "dfbuyer": "0xdfBuyerAddress", + "some_other_address": "0xSomeOtherAddress", + } + mock_query_subgraph.return_value = {"data": {"predictContracts": []}} + mock_token.return_value.balanceOf.return_value = to_wei(1000) + + mock_w3 = Mock() # pylint: disable=not-callable + mock_w3.eth.get_balance.return_value = to_wei(1000) + ppss.web3_pp.web3_config.w3 = mock_w3 + check_network_main(ppss, lookback_hours=24) + + mock_get_opf_addresses.assert_called_once_with("sapphire-mainnet") + assert mock_query_subgraph.call_count == 1 + mock_token.assert_called() + ppss.web3_pp.web3_config.w3.eth.get_balance.assert_called() + + +@enforce_types 
+@patch(f"{PATH}.check_dfbuyer") +@patch(f"{PATH}.get_opf_addresses") +@patch(f"{PATH}.Token") +def test_check_network_others( # pylint: disable=unused-argument + mock_token, + mock_get_opf_addresses, + mock_check_dfbuyer, + tmpdir, + monkeypatch, +): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + mock_query_subgraph = Mock() + + # test if predictoor contracts are found, iterates through them + with patch(f"{PATH}.query_subgraph") as mock_query_subgraph: + mock_query_subgraph.return_value = { + "data": { + "predictContracts": [ + { + "slots": {}, + "token": {"name": "aa"}, + "secondsPerEpoch": 86400, + }, + { + "slots": {}, + "token": {"name": "bb"}, + "secondsPerEpoch": 86400, + }, + ] + } + } + check_network_main(ppss, lookback_hours=24) + assert mock_query_subgraph.call_count == 1 + assert mock_check_dfbuyer.call_count == 1 + + +@enforce_types +@patch(f"{PATH}.check_dfbuyer") +@patch(f"{PATH}.get_opf_addresses") +@patch(f"{PATH}.Token") +def test_check_network_without_mock( # pylint: disable=unused-argument + mock_token, + mock_get_opf_addresses, + mock_check_dfbuyer, + tmpdir, + monkeypatch, +): + mock_token.balanceOf.return_value = 1000e18 + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + check_network_main(ppss, lookback_hours=1) + assert mock_check_dfbuyer.call_count == 1 diff --git a/pdr_backend/analytics/test/test_get_predictions_info.py b/pdr_backend/analytics/test/test_get_predictions_info.py new file mode 100644 index 000000000..61c4663cc --- /dev/null +++ b/pdr_backend/analytics/test/test_get_predictions_info.py @@ -0,0 +1,68 @@ +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types + +from pdr_backend.analytics.get_predictions_info import get_predictions_info_main +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.subgraph.subgraph_predictions import FilterMode + + +@enforce_types +def test_get_predictions_info_main_mainnet( + _sample_first_predictions, + tmpdir, +): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_getids = Mock(return_value=["0x123", "0x234"]) + mock_fetch = Mock(return_value=_sample_first_predictions) + mock_save = Mock() + mock_getstats = Mock() + + PATH = "pdr_backend.analytics.get_predictions_info" + with patch(f"{PATH}.get_all_contract_ids_by_owner", mock_getids), patch( + f"{PATH}.fetch_filtered_predictions", mock_fetch + ), patch(f"{PATH}.save_analysis_csv", mock_save), patch( + f"{PATH}.get_cli_statistics", mock_getstats + ): + st_timestr = "2023-11-02" + fin_timestr = "2023-11-05" + + get_predictions_info_main( + ppss, "0x123", st_timestr, fin_timestr, "parquet_data/" + ) + + mock_fetch.assert_called_with( + 1698883200, + 1699142400, + ["0x123"], + "mainnet", + FilterMode.CONTRACT, + payout_only=True, + trueval_only=True, + ) + mock_save.assert_called() + mock_getstats.assert_called_with(_sample_first_predictions) + + +@enforce_types +def test_get_predictions_info_empty(tmpdir, capfd): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_getids = Mock(return_value=[]) + mock_fetch = Mock(return_value={}) + + PATH = "pdr_backend.analytics.get_predictions_info" + with patch(f"{PATH}.get_all_contract_ids_by_owner", mock_getids), patch( + f"{PATH}.fetch_filtered_predictions", mock_fetch + ): + st_timestr = "2023-11-02" + fin_timestr = "2023-11-05" + + get_predictions_info_main( + ppss, "0x123", st_timestr, fin_timestr, "parquet_data/" + ) + + assert ( + "No records found. 
Please adjust start and end times" in capfd.readouterr().out + ) diff --git a/pdr_backend/analytics/test/test_get_predictoors_info.py b/pdr_backend/analytics/test/test_get_predictoors_info.py new file mode 100644 index 000000000..2b736f336 --- /dev/null +++ b/pdr_backend/analytics/test/test_get_predictoors_info.py @@ -0,0 +1,38 @@ +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types + +from pdr_backend.analytics.get_predictoors_info import get_predictoors_info_main +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.subgraph.subgraph_predictions import FilterMode + + +@enforce_types +def test_get_predictoors_info_main_mainnet(tmpdir): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_fetch = Mock(return_value=[]) + mock_save = Mock() + mock_getstats = Mock() + + PATH = "pdr_backend.analytics.get_predictoors_info" + with patch(f"{PATH}.fetch_filtered_predictions", mock_fetch), patch( + f"{PATH}.save_prediction_csv", mock_save + ), patch(f"{PATH}.get_cli_statistics", mock_getstats): + get_predictoors_info_main( + ppss, + "0x123", + "2023-01-01", + "2023-01-02", + "parquet_data/", + ) + + mock_fetch.assert_called_with( + 1672531200, + 1672617600, + ["0x123"], + "mainnet", + FilterMode.PREDICTOOR, + ) + mock_save.assert_called_with([], "parquet_data/") + mock_getstats.assert_called_with([]) diff --git a/pdr_backend/analytics/test/test_get_traction_info.py b/pdr_backend/analytics/test/test_get_traction_info.py new file mode 100644 index 000000000..d13ccb180 --- /dev/null +++ b/pdr_backend/analytics/test/test_get_traction_info.py @@ -0,0 +1,100 @@ +from unittest.mock import Mock, patch + +import polars as pl +import pytest +from enforce_typing import enforce_types + +from pdr_backend.analytics.get_traction_info import get_traction_info_main +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.subgraph.subgraph_predictions import FilterMode +from pdr_backend.util.timeutil import timestr_to_ut + + +@enforce_types +def test_get_traction_info_main_mainnet( + _sample_daily_predictions, + tmpdir, +): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_traction_stat = Mock() + mock_plot_cumsum = Mock() + mock_plot_daily = Mock() + mock_getids = Mock(return_value=["0x123"]) + mock_fetch = Mock(return_value=_sample_daily_predictions) + + PATH = "pdr_backend.analytics.get_traction_info" + PATH2 = "pdr_backend.lake" + with patch(f"{PATH}.get_traction_statistics", mock_traction_stat), patch( + f"{PATH}.plot_traction_cum_sum_statistics", mock_plot_cumsum + ), patch(f"{PATH}.plot_traction_daily_statistics", mock_plot_daily), patch( + f"{PATH2}.gql_data_factory.get_all_contract_ids_by_owner", mock_getids + ), patch( + f"{PATH2}.table_pdr_predictions.fetch_filtered_predictions", mock_fetch + ): + st_timestr = "2023-11-02" + fin_timestr = "2023-11-05" + + get_traction_info_main(ppss, st_timestr, fin_timestr, "parquet_data/") + + mock_fetch.assert_called_with( + 1698883200, + 1699142400, + ["0x123"], + "mainnet", + FilterMode.CONTRACT_TS, + payout_only=False, + trueval_only=False, + ) + + # calculate ms locally so we can filter raw Predictions + st_ut = timestr_to_ut(st_timestr) + fin_ut = timestr_to_ut(fin_timestr) + st_ut_sec = st_ut // 1000 + fin_ut_sec = fin_ut // 1000 + + # Get all predictions into a dataframe + preds = [ + x + for x in _sample_daily_predictions + if st_ut_sec <= x.timestamp <= fin_ut_sec + ] + preds = [pred.__dict__ for pred in preds] + preds_df = pl.DataFrame(preds) 
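+        # The sample Prediction objects carry second-resolution timestamps;
+        # the next step scales them to milliseconds before comparing.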
+ preds_df = preds_df.with_columns( + [ + pl.col("timestamp").mul(1000).alias("timestamp"), + ] + ) + + # Assert calls and values + pl.DataFrame.equals(mock_traction_stat.call_args, preds_df) + mock_plot_cumsum.assert_called() + mock_plot_daily.assert_called() + + +@enforce_types +def test_get_traction_info_empty(tmpdir, capfd): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + mock_empty = Mock(return_value=[]) + + PATH = "pdr_backend.analytics.get_traction_info" + with patch(f"{PATH}.GQLDataFactory.get_gql_dfs", mock_empty): + st_timestr = "2023-11-02" + fin_timestr = "2023-11-05" + + get_traction_info_main(ppss, st_timestr, fin_timestr, "parquet_data/") + + assert ( + "No records found. Please adjust start and end times." in capfd.readouterr().out + ) + + with patch("requests.post") as mock_post: + mock_post.return_value.status_code = 503 + # don't actually sleep in tests + with patch("time.sleep"): + with pytest.raises(Exception): + get_traction_info_main(ppss, st_timestr, fin_timestr, "parquet_data/") + + assert mock_post.call_count == 3 diff --git a/pdr_backend/analytics/test/test_predictoor_stats.py b/pdr_backend/analytics/test/test_predictoor_stats.py new file mode 100644 index 000000000..c87a37978 --- /dev/null +++ b/pdr_backend/analytics/test/test_predictoor_stats.py @@ -0,0 +1,172 @@ +from typing import List +from unittest.mock import patch + +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.analytics.predictoor_stats import ( + aggregate_prediction_statistics, + calculate_slot_daily_statistics, + get_cli_statistics, + get_endpoint_statistics, + get_slot_statistics, + get_traction_statistics, + plot_slot_daily_statistics, + plot_traction_cum_sum_statistics, + plot_traction_daily_statistics, +) + + +@enforce_types +def test_aggregate_prediction_statistics(_sample_first_predictions): + stats, correct_predictions = aggregate_prediction_statistics( + _sample_first_predictions + ) + assert isinstance(stats, dict) + assert "pair_timeframe" in stats + assert "predictor" in stats + assert correct_predictions == 1 # Adjust based on your sample data + + +@enforce_types +def test_get_endpoint_statistics(_sample_first_predictions): + accuracy, pair_timeframe_stats, predictoor_stats = get_endpoint_statistics( + _sample_first_predictions + ) + assert isinstance(accuracy, float) + assert isinstance(pair_timeframe_stats, List) # List[PairTimeframeStat] + assert isinstance(predictoor_stats, List) # List[PredictoorStat] + for pair_timeframe_stat in pair_timeframe_stats: + for key in [ + "pair", + "timeframe", + "accuracy", + "stake", + "payout", + "number_of_predictions", + ]: + assert key in pair_timeframe_stat + + for predictoor_stat in predictoor_stats: + for key in [ + "predictoor_address", + "accuracy", + "stake", + "payout", + "number_of_predictions", + "details", + ]: + assert key in predictoor_stat + assert len(predictoor_stat["details"]) == 2 + + +@enforce_types +def test_get_cli_statistics(capsys, _sample_first_predictions): + get_cli_statistics(_sample_first_predictions) + captured = capsys.readouterr() + output = captured.out + assert "Overall Accuracy" in output + assert "Accuracy for Pair" in output + assert "Accuracy for Predictoor Address" in output + + get_cli_statistics([]) + assert "No predictions found" in capsys.readouterr().out + + with patch( + "pdr_backend.analytics.predictoor_stats.aggregate_prediction_statistics" + ) as mock: + mock.return_value = ({}, 0) + get_cli_statistics(_sample_first_predictions) + + 
assert "No correct predictions found" in capsys.readouterr().out + + +@enforce_types +@patch("matplotlib.pyplot.savefig") +def test_get_traction_statistics( + mock_savefig, _sample_first_predictions, _sample_second_predictions +): + predictions = _sample_first_predictions + _sample_second_predictions + + # Get all predictions into a dataframe + preds_dicts = [pred.__dict__ for pred in predictions] + preds_df = pl.DataFrame(preds_dicts) + + stats_df = get_traction_statistics(preds_df) + assert isinstance(stats_df, pl.DataFrame) + assert stats_df.shape == (3, 3) + assert "datetime" in stats_df.columns + assert "daily_unique_predictoors_count" in stats_df.columns + assert stats_df["cum_daily_unique_predictoors_count"].to_list() == [2, 3, 4] + + pq_dir = "parquet_data/" + plot_traction_daily_statistics(stats_df, pq_dir) + plot_traction_cum_sum_statistics(stats_df, pq_dir) + + assert mock_savefig.call_count == 2 + + +@enforce_types +def test_get_slot_statistics(_sample_first_predictions, _sample_second_predictions): + predictions = _sample_first_predictions + _sample_second_predictions + + # Get all predictions into a dataframe + preds_dicts = [pred.__dict__ for pred in predictions] + preds_df = pl.DataFrame(preds_dicts) + + # calculate slot stats + slots_df = get_slot_statistics(preds_df) + assert isinstance(slots_df, pl.DataFrame) + assert slots_df.shape == (7, 9) + + for key in [ + "datetime", + "pair", + "timeframe", + "slot", + "pair_timeframe", + "n_predictoors", + "slot_stake", + "slot_payout", + ]: + assert key in slots_df.columns + + assert slots_df["slot_payout"].to_list() == [0.0, 0.05, 0.05, 0.0, 0.0, 0.0, 0.1] + assert slots_df["slot_stake"].to_list() == [0.05, 0.05, 0.05, 0.05, 0.05, 0.0, 0.1] + + +@enforce_types +@patch("matplotlib.pyplot.savefig") +def test_plot_slot_statistics( + mock_savefig, _sample_first_predictions, _sample_second_predictions +): + predictions = _sample_first_predictions + _sample_second_predictions + + # Get all predictions into a dataframe + preds_dicts = [pred.__dict__ for pred in predictions] + preds_df = pl.DataFrame(preds_dicts) + + # calculate slot stats + slots_df = get_slot_statistics(preds_df) + slot_daily_df = calculate_slot_daily_statistics(slots_df) + + for key in [ + "datetime", + "daily_average_stake", + "daily_average_payout", + "daily_average_predictoor_count", + ]: + assert key in slot_daily_df.columns + + assert slot_daily_df["daily_average_stake"].round(2).to_list() == [0.1, 0.1, 0.15] + assert slot_daily_df["daily_average_payout"].round(2).to_list() == [0.0, 0.05, 0.15] + assert slot_daily_df["daily_average_predictoor_count"].round(2).to_list() == [ + 1.0, + 1.0, + 1.0, + ] + + pq_dir = "parquet_data/" + plot_slot_daily_statistics(slots_df, pq_dir) + + assert mock_savefig.call_count == 2 diff --git a/pdr_backend/cli/arg_exchange.py b/pdr_backend/cli/arg_exchange.py new file mode 100644 index 000000000..fc0cb670b --- /dev/null +++ b/pdr_backend/cli/arg_exchange.py @@ -0,0 +1,43 @@ +from typing import List, Union + +import ccxt + + +class ArgExchange: + def __init__(self, exchange: str): + if not exchange: + raise ValueError(exchange) + + if not hasattr(ccxt, exchange): + raise ValueError(exchange) + + self.exchange = exchange + + @property + def exchange_class(self): + return getattr(ccxt, self.exchange) + + def __str__(self): + return self.exchange + + def __eq__(self, other): + return self.exchange == str(other) + + def __hash__(self): + return hash(self.exchange) + + +class ArgExchanges(List[ArgExchange]): + def __init__(self, 
exchanges: Union[List[str], List[ArgExchange]]): + if not isinstance(exchanges, list): + raise TypeError("exchanges must be a list") + + converted = [ArgExchange(str(exchange)) for exchange in exchanges if exchange] + + if not converted: + raise ValueError(exchanges) + + super().__init__(converted) + + def __str__(self): + return ",".join([str(exchange) for exchange in self]) diff --git a/pdr_backend/cli/arg_feed.py b/pdr_backend/cli/arg_feed.py new file mode 100644 index 000000000..d90a20c2d --- /dev/null +++ b/pdr_backend/cli/arg_feed.py @@ -0,0 +1,221 @@ +from collections import defaultdict +from typing import List, Optional, Union + +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_exchange import ArgExchange +from pdr_backend.cli.arg_pair import ArgPair, ArgPairs +from pdr_backend.cli.timeframe import ( + Timeframe, + Timeframes, + verify_timeframes_str, +) +from pdr_backend.util.signalstr import ( + signal_to_char, + signals_to_chars, + verify_signal_str, + unpack_signalchar_str, + verify_signalchar_str, +) + + +class ArgFeed: + def __init__( + self, + exchange, + signal: Union[str, None] = None, + pair: Union[ArgPair, str, None] = None, + timeframe: Optional[Union[Timeframe, str]] = None, + ): + if signal is not None: + verify_signal_str(signal) + + if pair is None: + raise ValueError("pair cannot be None") + + self.exchange = ArgExchange(exchange) if isinstance(exchange, str) else exchange + self.pair = ArgPair(pair) if isinstance(pair, str) else pair + self.signal = signal + + if timeframe is None: + self.timeframe = None + else: + self.timeframe = ( + Timeframe(timeframe) if isinstance(timeframe, str) else timeframe + ) + + def __str__(self): + feed_str = f"{self.exchange} {self.pair}" + + if self.signal is not None: + char = signal_to_char(self.signal) + feed_str += f" {char}" + + if self.timeframe is not None: + feed_str += f" {self.timeframe}" + + return feed_str + + def __eq__(self, other): + return ( + self.exchange == other.exchange + and self.signal == other.signal + and str(self.pair) == str(other.pair) + and str(self.timeframe) == str(other.timeframe) + ) + + def __hash__(self): + return hash((self.exchange, self.signal, str(self.pair))) + + @staticmethod + def from_str(feed_str: str, do_verify: bool = True) -> "ArgFeed": + """ + @description + Unpack the string for a *single* feed: 1 exchange, 1 signal, 1 pair + + Example: Given "binance ADA-USDT o" + Return Feed("binance", "open", "BTC/USDT") + + @argument + feed_str -- eg "binance ADA/USDT o"; not eg "binance oc ADA/USDT BTC/DAI" + do_verify - typically T. Only F to avoid recursion from verify functions + + @return + Feed + """ + feeds_str = feed_str + feeds = _unpack_feeds_str(feeds_str) + + if do_verify: + if len(feeds) != 1: + raise ValueError(feed_str) + feed = feeds[0] + return feed + + +@enforce_types +def _unpack_feeds_str(feeds_str: str) -> List[ArgFeed]: + """ + @description + Unpack a *single* feeds str. It can have >1 feeds of course. + + Example: Given "binance oc ADA/USDT BTC-USDT" + Return [ + ("binance", "open", "ADA/USDT"), + ("binance", "close", "ADA/USDT"), + ("binance", "open", "BTC/USDT"), + ("binance", "close", "BTC/USDT"), + ] + + @arguments + feeds_str - " " + do_verify - typically T. 
Only F to avoid recursion from verify functions + + @return + feed_tups - list of (exchange_str, signal_str, pair_str) + """ + feeds_str = feeds_str.strip() + feeds_str = " ".join(feeds_str.split()) # replace multiple whitespace w/ 1 + feeds_str_split = feeds_str.split(" ") + + exchange_str = feeds_str_split[0] + + timeframe_str = feeds_str_split[-1] + offset_end = None + + if verify_timeframes_str(timeframe_str): + timeframe_str_list = Timeframes.from_str(timeframe_str) + + # last part is a valid timeframe, and we might have a signal before it + signal_char_str = feeds_str_split[-2] + + if verify_signalchar_str(signal_char_str, True): + # last part is a valid timeframe and we have a valid signal before it + signal_str_list = unpack_signalchar_str(signal_char_str) + offset_end = -2 + else: + # last part is a valid timeframe, but there is no signal before it + signal_str_list = [None] + offset_end = -1 + else: + # last part is not a valid timeframe, but it might be a signal + timeframe_str_list = [None] + signal_char_str = feeds_str_split[-1] + + if verify_signalchar_str(signal_char_str, True): + # last part is a valid signal + signal_str_list = unpack_signalchar_str(signal_char_str) + offset_end = -1 + else: + # last part is not a valid timeframe, nor a signal + signal_str_list = [None] + offset_end = None + + pairs_list_str = " ".join(feeds_str_split[1:offset_end]) + + pairs = ArgPairs.from_str(pairs_list_str) + + feeds = [ + ArgFeed(exchange_str, signal_str, pair_str, timeframe_str) + for signal_str in signal_str_list + for pair_str in pairs + for timeframe_str in timeframe_str_list + ] + + return feeds + + +@enforce_types +def _pack_feeds_str(feeds: List[ArgFeed]) -> List[str]: + """ + Returns eg set([ + "binance BTC/USDT ohl 5m", + "binance ETH/USDT ohlv 5m", + "binance DOT/USDT c 5m", + "kraken BTC/USDT c", + ]) + """ + # merge signals via dict + grouped_signals = defaultdict(set) # [(exch,pair,timeframe)] : signals + for feed in feeds: + ept_tup = (str(feed.exchange), str(feed.pair), str(feed.timeframe)) + if ept_tup not in grouped_signals: + grouped_signals[ept_tup] = {str(feed.signal)} + else: + grouped_signals[ept_tup].add(str(feed.signal)) + + # convert new dict to list of 4-tups. Sort for consistency + epts_tups = [] + for (exch, pair, timeframe), signalset in grouped_signals.items(): + fr_signalset = frozenset(sorted(signalset)) + epts_tups.append((exch, pair, timeframe, fr_signalset)) + epts_tups = sorted(epts_tups) + + # then, merge pairs via dic + grouped_pairs = defaultdict(set) # [(exch,timeframe,signals)] : pairs + for exch, pair, timeframe, fr_signalset in epts_tups: + ets_tup = (exch, timeframe, fr_signalset) + if ets_tup not in grouped_pairs: + grouped_pairs[ets_tup] = {pair} + else: + grouped_pairs[ets_tup].add(pair) + + # convert new dict to list of 4-tups. 
Sort for consistency + etsp_tups = [] + for (exch, timeframe, fr_signalset), pairs in grouped_pairs.items(): + fr_pairs = frozenset(sorted(pairs)) + etsp_tups.append((exch, timeframe, fr_signalset, fr_pairs)) + etsp_tups = sorted(etsp_tups) + + # convert to list of str + strs = [] + for exch, timeframe, fr_signalset, fr_pairs in etsp_tups: + s = exch + s += " " + " ".join(sorted(fr_pairs)) + if fr_signalset != frozenset({"None"}): + s += " " + signals_to_chars(list(fr_signalset)) + if timeframe != "None": + s += " " + timeframe + strs.append(s) + + return strs diff --git a/pdr_backend/cli/arg_feeds.py b/pdr_backend/cli/arg_feeds.py new file mode 100644 index 000000000..26ea92796 --- /dev/null +++ b/pdr_backend/cli/arg_feeds.py @@ -0,0 +1,75 @@ +from typing import List, Set, Union + +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_exchange import ArgExchange +from pdr_backend.cli.arg_feed import ( + ArgFeed, + _unpack_feeds_str, + _pack_feeds_str, +) +from pdr_backend.cli.arg_pair import ArgPair +from pdr_backend.cli.timeframe import Timeframe + + +class ArgFeeds(List[ArgFeed]): + @enforce_types + def __init__(self, feeds: List[ArgFeed]): + super().__init__(feeds) + + @staticmethod + def from_str(feeds_str: str) -> "ArgFeeds": + return ArgFeeds(_unpack_feeds_str(feeds_str)) + + @staticmethod + def from_strs(feeds_strs: List[str], do_verify: bool = True) -> "ArgFeeds": + if do_verify: + if not feeds_strs: + raise ValueError(feeds_strs) + + feeds = [] + for feeds_str in feeds_strs: + feeds += _unpack_feeds_str(feeds_str) + + return ArgFeeds(feeds) + + @property + def pairs(self) -> Set[str]: + return set(str(feed.pair) for feed in self) + + @property + def exchanges(self) -> Set[str]: + return set(str(feed.exchange) for feed in self) + + @property + def signals(self) -> Set[str]: + return set(str(feed.signal) for feed in self) + + @enforce_types + def contains_combination( + self, + source: Union[str, ArgExchange], + pair: Union[str, ArgPair], + timeframe: Union[str, Timeframe], + ) -> bool: + for feed in self: + if ( + feed.exchange == source + and feed.pair == pair + and (not feed.timeframe or feed.timeframe == timeframe) + ): + return True + + return False + + @enforce_types + def __eq__(self, other) -> bool: + return sorted([str(f) for f in self]) == sorted([str(f) for f in other]) + + @enforce_types + def __str__(self) -> str: + return ", ".join(self.to_strs()) + + @enforce_types + def to_strs(self) -> List[str]: + return _pack_feeds_str(self[:]) diff --git a/pdr_backend/cli/arg_pair.py b/pdr_backend/cli/arg_pair.py new file mode 100644 index 000000000..f21865023 --- /dev/null +++ b/pdr_backend/cli/arg_pair.py @@ -0,0 +1,157 @@ +import re +from typing import List, Optional, Tuple, Union + +from enforce_typing import enforce_types + +from pdr_backend.util.constants import CAND_USDCOINS + +# convention: it allows "-" and "/" as input, and always outputs "/". + +# note: the only place that "/" doesn't work is filenames. +# So it converts to "-" just-in-time. That's outside this module. 
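+#
+# For example (illustrative):
+#   str(ArgPair("BTC-USDT")) == "BTC/USDT"
+#   str(ArgPair("BTC/USDT")) == "BTC/USDT"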
+ + +# don't use @enforce_types, causes problems +class ArgPair: + def __init__( + self, + pair_str: Optional[Union[str, "ArgPair"]] = None, + base_str: Optional[str] = None, + quote_str: Optional[str] = None, + ): + if not pair_str and None in [base_str, quote_str]: + raise ValueError( + "Must provide either pair_str, or both base_str and quote_str" + ) + + if isinstance(pair_str, ArgPair): + pair_str = str(pair_str) + + if pair_str is None: + pair_str = f"{base_str}/{quote_str}" + else: + pair_str = pair_str.strip() + if not re.match("[A-Z]+[-/][A-Z]+", pair_str): + raise ValueError(pair_str) + + base_str, quote_str = _unpack_pair_str(pair_str) + + _verify_base_str(base_str) + _verify_quote_str(quote_str) + + self.pair_str = pair_str + self.base_str = base_str + self.quote_str = quote_str + + def __eq__(self, other): + return self.pair_str == str(other) + + def __str__(self): + return f"{self.base_str}/{self.quote_str}" + + def __hash__(self): + return hash(self.pair_str) + + +class ArgPairs(List[ArgPair]): + def __init__(self, pairs: Union[List[str], List[ArgPair]]): + if not isinstance(pairs, list): + raise TypeError(pairs) + + if not pairs: + raise ValueError(pairs) + + pairs = [ArgPair(pair) for pair in pairs if pair] + super().__init__(pairs) + + @staticmethod + def from_str(pairs_str: str) -> "ArgPairs": + return ArgPairs(_unpack_pairs_str(pairs_str)) + + def __eq__(self, other): + return set(self) == set(other) + + @enforce_types + def __str__(self) -> str: + """ + Example: Given ArgPairs ["BTC/USDT","ETH-DAI"] + Return "BTC/USDT,ETH/DAI" + """ + return ",".join([str(pair) for pair in self]) + + +@enforce_types +def _unpack_pairs_str(pairs_str: str) -> List[str]: + """ + @description + Unpack the string for *one or more* pairs, into list of pair_str + + Example: Given 'ADA-USDT, BTC/USDT, ETH/USDT' + Return ['ADA/USDT', 'BTC/USDT', 'ETH/USDT'] + + @argument + pairs_str - '/' or 'base-quote' + + @return + pair_str_list -- List[], where all "-" are "/" + """ + pairs_str = pairs_str.strip() + pairs_str = " ".join(pairs_str.split()) # replace multiple whitespace w/ 1 + pairs_str = pairs_str.replace(", ", ",").replace(" ,", ",") + pairs_str = pairs_str.replace(" ", ",") + pairs_str = pairs_str.replace("-", "/") # ETH/USDT -> ETH-USDT. Safer files. + pair_str_list = pairs_str.split(",") + + if not any(pair_str_list): + raise ValueError(pairs_str) + + return pair_str_list + + +def _unpack_pair_str(pair_str: str) -> Tuple[str, str]: + """ + @description + Unpack the string for a *single* pair, into base_str and quote_str. + + Example: Given 'BTC/USDT' or 'BTC-USDT' + Return ('BTC', 'USDT') + + @argument + pair_str - '/' or 'base-quote' + + @return + base_str -- e.g. 'BTC' + quote_str -- e.g. 'USDT' + """ + pair_str = pair_str.replace("/", "-") + base_str, quote_str = pair_str.split("-") + + return (base_str, quote_str) + + +@enforce_types +def _verify_base_str(base_str: str): + """ + @description + Raise an error if base_str is invalid + + @argument + base_str -- e.g. 'ADA' or ' ETH ' + """ + base_str = base_str.strip() + if not re.match("[A-Z]+$", base_str): + raise ValueError(base_str) + + +@enforce_types +def _verify_quote_str(quote_str: str): + """ + @description + Raise an error if quote_str is invalid + + @argument + quote_str -- e.g. 
'USDT' or ' RAI ' + """ + quote_str = quote_str.strip() + if quote_str not in CAND_USDCOINS: + raise ValueError(quote_str) diff --git a/pdr_backend/cli/cli_arguments.py b/pdr_backend/cli/cli_arguments.py new file mode 100644 index 000000000..232dd12bb --- /dev/null +++ b/pdr_backend/cli/cli_arguments.py @@ -0,0 +1,285 @@ +import sys +from argparse import ArgumentParser as ArgParser +from argparse import Namespace + +from enforce_typing import enforce_types + +HELP_LONG = """Predictoor tool + +Usage: pdr sim|predictoor|trader|.. + +Main tools: + pdr sim PPSS_FILE + pdr predictoor APPROACH PPSS_FILE NETWORK + pdr trader APPROACH PPSS_FILE NETWORK + pdr lake PPSS_FILE NETWORK + pdr claim_OCEAN PPSS_FILE + pdr claim_ROSE PPSS_FILE + +Utilities: + pdr help + pdr -h + pdr get_predictoors_info ST END PQDIR PPSS_FILE NETWORK --PDRS + pdr get_predictions_info ST END PQDIR PPSS_FILE NETWORK --FEEDS + pdr get_traction_info ST END PQDIR PPSS_FILE NETWORK --FEEDS + pdr check_network PPSS_FILE NETWORK --LOOKBACK_HOURS + +Transactions are signed with envvar 'PRIVATE_KEY`. + +Tools for core team: + pdr trueval PPSS_FILE NETWORK + pdr dfbuyer PPSS_FILE NETWORK + pdr publisher PPSS_FILE NETWORK + pdr topup PPSS_FILE NETWORK + pytest, black, mypy, pylint, .. +""" + + +# ======================================================================== +# mixins +@enforce_types +class APPROACH_Mixin: + def add_argument_APPROACH(self): + self.add_argument("APPROACH", type=int, help="1|2|..") + + +@enforce_types +class ST_Mixin: + def add_argument_ST(self): + self.add_argument("ST", type=str, help="Start date yyyy-mm-dd") + + +@enforce_types +class END_Mixin: + def add_argument_END(self): + self.add_argument("END", type=str, help="End date yyyy-mm-dd") + + +@enforce_types +class PQDIR_Mixin: + def add_argument_PQDIR(self): + self.add_argument("PQDIR", type=str, help="Parquet output dir") + + +@enforce_types +class PPSS_Mixin: + def add_argument_PPSS(self): + self.add_argument("PPSS_FILE", type=str, help="PPSS yaml settings file") + + +@enforce_types +class NETWORK_Mixin: + def add_argument_NETWORK(self): + self.add_argument( + "NETWORK", + type=str, + help="sapphire-testnet|sapphire-mainnet|development|barge-pytest|..", + ) + + +@enforce_types +class PDRS_Mixin: + def add_argument_PDRS(self): + self.add_argument( + "--PDRS", + type=str, + help="Predictoor address(es), separated by comma. If not specified, uses all.", + required=False, + ) + + +@enforce_types +class FEEDS_Mixin: + def add_argument_FEEDS(self): + self.add_argument( + "--FEEDS", + type=str, + default="", + help="Predictoor feed address(es). 
If not specified, uses all.", + required=False, + ) + + +@enforce_types +class LOOKBACK_Mixin: + def add_argument_LOOKBACK(self): + self.add_argument( + "--LOOKBACK_HOURS", + default=24, + type=int, + help="# hours to check back on", + required=False, + ) + + +# ======================================================================== +# argparser base classes +class CustomArgParser(ArgParser): + def add_arguments_bulk(self, command_name, arguments): + self.add_argument("command", choices=[command_name]) + + for arg in arguments: + func = getattr(self, f"add_argument_{arg}") + func() + + +@enforce_types +class _ArgParser_PPSS(CustomArgParser, PPSS_Mixin): + @enforce_types + def __init__(self, description: str, command_name: str): + super().__init__(description=description) + self.add_arguments_bulk(command_name, ["PPSS"]) + + +@enforce_types +class _ArgParser_PPSS_NETWORK(CustomArgParser, PPSS_Mixin, NETWORK_Mixin): + @enforce_types + def __init__(self, description: str, command_name: str): + super().__init__(description=description) + self.add_arguments_bulk(command_name, ["PPSS", "NETWORK"]) + + +@enforce_types +class _ArgParser_APPROACH_PPSS_NETWORK( + CustomArgParser, + APPROACH_Mixin, + PPSS_Mixin, + NETWORK_Mixin, +): + def __init__(self, description: str, command_name: str): + super().__init__(description=description) + self.add_arguments_bulk(command_name, ["APPROACH", "PPSS", "NETWORK"]) + + +@enforce_types +class _ArgParser_PPSS_NETWORK_LOOKBACK( + CustomArgParser, + PPSS_Mixin, + NETWORK_Mixin, + LOOKBACK_Mixin, +): + @enforce_types + def __init__(self, description: str, command_name: str): + super().__init__(description=description) + self.add_arguments_bulk(command_name, ["PPSS", "NETWORK", "LOOKBACK"]) + + +@enforce_types +class _ArgParser_ST_END_PQDIR_NETWORK_PPSS_PDRS( + CustomArgParser, + ST_Mixin, + END_Mixin, + PQDIR_Mixin, + PPSS_Mixin, + NETWORK_Mixin, + PDRS_Mixin, +): # pylint: disable=too-many-ancestors + @enforce_types + def __init__(self, description: str, command_name: str): + super().__init__(description=description) + self.add_arguments_bulk( + command_name, ["ST", "END", "PQDIR", "PPSS", "NETWORK", "PDRS"] + ) + + +@enforce_types +class _ArgParser_ST_END_PQDIR_NETWORK_PPSS_FEEDS( + CustomArgParser, + ST_Mixin, + END_Mixin, + PQDIR_Mixin, + PPSS_Mixin, + NETWORK_Mixin, + FEEDS_Mixin, +): # pylint: disable=too-many-ancestors + @enforce_types + def __init__(self, description: str, command_name: str): + super().__init__(description=description) + self.add_arguments_bulk( + command_name, ["ST", "END", "PQDIR", "PPSS", "NETWORK", "FEEDS"] + ) + + +# ======================================================================== +# actual arg-parser implementations are just aliases to argparser base classes +# In order of help text. 
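+# For example (illustrative): "pdr sim PPSS_FILE" is handled by SimArgParser,
+# i.e. _ArgParser_PPSS instantiated with description "Run simulation" and
+# command_name "sim" in defined_parsers below.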
+ + +@enforce_types +def do_help_long(status_code=0): + print(HELP_LONG) + sys.exit(status_code) + + +@enforce_types +def print_args(arguments: Namespace): + arguments_dict = arguments.__dict__ + command = arguments_dict.pop("command", None) + + print(f"pdr {command}: Begin") + print("Arguments:") + + for arg_k, arg_v in arguments_dict.items(): + print(f"{arg_k}={arg_v}") + + +SimArgParser = _ArgParser_PPSS + +PredictoorArgParser = _ArgParser_APPROACH_PPSS_NETWORK + +TraderArgParser = _ArgParser_APPROACH_PPSS_NETWORK + +LakeArgParser = _ArgParser_PPSS_NETWORK + +ClaimOceanArgParser = _ArgParser_PPSS + +ClaimRoseArgParser = _ArgParser_PPSS + +GetPredictoorsInfoArgParser = _ArgParser_ST_END_PQDIR_NETWORK_PPSS_PDRS + +GetPredictionsInfoArgParser = _ArgParser_ST_END_PQDIR_NETWORK_PPSS_FEEDS + +GetTractionInfoArgParser = _ArgParser_ST_END_PQDIR_NETWORK_PPSS_FEEDS + +CheckNetworkArgParser = _ArgParser_PPSS_NETWORK_LOOKBACK + +TruevalArgParser = _ArgParser_PPSS_NETWORK + +DfbuyerArgParser = _ArgParser_PPSS_NETWORK + +PublisherArgParser = _ArgParser_PPSS_NETWORK + +TopupArgParser = _ArgParser_PPSS_NETWORK + +defined_parsers = { + "do_sim": SimArgParser("Run simulation", "sim"), + "do_predictoor": PredictoorArgParser("Run a predictoor bot", "predictoor"), + "do_trader": TraderArgParser("Run a trader bot", "trader"), + "do_lake": LakeArgParser("Run the lake tool", "lake"), + "do_claim_OCEAN": ClaimOceanArgParser("Claim OCEAN", "claim_OCEAN"), + "do_claim_ROSE": ClaimRoseArgParser("Claim ROSE", "claim_ROSE"), + "do_get_predictoors_info": GetPredictoorsInfoArgParser( + "For specified predictoors, report {accuracy, ..} of each predictoor", + "get_predictoors_info", + ), + "do_get_predictions_info": GetPredictionsInfoArgParser( + "For specified feeds, report {accuracy, ..} of each predictoor", + "get_predictions_info", + ), + "do_get_traction_info": GetTractionInfoArgParser( + "Get traction info: # predictoors vs time, etc", + "get_traction_info", + ), + "do_check_network": CheckNetworkArgParser("Check network", "check_network"), + "do_trueval": TruevalArgParser("Run trueval bot", "trueval"), + "do_dfbuyer": DfbuyerArgParser("Run dfbuyer bot", "dfbuyer"), + "do_publisher": PublisherArgParser("Publish feeds", "publisher"), + "do_topup": TopupArgParser("Topup OCEAN and ROSE in dfbuyer, trueval, ..", "topup"), +} + + +def get_arg_parser(func_name): + if func_name not in defined_parsers: + raise ValueError(f"Unknown function name: {func_name}") + + return defined_parsers[func_name] diff --git a/pdr_backend/cli/cli_module.py b/pdr_backend/cli/cli_module.py new file mode 100644 index 000000000..b9bedfcc2 --- /dev/null +++ b/pdr_backend/cli/cli_module.py @@ -0,0 +1,165 @@ +import sys + +from enforce_typing import enforce_types + +from pdr_backend.analytics.check_network import check_network_main +from pdr_backend.analytics.get_predictions_info import get_predictions_info_main +from pdr_backend.analytics.get_predictoors_info import get_predictoors_info_main +from pdr_backend.analytics.get_traction_info import get_traction_info_main +from pdr_backend.cli.cli_arguments import ( + do_help_long, + get_arg_parser, + print_args, +) +from pdr_backend.dfbuyer.dfbuyer_agent import DFBuyerAgent +from pdr_backend.lake.ohlcv_data_factory import OhlcvDataFactory +from pdr_backend.payout.payout import do_ocean_payout, do_rose_payout +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.predictoor.approach1.predictoor_agent1 import PredictoorAgent1 +from pdr_backend.predictoor.approach3.predictoor_agent3 import 
PredictoorAgent3 +from pdr_backend.publisher.publish_assets import publish_assets +from pdr_backend.sim.sim_engine import SimEngine +from pdr_backend.trader.approach1.trader_agent1 import TraderAgent1 +from pdr_backend.trader.approach2.trader_agent2 import TraderAgent2 +from pdr_backend.trueval.trueval_agent import TruevalAgent +from pdr_backend.util.contract import get_address +from pdr_backend.util.fund_accounts import fund_accounts_with_OCEAN +from pdr_backend.util.topup import topup_main + + +@enforce_types +def _do_main(): + if len(sys.argv) <= 1 or sys.argv[1] == "help": + do_help_long(0) + + func_name = f"do_{sys.argv[1]}" + func = globals().get(func_name) + if func is None: + do_help_long(1) + + parser = get_arg_parser(func_name) + args = parser.parse_args() + print_args(args) + + func(args) + + +# ======================================================================== +# actual cli implementations. In order of help text. + + +@enforce_types +def do_sim(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network="development") + sim_engine = SimEngine(ppss) + sim_engine.run() + + +@enforce_types +def do_predictoor(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + + approach = args.APPROACH + if approach == 1: + agent = PredictoorAgent1(ppss) + + elif approach == 3: + agent = PredictoorAgent3(ppss) + + else: + raise ValueError(f"Unknown predictoor approach {approach}") + + agent.run() + + +@enforce_types +def do_trader(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + approach = args.APPROACH + + if approach == 1: + agent = TraderAgent1(ppss) + elif approach == 2: + agent = TraderAgent2(ppss) + else: + raise ValueError(f"Unknown trader approach {approach}") + + agent.run() + + +@enforce_types +def do_lake(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + ohlcv_data_factory = OhlcvDataFactory(ppss.lake_ss) + df = ohlcv_data_factory.get_mergedohlcv_df() + print(df) + + +# do_help() is implemented in cli_arguments and imported, so nothing needed here + + +@enforce_types +def do_claim_OCEAN(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network="sapphire-mainnet") + do_ocean_payout(ppss) + + +@enforce_types +def do_claim_ROSE(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network="sapphire-mainnet") + do_rose_payout(ppss) + + +@enforce_types +def do_get_predictoors_info(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + get_predictoors_info_main(ppss, args.PDRS, args.ST, args.END, args.PQDIR) + + +@enforce_types +def do_get_predictions_info(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + get_predictions_info_main(ppss, args.FEEDS, args.ST, args.END, args.PQDIR) + + +@enforce_types +def do_get_traction_info(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + get_traction_info_main(ppss, args.ST, args.END, args.PQDIR) + + +@enforce_types +def do_check_network(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + check_network_main(ppss, args.LOOKBACK_HOURS) + + +@enforce_types +def do_trueval(args, testing=False): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + predictoor_batcher_addr = get_address(ppss.web3_pp, "PredictoorHelper") + agent = TruevalAgent(ppss, predictoor_batcher_addr) + + agent.run(testing) + + +# @enforce_types +def do_dfbuyer(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + agent = DFBuyerAgent(ppss) + agent.run() + + +@enforce_types +def 
do_publisher(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + + if ppss.web3_pp.network == "development": + fund_accounts_with_OCEAN(ppss.web3_pp) + publish_assets(ppss.web3_pp, ppss.publisher_ss) + + +@enforce_types +def do_topup(args): + ppss = PPSS(yaml_filename=args.PPSS_FILE, network=args.NETWORK) + topup_main(ppss) diff --git a/pdr_backend/cli/test/test_arg_exchange.py b/pdr_backend/cli/test/test_arg_exchange.py new file mode 100644 index 000000000..34005ecb7 --- /dev/null +++ b/pdr_backend/cli/test/test_arg_exchange.py @@ -0,0 +1,52 @@ +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_exchange import ArgExchange, ArgExchanges + + +@enforce_types +def test_pack_exchange_str_list(): + assert str(ArgExchanges(["binance"])) == "binance" + assert str(ArgExchanges(["binance", "kraken"])) == "binance,kraken" + + with pytest.raises(TypeError): + ArgExchanges("") + + with pytest.raises(TypeError): + ArgExchanges(None) + + with pytest.raises(ValueError): + ArgExchange(None) + + with pytest.raises(TypeError): + ArgExchanges("") + + with pytest.raises(ValueError): + ArgExchanges([]) + + with pytest.raises(ValueError): + ArgExchanges(["adfs"]) + + with pytest.raises(ValueError): + ArgExchanges(["binance fgds"]) + + +@enforce_types +def test_verify_exchange_str(): + # ok + strs = [ + "binance", + "kraken", + ] + for exchange_str in strs: + ArgExchanges([exchange_str]) + + # not ok + strs = [ + "", + " ", + "xyz", + ] + for exchange_str in strs: + with pytest.raises(ValueError): + ArgExchanges([exchange_str]) diff --git a/pdr_backend/cli/test/test_arg_feed.py b/pdr_backend/cli/test/test_arg_feed.py new file mode 100644 index 000000000..91a8ec89d --- /dev/null +++ b/pdr_backend/cli/test/test_arg_feed.py @@ -0,0 +1,64 @@ +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed + + +@enforce_types +def test_ArgFeed_main_constructor(): + # ok + tups = [ + ("binance", "open", "BTC/USDT"), + ("kraken", "close", "BTC/DAI"), + ("kraken", "close", "BTC-DAI"), + ] + for feed_tup in tups: + ArgFeed(*feed_tup) + + # not ok - Value Error + tups = [ + ("binance", "open", ""), + ("xyz", "open", "BTC/USDT"), + ("xyz", "open", "BTC-USDT"), + ("binance", "xyz", "BTC/USDT"), + ("binance", "open", "BTC/XYZ"), + ("binance", "open"), + ] + for feed_tup in tups: + with pytest.raises(ValueError): + ArgFeed(*feed_tup) + + # not ok - Type Error + tups = [ + (), + ("binance", "open", "BTC/USDT", "", ""), + ] + for feed_tup in tups: + with pytest.raises(TypeError): + ArgFeed(*feed_tup) + + +@enforce_types +def test_ArgFeed_from_str(): + target_feed = ArgFeed("binance", "close", "BTC/USDT") + assert ArgFeed.from_str("binance BTC/USDT c") == target_feed + assert ArgFeed.from_str("binance BTC-USDT c") == target_feed + + target_feed = ArgFeed("binance", "close", "BTC/USDT", "1h") + assert ArgFeed.from_str("binance BTC/USDT c 1h") == target_feed + assert ArgFeed.from_str("binance BTC-USDT c 1h") == target_feed + + +@enforce_types +def test_ArgFeed_str(): + target_feed_str = "binance BTC/USDT o" + assert str(ArgFeed("binance", "open", "BTC/USDT")) == target_feed_str + assert str(ArgFeed("binance", "open", "BTC-USDT")) == target_feed_str + + target_feed_str = "binance BTC/USDT o 5m" + assert str(ArgFeed("binance", "open", "BTC/USDT", "5m")) == target_feed_str + assert str(ArgFeed("binance", "open", "BTC-USDT", "5m")) == target_feed_str + + target_feed_str = "binance BTC/USDT 5m" + assert str(ArgFeed("binance", None, 
"BTC/USDT", "5m")) == target_feed_str + assert str(ArgFeed("binance", None, "BTC-USDT", "5m")) == target_feed_str diff --git a/pdr_backend/cli/test/test_arg_feeds.py b/pdr_backend/cli/test/test_arg_feeds.py new file mode 100644 index 000000000..085099c14 --- /dev/null +++ b/pdr_backend/cli/test/test_arg_feeds.py @@ -0,0 +1,347 @@ +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.cli.arg_feeds import ArgFeeds + + +@enforce_types +def test_ArgFeeds_from_str(): + # 1 feed + target_feeds = [ArgFeed("binance", "open", "ADA/USDT")] + assert ArgFeeds.from_str("binance ADA/USDT o") == target_feeds + assert ArgFeeds.from_str("binance ADA-USDT o") == target_feeds + + # >1 signal, so >1 feed + target_feeds = [ + ArgFeed("binance", "open", "ADA/USDT"), + ArgFeed("binance", "close", "ADA/USDT"), + ] + assert ArgFeeds.from_str("binance ADA/USDT oc") == target_feeds + assert ArgFeeds.from_str("binance ADA-USDT oc") == target_feeds + + # >1 pair, so >1 feed + target_feeds = [ + ArgFeed("binance", "open", "ADA/USDT"), + ArgFeed("binance", "open", "ETH/RAI"), + ] + assert ArgFeeds.from_str("binance ADA/USDT ETH/RAI o") == target_feeds + assert ArgFeeds.from_str("binance ADA-USDT ETH/RAI o") == target_feeds + assert ArgFeeds.from_str("binance ADA-USDT ETH-RAI o") == target_feeds + + # >1 signal and >1 pair, so >1 feed + target = ArgFeeds( + [ + ArgFeed("binance", "close", "ADA/USDT"), + ArgFeed("binance", "close", "BTC/USDT"), + ArgFeed("binance", "open", "ADA/USDT"), + ArgFeed("binance", "open", "BTC/USDT"), + ] + ) + assert ArgFeeds.from_str("binance ADA/USDT,BTC/USDT oc") == target + assert ArgFeeds.from_str("binance ADA-USDT,BTC/USDT oc") == target + assert ArgFeeds.from_str("binance ADA-USDT,BTC-USDT oc") == target + + # >1 signal and >1 pair, so >1 feed + target = ArgFeeds( + [ + ArgFeed("binance", "close", "ADA/USDT", "1h"), + ArgFeed("binance", "close", "ADA/USDT", "5m"), + ArgFeed("binance", "close", "BTC/USDT", "1h"), + ArgFeed("binance", "close", "BTC/USDT", "5m"), + ArgFeed("binance", "open", "ADA/USDT", "1h"), + ArgFeed("binance", "open", "ADA/USDT", "5m"), + ArgFeed("binance", "open", "BTC/USDT", "1h"), + ArgFeed("binance", "open", "BTC/USDT", "5m"), + ] + ) + assert ArgFeeds.from_str("binance ADA/USDT,BTC/USDT oc 1h,5m") == target + + # unhappy paths. 
Verify section has way more, this is just for baseline + strs = [ + "xyz ADA/USDT o", + "binance ADA/USDT ox", + "binance ADA/X o", + ] + for feeds_str in strs: + with pytest.raises(ValueError): + ArgFeeds.from_str(feeds_str) + + targ_prs = set(["ADA/USDT", "BTC/USDT"]) + assert ArgFeeds.from_str("binance ADA/USDT BTC/USDT o").pairs == targ_prs + assert ArgFeeds.from_str("binance ADA-USDT BTC/USDT o").pairs == targ_prs + assert ArgFeeds.from_str("binance ADA-USDT BTC-USDT o").pairs == targ_prs + + targ_prs = set(["ADA/USDT", "BTC/USDT"]) + assert ArgFeeds.from_str("binance ADA/USDT,BTC/USDT oc").pairs == targ_prs + assert ArgFeeds.from_str("binance ADA-USDT,BTC/USDT oc").pairs == targ_prs + assert ArgFeeds.from_str("binance ADA-USDT,BTC-USDT oc").pairs == targ_prs + + targ_prs = set(["ADA/USDT", "BTC/USDT", "ETH/USDC", "DOT/DAI"]) + assert ( + ArgFeeds.from_str("binance ADA/USDT BTC/USDT ,ETH/USDC, DOT/DAI oc").pairs + == targ_prs + ) + + +@enforce_types +def test_ArgFeeds_from_strs_main(): + # 1 str w 1 feed, 1 feed total + target_feeds = [ArgFeed("binance", "open", "ADA/USDT")] + assert ArgFeeds.from_strs(["binance ADA/USDT o"]) == target_feeds + assert ArgFeeds.from_strs(["binance ADA-USDT o"]) == target_feeds + + target_feeds = [ArgFeed("binance", "open", "ADA/USDT", "1h")] + assert ArgFeeds.from_strs(["binance ADA-USDT o 1h"]) == target_feeds + + # 1 str w 2 feeds, 2 feeds total + target_feeds = ArgFeeds( + [ + ArgFeed("binance", "open", "ADA/USDT"), + ArgFeed("binance", "high", "ADA/USDT"), + ] + ) + assert ArgFeeds.from_strs(["binance ADA/USDT oh"]) == target_feeds + assert ArgFeeds.from_strs(["binance ADA-USDT oh"]) == target_feeds + assert target_feeds.signals == set(["open", "high"]) + assert target_feeds.exchanges == set(["binance"]) + + # 2 strs each w 1 feed, 2 feeds total + target_feeds = [ + ArgFeed("binance", "open", "ADA/USDT"), + ArgFeed("kraken", "high", "ADA/RAI"), + ] + feeds = ArgFeeds.from_strs( + [ + "binance ADA-USDT o", + "kraken ADA/RAI h", + ] + ) + assert feeds == target_feeds + + # 2 strs each w 1 feed, 2 feeds total, with timeframes and without signals + target_feeds = [ + ArgFeed("binance", None, "ADA/USDT", "5m"), + ArgFeed("kraken", None, "ADA/RAI", "1h"), + ] + feeds = ArgFeeds.from_strs( + [ + "binance ADA-USDT 5m", + "kraken ADA/RAI 1h", + ] + ) + assert feeds == target_feeds + + # 2 strs each w 1 feed, with timeframes 3 feeds total + target_feeds = [ + ArgFeed("binance", "open", "ADA/USDT", "5m"), + ArgFeed("binance", "open", "ADA/USDT", "1h"), + ArgFeed("kraken", "high", "ADA/RAI", "1h"), + ] + feeds = ArgFeeds.from_strs( + [ + "binance ADA-USDT o 5m,1h", + "kraken ADA/RAI h 1h", + ] + ) + assert feeds == target_feeds + + # first str has 4 feeds and second has 1 feed; 5 feeds total + target_feeds = ArgFeeds( + [ + ArgFeed("binance", "close", "ADA/USDT"), + ArgFeed("binance", "close", "BTC/USDT"), + ArgFeed("binance", "open", "ADA/USDT"), + ArgFeed("binance", "open", "BTC/USDT"), + ArgFeed("kraken", "high", "ADA/RAI"), + ] + ) + feeds = ArgFeeds.from_strs( + [ + "binance ADA-USDT BTC/USDT oc", + "kraken ADA-RAI h", + ] + ) + assert feeds == target_feeds + + # unhappy paths. 
Note: verify section has way more + lists = [ + [], + ["xyz ADA/USDT o"], + ["binance ADA/USDT ox"], + ["binance ADA/X o"], + ["binance ADA/X o 1h"], + ["binance ADA/X o 1h 1d"], + ["binance ADA/X o 10h"], + ] + for feeds_strs in lists: + with pytest.raises(ValueError): + ArgFeeds.from_strs(feeds_strs) + + +@enforce_types +def test_ArgFeeds_from_strs__many_inputs(): + # ok for verify_feeds_strs + lists = [ + ["binance ADA/USDT o"], + ["binance ADA-USDT o"], + ["binance ADA/USDT BTC/USDT oc", "kraken ADA/RAI h"], + ["binance ADA/USDT BTC-USDT oc", "kraken ADA/RAI h"], + [ + "binance ADA/USDT BTC-USDT oc 1h,5m", + "kraken ADA/RAI h 1h", + "binance BTC/USDT o", + ], + ] + for feeds_strs in lists: + ArgFeeds.from_strs(feeds_strs) + + # not ok for verify_feeds_strs + lists = [ + [], + [""], + ["kraken ADA/RAI xh"], + ["binance ADA/USDT BTC/USDT xoc", "kraken ADA/RAI h"], + ["binance ADA/USDT BTC/USDT xoc 1h", "kraken ADA/RAI 5m"], + ["", "kraken ADA/RAI h"], + ["", "kraken ADA/RAI h 5m"], + ] + for feeds_strs in lists: + with pytest.raises(ValueError): + ArgFeeds.from_strs(feeds_strs) + + +@enforce_types +def test_ArgFeeds_and_ArgFeed_from_str_many_inputs(): + # ok for verify_feeds_str, ok for verify_feed_str + # (well-formed 1 signal and 1 pair) + strs = [ + "binance ADA/USDT o", + "binance ADA-USDT o", + " binance ADA/USDT o", + "binance ADA/USDT o", + " binance ADA/USDT o", + " binance ADA/USDT o ", + "binance ADA/USDT", + " binance ADA/USDT ", + " binance ADA/USDT ", + " binance ADA/USDT ", + ] + for feed_str in strs: + ArgFeeds.from_str(feed_str) + for feeds_str in strs: + ArgFeeds.from_str(feeds_str) + + # not ok for verify_feed_str, ok for verify_feeds_str + # (well-formed >1 signal or >1 pair) + strs = [ + "binance ADA/USDT oh", + "binance ADA-USDT oh", + " binance ADA/USDT oh", + "binance ADA/USDT BTC/USDT oh", + " binance ADA/USDT BTC/USDT o", + "binance ADA/USDT, BTC/USDT ,ETH/USDC , DOT/DAI o", + " binance ADA/USDT, BTC/USDT ,ETH/USDC , DOT/DAI o", + " binance ADA/USDT, BTC-USDT ,ETH/USDC , DOT/DAI o", + ] + for feed_str in strs: + with pytest.raises(ValueError): + ArgFeed.from_str(feed_str) + + # not ok for verify_feed_str, not ok for verify_feeds_str + # (poorly formed) + strs = [ + "", + " ", + ",", + " , ", + " , ,", + " xyz ", + " xyz abc ", + "binance o", + "binance o ", + "binance o ,", + "o ADA/USDT", + "binance,ADA/USDT", + "binance,ADA-USDT", + "xyz ADA/USDT o", # catch non-exchanges! 
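# (Aside: the feed-string grammar these tests exercise, using the ArgFeed /
#  ArgFeeds names from this diff, is "<exchange> <pair> [<pair> ...]
#  [<signal chars>] [<timeframes>]"; each extra pair, signal character, or
#  timeframe multiplies the number of resulting feeds. For example,
#  "binance ADA/USDT,BTC/USDT oc 1h,5m" earlier in this file expands to
#  2 pairs x 2 signals x 2 timeframes = 8 ArgFeed objects.)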
+ "binancexyz ADA/USDT o", + "binance ADA/USDT ohx", + "binance ADA/USDT z", + "binance , ADA/USDT, o,", + "binance , ADA-USDT, o, ", + "binance,ADA/USDT,o", + "binance XYZ o", + "binance USDT o", + "binance ADA/ o", + "binance ADA- o", + "binance /USDT o", + "binance ADA:USDT o", + "binance ADA::USDT o", + "binance ADA,USDT o", + "binance ADA&USDT o", + "binance ADA/USDT XYZ o", + ] + + for feed_str in strs: + with pytest.raises(ValueError): + ArgFeed.from_str(feed_str) + + for feeds_str in strs: + with pytest.raises(ValueError): + ArgFeed.from_str(feeds_str) + + +@enforce_types +def test_ArgFeeds_contains_combination_1(): + # feeds have no timeframe so contains all timeframes + feeds = ArgFeeds( + [ArgFeed("binance", "close", "BTC/USDT"), ArgFeed("kraken", "close", "BTC/DAI")] + ) + + assert feeds.contains_combination("binance", "BTC/USDT", "1h") + assert feeds.contains_combination("kraken", "BTC/DAI", "5m") + assert not feeds.contains_combination("kraken", "BTC/USDT", "1h") + + # binance feed has a timeframe so contains just those timeframes + feeds = ArgFeeds( + [ + ArgFeed("binance", "close", "BTC/USDT", "5m"), + ArgFeed("kraken", "close", "BTC/DAI"), + ] + ) + + assert not feeds.contains_combination("binance", "BTC/USDT", "1h") + assert feeds.contains_combination("binance", "BTC/USDT", "5m") + assert feeds.contains_combination("kraken", "BTC/DAI", "5m") + assert feeds.contains_combination("kraken", "BTC/DAI", "1h") + + +@enforce_types +def test_ArgFeeds_str(): + feeds = ArgFeeds.from_strs(["binance BTC/USDT oh 5m"]) + assert str(feeds) == "binance BTC/USDT oh 5m" + + feeds = ArgFeeds.from_strs(["binance BTC/USDT oh 5m", "kraken BTC/USDT c"]) + assert str(feeds) == "binance BTC/USDT oh 5m, kraken BTC/USDT c" + + +@enforce_types +def test_ArgFeeds_to_strs(): + for feeds_strs in [ + ["binance BTC/USDT o"], + ["binance BTC/USDT oh"], + ["binance BTC/USDT 5m"], + ["binance BTC/USDT o 5m"], + ["binance BTC/USDT oh 5m"], + ["binance BTC/USDT ETH/USDT oh 5m"], + ["binance BTC/USDT ohl 5m", "binance ETH/USDT ohlv 5m"], + [ + "binance BTC/USDT ohl 5m", + "binance ETH/USDT ohlv 5m", + "binance DOT/USDT c 5m", + "kraken BTC/USDT c", + ], + ]: + feeds = ArgFeeds.from_strs(feeds_strs) + assert sorted(feeds.to_strs()) == sorted(feeds_strs) diff --git a/pdr_backend/util/test_noganache/test_pairstr.py b/pdr_backend/cli/test/test_arg_pair.py similarity index 54% rename from pdr_backend/util/test_noganache/test_pairstr.py rename to pdr_backend/cli/test/test_arg_pair.py index 7d0646d79..3a9b6cab8 100644 --- a/pdr_backend/util/test_noganache/test_pairstr.py +++ b/pdr_backend/cli/test/test_arg_pair.py @@ -1,40 +1,86 @@ -from enforce_typing import enforce_types import pytest +from enforce_typing import enforce_types -from pdr_backend.util.pairstr import ( - unpack_pairs_str, - unpack_pair_str, - verify_pairs_str, - verify_pair_str, - verify_base_str, - verify_quote_str, +from pdr_backend.cli.arg_pair import ( + ArgPair, + ArgPairs, + _unpack_pairs_str, + _verify_base_str, + _verify_quote_str, ) -# ========================================================================== -# unpack..() functions + +@enforce_types +def test_arg_pair_main(): + # basic tests + p1 = ArgPair("BTC/USDT") + p2 = ArgPair(base_str="BTC", quote_str="USDT") + assert p1.pair_str == p2.pair_str == "BTC/USDT" + assert p1.base_str == p2.base_str == "BTC" + assert p1.quote_str == p2.quote_str == "USDT" + + # test __eq__ + assert p1 == p2 + assert p1 == "BTC/USDT" + + assert p1 != ArgPair("ETH/USDT") + assert p1 != "ETH/USDT" + assert p1 != 
3 + + # test __str__ + assert str(p1) == "BTC/USDT" + assert str(p2) == "BTC/USDT" @enforce_types def test_unpack_pair_str(): - assert unpack_pair_str("BTC/USDT") == ("BTC", "USDT") - assert unpack_pair_str("BTC-USDT") == ("BTC", "USDT") + assert ArgPair("BTC/USDT").base_str == "BTC" + assert ArgPair("BTC/USDT").quote_str == "USDT" + assert ArgPair("BTC-USDT").base_str == "BTC" + assert ArgPair("BTC-USDT").quote_str == "USDT" @enforce_types def test_unpack_pairs_str(): - assert unpack_pairs_str("ADA-USDT BTC/USDT") == ["ADA-USDT", "BTC-USDT"] - assert unpack_pairs_str("ADA/USDT,BTC/USDT") == ["ADA-USDT", "BTC-USDT"] - assert unpack_pairs_str("ADA/USDT, BTC/USDT") == ["ADA-USDT", "BTC-USDT"] - assert unpack_pairs_str("ADA/USDT BTC/USDT,ETH-USDC, DOT/DAI") == [ - "ADA-USDT", - "BTC-USDT", - "ETH-USDC", - "DOT-DAI", + with pytest.raises(ValueError): + _unpack_pairs_str("") + + assert ArgPairs.from_str("ADA-USDT BTC/USDT") == ["ADA/USDT", "BTC/USDT"] + assert ArgPairs.from_str("ADA/USDT,BTC/USDT") == ["ADA/USDT", "BTC/USDT"] + assert ArgPairs.from_str("ADA/USDT, BTC/USDT") == ["ADA/USDT", "BTC/USDT"] + assert ArgPairs.from_str("ADA/USDT BTC/USDT,ETH-USDC, DOT/DAI") == [ + "ADA/USDT", + "BTC/USDT", + "ETH/USDC", + "DOT/DAI", ] -# ========================================================================== -# verify..() functions +@enforce_types +def test_pack_pair_str_list(): + assert str(ArgPairs(["ADA/USDT"])) == "ADA/USDT" + assert str(ArgPairs(["ADA-USDT"])) == "ADA/USDT" + assert str(ArgPairs(["ADA/USDT", "BTC/USDT"])) == "ADA/USDT,BTC/USDT" + assert str(ArgPairs(["ADA/USDT", "BTC-USDT"])) == "ADA/USDT,BTC/USDT" + assert str(ArgPairs(["ADA-USDT", "BTC-USDT"])) == "ADA/USDT,BTC/USDT" + + with pytest.raises(TypeError): + ArgPairs("") + + with pytest.raises(ValueError): + ArgPairs([]) + + with pytest.raises(TypeError): + ArgPairs(None) + + with pytest.raises(ValueError): + ArgPairs(["adfs"]) + + with pytest.raises(ValueError): + ArgPairs(["ADA-USDT fgds"]) + + pair_from_base_and_quote = ArgPair(base_str="BTC", quote_str="USDT") + assert str(ArgPair(pair_from_base_and_quote)) == "BTC/USDT" @enforce_types @@ -51,10 +97,12 @@ def test_verify_pairs_str__and__verify_pair_str(): "BTC/USDT ", " BTC/USDT ", ] - for pairs_str in strs: - verify_pairs_str(pairs_str) + for pair_str in strs: - verify_pair_str(pair_str) + ArgPair(pair_str) + + for pair_str in strs: + ArgPairs([pair_str]) # not ok for verify_pair_str, ok for verify_pairs_str # (well-formed >1 signal or >1 pair) @@ -69,10 +117,10 @@ def test_verify_pairs_str__and__verify_pair_str(): "ADA/USDT, BTC/USDT ETH-USDC , DOT/DAI", ] for pairs_str in strs: - verify_pairs_str(pairs_str) + ArgPairs.from_str(pairs_str) for pair_str in strs: with pytest.raises(ValueError): - verify_pair_str(pair_str) + ArgPair(pair_str) # not ok for verify_pair_str, not ok for verify_pairs_str # (poorly formed) @@ -114,13 +162,12 @@ def test_verify_pairs_str__and__verify_pair_str(): "ADA/USDT - BTC/USDT", "ADA/USDT / BTC/USDT", ] + for pairs_str in strs: with pytest.raises(ValueError): - verify_pairs_str(pairs_str) - - for pair_str in strs: + ArgPairs.from_str(pairs_str) with pytest.raises(ValueError): - verify_pair_str(pair_str) + ArgPair(pairs_str) @enforce_types @@ -136,7 +183,7 @@ def test_base_str(): "OCEAN", ] for base_str in strs: - verify_base_str(base_str) + _verify_base_str(base_str) # not ok strs = [ @@ -153,7 +200,7 @@ def test_base_str(): ] for base_str in strs: with pytest.raises(ValueError): - verify_base_str(base_str) + _verify_base_str(base_str) 
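# Illustrative sketch of the ArgPair / ArgPairs behavior exercised by these
# tests; class names and module path are the ones introduced in this diff.
from pdr_backend.cli.arg_pair import ArgPair, ArgPairs

pair = ArgPair("BTC-USDT")  # "-" and "/" separators are both accepted
assert (pair.base_str, pair.quote_str) == ("BTC", "USDT")
assert ArgPair("BTC/USDT") == ArgPair(base_str="BTC", quote_str="USDT")
assert str(ArgPair("BTC/USDT")) == "BTC/USDT"

pairs = ArgPairs.from_str("ADA-USDT BTC/USDT")  # space- or comma-separated
assert pairs == ["ADA/USDT", "BTC/USDT"]  # normalized to the "/" form
assert str(ArgPairs(["ADA-USDT", "BTC-USDT"])) == "ADA/USDT,BTC/USDT"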
@enforce_types @@ -169,7 +216,7 @@ def test_quote_str(): "DAI", ] for quote_str in strs: - verify_quote_str(quote_str) + _verify_quote_str(quote_str) # not ok strs = [ @@ -186,4 +233,4 @@ def test_quote_str(): ] for quote_str in strs: with pytest.raises(ValueError): - verify_quote_str(quote_str) + _verify_quote_str(quote_str) diff --git a/pdr_backend/cli/test/test_cli_arguments.py b/pdr_backend/cli/test/test_cli_arguments.py new file mode 100644 index 000000000..eb5e008f0 --- /dev/null +++ b/pdr_backend/cli/test/test_cli_arguments.py @@ -0,0 +1,41 @@ +import pytest + +from pdr_backend.cli.cli_arguments import ( + CustomArgParser, + defined_parsers, + do_help_long, + get_arg_parser, + print_args, +) + + +def test_arg_parser(): + for arg in defined_parsers: + parser = get_arg_parser(arg) + assert isinstance(parser, CustomArgParser) + + with pytest.raises(ValueError): + get_arg_parser("xyz") + + +def test_do_help_long(capfd): + with pytest.raises(SystemExit): + do_help_long() + + out, _ = capfd.readouterr() + assert "Predictoor tool" in out + assert "Main tools:" in out + + +def test_print_args(capfd): + SimArgParser = defined_parsers["do_sim"] + parser = SimArgParser + args = ["sim", "ppss.yaml"] + parsed_args = parser.parse_args(args) + + print_args(parsed_args) + + out, _ = capfd.readouterr() + assert "pdr sim: Begin" in out + assert "Arguments:" in out + assert "PPSS_FILE=ppss.yaml" in out diff --git a/pdr_backend/cli/test/test_cli_module.py b/pdr_backend/cli/test/test_cli_module.py new file mode 100644 index 000000000..2303667d8 --- /dev/null +++ b/pdr_backend/cli/test/test_cli_module.py @@ -0,0 +1,314 @@ +import os +from argparse import Namespace +from unittest.mock import Mock, patch + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.cli_module import ( + _do_main, + do_check_network, + do_claim_OCEAN, + do_claim_ROSE, + do_dfbuyer, + do_get_predictions_info, + do_get_predictoors_info, + do_get_traction_info, + do_lake, + do_predictoor, + do_publisher, + do_sim, + do_topup, + do_trader, + do_trueval, +) +from pdr_backend.ppss.ppss import PPSS + + +class _APPROACH: + APPROACH = 1 + + +class _APPROACH2: + APPROACH = 2 + + +class _APPROACH3: + APPROACH = 3 + + +class _APPROACH_BAD: + APPROACH = 99 + + +class _PPSS: + PPSS_FILE = os.path.abspath("ppss.yaml") + + +class _NETWORK: + NETWORK = "development" + + +class _LOOKBACK: + LOOKBACK_HOURS = 24 + + +class _ST: + ST = "2023-06-22" + + +class _END: + END = "2023-06-24" + + +class _PQDIR: + PQDIR = "my_parquet_data/" + + +class _FEEDS: + FEEDS = "0x95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe4" + + +class _PDRS: + PDRS = "0xa5222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe4" + + +class _Base: + def __init__(self, *args, **kwargs): + pass + + +class MockArgParser_PPSS(_Base): + def parse_args(self): + class MockArgs(Namespace, _PPSS): + pass + + return MockArgs() + + +class MockArgParser_PPSS_NETWORK(_Base): + def parse_args(self): + class MockArgs(Namespace, _PPSS, _NETWORK): + pass + + return MockArgs() + + +class MockArgParser_APPROACH_PPSS_NETWORK(_Base): + def __init__(self, approach=_APPROACH): + self.approach = approach + super().__init__() + + def parse_args(self): + class MockArgs(Namespace, self.approach, _PPSS, _NETWORK): + pass + + return MockArgs() + + +class MockArgParser_PPSS_NETWORK_LOOKBACK(_Base): + def parse_args(self): + class MockArgs(Namespace, _PPSS, _NETWORK, _LOOKBACK): + pass + + return MockArgs() + + +class MockArgParser_ST_END_PQDIR_NETWORK_PPSS_PDRS(_Base): + def parse_args(self): + class 
MockArgs( # pylint: disable=too-many-ancestors + Namespace, _ST, _END, _PQDIR, _NETWORK, _PPSS, _PDRS + ): + pass + + return MockArgs() + + +class MockArgParser_ST_END_PQDIR_NETWORK_PPSS_FEEDS(_Base): + def parse_args(self): + class MockArgs( # pylint: disable=too-many-ancestors + Namespace, _ST, _END, _PQDIR, _NETWORK, _PPSS, _FEEDS + ): + pass + + return MockArgs() + + +@enforce_types +class MockAgent: + was_run = False + + def __init__(self, ppss: PPSS, *args, **kwargs): + pass + + def run(self, *args, **kwargs): # pylint: disable=unused-argument + self.__class__.was_run = True + + +_CLI_PATH = "pdr_backend.cli.cli_module" + + +@enforce_types +def test_do_check_network(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.check_network_main", mock_f) + + do_check_network(MockArgParser_PPSS_NETWORK_LOOKBACK().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_lake(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.OhlcvDataFactory.get_mergedohlcv_df", mock_f) + + do_lake(MockArgParser_PPSS_NETWORK().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_claim_OCEAN(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.do_ocean_payout", mock_f) + + do_claim_OCEAN(MockArgParser_PPSS().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_claim_ROSE(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.do_rose_payout", mock_f) + + do_claim_ROSE(MockArgParser_PPSS().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_dfbuyer(monkeypatch): + monkeypatch.setattr(f"{_CLI_PATH}.DFBuyerAgent", MockAgent) + + do_dfbuyer(MockArgParser_PPSS_NETWORK().parse_args()) + assert MockAgent.was_run + + +@enforce_types +def test_do_get_predictions_info(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.get_predictions_info_main", mock_f) + + do_get_predictions_info( + MockArgParser_ST_END_PQDIR_NETWORK_PPSS_FEEDS().parse_args() + ) + mock_f.assert_called() + + +@enforce_types +def test_do_get_predictoors_info(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.get_predictoors_info_main", mock_f) + + do_get_predictoors_info(MockArgParser_ST_END_PQDIR_NETWORK_PPSS_PDRS().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_get_traction_info(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.get_traction_info_main", mock_f) + + do_get_traction_info(MockArgParser_ST_END_PQDIR_NETWORK_PPSS_FEEDS().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_predictoor(monkeypatch): + monkeypatch.setattr(f"{_CLI_PATH}.PredictoorAgent1", MockAgent) + + do_predictoor(MockArgParser_APPROACH_PPSS_NETWORK().parse_args()) + assert MockAgent.was_run + + monkeypatch.setattr(f"{_CLI_PATH}.PredictoorAgent3", MockAgent) + + do_predictoor(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH3).parse_args()) + assert MockAgent.was_run + + with pytest.raises(ValueError): + do_predictoor(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH_BAD).parse_args()) + + +@enforce_types +def test_do_publisher(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.publish_assets", mock_f) + + do_publisher(MockArgParser_PPSS_NETWORK().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_topup(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.topup_main", mock_f) + + do_topup(MockArgParser_PPSS_NETWORK().parse_args()) + mock_f.assert_called() + + +@enforce_types +def test_do_trader(monkeypatch): 
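# (These do_* handlers back the new `pdr` CLI entry point: _do_main maps
#  sys.argv[1] to the matching do_<command> function, so e.g. the argv
#  ["pdr", "sim", "ppss.yaml"] used in test_do_main below dispatches to do_sim.
#  Commands that touch a chain also take a NETWORK argument, and predictoor /
#  trader additionally take an APPROACH number, per the parser classes in
#  cli_arguments.py.)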
+ monkeypatch.setattr(f"{_CLI_PATH}.TraderAgent1", MockAgent) + + do_trader(MockArgParser_APPROACH_PPSS_NETWORK().parse_args()) + assert MockAgent.was_run + + monkeypatch.setattr(f"{_CLI_PATH}.TraderAgent2", MockAgent) + + do_trader(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH2).parse_args()) + assert MockAgent.was_run + + with pytest.raises(ValueError): + do_trader(MockArgParser_APPROACH_PPSS_NETWORK(_APPROACH_BAD).parse_args()) + + +@enforce_types +def test_do_trueval(monkeypatch): + monkeypatch.setattr(f"{_CLI_PATH}.TruevalAgent", MockAgent) + + do_trueval(MockArgParser_PPSS_NETWORK().parse_args()) + assert MockAgent.was_run + + +@enforce_types +def test_do_sim(monkeypatch): + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f) + + with patch("pdr_backend.sim.sim_engine.plt.show"): + do_sim(MockArgParser_PPSS_NETWORK().parse_args()) + + mock_f.assert_called() + + +@enforce_types +def test_do_main(monkeypatch, capfd): + with patch("sys.argv", ["pdr", "help"]): + with pytest.raises(SystemExit): + _do_main() + + assert "Predictoor tool" in capfd.readouterr().out + + with patch("sys.argv", ["pdr", "undefined_function"]): + with pytest.raises(SystemExit): + _do_main() + + assert "Predictoor tool" in capfd.readouterr().out + + mock_f = Mock() + monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f) + + with patch("pdr_backend.sim.sim_engine.plt.show"): + with patch("sys.argv", ["pdr", "sim", "ppss.yaml"]): + _do_main() + + assert mock_f.called diff --git a/pdr_backend/cli/test/test_timeframe.py b/pdr_backend/cli/test/test_timeframe.py new file mode 100644 index 000000000..581b53bec --- /dev/null +++ b/pdr_backend/cli/test/test_timeframe.py @@ -0,0 +1,102 @@ +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.timeframe import Timeframe, Timeframes, s_to_timeframe_str + + +@enforce_types +def test_timeframe_class_1m(): + t = Timeframe("1m") + assert t.timeframe_str == "1m" + assert t.m == 1 + assert t.s == 1 * 60 + assert t.ms == 1 * 60 * 1000 + + +@enforce_types +def test_timeframe_class_5m(): + t = Timeframe("5m") + assert t.timeframe_str == "5m" + assert t.m == 5 + assert t.s == 5 * 60 + assert t.ms == 5 * 60 * 1000 + + +@enforce_types +def test_timeframe_class_1h(): + t = Timeframe("1h") + assert t.timeframe_str == "1h" + assert t.m == 60 + assert t.s == 60 * 60 + assert t.ms == 60 * 60 * 1000 + + +@enforce_types +def test_timeframe_class_bad(): + with pytest.raises(ValueError): + Timeframe("foo") + + t = Timeframe("1h") + # forcefully change the model + t.timeframe_str = "BAD" + + with pytest.raises(ValueError): + _ = t.m + + +@enforce_types +def test_timeframe_class_eq(): + t = Timeframe("1m") + assert t == Timeframe("1m") + assert t == "1m" + + assert t != Timeframe("5m") + assert t != "5m" + + assert t != 5 + + +@enforce_types +def test_pack_timeframe_str_list(): + assert str(Timeframes([])) == "" + assert str(Timeframes(["1h"])) == "1h" + assert str(Timeframes(["1h", "5m"])) == "1h,5m" + + assert str(Timeframes.from_str("1h,5m")) == "1h,5m" + + with pytest.raises(TypeError): + Timeframes.from_str(None) + + with pytest.raises(TypeError): + Timeframes("") + + with pytest.raises(TypeError): + Timeframes(None) + + with pytest.raises(ValueError): + Timeframes(["adfs"]) + + with pytest.raises(ValueError): + Timeframes(["1h fgds"]) + + +@enforce_types +def test_verify_timeframe_str(): + Timeframe("1h") + Timeframe("1m") + + with pytest.raises(ValueError): + Timeframe("foo") + + +@enforce_types +def test_s_to_timeframe_str(): + assert 
s_to_timeframe_str(300) == "5m" + assert s_to_timeframe_str(3600) == "1h" + + assert s_to_timeframe_str(0) == "" + assert s_to_timeframe_str(100) == "" + assert s_to_timeframe_str(-300) == "" + + with pytest.raises(TypeError): + s_to_timeframe_str("300") diff --git a/pdr_backend/cli/timeframe.py b/pdr_backend/cli/timeframe.py new file mode 100644 index 000000000..7ad609def --- /dev/null +++ b/pdr_backend/cli/timeframe.py @@ -0,0 +1,101 @@ +from typing import List, Union + +from enforce_typing import enforce_types + +from pdr_backend.util.constants import CAND_TIMEFRAMES + + +# don't use @enforce_types, causes problems +class Timeframe: + def __init__(self, timeframe_str: str): + """ + @arguments + timeframe_str -- e.g. "5m" + """ + if timeframe_str not in CAND_TIMEFRAMES: + raise ValueError(timeframe_str) + self.timeframe_str = timeframe_str + + @property + def ms(self) -> int: + """Returns timeframe, in ms""" + return self.m * 60 * 1000 + + @property + def s(self) -> int: + """Returns timeframe, in s""" + return self.m * 60 + + @property + def m(self) -> int: + """Returns timeframe, in minutes""" + if self.timeframe_str == "1m": + return 1 + if self.timeframe_str == "5m": + return 5 + if self.timeframe_str == "1h": + return 60 + raise ValueError(f"need to support timeframe={self.timeframe_str}") + + def __eq__(self, other): + return self.timeframe_str == str(other) + + def __str__(self): + return self.timeframe_str + + +class Timeframes(List[Timeframe]): + def __init__(self, timeframes: Union[List[str], List[Timeframe]]): + if not isinstance(timeframes, list): + raise TypeError("timeframes must be a list") + + frames = [] + for timeframe in timeframes: + if isinstance(timeframe, str): + frame = Timeframe(timeframe) + + frames.append(frame) + + super().__init__(frames) + + @staticmethod + def from_str(timeframes_str: str): + """ + @description + Parses a comma-separated string of timeframes, e.g. "1h,5m" + """ + if not isinstance(timeframes_str, str): + raise TypeError("timeframes_strs must be a string") + + return Timeframes(timeframes_str.split(",")) + + def __str__(self): + if not self: + return "" + + return ",".join([str(frame) for frame in self]) + + +@enforce_types +def s_to_timeframe_str(seconds: int) -> str: + if seconds == 300: + return "5m" + if seconds == 3600: + return "1h" + return "" + + +@enforce_types +def verify_timeframes_str(signal_str: str): + """ + @description + Raise an error if signal is invalid. + + @argument + signal_str -- e.g. 
"close" + """ + try: + Timeframes.from_str(signal_str) + return True + except ValueError: + return False diff --git a/pdr_backend/conftest_ganache.py b/pdr_backend/conftest_ganache.py index 560ef2b16..76e925852 100644 --- a/pdr_backend/conftest_ganache.py +++ b/pdr_backend/conftest_ganache.py @@ -1,24 +1,21 @@ -import os +from unittest.mock import Mock + import pytest -from pdr_backend.models.token import Token -from pdr_backend.models.predictoor_batcher import PredictoorBatcher -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.publisher.publish import publish +from pdr_backend.contract.predictoor_batcher import PredictoorBatcher +from pdr_backend.contract.predictoor_contract import PredictoorContract +from pdr_backend.contract.token import Token +from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str +from pdr_backend.publisher.publish_asset import publish_asset from pdr_backend.util.contract import get_address -from pdr_backend.util.web3_config import Web3Config -SECONDS_PER_EPOCH = 300 -os.environ["MODELDIR"] = "my_model_dir" +CHAIN_ID = 8996 +S_PER_EPOCH = 300 -@pytest.fixture(scope="session") +@pytest.fixture(scope="session") # "session" = invoke once, across all tests def chain_id(): - return _chain_id() - - -def _chain_id(): - return 8996 + return CHAIN_ID @pytest.fixture(scope="session") @@ -27,57 +24,108 @@ def web3_config(): def _web3_config(): - return Web3Config(os.getenv("RPC_URL"), os.getenv("PRIVATE_KEY")) + return _web3_pp().web3_config + + +@pytest.fixture(scope="session") +def rpc_url(): + return _web3_pp().rpc_url + + +@pytest.fixture(scope="session") +def web3_pp(): + return _web3_pp() + + +def _web3_pp(): + return _ppss().web3_pp + + +@pytest.fixture(scope="session") +def ppss(): + return _ppss() + + +def _ppss(): + s = fast_test_yaml_str() + return PPSS(yaml_str=s, network="development") @pytest.fixture(scope="session") def ocean_token() -> Token: - token_address = get_address(_chain_id(), "Ocean") - return Token(_web3_config(), token_address) + token_address = get_address(_web3_pp(), "Ocean") + return Token(_web3_pp(), token_address) -@pytest.fixture(scope="module") +@pytest.fixture(scope="module") # "module" = invoke once per test module def predictoor_contract(): - config = Web3Config(os.getenv("RPC_URL"), os.getenv("PRIVATE_KEY")) - _, _, _, _, logs = publish( - s_per_epoch=SECONDS_PER_EPOCH, - s_per_subscription=SECONDS_PER_EPOCH * 24, + w3p = _web3_pp() + w3c = w3p.web3_config + _, _, _, _, logs = publish_asset( + s_per_epoch=S_PER_EPOCH, + s_per_subscription=S_PER_EPOCH * 24, base="ETH", quote="USDT", source="kraken", timeframe="5m", - trueval_submitter_addr=config.owner, - feeCollector_addr=config.owner, + trueval_submitter_addr=w3c.owner, + feeCollector_addr=w3c.owner, rate=3, cut=0.2, - web3_config=config, + web3_pp=w3p, ) dt_addr = logs["newTokenAddress"] - return PredictoorContract(config, dt_addr) + return PredictoorContract(w3p, dt_addr) @pytest.fixture(scope="module") def predictoor_contract2(): - config = Web3Config(os.getenv("RPC_URL"), os.getenv("PRIVATE_KEY")) - _, _, _, _, logs = publish( - s_per_epoch=SECONDS_PER_EPOCH, - s_per_subscription=SECONDS_PER_EPOCH * 24, + w3p = _web3_pp() + w3c = w3p.web3_config + _, _, _, _, logs = publish_asset( + s_per_epoch=S_PER_EPOCH, + s_per_subscription=S_PER_EPOCH * 24, + base="ETH", + quote="USDT", + source="kraken", + timeframe="5m", + trueval_submitter_addr=w3c.owner, + feeCollector_addr=w3c.owner, + rate=3, + cut=0.2, + web3_pp=w3p, + ) + dt_addr = 
logs["newTokenAddress"] + return PredictoorContract(w3p, dt_addr) + + +@pytest.fixture(scope="module") # "module" = invoke once per test module +def predictoor_contract_empty(): + w3p = _web3_pp() + w3c = w3p.web3_config + _, _, _, _, logs = publish_asset( + s_per_epoch=S_PER_EPOCH, + s_per_subscription=S_PER_EPOCH * 24, base="ETH", quote="USDT", source="kraken", timeframe="5m", - trueval_submitter_addr=config.owner, - feeCollector_addr=config.owner, + trueval_submitter_addr=w3c.owner, + feeCollector_addr=w3c.owner, rate=3, cut=0.2, - web3_config=config, + web3_pp=w3p, ) dt_addr = logs["newTokenAddress"] - return PredictoorContract(config, dt_addr) + predictoor_c = PredictoorContract(w3p, dt_addr) + predictoor_c.get_exchanges = Mock(return_value=[]) + + return predictoor_c # pylint: disable=redefined-outer-name @pytest.fixture(scope="module") def predictoor_batcher(): - predictoor_batcher_addr = get_address(_chain_id(), "PredictoorHelper") - return PredictoorBatcher(_web3_config(), predictoor_batcher_addr) + w3p = _web3_pp() + predictoor_batcher_addr = get_address(w3p, "PredictoorHelper") + return PredictoorBatcher(w3p, predictoor_batcher_addr) diff --git a/pdr_backend/contract/base_contract.py b/pdr_backend/contract/base_contract.py new file mode 100644 index 000000000..454b29c29 --- /dev/null +++ b/pdr_backend/contract/base_contract.py @@ -0,0 +1,56 @@ +from abc import ABC + +from enforce_typing import enforce_types +from sapphirepy import wrapper + + +@enforce_types +class BaseContract(ABC): + def __init__(self, web3_pp, address: str, contract_name: str): + super().__init__() + # pylint: disable=import-outside-toplevel + from pdr_backend.ppss.web3_pp import Web3PP + + # pylint: disable=import-outside-toplevel + from pdr_backend.util.contract import get_contract_abi + + if not isinstance(web3_pp, Web3PP): + raise ValueError(f"web3_pp is {web3_pp.__class__}, not Web3PP") + self.web3_pp = web3_pp + self.config = web3_pp.web3_config # for convenience + self.contract_address = self.config.w3.to_checksum_address(address) + self.contract_instance = self.config.w3.eth.contract( + address=self.config.w3.to_checksum_address(address), + abi=get_contract_abi(contract_name, web3_pp.address_file), + ) + + def send_encrypted_tx( + self, + function_name, + args, + sender=None, + receiver=None, + pk=None, + value=0, # in wei + gasLimit=10000000, + gasCost=0, # in wei + nonce=0, + ) -> tuple: + sender = self.config.owner if sender is None else sender + receiver = self.contract_instance.address if receiver is None else receiver + pk = self.config.account.key.hex()[2:] if pk is None else pk + + data = self.contract_instance.encodeABI(fn_name=function_name, args=args) + rpc_url = self.config.rpc_url + + return wrapper.send_encrypted_sapphire_tx( + pk, + sender, + receiver, + rpc_url, + value, + gasLimit, + data, + gasCost, + nonce, + ) diff --git a/pdr_backend/models/data_nft.py b/pdr_backend/contract/data_nft.py similarity index 69% rename from pdr_backend/models/data_nft.py rename to pdr_backend/contract/data_nft.py index 1d503ffbc..3d7029b6b 100644 --- a/pdr_backend/models/data_nft.py +++ b/pdr_backend/contract/data_nft.py @@ -1,31 +1,25 @@ -from typing import Union - import hashlib import json +from typing import Union from enforce_typing import enforce_types from web3 import Web3 -from web3.types import TxReceipt, HexBytes +from web3.types import HexBytes, TxReceipt -from pdr_backend.models.base_contract import BaseContract -from pdr_backend.util.web3_config import Web3Config +from 
pdr_backend.contract.base_contract import BaseContract @enforce_types class DataNft(BaseContract): - def __init__(self, config: Web3Config, address: str): - super().__init__(config, address, "ERC721Template") + def __init__(self, web3_pp, address: str): + super().__init__(web3_pp, address, "ERC721Template") def set_data(self, field_label, field_value, wait_for_receipt=True): """Set key/value data via ERC725, with strings for key/value""" field_label_hash = Web3.keccak(text=field_label) # to keccak256 hash field_value_bytes = field_value.encode() # to array of bytes - # gasPrice = self.config.w3.eth.gas_price - call_params = { - "from": self.config.owner, - "gasPrice": 100000000000, - "gas": 100000, - } + + call_params = self.web3_pp.tx_call_params(gas=100000) tx = self.contract_instance.functions.setNewData( field_label_hash, field_value_bytes ).transact(call_params) @@ -34,11 +28,7 @@ def set_data(self, field_label, field_value, wait_for_receipt=True): return tx def add_erc20_deployer(self, address, wait_for_receipt=True): - # gasPrice = self.config.w3.eth.gas_price - call_params = { - "from": self.config.owner, - "gasPrice": 100000000000, - } + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.addToCreateERC20List( self.config.w3.to_checksum_address(address) ).transact(call_params) @@ -47,12 +37,9 @@ def add_erc20_deployer(self, address, wait_for_receipt=True): return tx def set_ddo(self, ddo, wait_for_receipt=True): - call_params = { - "from": self.config.owner, - "gasPrice": 100000000000, - } js = json.dumps(ddo) stored_ddo = Web3.to_bytes(text=js) + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.setMetaData( 1, "", @@ -69,10 +56,12 @@ def set_ddo(self, ddo, wait_for_receipt=True): def add_to_create_erc20_list( self, addr: str, wait_for_receipt=True ) -> Union[HexBytes, TxReceipt]: - gasPrice = self.config.w3.eth.gas_price + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.addToCreateERC20List(addr).transact( - {"from": self.config.owner, "gasPrice": gasPrice} + call_params ) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) diff --git a/pdr_backend/models/dfrewards.py b/pdr_backend/contract/dfrewards.py similarity index 54% rename from pdr_backend/models/dfrewards.py rename to pdr_backend/contract/dfrewards.py index 6d67b5dda..4878f847a 100644 --- a/pdr_backend/models/dfrewards.py +++ b/pdr_backend/contract/dfrewards.py @@ -1,29 +1,32 @@ from enforce_typing import enforce_types -from pdr_backend.models.base_contract import BaseContract -from pdr_backend.util.web3_config import Web3Config +from pdr_backend.contract.base_contract import BaseContract +from pdr_backend.util.mathutil import from_wei @enforce_types class DFRewards(BaseContract): - def __init__(self, config: Web3Config, address: str): - super().__init__(config, address, "DFRewards") + def __init__(self, web3_pp, address: str): + super().__init__(web3_pp, address, "DFRewards") def claim_rewards(self, user_addr: str, token_addr: str, wait_for_receipt=True): - gasPrice = self.config.w3.eth.gas_price + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.claimFor(user_addr, token_addr).transact( - {"from": self.config.owner, "gasPrice": gasPrice} + call_params ) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) def get_claimable_rewards(self, user_addr: str, token_addr: str) -> float: """ - Returns the 
amount of claimable rewards in units of ETH + @return + claimable -- # claimable rewards (in units of ETH, not wei) """ claimable_wei = self.contract_instance.functions.claimable( user_addr, token_addr ).call() - claimable_rewards = self.config.w3.from_wei(claimable_wei, "ether") - return float(claimable_rewards) + claimable = from_wei(claimable_wei) + return claimable diff --git a/pdr_backend/models/erc721_factory.py b/pdr_backend/contract/erc721_factory.py similarity index 60% rename from pdr_backend/models/erc721_factory.py rename to pdr_backend/contract/erc721_factory.py index b3fb6bfca..b26d05870 100644 --- a/pdr_backend/models/erc721_factory.py +++ b/pdr_backend/contract/erc721_factory.py @@ -1,34 +1,30 @@ from enforce_typing import enforce_types from web3.logs import DISCARD -from pdr_backend.models.base_contract import BaseContract +from pdr_backend.contract.base_contract import BaseContract from pdr_backend.util.contract import get_address -from pdr_backend.util.web3_config import Web3Config @enforce_types -class ERC721Factory(BaseContract): - def __init__(self, config: Web3Config, chain_id=None): - if not chain_id: - chain_id = config.w3.eth.chain_id - address = get_address(chain_id, "ERC721Factory") +class Erc721Factory(BaseContract): + def __init__(self, web3_pp): + address = get_address(web3_pp, "ERC721Factory") + if not address: - raise ValueError("Cannot figure out ERC721Factory address") - super().__init__(config, address, "ERC721Factory") + raise ValueError("Cannot figure out Erc721Factory address") - def createNftWithErc20WithFixedRate(self, NftCreateData, ErcCreateData, FixedData): - # gasPrice = self.config.w3.eth.gas_price - call_params = { - "from": self.config.owner, - "gasPrice": 100000000000, - } + super().__init__(web3_pp, address, "ERC721Factory") + def createNftWithErc20WithFixedRate(self, NftCreateData, ErcCreateData, FixedData): + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance.functions.createNftWithErc20WithFixedRate( NftCreateData, ErcCreateData, FixedData ).transact(call_params) receipt = self.config.w3.eth.wait_for_transaction_receipt(tx) + if receipt["status"] != 1: raise ValueError(f"createNftWithErc20WithFixedRate failed in {tx.hex()}") + # print(receipt) logs_nft = self.contract_instance.events.NFTCreated().process_receipt( receipt, errors=DISCARD diff --git a/pdr_backend/contract/fixed_rate.py b/pdr_backend/contract/fixed_rate.py new file mode 100644 index 000000000..8e929bc2a --- /dev/null +++ b/pdr_backend/contract/fixed_rate.py @@ -0,0 +1,63 @@ +from typing import Tuple + +from enforce_typing import enforce_types + +from pdr_backend.contract.base_contract import BaseContract +from pdr_backend.util.mathutil import to_wei + + +@enforce_types +class FixedRate(BaseContract): + def __init__(self, web3_pp, address: str): + super().__init__(web3_pp, address, "FixedRateExchange") + + def get_dt_price(self, exchangeId) -> Tuple[int, int, int, int]: + """ + @description + # OCEAN needed to buy 1 datatoken + + @arguments + exchangeId - a unique exchange identifier. Typically a string. 
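        Cross-reference (as used elsewhere in this diff): exchange_addr and
        exchangeId come from PredictoorContract.get_exchanges(), and both
        buy_and_start_subscription and get_price take the first such pair:
            exchange = FixedRate(web3_pp, exchange_addr)
            (baseTokenAmt_wei, _, _, _) = exchange.get_dt_price(exchangeId)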
+ + @return + baseTokenAmt_wei - # basetokens needed + oceanFeeAmt_wei - fee to Ocean Protocol Community (OPC) + publishMktFeeAmt_wei - fee to publish market + consumeMktFeeAmt_wei - fee to consume market + + @notes + Assumes consumeMktSwapFeeAmt = 0 + """ + return self.calcBaseInGivenOutDT( + exchangeId, + datatokenAmt_wei=to_wei(1), + consumeMktSwapFeeAmt_wei=0, + ) + + def calcBaseInGivenOutDT( + self, + exchangeId, + datatokenAmt_wei: int, + consumeMktSwapFeeAmt_wei: int, + ) -> Tuple[int, int, int, int]: + """ + @description + Given an exact target # datatokens, calculates # basetokens + (OCEAN) needed to get it, and fee amounts too. + + @arguments + exchangeId - a unique exchange identifier. Typically a string. + datatokenAmt_wei - # datatokens to be exchanged + consumeMktSwapFeeAmt - fee amount for consume market + + @return + baseTokenAmt_wei - # OCEAN needed, in wei + oceanFeeAmt_wei - fee to Ocean community (OPC) + publishMktFeeAmt_wei - fee to publish market + consumeMktFeeAmt_wei - fee to consume market + """ + return self.contract_instance.functions.calcBaseInGivenOutDT( + exchangeId, + datatokenAmt_wei, + consumeMktSwapFeeAmt_wei, + ).call() diff --git a/pdr_backend/contract/predictoor_batcher.py b/pdr_backend/contract/predictoor_batcher.py new file mode 100644 index 000000000..385c2eeb5 --- /dev/null +++ b/pdr_backend/contract/predictoor_batcher.py @@ -0,0 +1,89 @@ +from typing import List +from unittest.mock import Mock + +from enforce_typing import enforce_types + +from pdr_backend.contract.base_contract import BaseContract +from pdr_backend.ppss.web3_pp import Web3PP + + +class PredictoorBatcher(BaseContract): + @enforce_types + def __init__(self, web3_pp, address: str): + super().__init__(web3_pp, address, "PredictoorHelper") + + @property + def web3_config(self): + return self.web3_pp.web3_config + + @property + def w3(self): + return self.web3_config.w3 + + @enforce_types + def consume_multiple( + self, + addresses: List[str], + times: List[int], + token_addr: str, + wait_for_receipt=True, + ): + call_params = self.web3_pp.tx_call_params(gas=14_000_000) + tx = self.contract_instance.functions.consumeMultiple( + addresses, times, token_addr + ).transact(call_params) + + if not wait_for_receipt: + return tx + + return self.w3.eth.wait_for_transaction_receipt(tx) + + @enforce_types + def submit_truevals_contracts( + self, + contract_addrs: List[str], + epoch_starts: List[List[int]], + trueVals: List[List[bool]], + cancelRounds: List[List[bool]], + wait_for_receipt=True, + ): + call_params = self.web3_pp.tx_call_params() + tx = self.contract_instance.functions.submitTruevalContracts( + contract_addrs, epoch_starts, trueVals, cancelRounds + ).transact(call_params) + + if not wait_for_receipt: + return tx + + return self.w3.eth.wait_for_transaction_receipt(tx) + + @enforce_types + def submit_truevals( + self, + contract_addr: str, + epoch_starts: List[int], + trueVals: List[bool], + cancelRounds: List[bool], + wait_for_receipt=True, + ): + call_params = self.web3_pp.tx_call_params() + tx = self.contract_instance.functions.submitTruevals( + contract_addr, epoch_starts, trueVals, cancelRounds + ).transact(call_params) + + if not wait_for_receipt: + return tx + + return self.w3.eth.wait_for_transaction_receipt(tx) + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def mock_predictoor_batcher(web3_pp: Web3PP) -> PredictoorBatcher: + b = Mock(spec=PredictoorBatcher) + b.web3_pp = web3_pp + 
b.contract_address = "0xPdrBatcherAddr" + return b diff --git a/pdr_backend/contract/predictoor_contract.py b/pdr_backend/contract/predictoor_contract.py new file mode 100644 index 000000000..00599c4d6 --- /dev/null +++ b/pdr_backend/contract/predictoor_contract.py @@ -0,0 +1,362 @@ +from typing import List, Tuple +from unittest.mock import Mock + +from enforce_typing import enforce_types + +from pdr_backend.contract.base_contract import BaseContract +from pdr_backend.contract.fixed_rate import FixedRate +from pdr_backend.contract.token import Token +from pdr_backend.util.constants import MAX_UINT, ZERO_ADDRESS +from pdr_backend.util.mathutil import from_wei, string_to_bytes32, to_wei + + +@enforce_types +class PredictoorContract(BaseContract): # pylint: disable=too-many-public-methods + def __init__(self, web3_pp, address: str): + super().__init__(web3_pp, address, "ERC20Template3") + stake_token = self.get_stake_token() + self.token = Token(web3_pp, stake_token) + self.last_allowance = 0 + + def is_valid_subscription(self): + """Does this account have a subscription to this feed yet?""" + return self.contract_instance.functions.isValidSubscription( + self.config.owner + ).call() + + def getid(self): + """Return the ID of this contract.""" + return self.contract_instance.functions.getId().call() + + def buy_and_start_subscription(self, gasLimit=None, wait_for_receipt=True): + """ + @description + Buys 1 datatoken and starts a subscription. + + @return + tx - transaction hash. Or, returns None if an error while transacting + """ + exchanges = self.get_exchanges() + if not exchanges: + raise ValueError("No exchanges available") + + (exchange_addr, exchangeId) = exchanges[0] + + # get datatoken price + exchange = FixedRate(self.web3_pp, exchange_addr) + (baseTokenAmt_wei, _, _, _) = exchange.get_dt_price(exchangeId) + print(f" Price of feed: {from_wei(baseTokenAmt_wei)} OCEAN") + + # approve + print(" Approve spend OCEAN: begin") + self.token.approve(self.contract_instance.address, baseTokenAmt_wei) + print(" Approve spend OCEAN: done") + + # buy 1 DT + call_params = self.web3_pp.tx_call_params() + orderParams = ( # OrderParams + self.config.owner, # consumer + 0, # serviceIndex + ( # providerFee, with zeroed values + ZERO_ADDRESS, + ZERO_ADDRESS, + 0, + 0, + string_to_bytes32(""), + string_to_bytes32(""), + 0, + self.config.w3.to_bytes(b""), + ), + ( # consumeMarketFee, with zeroed values + ZERO_ADDRESS, + ZERO_ADDRESS, + 0, + ), + ) + freParams = ( # FreParams + self.config.w3.to_checksum_address(exchange_addr), # exchangeContract + self.config.w3.to_bytes(exchangeId), # exchangeId + baseTokenAmt_wei, # maxBaseTokenAmount + 0, # swapMarketFee + ZERO_ADDRESS, # marketFeeAddress + ) + + if gasLimit is None: + try: + print(" Estimate gasLimit: begin") + gasLimit = self.contract_instance.functions.buyFromFreAndOrder( + orderParams, freParams + ).estimate_gas(call_params) + except Exception as e: + print( + f" Estimate gasLimit had error in estimate_gas(): {e}" + " Because of error, use get_max_gas() as workaround" + ) + gasLimit = self.config.get_max_gas() + assert gasLimit is not None, "should have non-None gasLimit by now" + print(f" Estimate gasLimit: done. 
gasLimit={gasLimit}") + call_params["gas"] = gasLimit + 1 + + try: + print(" buyFromFreAndOrder: begin") + tx = self.contract_instance.functions.buyFromFreAndOrder( + orderParams, freParams + ).transact(call_params) + if not wait_for_receipt: + print(" buyFromFreAndOrder: WIP, didn't wait around") + return tx + tx = self.config.w3.eth.wait_for_transaction_receipt(tx) + print(" buyFromFreAndOrder: waited around, it's done") + return tx + except Exception as e: + print(f" buyFromFreAndOrder hit an error: {e}") + return None + + def buy_many(self, n_to_buy: int, gasLimit=None, wait_for_receipt=False): + """Buys multiple subscriptions and returns tx hashes""" + if n_to_buy < 1: + return None + print(f"Purchase {n_to_buy} subscriptions for this feed: begin") + txs = [] + for i in range(n_to_buy): + print(f"Purchase access #{i+1}/{n_to_buy} for this feed") + tx = self.buy_and_start_subscription(gasLimit, wait_for_receipt) + txs.append(tx) + print(f"Purchase {n_to_buy} subscriptions for this feed: done") + return txs + + def get_exchanges(self) -> List[Tuple[str, str]]: + """ + @description + Returns the fixed-rate exchanges created for this datatoken + + @return + exchanges - list of (exchange_addr:str, exchangeId: str) + """ + return self.contract_instance.functions.getFixedRates().call() + + def get_stake_token(self): + """Returns the token used for staking & purchases. Eg OCEAN.""" + return self.contract_instance.functions.stakeToken().call() + + def get_price(self) -> int: + """ + @description + # OCEAN needed to buy 1 datatoken + + @return + baseTokenAmt_wei - # OCEAN needed, in wei + + @notes + Assumes consumeMktSwapFeeAmt = 0 + """ + exchanges = self.get_exchanges() # fixed rate exchanges + if not exchanges: + raise ValueError("No exchanges available") + (exchange_addr, exchangeId) = exchanges[0] + + exchange = FixedRate(self.web3_pp, exchange_addr) + (baseTokenAmt_wei, _, _, _) = exchange.get_dt_price(exchangeId) + return baseTokenAmt_wei + + def get_current_epoch(self) -> int: + """ + Whereas curEpoch returns the timestamp of current candle start... + *This* function returns the 'epoch number' that increases + by one each secondsPerEpoch seconds + """ + current_epoch_ts = self.get_current_epoch_ts() + seconds_per_epoch = self.get_secondsPerEpoch() + return int(current_epoch_ts / seconds_per_epoch) + + def get_current_epoch_ts(self) -> int: + """returns the current candle start timestamp""" + return self.contract_instance.functions.curEpoch().call() + + def get_secondsPerEpoch(self) -> int: + """How many seconds are in each epoch? (According to contract)""" + return self.contract_instance.functions.secondsPerEpoch().call() + + def get_agg_predval(self, timestamp: int) -> Tuple[float, float]: + """ + @description + Get aggregated prediction value. 
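        (Both return values are converted from wei to ETH-denominated floats
        via from_wei before being returned; see the return statement below.)

        Illustrative call, variable names assumed:
            stake_up, stake_total = contract.get_agg_predval(candle_start_ts)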
+ + @arguments + timestamp - + + @return + nom - numerator = # OCEAN staked for 'up' (in units of ETH, not wei) + denom - denominator = total # OCEAN staked ("") + """ + auth = self.config.get_auth_signature() + call_params = self.web3_pp.tx_call_params() + (nom_wei, denom_wei) = self.contract_instance.functions.getAggPredval( + timestamp, auth + ).call(call_params) + return from_wei(nom_wei), from_wei(denom_wei) + + def payout_multiple(self, slots: List[int], wait_for_receipt: bool = True): + """Claims the payout for given slots""" + call_params = self.web3_pp.tx_call_params() + try: + tx = self.contract_instance.functions.payoutMultiple( + slots, self.config.owner + ).transact(call_params) + + if not wait_for_receipt: + return tx + + return self.config.w3.eth.wait_for_transaction_receipt(tx) + except Exception as e: + print(e) + return None + + def payout(self, slot, wait_for_receipt=False): + """Claims the payout for one slot""" + call_params = self.web3_pp.tx_call_params() + try: + tx = self.contract_instance.functions.payout( + slot, self.config.owner + ).transact(call_params) + + if not wait_for_receipt: + return tx + + return self.config.w3.eth.wait_for_transaction_receipt(tx) + except Exception as e: + print(e) + return None + + def soonest_timestamp_to_predict(self, timestamp: int) -> int: + """Returns the soonest epoch to predict (expressed as a timestamp)""" + return self.contract_instance.functions.soonestEpochToPredict(timestamp).call() + + def submit_prediction( + self, + predicted_value: bool, + stake_amt: float, + prediction_ts: int, + wait_for_receipt=True, + ): + """ + @description + Submits a prediction with the specified stake amount, to the contract. + + @arguments + predicted_value: The predicted value (True or False) + stake_amt: The amount of OCEAN to stake in prediction (in ETH, not wei) + prediction_ts: The prediction timestamp == start a candle. + wait_for_receipt: + If True, waits for tx receipt after submission. + If False, immediately after sending the transaction. + Default is True. + + @return: + If wait_for_receipt is True, returns the tx receipt. + If False, returns the tx hash immediately after sending. + If an exception occurs during the process, returns None. 
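        @example (illustrative only; variable names are assumptions)
            ts = contract.soonest_timestamp_to_predict(contract.get_current_epoch_ts())
            contract.submit_prediction(True, 10.0, ts)  # stake 10.0 OCEAN on "up"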
+ """ + stake_amt_wei = to_wei(stake_amt) + + # Check allowance first, only approve if needed + if self.last_allowance <= 0: + self.last_allowance = self.token.allowance( + self.config.owner, self.contract_address + ) + if self.last_allowance < stake_amt_wei: + try: + self.token.approve(self.contract_address, MAX_UINT) + self.last_allowance = MAX_UINT + except Exception as e: + print("Error while approving the contract to spend tokens:", e) + return None + + call_params = self.web3_pp.tx_call_params() + try: + txhash = None + if self.config.is_sapphire: + res, txhash = self.send_encrypted_tx( + "submitPredval", [predicted_value, stake_amt_wei, prediction_ts] + ) + print("Encrypted transaction status code:", res) + else: + tx = self.contract_instance.functions.submitPredval( + predicted_value, stake_amt_wei, prediction_ts + ).transact(call_params) + txhash = tx.hex() + self.last_allowance -= stake_amt_wei + print(f"Submitted prediction, txhash: {txhash}") + + if not wait_for_receipt: + return txhash + + return self.config.w3.eth.wait_for_transaction_receipt(txhash) + except Exception as e: + print(e) + return None + + def get_trueValSubmitTimeout(self): + """Returns the timeout for submitting truevals, according to contract""" + return self.contract_instance.functions.trueValSubmitTimeout().call() + + def get_prediction(self, slot: int, address: str): + """Returns the prediction made by this account, for + the specified time slot and address.""" + auth_signature = self.config.get_auth_signature() + call_params = {"from": self.config.owner} + return self.contract_instance.functions.getPrediction( + slot, address, auth_signature + ).call(call_params) + + def submit_trueval(self, trueval, timestamp, cancel_round, wait_for_receipt=True): + """Submit true value for this feed, at the specified time. + Alternatively, cancel this epoch (round). + Can only be called by the owner. + Returns the hash of the transaction. 
+ """ + call_params = self.web3_pp.tx_call_params() + tx = self.contract_instance.functions.submitTrueVal( + timestamp, trueval, cancel_round + ).transact(call_params) + print(f"Submit trueval: txhash={tx.hex()}") + + if wait_for_receipt: + tx = self.config.w3.eth.wait_for_transaction_receipt(tx) + + return tx + + def redeem_unused_slot_revenue(self, timestamp, wait_for_receipt=True): + """Redeem unused slot revenue.""" + call_params = self.web3_pp.tx_call_params() + try: + tx = self.contract_instance.functions.redeemUnusedSlotRevenue( + timestamp + ).transact(call_params) + + if not wait_for_receipt: + return tx + + return self.config.w3.eth.wait_for_transaction_receipt(tx) + except Exception as e: + print(e) + return None + + def erc721_addr(self) -> str: + """What's the ERC721 address from which this ERC20 feed was created?""" + return self.contract_instance.functions.getERC721Address().call() + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def mock_predictoor_contract( + contract_address: str, + agg_predval: tuple = (1, 2), +) -> PredictoorContract: + c = Mock(spec=PredictoorContract) + c.contract_address = contract_address + c.get_agg_predval.return_value = agg_predval + return c diff --git a/pdr_backend/contract/slot.py b/pdr_backend/contract/slot.py new file mode 100644 index 000000000..ce39f95e7 --- /dev/null +++ b/pdr_backend/contract/slot.py @@ -0,0 +1,7 @@ +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed + + +class Slot: + def __init__(self, slot_number: int, feed: SubgraphFeed): + self.slot_number = slot_number + self.feed = feed diff --git a/pdr_backend/models/test/conftest.py b/pdr_backend/contract/test/conftest.py similarity index 100% rename from pdr_backend/models/test/conftest.py rename to pdr_backend/contract/test/conftest.py diff --git a/pdr_backend/contract/test/test_base_contract.py b/pdr_backend/contract/test/test_base_contract.py new file mode 100644 index 000000000..aebeee946 --- /dev/null +++ b/pdr_backend/contract/test/test_base_contract.py @@ -0,0 +1,82 @@ +import os +from unittest.mock import Mock + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.contract.token import Token +from pdr_backend.util.contract import get_address + + +@pytest.fixture +def mock_send_encrypted_sapphire_tx(monkeypatch): + mock_function = Mock(return_value=(0, "dummy_tx_hash")) + monkeypatch.setattr("sapphirepy.wrapper.send_encrypted_sapphire_tx", mock_function) + return mock_function + + +@enforce_types +def test_base_contract(web3_pp, web3_config): + OCEAN_address = get_address(web3_pp, "Ocean") + + # success + Token(web3_pp, OCEAN_address) + + # catch failure + web3_config = web3_pp.web3_config + with pytest.raises(ValueError): + Token(web3_config, OCEAN_address) + + +@enforce_types +def test_send_encrypted_tx( + mock_send_encrypted_sapphire_tx, # pylint: disable=redefined-outer-name + ocean_token, + web3_pp, +): + OCEAN_address = get_address(web3_pp, "Ocean") + contract = Token(web3_pp, OCEAN_address) + + # Set up dummy return value for the mocked function + mock_send_encrypted_sapphire_tx.return_value = ( + 0, + "dummy_tx_hash", + ) + + # Sample inputs for send_encrypted_tx + function_name = "transfer" + args = [web3_pp.web3_config.owner, 100] + sender = web3_pp.web3_config.owner + receiver = web3_pp.web3_config.w3.eth.accounts[1] + rpc_url = web3_pp.rpc_url + value = 0 + gasLimit = 10000000 + gasCost = 0 + nonce = 0 + pk = os.getenv("PRIVATE_KEY") + + tx_hash, 
encrypted_data = contract.send_encrypted_tx( + function_name, + args, + sender, + receiver, + pk, + value, + gasLimit, + gasCost, + nonce, + ) + assert tx_hash == 0 + assert encrypted_data == "dummy_tx_hash" + + mock_send_encrypted_sapphire_tx.assert_called_once_with( + pk, + sender, + receiver, + rpc_url, + value, + gasLimit, + ocean_token.contract_instance.encodeABI(fn_name=function_name, args=args), + gasCost, + nonce, + ) diff --git a/pdr_backend/models/test/test_data_nft.py b/pdr_backend/contract/test/test_data_nft.py similarity index 72% rename from pdr_backend/models/test/test_data_nft.py rename to pdr_backend/contract/test/test_data_nft.py index 095f3b684..6c5e931c7 100644 --- a/pdr_backend/models/test/test_data_nft.py +++ b/pdr_backend/contract/test/test_data_nft.py @@ -5,15 +5,17 @@ from eth_account import Account from web3.logs import DISCARD -from pdr_backend.models.data_nft import DataNft -from pdr_backend.models.erc721_factory import ERC721Factory +from pdr_backend.contract.data_nft import DataNft +from pdr_backend.contract.erc721_factory import Erc721Factory from pdr_backend.util.constants import MAX_UINT from pdr_backend.util.contract import get_address -from pdr_backend.util.web3_config import Web3Config +from pdr_backend.util.mathutil import to_wei @enforce_types -def test_set_ddo(): +def test_set_ddo(web3_pp, web3_config): + private_key = os.getenv("PRIVATE_KEY") + path = os.path.join( os.path.dirname(__file__), "../../tests/resources/ddo_v4_sample.json" ) @@ -22,15 +24,12 @@ def test_set_ddo(): content = file_handle.read() ddo = json.loads(content) - private_key = os.getenv("PRIVATE_KEY") owner = Account.from_key( # pylint:disable=no-value-for-parameter private_key=private_key ) - rpc_url = os.getenv("RPC_URL") - web3_config = Web3Config(rpc_url, private_key) - factory = ERC721Factory(web3_config) - ocean_address = get_address(web3_config.w3.eth.chain_id, "Ocean") - fre_address = get_address(web3_config.w3.eth.chain_id, "FixedPrice") + factory = Erc721Factory(web3_pp) + ocean_address = get_address(web3_pp, "Ocean") + fre_address = get_address(web3_pp, "FixedPrice") feeCollector = owner.address @@ -49,8 +48,8 @@ def test_set_ddo(): [], ) - rate = web3_config.w3.to_wei(3, "ether") - cut = web3_config.w3.to_wei(0.2, "ether") + rate = to_wei(3) + cut = to_wei(0.2) fre_data = ( fre_address, [ocean_address, owner.address, feeCollector, owner.address], @@ -59,7 +58,7 @@ def test_set_ddo(): logs_nft, _ = factory.createNftWithErc20WithFixedRate(nft_data, erc_data, fre_data) data_nft_address = logs_nft["newTokenAddress"] print(f"Deployed NFT: {data_nft_address}") - data_nft = DataNft(web3_config, data_nft_address) + data_nft = DataNft(web3_pp, data_nft_address) tx = data_nft.set_ddo(ddo, wait_for_receipt=True) receipt = web3_config.w3.eth.wait_for_transaction_receipt(tx) diff --git a/pdr_backend/models/test/test_dfrewards.py b/pdr_backend/contract/test/test_dfrewards.py similarity index 50% rename from pdr_backend/models/test/test_dfrewards.py rename to pdr_backend/contract/test/test_dfrewards.py index f05a3567a..32a4cec89 100644 --- a/pdr_backend/models/test/test_dfrewards.py +++ b/pdr_backend/contract/test/test_dfrewards.py @@ -1,19 +1,18 @@ from enforce_typing import enforce_types -from pdr_backend.models.dfrewards import DFRewards +from pdr_backend.contract.dfrewards import DFRewards from pdr_backend.util.contract import get_address -from pdr_backend.util.web3_config import Web3Config @enforce_types -def test_dfrewards(web3_config: Web3Config): - dfrewards_addr = 
get_address(web3_config.w3.eth.chain_id, "DFRewards") +def test_dfrewards(web3_pp, web3_config): + dfrewards_addr = get_address(web3_pp, "DFRewards") assert isinstance(dfrewards_addr, str) - ocean_addr = get_address(web3_config.w3.eth.chain_id, "Ocean") + ocean_addr = get_address(web3_pp, "Ocean") assert isinstance(dfrewards_addr, str) - contract = DFRewards(web3_config, dfrewards_addr) + contract = DFRewards(web3_pp, dfrewards_addr) rewards = contract.get_claimable_rewards(web3_config.owner, ocean_addr) assert rewards == 0 diff --git a/pdr_backend/contract/test/test_erc721_factory.py b/pdr_backend/contract/test/test_erc721_factory.py new file mode 100644 index 000000000..183b8c206 --- /dev/null +++ b/pdr_backend/contract/test/test_erc721_factory.py @@ -0,0 +1,74 @@ +from unittest.mock import Mock, patch + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.contract.erc721_factory import Erc721Factory +from pdr_backend.util.contract import get_address +from pdr_backend.util.mathutil import to_wei + + +@enforce_types +def test_Erc721Factory(web3_pp, web3_config): + factory = Erc721Factory(web3_pp) + assert factory is not None + + ocean_address = get_address(web3_pp, "Ocean") + fre_address = get_address(web3_pp, "FixedPrice") + + rate = 3 + cut = 0.2 + + nft_data = ("TestToken", "TT", 1, "", True, web3_config.owner) + erc_data = ( + 3, + ["ERC20Test", "ET"], + [ + web3_config.owner, + web3_config.owner, + web3_config.owner, + ocean_address, + ocean_address, + ], + [2**256 - 1, 0, 300, 3000, 30000], + [], + ) + fre_data = ( + fre_address, + [ + ocean_address, + web3_config.owner, + web3_config.owner, + web3_config.owner, + ], + [ + 18, + 18, + to_wei(rate), + to_wei(cut), + 1, + ], + ) + + logs_nft, logs_erc = factory.createNftWithErc20WithFixedRate( + nft_data, erc_data, fre_data + ) + + assert len(logs_nft) > 0 + assert len(logs_erc) > 0 + + config = Mock() + receipt = {"status": 0} + config.w3.eth.wait_for_transaction_receipt.return_value = receipt + + with patch.object(factory, "config") as mock_config: + mock_config.return_value = config + with pytest.raises(ValueError): + factory.createNftWithErc20WithFixedRate(nft_data, erc_data, fre_data) + + +@enforce_types +def test_Erc721Factory_no_address(web3_pp): + with patch("pdr_backend.contract.erc721_factory.get_address", return_value=None): + with pytest.raises(ValueError): + Erc721Factory(web3_pp) diff --git a/pdr_backend/contract/test/test_fixed_rate.py b/pdr_backend/contract/test/test_fixed_rate.py new file mode 100644 index 000000000..de43a33f6 --- /dev/null +++ b/pdr_backend/contract/test/test_fixed_rate.py @@ -0,0 +1,36 @@ +from enforce_typing import enforce_types +from pytest import approx + +from pdr_backend.contract.fixed_rate import FixedRate +from pdr_backend.util.mathutil import from_wei, to_wei + + +@enforce_types +def test_FixedRate(predictoor_contract, web3_pp): + exchanges = predictoor_contract.get_exchanges() + print(exchanges) + + address = exchanges[0][0] + exchangeId = exchanges[0][1] + + # constructor + exchange = FixedRate(web3_pp, address) + + # test get_dt_price() + tup = exchange.get_dt_price(exchangeId) + ( + baseTokenAmt_wei, + oceanFeeAmt_wei, + publishMktFeeAmt_wei, + consumeMktFeeAmt_wei, + ) = tup + + assert from_wei(baseTokenAmt_wei) == approx(3.603) + + assert from_wei(oceanFeeAmt_wei) == approx(0.003) + assert from_wei(publishMktFeeAmt_wei) == approx(0.6) + assert consumeMktFeeAmt_wei == 0 + + # test calcBaseInGivenOutDT() + tup2 = exchange.calcBaseInGivenOutDT(exchangeId, 
to_wei(1), 0) + assert tup == tup2 diff --git a/pdr_backend/models/test/test_predictoor_batcher.py b/pdr_backend/contract/test/test_predictoor_batcher.py similarity index 64% rename from pdr_backend/models/test/test_predictoor_batcher.py rename to pdr_backend/contract/test/test_predictoor_batcher.py index 3030a1ce1..1b501ef2d 100644 --- a/pdr_backend/models/test/test_predictoor_batcher.py +++ b/pdr_backend/contract/test/test_predictoor_batcher.py @@ -1,35 +1,35 @@ +from unittest.mock import Mock + +from enforce_typing import enforce_types from web3.types import RPCEndpoint -from pdr_backend.conftest_ganache import SECONDS_PER_EPOCH -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.models.predictoor_batcher import PredictoorBatcher -from pdr_backend.models.data_nft import DataNft -from pdr_backend.models.token import Token -from pdr_backend.util.web3_config import Web3Config - - -def test_submit_truevals( - predictoor_contract: PredictoorContract, - web3_config: Web3Config, - predictoor_batcher: PredictoorBatcher, -): + +from pdr_backend.conftest_ganache import S_PER_EPOCH +from pdr_backend.contract.data_nft import DataNft +from pdr_backend.contract.predictoor_batcher import mock_predictoor_batcher +from pdr_backend.ppss.web3_pp import Web3PP + + +@enforce_types +def test_submit_truevals(predictoor_contract, web3_pp, predictoor_batcher): + web3_config = web3_pp.web3_config current_epoch = predictoor_contract.get_current_epoch_ts() # fast forward time - predictoor_contract.config.w3.provider.make_request( - RPCEndpoint("evm_increaseTime"), [SECONDS_PER_EPOCH * 10] + web3_config.w3.provider.make_request( + RPCEndpoint("evm_increaseTime"), [S_PER_EPOCH * 10] ) - predictoor_contract.config.w3.provider.make_request(RPCEndpoint("evm_mine"), []) + web3_config.w3.provider.make_request(RPCEndpoint("evm_mine"), []) - end_epoch = current_epoch + SECONDS_PER_EPOCH * 10 + end_epoch = current_epoch + S_PER_EPOCH * 10 # get trueval for epochs - epochs = list(range(current_epoch, end_epoch, SECONDS_PER_EPOCH)) + epochs = list(range(current_epoch, end_epoch, S_PER_EPOCH)) truevals = [True] * len(epochs) cancels = [False] * len(epochs) # add predictoor helper as ercdeployer erc721addr = predictoor_contract.erc721_addr() - datanft = DataNft(web3_config, erc721addr) + datanft = DataNft(web3_pp, erc721addr) datanft.add_to_create_erc20_list(predictoor_batcher.contract_address) truevals_before = [ @@ -51,27 +51,27 @@ def test_submit_truevals( assert trueval is True +@enforce_types def test_submit_truevals_contracts( - predictoor_contract: PredictoorContract, - predictoor_contract2: PredictoorContract, - web3_config: Web3Config, - predictoor_batcher: PredictoorBatcher, + predictoor_contract, + predictoor_contract2, + web3_pp, + web3_config, + predictoor_batcher, ): current_epoch = predictoor_contract.get_current_epoch_ts() # fast forward time - predictoor_contract.config.w3.provider.make_request( - RPCEndpoint("evm_increaseTime"), [SECONDS_PER_EPOCH * 10] + web3_config.w3.provider.make_request( + RPCEndpoint("evm_increaseTime"), [S_PER_EPOCH * 10] ) - predictoor_contract.config.w3.provider.make_request(RPCEndpoint("evm_mine"), []) + web3_config.w3.provider.make_request(RPCEndpoint("evm_mine"), []) - end_epoch = current_epoch + SECONDS_PER_EPOCH * 10 + end_epoch = current_epoch + S_PER_EPOCH * 10 # get trueval for epochs - epochs1 = list(range(current_epoch, end_epoch, SECONDS_PER_EPOCH)) - epochs2 = list( - range(current_epoch + SECONDS_PER_EPOCH * 2, end_epoch, 
SECONDS_PER_EPOCH) - ) + epochs1 = list(range(current_epoch, end_epoch, S_PER_EPOCH)) + epochs2 = list(range(current_epoch + S_PER_EPOCH * 2, end_epoch, S_PER_EPOCH)) epochs = [epochs1, epochs2] truevals = [[True] * len(epochs1), [True] * len(epochs2)] cancels = [[False] * len(epochs1), [False] * len(epochs2)] @@ -82,10 +82,10 @@ def test_submit_truevals_contracts( # add predictoor helper as ercdeployer erc721addr = predictoor_contract.erc721_addr() - datanft = DataNft(web3_config, erc721addr) + datanft = DataNft(web3_pp, erc721addr) datanft.add_to_create_erc20_list(predictoor_batcher.contract_address) erc721addr = predictoor_contract2.erc721_addr() - datanft = DataNft(web3_config, erc721addr) + datanft = DataNft(web3_pp, erc721addr) datanft.add_to_create_erc20_list(predictoor_batcher.contract_address) truevals_before_1 = [ @@ -126,11 +126,8 @@ def test_submit_truevals_contracts( assert trueval is True -def test_consume_multiple( - predictoor_contract: PredictoorContract, - ocean_token: Token, - predictoor_batcher: PredictoorBatcher, -): +@enforce_types +def test_consume_multiple(predictoor_contract, ocean_token, predictoor_batcher): owner = ocean_token.config.owner price = predictoor_contract.get_price() @@ -148,3 +145,11 @@ def test_consume_multiple( balance_after = ocean_token.balanceOf(owner) assert balance_after + cost == balance_before + + +@enforce_types +def test_mock_predictoor_batcher(): + web3_pp = Mock(spec=Web3PP) + b = mock_predictoor_batcher(web3_pp) + assert id(b.web3_pp) == id(web3_pp) + assert b.contract_address == "0xPdrBatcherAddr" diff --git a/pdr_backend/contract/test/test_predictoor_contract.py b/pdr_backend/contract/test/test_predictoor_contract.py new file mode 100644 index 000000000..eed413bae --- /dev/null +++ b/pdr_backend/contract/test/test_predictoor_contract.py @@ -0,0 +1,173 @@ +from unittest.mock import Mock + +import pytest +from enforce_typing import enforce_types +from pytest import approx + +from pdr_backend.conftest_ganache import S_PER_EPOCH +from pdr_backend.contract.predictoor_contract import mock_predictoor_contract +from pdr_backend.contract.token import Token +from pdr_backend.util.contract import get_address +from pdr_backend.util.mathutil import from_wei, to_wei + + +@enforce_types +def test_get_id(predictoor_contract): + id_ = predictoor_contract.getid() + assert id_ == 3 + + +@enforce_types +def test_is_valid_subscription_initially(predictoor_contract): + is_valid_sub = predictoor_contract.is_valid_subscription() + assert not is_valid_sub + + +@enforce_types +def test_buy_and_start_subscription(predictoor_contract): + receipt = predictoor_contract.buy_and_start_subscription() + assert receipt["status"] == 1 + is_valid_sub = predictoor_contract.is_valid_subscription() + assert is_valid_sub + + +@enforce_types +def test_buy_and_start_subscription_empty(predictoor_contract_empty): + with pytest.raises(ValueError): + assert predictoor_contract_empty.buy_and_start_subscription() + + +@enforce_types +def test_buy_many(predictoor_contract): + receipts = predictoor_contract.buy_many(2, None, True) + assert len(receipts) == 2 + + assert predictoor_contract.buy_many(0, None, True) is None + + +@enforce_types +def test_get_exchanges(predictoor_contract): + exchanges = predictoor_contract.get_exchanges() + assert exchanges[0][0].startswith("0x") + + +@enforce_types +def test_get_stake_token(predictoor_contract, web3_pp): + stake_token = predictoor_contract.get_stake_token() + ocean_address = get_address(web3_pp, "Ocean") + assert stake_token == 
ocean_address + + +@enforce_types +def test_get_price(predictoor_contract): + price = predictoor_contract.get_price() + assert price / 1e18 == approx(3.603) + + +@enforce_types +def test_get_price_no_exchanges(predictoor_contract_empty): + predictoor_contract_empty.get_exchanges = Mock(return_value=[]) + with pytest.raises(ValueError): + predictoor_contract_empty.get_price() + + +@enforce_types +def test_get_current_epoch(predictoor_contract): + current_epoch = predictoor_contract.get_current_epoch() + now = predictoor_contract.config.get_block("latest").timestamp + assert current_epoch == int(now // S_PER_EPOCH) + + +@enforce_types +def test_get_current_epoch_ts(predictoor_contract): + current_epoch = predictoor_contract.get_current_epoch_ts() + now = predictoor_contract.config.get_block("latest").timestamp + assert current_epoch == int(now // S_PER_EPOCH) * S_PER_EPOCH + + +@enforce_types +def test_get_seconds_per_epoch(predictoor_contract): + seconds_per_epoch = predictoor_contract.get_secondsPerEpoch() + assert seconds_per_epoch == S_PER_EPOCH + + +@enforce_types +def test_get_aggpredval(predictoor_contract): + current_epoch = predictoor_contract.get_current_epoch_ts() + aggpredval = predictoor_contract.get_agg_predval(current_epoch) + assert aggpredval == (0, 0) + + +@enforce_types +def test_soonest_timestamp_to_predict(predictoor_contract): + current_epoch = predictoor_contract.get_current_epoch_ts() + soonest_timestamp = predictoor_contract.soonest_timestamp_to_predict(current_epoch) + assert soonest_timestamp == current_epoch + S_PER_EPOCH * 2 + + +@enforce_types +def test_get_trueValSubmitTimeout(predictoor_contract): + trueValSubmitTimeout = predictoor_contract.get_trueValSubmitTimeout() + assert trueValSubmitTimeout == 3 * 24 * 60 * 60 + + +@enforce_types +def test_submit_prediction_trueval_payout( + predictoor_contract, + ocean_token: Token, +): + OCEAN = ocean_token + w3 = predictoor_contract.config.w3 + owner_addr = predictoor_contract.config.owner + OCEAN_before = from_wei(OCEAN.balanceOf(owner_addr)) + cur_epoch = predictoor_contract.get_current_epoch_ts() + soonest_ts = predictoor_contract.soonest_timestamp_to_predict(cur_epoch) + predval = True + stake_amt = 1.0 + receipt = predictoor_contract.submit_prediction( + predval, + stake_amt, + soonest_ts, + wait_for_receipt=True, + ) + assert receipt["status"] == 1 + + OCEAN_after = from_wei(OCEAN.balanceOf(owner_addr)) + assert (OCEAN_before - OCEAN_after) == approx(stake_amt, 1e-8) + + pred_tup = predictoor_contract.get_prediction( + soonest_ts, + predictoor_contract.config.owner, + ) + assert pred_tup[0] == predval + assert pred_tup[1] == to_wei(stake_amt) + + w3.provider.make_request("evm_increaseTime", [S_PER_EPOCH * 2]) + w3.provider.make_request("evm_mine", []) + trueval = True + receipt = predictoor_contract.submit_trueval( + trueval, + soonest_ts, + cancel_round=False, + wait_for_receipt=True, + ) + assert receipt["status"] == 1 + + receipt = predictoor_contract.payout(soonest_ts, wait_for_receipt=True) + assert receipt["status"] == 1 + OCEAN_final = from_wei(OCEAN.balanceOf(owner_addr)) + assert OCEAN_before == approx(OCEAN_final, 2.0) # + sub revenue + + +@enforce_types +def test_redeem_unused_slot_revenue(predictoor_contract): + cur_epoch = predictoor_contract.get_current_epoch_ts() - S_PER_EPOCH * 123 + receipt = predictoor_contract.redeem_unused_slot_revenue(cur_epoch, True) + assert receipt["status"] == 1 + + +@enforce_types +def test_mock_predictoor_contract(): + c = mock_predictoor_contract("0x123", (3, 4)) + 
assert c.contract_address == "0x123" + assert c.get_agg_predval() == (3, 4) diff --git a/pdr_backend/models/test/test_slot.py b/pdr_backend/contract/test/test_slot.py similarity index 64% rename from pdr_backend/models/test/test_slot.py rename to pdr_backend/contract/test/test_slot.py index aafffa9bd..42963bb25 100644 --- a/pdr_backend/models/test/test_slot.py +++ b/pdr_backend/contract/test/test_slot.py @@ -1,13 +1,12 @@ -from pdr_backend.models.slot import Slot -from pdr_backend.models.feed import Feed +from pdr_backend.contract.slot import Slot +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed def test_slot_initialization(): - feed = Feed( + feed = SubgraphFeed( "Contract Name", "0x12345", "test", - 300, 60, 15, "0xowner", @@ -21,4 +20,4 @@ def test_slot_initialization(): assert slot.slot_number == slot_number assert slot.feed == feed - assert isinstance(slot.feed, Feed) + assert isinstance(slot.feed, SubgraphFeed) diff --git a/pdr_backend/models/test/test_token.py b/pdr_backend/contract/test/test_token.py similarity index 53% rename from pdr_backend/models/test/test_token.py rename to pdr_backend/contract/test/test_token.py index 376a7c078..bd541e0de 100644 --- a/pdr_backend/models/test/test_token.py +++ b/pdr_backend/contract/test/test_token.py @@ -1,23 +1,23 @@ import time +from unittest.mock import patch from enforce_typing import enforce_types +from pdr_backend.contract.token import NativeToken, Token from pdr_backend.util.contract import get_address -from pdr_backend.models.token import Token @enforce_types -def test_Token(web3_config, chain_id): - token_address = get_address(chain_id, "Ocean") - token = Token(web3_config, token_address) +def test_token(web3_pp, web3_config): + token_address = get_address(web3_pp, "Ocean") + token = Token(web3_pp, token_address) accounts = web3_config.w3.eth.accounts owner_addr = web3_config.owner alice = accounts[1] - token.contract_instance.functions.mint(owner_addr, 1000000000).transact( - {"from": owner_addr, "gasPrice": web3_config.w3.eth.gas_price} - ) + call_params = web3_pp.tx_call_params() + token.contract_instance.functions.mint(owner_addr, 1000000000).transact(call_params) allowance_start = token.allowance(owner_addr, alice) token.approve(alice, allowance_start + 100, True) @@ -29,3 +29,16 @@ def test_Token(web3_config, chain_id): token.transfer(alice, 100, owner_addr) balance_end = token.balanceOf(alice) assert balance_end - balance_start == 100 + + +@enforce_types +def test_native_token(web3_pp): + token = NativeToken(web3_pp) + assert token.w3 + + owner = web3_pp.web3_config.owner + assert token.balanceOf(owner) + + with patch("web3.eth.Eth.send_transaction") as mock: + token.transfer(owner, 100, "0x123", False) + assert mock.called diff --git a/pdr_backend/contract/test/test_wrapped_token.py b/pdr_backend/contract/test/test_wrapped_token.py new file mode 100644 index 000000000..0d8b94379 --- /dev/null +++ b/pdr_backend/contract/test/test_wrapped_token.py @@ -0,0 +1,21 @@ +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types + +from pdr_backend.contract.wrapped_token import WrappedToken +from pdr_backend.util.contract import get_address + + +@enforce_types +def test_native_token(web3_pp): + token_address = get_address(web3_pp, "Ocean") + mock_wrapped_contract = Mock() + mock_transaction = Mock() + mock_transaction.transact.return_value = "mock_tx" + mock_wrapped_contract.functions.withdraw.return_value = mock_transaction + + with patch("web3.eth.Eth.contract") as mock: + mock.return_value = 
mock_wrapped_contract + token = WrappedToken(web3_pp, token_address) + + assert token.withdraw(100, False) == "mock_tx" diff --git a/pdr_backend/models/token.py b/pdr_backend/contract/token.py similarity index 60% rename from pdr_backend/models/token.py rename to pdr_backend/contract/token.py index c10cd9e3f..acec042cf 100644 --- a/pdr_backend/models/token.py +++ b/pdr_backend/contract/token.py @@ -1,14 +1,13 @@ from enforce_typing import enforce_types from web3.types import TxParams, Wei -from pdr_backend.models.base_contract import BaseContract -from pdr_backend.util.web3_config import Web3Config +from pdr_backend.contract.base_contract import BaseContract @enforce_types class Token(BaseContract): - def __init__(self, config: Web3Config, address: str): - super().__init__(config, address, "ERC20Template3") + def __init__(self, web3_pp, address: str): + super().__init__(web3_pp, address, "ERC20Template3") def allowance(self, account, spender): return self.contract_instance.functions.allowance(account, spender).call() @@ -17,44 +16,55 @@ def balanceOf(self, account): return self.contract_instance.functions.balanceOf(account).call() def transfer(self, to: str, amount: int, sender, wait_for_receipt=True): - gasPrice = self.config.w3.eth.gas_price + gas_price = self.web3_pp.tx_gas_price() + call_params = {"from": sender, "gasPrice": gas_price} tx = self.contract_instance.functions.transfer(to, int(amount)).transact( - {"from": sender, "gasPrice": gasPrice} + call_params ) if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) def approve(self, spender, amount, wait_for_receipt=True): - gasPrice = self.config.w3.eth.gas_price + call_params = self.web3_pp.tx_call_params() # print(f"Approving {amount} for {spender} on contract {self.contract_address}") tx = self.contract_instance.functions.approve(spender, amount).transact( - {"from": self.config.owner, "gasPrice": gasPrice} + call_params ) + if not wait_for_receipt: return tx + return self.config.w3.eth.wait_for_transaction_receipt(tx) class NativeToken: - def __init__(self, config: Web3Config): - self.config = config + @enforce_types + def __init__(self, web3_pp): + self.web3_pp = web3_pp + @property + def w3(self): + return self.web3_pp.web3_config.w3 + + @enforce_types def balanceOf(self, account): - return self.config.w3.eth.get_balance(account) + return self.w3.eth.get_balance(account) + @enforce_types def transfer(self, to: str, amount: int, sender, wait_for_receipt=True): - gasPrice = self.config.w3.eth.gas_price - params: TxParams = { + gas_price = self.web3_pp.tx_gas_price() + call_params: TxParams = { "from": sender, "gas": 25000, "value": Wei(amount), - "gasPrice": Wei(gasPrice), + "gasPrice": gas_price, "to": to, } - tx = self.config.w3.eth.send_transaction(transaction=params) + tx = self.w3.eth.send_transaction(transaction=call_params) if not wait_for_receipt: return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) + return self.w3.eth.wait_for_transaction_receipt(tx) diff --git a/pdr_backend/models/wrapped_token.py b/pdr_backend/contract/wrapped_token.py similarity index 67% rename from pdr_backend/models/wrapped_token.py rename to pdr_backend/contract/wrapped_token.py index af548ef0e..c06b62a7a 100644 --- a/pdr_backend/models/wrapped_token.py +++ b/pdr_backend/contract/wrapped_token.py @@ -1,10 +1,9 @@ -from pdr_backend.models.token import Token -from pdr_backend.util.web3_config import Web3Config +from pdr_backend.contract.token import Token class WrappedToken(Token): - def 
__init__(self, config: Web3Config, address: str): - super().__init__(config, address) + def __init__(self, web3_pp, address: str): + super().__init__(web3_pp, address) abi = [ { "constant": False, @@ -16,7 +15,7 @@ def __init__(self, config: Web3Config, address: str): "type": "function", }, ] - self.contract_instance_wrapped = config.w3.eth.contract( + self.contract_instance_wrapped = self.config.w3.eth.contract( address=self.contract_address, abi=abi ) @@ -24,10 +23,9 @@ def withdraw(self, amount: int, wait_for_receipt=True): """ Converts Wrapped Token to Token, amount is in wei. """ - gas_price = self.config.w3.eth.gas_price - + call_params = self.web3_pp.tx_call_params() tx = self.contract_instance_wrapped.functions.withdraw(amount).transact( - {"from": self.config.owner, "gasPrice": gas_price} + call_params ) if not wait_for_receipt: return tx diff --git a/pdr_backend/data_eng/data_factory.py b/pdr_backend/data_eng/data_factory.py deleted file mode 100644 index e90af4519..000000000 --- a/pdr_backend/data_eng/data_factory.py +++ /dev/null @@ -1,310 +0,0 @@ -import os -import sys -from typing import Dict - -from enforce_typing import enforce_types -import numpy as np -import pandas as pd - -from pdr_backend.data_eng.constants import ( - OHLCV_COLS, - TOHLCV_COLS, - OHLCV_MULT_MIN, - OHLCV_MULT_MAX, -) -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.data_eng.data_ss import DataSS -from pdr_backend.data_eng.pdutil import ( - initialize_df, - concat_next_df, - save_csv, - load_csv, - has_data, - oldest_ut, - newest_ut, -) -from pdr_backend.util.mathutil import has_nan, fill_nans -from pdr_backend.util.timeutil import pretty_timestr, current_ut - - -@enforce_types -class DataFactory: - def __init__(self, pp: DataPP, ss: DataSS): - self.pp = pp - self.ss = ss - - def get_hist_df(self) -> pd.DataFrame: - """ - @description - Get historical dataframe, across many exchanges & pairs. - - @return - hist_df -- df w/ cols={exchange_str}:{pair_str}:{signal}+"datetime", - and index=timestamp - """ - print("Get historical data, across many exchanges & pairs: begin.") - - # Ss_timestamp is calculated dynamically if ss.fin_timestr = "now". - # But, we don't want fin_timestamp changing as we gather data here. - # To solve, for a given call to this method, we make a constant fin_ut - fin_ut = self.ss.fin_timestamp - - print(f" Data start: {pretty_timestr(self.ss.st_timestamp)}") - print(f" Data fin: {pretty_timestr(fin_ut)}") - - self._update_csvs(fin_ut) - csv_dfs = self._load_csvs(fin_ut) - hist_df = self._merge_csv_dfs(csv_dfs) - - print("Get historical data, across many exchanges & pairs: done.") - return hist_df - - def _update_csvs(self, fin_ut: int): - print(" Update csvs.") - for exch_str, pair_str in self.ss.exchange_pair_tups: - self._update_hist_csv_at_exch_and_pair(exch_str, pair_str, fin_ut) - - def _update_hist_csv_at_exch_and_pair( - self, exch_str: str, pair_str: str, fin_ut: int - ): - pair_str = pair_str.replace("/", "-") - print(f" Update csv at exchange={exch_str}, pair={pair_str}.") - - filename = self._hist_csv_filename(exch_str, pair_str) - print(f" filename={filename}") - - st_ut = self._calc_start_ut_maybe_delete(filename) - print(f" Aim to fetch data from start time: {pretty_timestr(st_ut)}") - if st_ut > min(current_ut(), fin_ut): - print(" Given start time, no data to gather. 
Exit.") - return - - # Fill in df - df = initialize_df(OHLCV_COLS) - while True: - print(f" Fetch 1000 pts from {pretty_timestr(st_ut)}") - - exch = self.ss.exchs_dict[exch_str] - - # C is [sample x signal(TOHLCV)]. Row 0 is oldest - # TOHLCV = unixTime (in ms), Open, High, Low, Close, Volume - raw_tohlcv_data = exch.fetch_ohlcv( - symbol=pair_str.replace("-", "/"), # eg "BTC/USDT" - timeframe=self.pp.timeframe, # eg "5m", "1h" - since=st_ut, # timestamp of first candle - limit=1000, # max # candles to retrieve - ) - uts = [vec[0] for vec in raw_tohlcv_data] - if len(uts) > 1: - # Ideally, time between ohclv candles is always 5m or 1h - # But exchange data often has gaps. Warn about worst violations - diffs_ms = np.array(uts[1:]) - np.array(uts[:-1]) # in ms - diffs_m = diffs_ms / 1000 / 60 # in minutes - mn_thr = self.pp.timeframe_m * OHLCV_MULT_MIN - mx_thr = self.pp.timeframe_m * OHLCV_MULT_MAX - - if min(diffs_m) < mn_thr: - print(f" **WARNING: short candle time: {min(diffs_m)} min") - if max(diffs_m) > mx_thr: - print(f" **WARNING: long candle time: {max(diffs_m)} min") - - raw_tohlcv_data = [vec for vec in raw_tohlcv_data if vec[0] <= fin_ut] - next_df = pd.DataFrame(raw_tohlcv_data, columns=TOHLCV_COLS) - df = concat_next_df(df, next_df) - - if len(raw_tohlcv_data) < 1000: # no more data, we're at newest time - break - - # prep next iteration - newest_ut_value = int(df.index.values[-1]) - st_ut = newest_ut_value + self.pp.timeframe_ms - - # output to csv - save_csv(filename, df) - - def _calc_start_ut_maybe_delete(self, filename: str) -> int: - """ - @description - Calculate start timestamp, reconciling whether file exists and where - its data starts. Will delete file if it's inconvenient to re-use - - @arguments - filename - csv file with data. May or may not exist. 
- - @return - start_ut - timestamp (ut) to start grabbing data for - """ - if not os.path.exists(filename): - print(" No file exists yet, so will fetch all data") - return self.ss.st_timestamp - - print(" File already exists") - if not has_data(filename): - print(" File has no data, so delete it") - os.remove(filename) - return self.ss.st_timestamp - - file_ut0, file_utN = oldest_ut(filename), newest_ut(filename) - print(f" File starts at: {pretty_timestr(file_ut0)}") - print(f" File finishes at: {pretty_timestr(file_utN)}") - - if self.ss.st_timestamp >= file_ut0: - print(" User-specified start >= file start, so append file") - return file_utN + self.pp.timeframe_ms - - print(" User-specified start < file start, so delete file") - os.remove(filename) - return self.ss.st_timestamp - - def _load_csvs(self, fin_ut: int) -> Dict[str, Dict[str, pd.DataFrame]]: - """ - @arguments - fin_ut -- finish timestamp - - @return - csv_dfs -- dict of [exch_str][pair_str] : df - Where df has columns=OHLCV_COLS+"datetime", and index=timestamp - """ - print(" Load csvs.") - st_ut = self.ss.st_timestamp - - csv_dfs: Dict[str, Dict[str, pd.DataFrame]] = {} # [exch][pair] : df - for exch_str in self.ss.exchange_strs: - csv_dfs[exch_str] = {} - - for exch_str, pair_str in self.ss.exchange_pair_tups: - print(f"Load csv from exchange={exch_str}, pair={pair_str}") - filename = self._hist_csv_filename(exch_str, pair_str) - cols = [ - signal_str # cols is a subset of TOHLCV_COLS - for e, signal_str, p in self.ss.input_feed_tups - if e == exch_str and p == pair_str - ] - csv_df = load_csv(filename, cols, st_ut, fin_ut) - assert "datetime" in csv_df.columns - assert csv_df.index.name == "timestamp" - csv_dfs[exch_str][pair_str] = csv_df - - return csv_dfs - - def _merge_csv_dfs(self, csv_dfs: dict) -> pd.DataFrame: - """ - @arguments - csv_dfs -- dict [exch_str][pair_str] : df - where df has cols={signal_str}+"datetime", and index=timestamp - @return - hist_df -- df w/ cols={exch_str}:{pair_str}:{signal_str}+"datetime", - and index=timestamp - """ - print(" Merge csv DFs.") - hist_df = pd.DataFrame() - for exch_str in csv_dfs.keys(): - for pair_str, csv_df in csv_dfs[exch_str].items(): - assert "-" in pair_str, pair_str - assert "datetime" in csv_df.columns - assert csv_df.index.name == "timestamp" - - for csv_col in csv_df.columns: - if csv_col == "datetime": - if "datetime" in hist_df.columns: - continue - hist_col = csv_col - else: - signal_str = csv_col # eg "close" - hist_col = f"{exch_str}:{pair_str}:{signal_str}" - hist_df[hist_col] = csv_df[csv_col] - - assert "datetime" in hist_df.columns - assert hist_df.index.name == "timestamp" - return hist_df - - def create_xy( - self, - hist_df: pd.DataFrame, - testshift: int, - do_fill_nans: bool = True, - ): - """ - @arguments - hist_df -- df w cols={exch_str}:{pair_str}:{signal_str}+"datetime", - and index=timestamp - testshift -- to simulate across historical test data - do_fill_nans -- if any values are nan, fill them? (Via interpolation) - If you turn this off and hist_df has nans, then X/y/etc gets nans - - @return -- - X -- 2d array of [sample_i, var_i] : value - y -- 1d array of [sample_i] - x_df -- df w/ cols={exch_str}:{pair_str}:{signal}:t-{x} + "datetime" - index=0,1,.. 
(nothing special) - """ - if do_fill_nans and has_nan(hist_df): - hist_df = fill_nans(hist_df) - - ss = self.ss - x_df = pd.DataFrame() - - target_hist_cols = [ - f"{exch_str}:{pair_str}:{signal_str}" - for exch_str, signal_str, pair_str in ss.input_feed_tups - ] - - for hist_col in target_hist_cols: - assert hist_col in hist_df.columns, "missing a data col" - z = hist_df[hist_col].tolist() # [..., z(t-3), z(t-2), z(t-1)] - maxshift = testshift + ss.autoregressive_n - N_train = min(ss.max_n_train, len(z) - maxshift - 1) - if N_train <= 0: - print( - f"Too little data. len(z)={len(z)}, maxshift={maxshift}" - " (= testshift + autoregressive_n = " - f"{testshift} + {ss.autoregressive_n})\n" - "To fix: broaden time, shrink testshift, " - "or shrink autoregressive_n" - ) - sys.exit(1) - for delayshift in range(ss.autoregressive_n, 0, -1): # eg [2, 1, 0] - shift = testshift + delayshift - x_col = hist_col + f":t-{delayshift+1}" - assert (shift + N_train + 1) <= len(z) - # 1 point for test, the rest for train data - x_df[x_col] = _slice(z, -shift - N_train - 1, -shift) - - X = x_df.to_numpy() - - # y is set from yval_{exch_str, signal_str, pair_str} - # eg y = [BinEthC_-1, BinEthC_-2, ..., BinEthC_-450, BinEthC_-451] - pp = self.pp - hist_col = f"{pp.exchange_str}:{pp.pair_str}:{pp.signal_str}" - z = hist_df[hist_col].tolist() - y = np.array(_slice(z, -testshift - N_train - 1, -testshift)) - - # postconditions - assert X.shape[0] == y.shape[0] - assert X.shape[0] <= (ss.max_n_train + 1) - assert X.shape[1] == ss.n - - # return - return X, y, x_df - - def _hist_csv_filename(self, exch_str, pair_str) -> str: - """ - Given exch_str and pair_str (and self path), - compute csv filename - """ - pair_str = pair_str.replace("/", "-") - basename = f"{exch_str}_{pair_str}_{self.pp.timeframe}.csv" - filename = os.path.join(self.ss.csv_dir, basename) - return filename - - -@enforce_types -def _slice(x: list, st: int, fin: int) -> list: - """Python list slice returns an empty list on x[st:fin] if st<0 and fin=0 - This overcomes that issue, for cases when st<0""" - assert st < 0 - assert fin <= 0 - assert st < fin - if fin == 0: - return x[st:] - return x[st:fin] diff --git a/pdr_backend/data_eng/data_pp.py b/pdr_backend/data_eng/data_pp.py deleted file mode 100644 index 63faff2c6..000000000 --- a/pdr_backend/data_eng/data_pp.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Tuple - -from enforce_typing import enforce_types -import numpy as np - -from pdr_backend.util.constants import CAND_TIMEFRAMES -from pdr_backend.util.feedstr import unpack_feed_str, verify_feed_str -from pdr_backend.util.pairstr import unpack_pair_str - - -class DataPP: # user-uncontrollable params, at data-eng level - """ - DataPP specifies the output variable (yval), ie what to predict. - - DataPP is problem definition -> uncontrollable. - DataSS is solution strategy -> controllable. - For a given problem definition (DataPP), you can try different DataSS vals - """ - - # pylint: disable=too-many-instance-attributes - @enforce_types - def __init__( - self, - timeframe: str, # eg "1m", "1h" - predict_feed_str: str, # eg "binance c BTC/USDT", "kraken h BTC/USDT" - N_test: int, # eg 100. 
num pts to test on, 1 at a time (online) - ): - # preconditions - assert timeframe in CAND_TIMEFRAMES - verify_feed_str(predict_feed_str) - assert 0 < N_test < np.inf - - # save values - self.timeframe = timeframe - self.predict_feed_str = predict_feed_str - self.N_test = N_test - - @property - def timeframe_ms(self) -> int: - """Returns timeframe, in ms""" - return self.timeframe_m * 60 * 1000 - - @property - def timeframe_m(self) -> int: - """Returns timeframe, in minutes""" - if self.timeframe == "5m": - return 5 - if self.timeframe == "1h": - return 60 - raise ValueError("need to support timeframe={self.timeframe}") - - @property - def predict_feed_tup(self) -> Tuple[str, str, str]: - """ - Return (exchange_str, signal_str, pair_str) - E.g. ("binance", "close", "BTC/USDT") - """ - return unpack_feed_str(self.predict_feed_str) - - @property - def exchange_str(self) -> str: - """Return e.g. 'binance'""" - return self.predict_feed_tup[0] - - @property - def signal_str(self) -> str: - """Return e.g. 'high'""" - return self.predict_feed_tup[1] - - @property - def pair_str(self) -> str: - """Return e.g. 'ETH/USDT'""" - return self.predict_feed_tup[2] - - @property - def base_str(self) -> str: - """Return e.g. 'ETH'""" - return unpack_pair_str(self.predict_feed_tup[2])[0] - - @property - def quote_str(self) -> str: - """Return e.g. 'USDT'""" - return unpack_pair_str(self.predict_feed_tup[2])[1] - - @enforce_types - def __str__(self) -> str: - s = "DataPP={\n" - - s += f" timeframe={self.timeframe}\n" - s += f" predict_feed_str={self.predict_feed_str}\n" - s += f" N_test={self.N_test} -- # pts to test on, 1 at a time\n" - - s += "/DataPP}\n" - return s diff --git a/pdr_backend/data_eng/data_ss.py b/pdr_backend/data_eng/data_ss.py deleted file mode 100644 index aeb7dbe0b..000000000 --- a/pdr_backend/data_eng/data_ss.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -from typing import List, Set, Tuple - -import ccxt -from enforce_typing import enforce_types -import numpy as np - -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.util.feedstr import unpack_feeds_strs, verify_feeds_strs -from pdr_backend.util.timeutil import pretty_timestr, timestr_to_ut - - -class DataSS: # user-controllable params, at data-eng level - """ - DataPP specifies the output variable (yval), ie what to predict. - - DataPP is problem definition -> uncontrollable. - DataSS is solution strategy -> controllable. - For a given problem definition (DataPP), you can try different DataSS vals - - DataSS specifies the inputs, and how much training data to get - - Input vars: autoregressive_n vars for each of {all signals}x{all coins}x{all exch} - - How much trn data: time range st->fin_timestamp, bound by max_N_trn - """ - - # pylint: disable=too-many-instance-attributes - @enforce_types - def __init__( - self, - input_feeds_strs: List[str], # eg ["binance ohlcv BTC/USDT", " ", ...] - csv_dir: str, # eg "csvs". abs or rel loc'n of csvs dir - st_timestr: str, # eg "2019-09-13_04:00" (earliest), 2019-09-13" - fin_timestr: str, # eg "now", "2023-09-23_17:55", "2023-09-23" - max_n_train, # eg 50000. if inf, only limited by data available - autoregressive_n: int, # eg 10. 
model inputs ar_n past pts z[t-1], .., z[t-ar_n] - ): - # preconditions - if not os.path.exists(csv_dir): - print(f"Could not find csv dir, creating one at: {csv_dir}") - os.makedirs(csv_dir) - assert 0 <= timestr_to_ut(st_timestr) <= timestr_to_ut(fin_timestr) <= np.inf - assert 0 < max_n_train - assert 0 < autoregressive_n < np.inf - verify_feeds_strs(input_feeds_strs) - - # save values - self.input_feeds_strs: List[str] = input_feeds_strs - - self.csv_dir: str = csv_dir - self.st_timestr: str = st_timestr - self.fin_timestr: str = fin_timestr - - self.max_n_train: int = max_n_train - self.autoregressive_n: int = autoregressive_n - - self.exchs_dict: dict = {} # e.g. {"binance" : ccxt.binance()} - feed_tups = unpack_feeds_strs(input_feeds_strs) - for exchange_str, _, _ in feed_tups: - exchange_class = getattr(ccxt, exchange_str) - self.exchs_dict[exchange_str] = exchange_class() - - @property - def st_timestamp(self) -> int: - """ - Return start timestamp, in ut. - Calculated from self.st_timestr. - """ - return timestr_to_ut(self.st_timestr) - - @property - def fin_timestamp(self) -> int: - """ - Return fin timestamp, in ut. - Calculated from self.fin_timestr. - - ** This value will change dynamically if fin_timestr is "now". - """ - return timestr_to_ut(self.fin_timestr) - - @property - def n(self) -> int: - """Number of input dimensions == # columns in X""" - return self.n_input_feeds * self.autoregressive_n - - @property - def n_exchs(self) -> int: - return len(self.exchs_dict) - - @property - def exchange_strs(self) -> List[str]: - return sorted(self.exchs_dict.keys()) - - @property - def n_input_feeds(self) -> int: - return len(self.input_feed_tups) - - @property - def input_feed_tups(self) -> List[Tuple[str, str, str]]: - """Return list of (exchange_str, signal_str, pair_str)""" - return unpack_feeds_strs(self.input_feeds_strs) - - @property - def exchange_pair_tups(self) -> Set[Tuple[str, str]]: - """Return set of unique (exchange_str, pair_str) tuples""" - return set( - (exch_str, pair_str) for (exch_str, _, pair_str) in self.input_feed_tups - ) - - @enforce_types - def __str__(self) -> str: - s = "DataSS={\n" - - s += f" input_feeds_strs={self.input_feeds_strs}" - s += f" -> n_inputfeeds={self.n_input_feeds}" - s += " \n" - - s += f" csv_dir={self.csv_dir}\n" - s += f" st_timestr={self.st_timestr}\n" - s += f" -> st_timestamp={pretty_timestr(self.st_timestamp)}\n" - s += f" fin_timestr={self.fin_timestr}\n" - s += f" -> fin_timestamp={pretty_timestr(self.fin_timestamp)}\n" - s += " \n" - - s += f" max_n_train={self.max_n_train} -- max # pts to train on\n" - s += " \n" - - s += f" autoregressive_n={self.autoregressive_n}" - s += " -- model inputs ar_n past pts z[t-1], .., z[t-ar_n]\n" - s += " \n" - - s += f" -> n_input_feeds * ar_n = n = {self.n}" - s += "-- # input variables to model\n" - s += " \n" - - s += f" exchs_dict={self.exchs_dict}\n" - s += f" -> n_exchs={self.n_exchs}\n" - s += f" -> exchange_strs={self.exchange_strs}\n" - s += " \n" - - s += "/DataSS}\n" - return s - - @enforce_types - def copy_with_yval(self, data_pp: DataPP): - """Copy self, add data_pp's yval to new data_ss' inputs as needed""" - input_feeds_strs = self.input_feeds_strs[:] - if data_pp.predict_feed_tup not in self.input_feed_tups: - input_feeds_strs.append(data_pp.predict_feed_str) - - return DataSS( - input_feeds_strs=input_feeds_strs, - csv_dir=self.csv_dir, - st_timestr=self.st_timestr, - fin_timestr=self.fin_timestr, - max_n_train=self.max_n_train, - autoregressive_n=self.autoregressive_n, - 
) diff --git a/pdr_backend/data_eng/pdutil.py b/pdr_backend/data_eng/pdutil.py deleted file mode 100644 index 5a864b4df..000000000 --- a/pdr_backend/data_eng/pdutil.py +++ /dev/null @@ -1,187 +0,0 @@ -""" -pdutil: pandas dataframe & cvs utilities. -These utilities are specific to the time-series dataframe columns we're using. -""" -import os -from typing import List - -from enforce_typing import enforce_types -import numpy as np -import pandas as pd - -from pdr_backend.data_eng.constants import ( - OHLCV_COLS, - TOHLCV_COLS, - TOHLCV_DTYPES, -) - - -@enforce_types -def initialize_df(cols: List[str]) -> pd.DataFrame: - """Start a new df, with the expected columns, index, and dtypes - It's ok whether cols has "timestamp" or not. Same for "datetime". - The return df has "timestamp" for index and "datetime" as col - """ - dtypes = { - col: pd.Series(dtype=dtype) - for col, dtype in zip(TOHLCV_COLS, TOHLCV_DTYPES) - if col in cols or col == "timestamp" - } - df = pd.DataFrame(dtypes) - df = df.set_index("timestamp") - # pylint: disable=unsupported-assignment-operation - df["datetime"] = pd.to_datetime(df.index.values, unit="ms", utc=True) - - return df - - -@enforce_types -def concat_next_df(df: pd.DataFrame, next_df: pd.DataFrame) -> pd.DataFrame: - """Add a next df to existing df, with the expected columns etc. - The existing df *should* have the 'datetime' col, and next_df should *not*. - """ - assert "datetime" in df.columns - assert "datetime" not in next_df.columns - next_df = next_df.set_index("timestamp") - next_df["datetime"] = pd.to_datetime(next_df.index.values, unit="ms", utc=True) - df = pd.concat([df, next_df]) - return df - - -@enforce_types -def save_csv(filename: str, df: pd.DataFrame): - """Append to csv file if it exists, otherwise create new one. - With header=True and index=True, it will set the index_col too - """ - # preconditions - assert df.columns.tolist() == OHLCV_COLS + ["datetime"] - - # csv column order: timestamp (index), datetime, O, H, L, C, V - columns = ["datetime"] + OHLCV_COLS - - if os.path.exists(filename): # append existing file - df.to_csv(filename, mode="a", header=False, index=True, columns=columns) - print(f" Just appended {df.shape[0]} df rows to file {filename}") - else: # write new file - df.to_csv(filename, mode="w", header=True, index=True, columns=columns) - print(f" Just saved df with {df.shape[0]} rows to new file {filename}") - - -@enforce_types -def load_csv(filename: str, cols=None, st=None, fin=None) -> pd.DataFrame: - """Load csv file as a dataframe. - - Features: - - Ensure that all dtypes are correct - - Filter to just the input columns - - Filter to just the specified start & end times - - Memory stays reasonable - - @arguments - cols -- what columns to use, eg ["open","high"]. Set to None for all cols. - st -- starting timestamp, in ut. Set to 0 or None for very beginning - fin -- ending timestamp, in ut. 
Set to inf or None for very end - - @return - df -- dataframe - - @notes - Don't specify "timestamp" as a column because it's the df *index* - Don't specify "datetime" as a column, as that'll get calc'd from timestamp - """ - if cols is None: - cols = OHLCV_COLS - assert "timestamp" not in cols - assert "datetime" not in cols - cols = ["timestamp"] + cols - - # set skiprows, nrows - if st in [0, None] and fin in [np.inf, None]: - skiprows, nrows = None, None - else: - df0 = pd.read_csv(filename, usecols=["timestamp"]) - timestamps = df0["timestamp"].tolist() - skiprows = [ - i + 1 for i, timestamp in enumerate(timestamps) if timestamp < st - ] # "+1" to account for header - if skiprows == []: - skiprows = None - nrows = sum( - 1 for row, timestamp in enumerate(timestamps) if st <= timestamp <= fin - ) - - # set dtypes - cand_dtypes = dict(zip(TOHLCV_COLS, TOHLCV_DTYPES)) - dtypes = {col: cand_dtypes[col] for col in cols} - - # load - df = pd.read_csv( - filename, - dtype=dtypes, - usecols=cols, - skiprows=skiprows, - nrows=nrows, - ) - - # add in datetime column - df0 = initialize_df(cols) - df = concat_next_df(df0, df) - - # postconditions, return - assert "timestamp" not in df.columns - assert df.index.name == "timestamp" and df.index.dtype == np.int64 - assert "datetime" in df.columns - return df - - -@enforce_types -def has_data(filename: str) -> bool: - """Returns True if the file has >0 data entries""" - with open(filename) as f: - for i, _ in enumerate(f): - if i >= 1: - return True - return False - - -@enforce_types -def oldest_ut(filename: str) -> int: - """ - Return the timestamp for the oldest entry in the file. - Assumes the oldest entry is the second line in the file. - (First line is header) - """ - line = _get_second_line(filename) - ut = int(line.split(",")[0]) - return ut - - -@enforce_types -def _get_second_line(filename) -> str: - """Returns the last line in a file, as a string""" - with open(filename) as f: - for i, line in enumerate(f): - if i == 1: - return line - raise ValueError(f"File {filename} has no entries") - - -@enforce_types -def newest_ut(filename: str) -> int: - """ - Return the timestamp for the youngest entry in the file. - Assumes the youngest entry is the very last line in the file. 
- """ - line = _get_last_line(filename) - ut = int(line.split(",")[0]) - return ut - - -@enforce_types -def _get_last_line(filename: str) -> str: - """Returns the last line in a file, as a string""" - line = None - with open(filename) as f: - for line in f: - pass - return line if line is not None else "" diff --git a/pdr_backend/data_eng/test/test_data_factory.py b/pdr_backend/data_eng/test/test_data_factory.py deleted file mode 100644 index 9da4620d1..000000000 --- a/pdr_backend/data_eng/test/test_data_factory.py +++ /dev/null @@ -1,396 +0,0 @@ -import copy -from typing import List, Tuple - -from enforce_typing import enforce_types -import numpy as np -import pandas as pd - -from pdr_backend.data_eng.constants import TOHLCV_COLS -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.data_eng.data_ss import DataSS -from pdr_backend.data_eng.data_factory import DataFactory -from pdr_backend.data_eng.pdutil import initialize_df, concat_next_df, load_csv -from pdr_backend.util.mathutil import has_nan, fill_nans -from pdr_backend.util.timeutil import current_ut, ut_to_timestr - -MS_PER_5M_EPOCH = 300000 - -# ==================================================================== -# test csv updating - - -def test_update_csv1(tmpdir): - _test_update_csv("2023-01-01_0:00", "2023-01-01_0:00", tmpdir, n_uts=1) - - -def test_update_csv2(tmpdir): - _test_update_csv("2023-01-01_0:00", "2023-01-01_0:05", tmpdir, n_uts=2) - - -def test_update_csv3(tmpdir): - _test_update_csv("2023-01-01_0:00", "2023-01-01_0:10", tmpdir, n_uts=3) - - -def test_update_csv4(tmpdir): - _test_update_csv("2023-01-01_0:00", "2023-01-01_0:45", tmpdir, n_uts=10) - - -def test_update_csv5(tmpdir): - _test_update_csv("2023-01-01", "2023-06-21", tmpdir, n_uts=">1K") - - -@enforce_types -def _test_update_csv(st_timestr: str, fin_timestr: str, tmpdir, n_uts): - """n_uts -- expected # timestamps. Typically int. If '>1K', expect >1000""" - - # setup: base data - csvdir = str(tmpdir) - - # setup: uts helpers - def _calc_ut(since: int, i: int) -> int: - return since + i * MS_PER_5M_EPOCH - - def _uts_in_range(st_ut: int, fin_ut: int) -> List[int]: - return [ - _calc_ut(st_ut, i) - for i in range(100000) # assume <=100K epochs - if _calc_ut(st_ut, i) <= fin_ut - ] - - def _uts_from_since(cur_ut: int, since_ut: int, limit_N: int) -> List[int]: - return [ - _calc_ut(since_ut, i) - for i in range(limit_N) - if _calc_ut(since_ut, i) <= cur_ut - ] - - # setup: exchange - class FakeExchange: - def __init__(self): - self.cur_ut: int = current_ut() # fixed value, for easier testing - - # pylint: disable=unused-argument - def fetch_ohlcv(self, since, limit, *args, **kwargs) -> list: - uts: List[int] = _uts_from_since(self.cur_ut, since, limit) - return [[ut] + [1.0] * 5 for ut in uts] # 1.0 for open, high, .. 
- - exchange = FakeExchange() - - # setup: pp - pp = DataPP( # user-uncontrollable params - "5m", - "binanceus h ETH/USDT", - N_test=2, - ) - - # setup: ss - ss = DataSS( # user-controllable params - ["binanceus h ETH/USDT"], - csv_dir=csvdir, - st_timestr=st_timestr, - fin_timestr=fin_timestr, - max_n_train=7, - autoregressive_n=3, - ) - ss.exchs_dict["binanceus"] = exchange # override with fake exchange - - # setup: data_factory, filename - data_factory = DataFactory(pp, ss) - filename = data_factory._hist_csv_filename("binanceus", "ETH/USDT") - - def _uts_in_csv(filename: str) -> List[int]: - df = load_csv(filename) - return df.index.values.tolist() - - # work 1: new csv - data_factory._update_hist_csv_at_exch_and_pair( - "binanceus", "ETH/USDT", ss.fin_timestamp - ) - uts: List[int] = _uts_in_csv(filename) - if isinstance(n_uts, int): - assert len(uts) == n_uts - elif n_uts == ">1K": - assert len(uts) > 1000 - assert sorted(uts) == uts - assert uts[0] == ss.st_timestamp - assert uts[-1] == ss.fin_timestamp - assert uts == _uts_in_range(ss.st_timestamp, ss.fin_timestamp) - - # work 2: two more epochs at end --> it'll append existing csv - ss.fin_timestr = ut_to_timestr(ss.fin_timestamp + 2 * MS_PER_5M_EPOCH) - data_factory._update_hist_csv_at_exch_and_pair( - "binanceus", "ETH/USDT", ss.fin_timestamp - ) - uts2 = _uts_in_csv(filename) - assert uts2 == _uts_in_range(ss.st_timestamp, ss.fin_timestamp) - - # work 3: two more epochs at beginning *and* end --> it'll create new csv - ss.st_timestr = ut_to_timestr(ss.st_timestamp - 2 * MS_PER_5M_EPOCH) - ss.fin_timestr = ut_to_timestr(ss.fin_timestamp + 4 * MS_PER_5M_EPOCH) - data_factory._update_hist_csv_at_exch_and_pair( - "binanceus", "ETH/USDT", ss.fin_timestamp - ) - uts3 = _uts_in_csv(filename) - assert uts3 == _uts_in_range(ss.st_timestamp, ss.fin_timestamp) - - -# ====================================================================== -# end-to-end tests - -BINANCE_ETH_DATA = [ - # time #o #h #l #c #v - [1686805500000, 0.5, 12, 0.12, 1.1, 7.0], - [1686805800000, 0.5, 11, 0.11, 2.2, 7.0], - [1686806100000, 0.5, 10, 0.10, 3.3, 7.0], - [1686806400000, 1.1, 9, 0.09, 4.4, 1.4], - [1686806700000, 3.5, 8, 0.08, 5.5, 2.8], - [1686807000000, 4.7, 7, 0.07, 6.6, 8.1], - [1686807300000, 4.5, 6, 0.06, 7.7, 8.1], - [1686807600000, 0.6, 5, 0.05, 8.8, 8.1], - [1686807900000, 0.9, 4, 0.04, 9.9, 8.1], - [1686808200000, 2.7, 3, 0.03, 10.10, 8.1], - [1686808500000, 0.7, 2, 0.02, 11.11, 8.1], - [1686808800000, 0.7, 1, 0.01, 12.12, 8.3], -] - - -@enforce_types -def _addval(DATA: list, val: float) -> list: - DATA2 = copy.deepcopy(DATA) - for row_i, row in enumerate(DATA2): - for col_j, _ in enumerate(row): - if col_j == 0: - continue - DATA2[row_i][col_j] += val - return DATA2 - - -BINANCE_BTC_DATA = _addval(BINANCE_ETH_DATA, 10000.0) -KRAKEN_ETH_DATA = _addval(BINANCE_ETH_DATA, 0.0001) -KRAKEN_BTC_DATA = _addval(BINANCE_ETH_DATA, 10000.0 + 0.0001) - - -@enforce_types -def test_create_xy__1exchange_1coin_1signal(tmpdir): - csvdir = str(tmpdir) - - csv_dfs = {"kraken": {"ETH-USDT": _df_from_raw_data(BINANCE_ETH_DATA)}} - - pp, ss = _data_pp_ss_1exchange_1coin_1signal(csvdir) - - assert ss.n == 1 * 1 * 1 * 3 # n_exchs * n_coins * n_signals * autoregressive_n - - data_factory = DataFactory(pp, ss) - hist_df = data_factory._merge_csv_dfs(csv_dfs) - X, y, x_df = data_factory.create_xy(hist_df, testshift=0) - _assert_shapes(ss, X, y, x_df) - - assert X[-1, :].tolist() == [4, 3, 2] and y[-1] == 1 - assert X[-2, :].tolist() == [5, 4, 3] and y[-2] == 2 - assert X[0, 
:].tolist() == [11, 10, 9] and y[0] == 8 - - assert x_df.iloc[-1].tolist() == [4, 3, 2] - - found_cols = x_df.columns.tolist() - target_cols = [ - "kraken:ETH-USDT:high:t-4", - "kraken:ETH-USDT:high:t-3", - "kraken:ETH-USDT:high:t-2", - ] - assert found_cols == target_cols - - assert x_df["kraken:ETH-USDT:high:t-2"].tolist() == [9, 8, 7, 6, 5, 4, 3, 2] - assert X[:, 2].tolist() == [9, 8, 7, 6, 5, 4, 3, 2] - - # =========== now have a different testshift (1 not 0) - X, y, x_df = data_factory.create_xy(hist_df, testshift=1) - _assert_shapes(ss, X, y, x_df) - - assert X[-1, :].tolist() == [5, 4, 3] and y[-1] == 2 - assert X[-2, :].tolist() == [6, 5, 4] and y[-2] == 3 - assert X[0, :].tolist() == [12, 11, 10] and y[0] == 9 - - assert x_df.iloc[-1].tolist() == [5, 4, 3] - - found_cols = x_df.columns.tolist() - target_cols = [ - "kraken:ETH-USDT:high:t-4", - "kraken:ETH-USDT:high:t-3", - "kraken:ETH-USDT:high:t-2", - ] - assert found_cols == target_cols - - assert x_df["kraken:ETH-USDT:high:t-2"].tolist() == [10, 9, 8, 7, 6, 5, 4, 3] - assert X[:, 2].tolist() == [10, 9, 8, 7, 6, 5, 4, 3] - - # =========== now have a different max_n_train - ss.max_n_train = 5 - # ss.autoregressive_n = 2 - - X, y, x_df = data_factory.create_xy(hist_df, testshift=0) - _assert_shapes(ss, X, y, x_df) - - assert X.shape[0] == 5 + 1 # +1 for one test point - assert y.shape[0] == 5 + 1 - assert len(x_df) == 5 + 1 - - assert X[-1, :].tolist() == [4, 3, 2] and y[-1] == 1 - assert X[-2, :].tolist() == [5, 4, 3] and y[-2] == 2 - assert X[0, :].tolist() == [9, 8, 7] and y[0] == 6 - - -@enforce_types -def test_create_xy__2exchanges_2coins_2signals(tmpdir): - csvdir = str(tmpdir) - - csv_dfs = { - "binanceus": { - "BTC-USDT": _df_from_raw_data(BINANCE_BTC_DATA), - "ETH-USDT": _df_from_raw_data(BINANCE_ETH_DATA), - }, - "kraken": { - "BTC-USDT": _df_from_raw_data(KRAKEN_BTC_DATA), - "ETH-USDT": _df_from_raw_data(KRAKEN_ETH_DATA), - }, - } - - pp = DataPP( - "5m", - "binanceus h ETH/USDT", - N_test=2, - ) - - ss = DataSS( - ["binanceus hl BTC/USDT,ETH/USDT", "kraken hl BTC/USDT,ETH/USDT"], - csv_dir=csvdir, - st_timestr="2023-06-18", - fin_timestr="2023-06-21", - max_n_train=7, - autoregressive_n=3, - ) - - assert ss.n == 2 * 2 * 2 * 3 # n_exchs * n_coins * n_signals * autoregressive_n - - data_factory = DataFactory(pp, ss) - hist_df = data_factory._merge_csv_dfs(csv_dfs) - X, y, x_df = data_factory.create_xy(hist_df, testshift=0) - _assert_shapes(ss, X, y, x_df) - - found_cols = x_df.columns.tolist() - target_cols = [ - "binanceus:BTC-USDT:high:t-4", - "binanceus:BTC-USDT:high:t-3", - "binanceus:BTC-USDT:high:t-2", - "binanceus:ETH-USDT:high:t-4", - "binanceus:ETH-USDT:high:t-3", - "binanceus:ETH-USDT:high:t-2", - "binanceus:BTC-USDT:low:t-4", - "binanceus:BTC-USDT:low:t-3", - "binanceus:BTC-USDT:low:t-2", - "binanceus:ETH-USDT:low:t-4", - "binanceus:ETH-USDT:low:t-3", - "binanceus:ETH-USDT:low:t-2", - "kraken:BTC-USDT:high:t-4", - "kraken:BTC-USDT:high:t-3", - "kraken:BTC-USDT:high:t-2", - "kraken:ETH-USDT:high:t-4", - "kraken:ETH-USDT:high:t-3", - "kraken:ETH-USDT:high:t-2", - "kraken:BTC-USDT:low:t-4", - "kraken:BTC-USDT:low:t-3", - "kraken:BTC-USDT:low:t-2", - "kraken:ETH-USDT:low:t-4", - "kraken:ETH-USDT:low:t-3", - "kraken:ETH-USDT:low:t-2", - ] - assert found_cols == target_cols - - # test binanceus:ETH-USDT:high like in 1-signal - assert target_cols[3:6] == [ - "binanceus:ETH-USDT:high:t-4", - "binanceus:ETH-USDT:high:t-3", - "binanceus:ETH-USDT:high:t-2", - ] - Xa = X[:, 3:6] - assert Xa[-1, :].tolist() == [4, 3, 2] 
and y[-1] == 1 - assert Xa[-2, :].tolist() == [5, 4, 3] and y[-2] == 2 - assert Xa[0, :].tolist() == [11, 10, 9] and y[0] == 8 - - assert x_df.iloc[-1].tolist()[3:6] == [4, 3, 2] - assert x_df.iloc[-2].tolist()[3:6] == [5, 4, 3] - assert x_df.iloc[0].tolist()[3:6] == [11, 10, 9] - - assert x_df["binanceus:ETH-USDT:high:t-2"].tolist() == [9, 8, 7, 6, 5, 4, 3, 2] - assert Xa[:, 2].tolist() == [9, 8, 7, 6, 5, 4, 3, 2] - - -@enforce_types -def test_create_xy__handle_nan(tmpdir): - # create hist_df - csvdir = str(tmpdir) - csv_dfs = {"kraken": {"ETH-USDT": _df_from_raw_data(BINANCE_ETH_DATA)}} - pp, ss = _data_pp_ss_1exchange_1coin_1signal(csvdir) - data_factory = DataFactory(pp, ss) - hist_df = data_factory._merge_csv_dfs(csv_dfs) - - # corrupt hist_df with nans - all_signal_strs = set(signal_str for _, signal_str, _ in ss.input_feed_tups) - assert "high" in all_signal_strs - hist_df.at[1686805800000, "kraken:ETH-USDT:high"] = np.nan # first row - hist_df.at[1686806700000, "kraken:ETH-USDT:high"] = np.nan # middle row - hist_df.at[1686808800000, "kraken:ETH-USDT:high"] = np.nan # last row - assert has_nan(hist_df) - - # run create_xy() and force the nans to stick around - # -> we want to ensure that we're building X/y with risk of nan - X, y, x_df = data_factory.create_xy(hist_df, testshift=0, do_fill_nans=False) - assert has_nan(X) and has_nan(y) and has_nan(x_df) - - # nan approach 1: fix externally - hist_df2 = fill_nans(hist_df) - assert not has_nan(hist_df2) - - # nan approach 2: explicitly tell create_xy to fill nans - X, y, x_df = data_factory.create_xy(hist_df, testshift=0, do_fill_nans=True) - assert not has_nan(X) and not has_nan(y) and not has_nan(x_df) - - # nan approach 3: create_xy fills nans by default (best) - X, y, x_df = data_factory.create_xy(hist_df, testshift=0) - assert not has_nan(X) and not has_nan(y) and not has_nan(x_df) - - -@enforce_types -def _data_pp_ss_1exchange_1coin_1signal(csvdir: str) -> Tuple[DataPP, DataSS]: - pp = DataPP( - "5m", - "kraken h ETH/USDT", - N_test=2, - ) - - ss = DataSS( - [pp.predict_feed_str], - csv_dir=csvdir, - st_timestr="2023-06-18", - fin_timestr="2023-06-21", - max_n_train=7, - autoregressive_n=3, - ) - return pp, ss - - -@enforce_types -def _assert_shapes(ss: DataSS, X: np.ndarray, y: np.ndarray, x_df: pd.DataFrame): - assert X.shape[0] == y.shape[0] - assert X.shape[0] == (ss.max_n_train + 1) # 1 for test, rest for train - assert X.shape[1] == ss.n - - assert len(x_df) == X.shape[0] - assert len(x_df.columns) == ss.n - - -@enforce_types -def _df_from_raw_data(raw_data: list) -> pd.DataFrame: - df = initialize_df(TOHLCV_COLS) - next_df = pd.DataFrame(raw_data, columns=TOHLCV_COLS) - df = concat_next_df(df, next_df) - return df diff --git a/pdr_backend/data_eng/test/test_data_pp.py b/pdr_backend/data_eng/test/test_data_pp.py deleted file mode 100644 index de3492c42..000000000 --- a/pdr_backend/data_eng/test/test_data_pp.py +++ /dev/null @@ -1,45 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.util.constants import CAND_TIMEFRAMES - - -@enforce_types -def test_data_pp_5m(): - # construct - pp = _test_pp("5m") - - # test attributes - assert pp.timeframe == "5m" - assert pp.predict_feed_str == "kraken h ETH/USDT" - assert pp.N_test == 2 - - # test properties - assert pp.timeframe_ms == 5 * 60 * 1000 - assert pp.timeframe_m == 5 - assert pp.predict_feed_tup == ("kraken", "high", "ETH-USDT") - assert pp.exchange_str == "kraken" - assert pp.signal_str == "high" - assert 
pp.pair_str == "ETH-USDT" - assert pp.base_str == "ETH" - assert pp.quote_str == "USDT" - - -@enforce_types -def test_data_pp_1h(): - ss = _test_pp("1h") - - assert ss.timeframe == "1h" - assert ss.timeframe_ms == 60 * 60 * 1000 - assert ss.timeframe_m == 60 - - -@enforce_types -def _test_pp(timeframe: str) -> DataPP: - assert timeframe in CAND_TIMEFRAMES - pp = DataPP( - timeframe=timeframe, - predict_feed_str="kraken h ETH/USDT", - N_test=2, - ) - return pp diff --git a/pdr_backend/data_eng/test/test_data_ss.py b/pdr_backend/data_eng/test/test_data_ss.py deleted file mode 100644 index 7e8c4f731..000000000 --- a/pdr_backend/data_eng/test/test_data_ss.py +++ /dev/null @@ -1,97 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.data_eng.data_ss import DataSS -from pdr_backend.util.timeutil import timestr_to_ut - - -@enforce_types -def test_data_ss_basic(tmpdir): - ss = DataSS( - ["kraken hc ETH/USDT", "binanceus h ETH/USDT,TRX/DAI"], - csv_dir=str(tmpdir), - st_timestr="2023-06-18", - fin_timestr="2023-06-21", - max_n_train=7, - autoregressive_n=3, - ) - - # test attributes - assert ss.input_feeds_strs == ["kraken hc ETH/USDT", "binanceus h ETH/USDT,TRX/DAI"] - assert ss.csv_dir == str(tmpdir) - assert ss.st_timestr == "2023-06-18" - assert ss.fin_timestr == "2023-06-21" - - assert ss.max_n_train == 7 - assert ss.autoregressive_n == 3 - - assert sorted(ss.exchs_dict.keys()) == ["binanceus", "kraken"] - - # test properties - assert ss.st_timestamp == timestr_to_ut("2023-06-18") - assert ss.fin_timestamp == timestr_to_ut("2023-06-21") - assert ss.input_feed_tups == [ - ("kraken", "high", "ETH-USDT"), - ("kraken", "close", "ETH-USDT"), - ("binanceus", "high", "ETH-USDT"), - ("binanceus", "high", "TRX-DAI"), - ] - assert ss.exchange_pair_tups == set( - [ - ("kraken", "ETH-USDT"), - ("binanceus", "ETH-USDT"), - ("binanceus", "TRX-DAI"), - ] - ) - assert len(ss.input_feed_tups) == ss.n_input_feeds == 4 - assert ss.n == 4 * 3 == 12 - assert ss.n_exchs == 2 - assert len(ss.exchange_strs) == 2 - assert "binanceus" in ss.exchange_strs - - # test str - assert "DataSS=" in str(ss) - - -@enforce_types -def test_data_ss_now(tmpdir): - ss = DataSS( - ["kraken h ETH/USDT"], - csv_dir=str(tmpdir), - st_timestr="2023-06-18", - fin_timestr="now", - max_n_train=7, - autoregressive_n=3, - ) - assert ss.fin_timestr == "now" - assert ss.fin_timestamp == timestr_to_ut("now") - - -@enforce_types -def test_data_ss_copy(tmpdir): - ss = DataSS( - ["kraken h ETH/USDT BTC/USDT"], - csv_dir=str(tmpdir), - st_timestr="2023-06-18", - fin_timestr="now", - max_n_train=7, - autoregressive_n=3, - ) - - # copy 1: don't need to append the new feed - pp = DataPP( - "5m", - "kraken h ETH/USDT", - N_test=2, - ) - ss2 = ss.copy_with_yval(pp) - assert ss2.n_input_feeds == 2 - - # copy 2: do need to append the new feed - pp = DataPP( - "5m", - "binanceus c TRX/USDC", - N_test=2, - ) - ss3 = ss.copy_with_yval(pp) - assert ss3.n_input_feeds == 3 diff --git a/pdr_backend/data_eng/test/test_pdutil.py b/pdr_backend/data_eng/test/test_pdutil.py deleted file mode 100644 index 17e7354f3..000000000 --- a/pdr_backend/data_eng/test/test_pdutil.py +++ /dev/null @@ -1,214 +0,0 @@ -import os - -from enforce_typing import enforce_types -import numpy as np -import pandas as pd -import pytest - -from pdr_backend.data_eng.constants import ( - OHLCV_COLS, - OHLCV_DTYPES, - TOHLCV_COLS, -) -from pdr_backend.data_eng.pdutil import ( - initialize_df, - concat_next_df, - save_csv, - 
load_csv, - has_data, - oldest_ut, - newest_ut, - _get_last_line, -) - -FOUR_ROWS_RAW_TOHLCV_DATA = [ - [1686806100000, 1648.58, 1648.58, 1646.27, 1646.64, 7.4045], - [1686806400000, 1647.05, 1647.05, 1644.61, 1644.86, 14.452], - [1686806700000, 1644.57, 1646.41, 1642.49, 1645.81, 22.8612], - [1686807000000, 1645.77, 1646.2, 1645.23, 1646.05, 8.1741], -] -ONE_ROW_RAW_TOHLCV_DATA = [[1686807300000, 1646, 1647.2, 1646.23, 1647.05, 8.1742]] - - -@enforce_types -def test_initialize_df(): - df = initialize_df(TOHLCV_COLS) - - assert isinstance(df, pd.DataFrame) - _assert_TOHLCVd_cols_and_types(df) - - df = initialize_df(OHLCV_COLS[:2]) - assert df.columns.tolist() == OHLCV_COLS[:2] + ["datetime"] - assert df.dtypes.tolist()[:-1] == OHLCV_DTYPES[:2] - - -@enforce_types -def test_concat_next_df(): - # baseline data - df = initialize_df(TOHLCV_COLS) - assert len(df) == 0 - - next_df = pd.DataFrame(FOUR_ROWS_RAW_TOHLCV_DATA, columns=TOHLCV_COLS) - assert len(next_df) == 4 - - # add 4 rows to empty df - df = concat_next_df(df, next_df) - assert len(df) == 4 - _assert_TOHLCVd_cols_and_types(df) - - # from df with 4 rows, add 1 more row - next_df = pd.DataFrame(ONE_ROW_RAW_TOHLCV_DATA, columns=TOHLCV_COLS) - assert len(next_df) == 1 - - df = concat_next_df(df, next_df) - assert len(df) == 4 + 1 - _assert_TOHLCVd_cols_and_types(df) - - -@enforce_types -def _assert_TOHLCVd_cols_and_types(df: pd.DataFrame): - assert df.columns.tolist() == OHLCV_COLS + ["datetime"] - assert df.dtypes.tolist()[:-1] == OHLCV_DTYPES - assert str(df.dtypes.tolist()[-1]) == "datetime64[ns, UTC]" - assert df.index.name == "timestamp" and df.index.dtype == np.int64 - - -def _filename(tmpdir) -> str: - return os.path.join(tmpdir, "foo.csv") - - -@enforce_types -def test_load_basic(tmpdir): - filename = _filename(tmpdir) - df = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) - save_csv(filename, df) - - # simplest specification. 
Don't specify cols, st or fin - df2 = load_csv(filename) - _assert_TOHLCVd_cols_and_types(df2) - assert len(df2) == 4 and str(df) == str(df2) - - # explicitly specify cols, but not st or fin - df2 = load_csv(filename, OHLCV_COLS) - _assert_TOHLCVd_cols_and_types(df2) - assert len(df2) == 4 and str(df) == str(df2) - - # explicitly specify cols, st, fin - df2 = load_csv(filename, OHLCV_COLS, st=None, fin=None) - _assert_TOHLCVd_cols_and_types(df2) - assert len(df2) == 4 and str(df) == str(df2) - - df2 = load_csv(filename, OHLCV_COLS, st=0, fin=np.inf) - _assert_TOHLCVd_cols_and_types(df2) - assert len(df2) == 4 and str(df) == str(df2) - - -@enforce_types -def test_load_append(tmpdir): - # save 4-row csv - filename = _filename(tmpdir) - df_4_rows = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) - save_csv(filename, df_4_rows) # write new file - - # append 1 row to csv - df_1_row = _df_from_raw_data(ONE_ROW_RAW_TOHLCV_DATA) - save_csv(filename, df_1_row) # will append existing file - - # test - df_5_rows = concat_next_df( - df_4_rows, pd.DataFrame(ONE_ROW_RAW_TOHLCV_DATA, columns=TOHLCV_COLS) - ) - df_5_rows_loaded = load_csv(filename) - _assert_TOHLCVd_cols_and_types(df_5_rows_loaded) - assert len(df_5_rows_loaded) == 5 - assert str(df_5_rows) == str(df_5_rows_loaded) - - -@enforce_types -def test_load_filtered(tmpdir): - # save - filename = _filename(tmpdir) - df = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) - save_csv(filename, df) - - # load with filters on rows & columns - cols = OHLCV_COLS[:2] # ["open", "high"] - timestamps = [row[0] for row in FOUR_ROWS_RAW_TOHLCV_DATA] - st = timestamps[1] # 1686806400000 - fin = timestamps[2] # 1686806700000 - df2 = load_csv(filename, cols, st, fin) - - # test entries - assert len(df2) == 2 - assert len(df2.index.values) == 2 - assert df2.index.values.tolist() == timestamps[1:3] - - # test cols and types - assert df2.columns.tolist() == OHLCV_COLS[:2] + ["datetime"] - assert df2.dtypes.tolist()[:-1] == OHLCV_DTYPES[:2] - assert str(df2.dtypes.tolist()[-1]) == "datetime64[ns, UTC]" - assert df2.index.name == "timestamp" - assert df2.index.dtype == np.int64 - - -@enforce_types -def _df_from_raw_data(raw_data: list): - df = initialize_df(OHLCV_COLS) - next_df = pd.DataFrame(raw_data, columns=TOHLCV_COLS) - df = concat_next_df(df, next_df) - return df - - -@enforce_types -def test_has_data(tmpdir): - filename0 = os.path.join(tmpdir, "f0.csv") - save_csv(filename0, _df_from_raw_data([])) - assert not has_data(filename0) - - filename1 = os.path.join(tmpdir, "f1.csv") - save_csv(filename1, _df_from_raw_data(ONE_ROW_RAW_TOHLCV_DATA)) - assert has_data(filename1) - - filename4 = os.path.join(tmpdir, "f4.csv") - save_csv(filename4, _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA)) - assert has_data(filename4) - - -@enforce_types -def test_oldest_ut_and_newest_ut__with_data(tmpdir): - filename = _filename(tmpdir) - df = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) - save_csv(filename, df) - - ut0 = oldest_ut(filename) - utN = newest_ut(filename) - assert ut0 == FOUR_ROWS_RAW_TOHLCV_DATA[0][0] - assert utN == FOUR_ROWS_RAW_TOHLCV_DATA[-1][0] - - -@enforce_types -def test_oldest_ut_and_newest_ut__no_data(tmpdir): - filename = _filename(tmpdir) - df = _df_from_raw_data([]) - save_csv(filename, df) - - with pytest.raises(ValueError): - oldest_ut(filename) - with pytest.raises(ValueError): - newest_ut(filename) - - -@enforce_types -def test_get_last_line(tmpdir): - filename = os.path.join(tmpdir, "foo.csv") - - with open(filename, "w") as f: - f.write( - """line0 boo bo 
bum -line1 foo bar -line2 bah bah -line3 ha ha lol""" - ) - target_last_line = "line3 ha ha lol" - found_last_line = _get_last_line(filename) - assert found_last_line == target_last_line diff --git a/pdr_backend/dfbuyer/dfbuyer_agent.py b/pdr_backend/dfbuyer/dfbuyer_agent.py index d0517b784..fe84d513c 100644 --- a/pdr_backend/dfbuyer/dfbuyer_agent.py +++ b/pdr_backend/dfbuyer/dfbuyer_agent.py @@ -1,74 +1,87 @@ import math +import os import time from typing import Dict, List, Tuple from enforce_typing import enforce_types -from pdr_backend.dfbuyer.dfbuyer_config import DFBuyerConfig -from pdr_backend.models.predictoor_batcher import PredictoorBatcher -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.models.token import Token + +from pdr_backend.contract.predictoor_batcher import PredictoorBatcher +from pdr_backend.contract.predictoor_contract import PredictoorContract +from pdr_backend.contract.token import Token +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_consume_so_far import get_consume_so_far_per_contract +from pdr_backend.subgraph.subgraph_feed import print_feeds +from pdr_backend.subgraph.subgraph_sync import wait_until_subgraph_syncs from pdr_backend.util.constants import MAX_UINT from pdr_backend.util.contract import get_address -from pdr_backend.util.subgraph import ( - get_consume_so_far_per_contract, - wait_until_subgraph_syncs, -) +from pdr_backend.util.mathutil import from_wei WEEK = 7 * 86400 @enforce_types class DFBuyerAgent: - def __init__(self, config: DFBuyerConfig): - self.config: DFBuyerConfig = config - self.last_consume_ts = 0 - self.feeds = config.get_feeds() - self.predictoor_batcher: PredictoorBatcher = PredictoorBatcher( - self.config.web3_config, - get_address(config.web3_config.w3.eth.chain_id, "PredictoorHelper"), - ) - self.token_addr = get_address(config.web3_config.w3.eth.chain_id, "Ocean") - self.fail_counter = 0 + def __init__(self, ppss: PPSS): + # ppss + self.ppss = ppss + print("\n" + "-" * 80) + print(self.ppss) - print("-" * 80) - print("Config:") - print(self.config) + # set self.feeds + cand_feeds = ppss.web3_pp.query_feed_contracts() + print_feeds(cand_feeds, f"all feeds, owner={ppss.web3_pp.owner_addrs}") - print("\n" + "." * 80) - print("Feeds (detailed):") - for feed in self.feeds.values(): - print(f" {feed.longstr()}") + self.feeds = ppss.dfbuyer_ss.filter_feeds_from_candidates(cand_feeds) - print("\n" + "." 
* 80) - print("Feeds (succinct):") - for addr, feed in self.feeds.items(): - print(f" {feed}, {feed.seconds_per_epoch} s/epoch, addr={addr}") + if not self.feeds: + raise ValueError("No feeds found.") - token = Token(self.config.web3_config, self.token_addr) + # addresses + batcher_addr = get_address(ppss.web3_pp, "PredictoorHelper") + self.OCEAN_addr = get_address(ppss.web3_pp, "Ocean") + + # set attribs to track progress + self.last_consume_ts = 0 + self.predictoor_batcher: PredictoorBatcher = PredictoorBatcher( + ppss.web3_pp, + batcher_addr, + ) + self.fail_counter = 0 + self.batch_size = ppss.dfbuyer_ss.batch_size # Check allowance and approve if necessary print("Checking allowance...") - allowance = token.allowance( - self.config.web3_config.owner, self.predictoor_batcher.contract_address + OCEAN = Token(ppss.web3_pp, self.OCEAN_addr) + allowance = OCEAN.allowance( + ppss.web3_pp.web3_config.owner, + self.predictoor_batcher.contract_address, ) if allowance < MAX_UINT - 10**50: print("Approving tokens for predictoor_batcher") - tx = token.approve( + tx = OCEAN.approve( self.predictoor_batcher.contract_address, int(MAX_UINT), True ) print(f"Done: {tx['transactionHash'].hex()}") def run(self, testing: bool = False): + if not self.feeds: + return + while True: - ts = self.config.web3_config.get_block("latest")["timestamp"] + ts = self.ppss.web3_pp.web3_config.get_current_timestamp() self.take_step(ts) - if testing: + if testing or os.getenv("TEST") == "true": break def take_step(self, ts: int): + if not self.feeds: + return + print("Taking step for timestamp:", ts) - wait_until_subgraph_syncs(self.config.web3_config, self.config.subgraph_url) + wait_until_subgraph_syncs( + self.ppss.web3_pp.web3_config, self.ppss.web3_pp.subgraph_url + ) missing_consumes_amt = self._get_missing_consumes(ts) print("Missing consume amounts:", missing_consumes_amt) @@ -86,25 +99,29 @@ def take_step(self, ts: int): print("One or more consumes have failed...") self.fail_counter += 1 - if self.fail_counter > 3 and self.config.batch_size > 6: - self.config.batch_size = self.config.batch_size * 2 // 3 + batch_size = self.ppss.dfbuyer_ss.batch_size + if self.fail_counter > 3 and batch_size > 6: + self.batch_size = batch_size * 2 // 3 print( - f"Seems like we keep failing, adjusting batch size to: {self.config.batch_size}" + f"Seems like we keep failing, adjusting batch size to: {batch_size}" ) self.fail_counter = 0 print("Sleeping for a minute and trying again") time.sleep(60) return + self.fail_counter = 0 + self._sleep_until_next_consume_interval() + def _sleep_until_next_consume_interval(self): # sleep until next consume interval - ts = self.config.web3_config.get_block("latest")["timestamp"] - interval_start = ( - int(ts / self.config.consume_interval_seconds) - * self.config.consume_interval_seconds - ) - seconds_left = (interval_start + self.config.consume_interval_seconds) - ts + 60 + ts = self.ppss.web3_pp.web3_config.get_current_timestamp() + consume_interval_seconds = self.ppss.dfbuyer_ss.consume_interval_seconds + + interval_start = int(ts / consume_interval_seconds) * consume_interval_seconds + seconds_left = (interval_start + consume_interval_seconds) - ts + 60 + print( f"-- Sleeping for {seconds_left} seconds until next consume interval... 
--" ) @@ -122,64 +139,69 @@ def _get_missing_consumes(self, ts: int) -> Dict[str, float]: actual_consumes = self._get_consume_so_far(ts) expected_consume_per_feed = self._get_expected_amount_per_feed(ts) - missing_consumes_amt: Dict[str, float] = {} - - for address in self.feeds: - missing = expected_consume_per_feed - actual_consumes[address] - if missing > 0: - missing_consumes_amt[address] = missing - - return missing_consumes_amt + return { + address: expected_consume_per_feed - actual_consumes[address] + for address in self.feeds + if expected_consume_per_feed > actual_consumes[address] + } def _prepare_batches( self, consume_times: Dict[str, int] ) -> List[Tuple[List[str], List[int]]]: + batch_size = self.ppss.dfbuyer_ss.batch_size + max_no_of_addresses_in_batch = 3 # to avoid gas issues batches: List[Tuple[List[str], List[int]]] = [] addresses_to_consume: List[str] = [] times_to_consume: List[int] = [] + for address, times in consume_times.items(): while times > 0: current_times_to_consume = min( - times, self.config.batch_size - sum(times_to_consume) + times, batch_size - sum(times_to_consume) ) if current_times_to_consume > 0: addresses_to_consume.append( - self.config.web3_config.w3.to_checksum_address(address) + self.ppss.web3_pp.web3_config.w3.to_checksum_address(address) ) times_to_consume.append(current_times_to_consume) times -= current_times_to_consume if ( - sum(times_to_consume) == self.config.batch_size + sum(times_to_consume) == batch_size or address == list(consume_times.keys())[-1] or len(addresses_to_consume) == max_no_of_addresses_in_batch ): batches.append((addresses_to_consume, times_to_consume)) addresses_to_consume = [] times_to_consume = [] + return batches def _consume(self, addresses_to_consume, times_to_consume): - for i in range(self.config.max_request_tries): + for i in range(self.ppss.dfbuyer_ss.max_request_tries): try: tx = self.predictoor_batcher.consume_multiple( addresses_to_consume, times_to_consume, - self.token_addr, + self.OCEAN_addr, True, ) tx_hash = tx["transactionHash"].hex() + if tx["status"] != 1: print(f" Tx reverted: {tx_hash}") return False + print(f" Tx sent: {tx_hash}") return True except Exception as e: print(f" Attempt {i+1} failed with error: {e}") time.sleep(1) + if i == 4: print(" Failed to consume contracts after 5 attempts.") raise + return False def _consume_batch(self, addresses_to_consume, times_to_consume) -> bool: @@ -198,13 +220,16 @@ def _consume_batch(self, addresses_to_consume, times_to_consume) -> bool: for address, times in zip(addresses_to_consume, times_to_consume): if self._consume([address], [times]): continue # If successful, continue to the next address + # If individual consumption fails, split the consumption into two parts half_time = times // 2 + if half_time > 0: print(f" Consuming {address} for {half_time} times") if not self._consume([address], [half_time]): print("Transaction reverted again, please adjust batch size") one_or_more_failed = True + remaining_times = times - half_time if remaining_times > 0: print(f" Consuming {address} for {remaining_times} times") @@ -214,42 +239,45 @@ def _consume_batch(self, addresses_to_consume, times_to_consume) -> bool: else: print(f" Unable to consume {address} for {times} times") one_or_more_failed = True + return one_or_more_failed def _batch_txs(self, consume_times: Dict[str, int]) -> bool: batches = self._prepare_batches(consume_times) print(f"Processing {len(batches)} batches...") - one_or_more_failed = False + + failures = 0 + for addresses_to_consume, 
times_to_consume in batches: - failed = self._consume_batch(addresses_to_consume, times_to_consume) - if failed: - one_or_more_failed = True - return one_or_more_failed + failures += int(self._consume_batch(addresses_to_consume, times_to_consume)) + + return bool(failures) def _get_prices(self, contract_addresses: List[str]) -> Dict[str, float]: - prices: Dict[str, float] = {} - for address in contract_addresses: - rate_wei = PredictoorContract(self.config.web3_config, address).get_price() - rate_float = float(self.config.web3_config.w3.from_wei(rate_wei, "ether")) - prices[address] = rate_float - return prices + return { + address: from_wei( + PredictoorContract(self.ppss.web3_pp, address).get_price() + ) + for address in contract_addresses + } def _get_consume_so_far(self, ts: int) -> Dict[str, float]: week_start = (math.floor(ts / WEEK)) * WEEK consume_so_far = get_consume_so_far_per_contract( - self.config.subgraph_url, - self.config.web3_config.owner, + self.ppss.web3_pp.subgraph_url, + self.ppss.web3_pp.web3_config.owner, week_start, list(self.feeds.keys()), ) return consume_so_far def _get_expected_amount_per_feed(self, ts: int): - amount_per_feed_per_interval = self.config.amount_per_interval / len(self.feeds) + ss = self.ppss.dfbuyer_ss + amount_per_feed_per_interval = ss.amount_per_interval / len(self.feeds) week_start = (math.floor(ts / WEEK)) * WEEK time_passed = ts - week_start # find out how many intervals has passed - n_intervals = int(time_passed / self.config.consume_interval_seconds) + 1 + n_intervals = int(time_passed / ss.consume_interval_seconds) + 1 return n_intervals * amount_per_feed_per_interval diff --git a/pdr_backend/dfbuyer/dfbuyer_config.py b/pdr_backend/dfbuyer/dfbuyer_config.py deleted file mode 100644 index 5bdbe5e85..000000000 --- a/pdr_backend/dfbuyer/dfbuyer_config.py +++ /dev/null @@ -1,22 +0,0 @@ -from os import getenv - -from enforce_typing import enforce_types - -from pdr_backend.models.base_config import BaseConfig - - -@enforce_types -class DFBuyerConfig(BaseConfig): - def __init__(self): - super().__init__() - self.weekly_spending_limit = int(getenv("WEEKLY_SPENDING_LIMIT", "37000")) - self.consume_interval_seconds = int(getenv("CONSUME_INTERVAL_SECONDS", "86400")) - - # number of consumes to execute in a single transaction - self.batch_size = int(getenv("CONSUME_BATCH_SIZE", "20")) - - self.amount_per_interval = float( - self.weekly_spending_limit / (7 * 24 * 3600) * self.consume_interval_seconds - ) - - self.max_request_tries = 5 diff --git a/pdr_backend/dfbuyer/main.py b/pdr_backend/dfbuyer/main.py deleted file mode 100644 index 291a39f0d..000000000 --- a/pdr_backend/dfbuyer/main.py +++ /dev/null @@ -1,14 +0,0 @@ -from pdr_backend.dfbuyer.dfbuyer_agent import DFBuyerAgent -from pdr_backend.dfbuyer.dfbuyer_config import DFBuyerConfig - - -def main(): - print("Starting main loop...") - config = DFBuyerConfig() - agent = DFBuyerAgent(config) - - agent.run() - - -if __name__ == "__main__": - main() diff --git a/pdr_backend/dfbuyer/test/conftest.py b/pdr_backend/dfbuyer/test/conftest.py index bbe0a750d..11902c462 100644 --- a/pdr_backend/dfbuyer/test/conftest.py +++ b/pdr_backend/dfbuyer/test/conftest.py @@ -1,47 +1,62 @@ -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch + import pytest + +from pdr_backend.contract.predictoor_batcher import mock_predictoor_batcher from pdr_backend.dfbuyer.dfbuyer_agent import DFBuyerAgent -from pdr_backend.dfbuyer.dfbuyer_config import DFBuyerConfig -from 
pdr_backend.models.feed import Feed -from pdr_backend.util.constants import ( - MAX_UINT, - ZERO_ADDRESS, -) # pylint: disable=wildcard-import +from pdr_backend.ppss.ppss import mock_feed_ppss +from pdr_backend.ppss.web3_pp import inplace_mock_feedgetters +from pdr_backend.util.constants import MAX_UINT, ZERO_ADDRESS +PATH = "pdr_backend.dfbuyer.dfbuyer_agent" -def mock_feed(): - feed = Mock(spec=Feed) - feed.name = "test feed" - feed.seconds_per_epoch = 60 - return feed + +@pytest.fixture +def mock_get_address(): + with patch(f"{PATH}.get_address") as mock: + mock.return_value = ZERO_ADDRESS + yield mock @pytest.fixture() def mock_token(): - with patch("pdr_backend.dfbuyer.dfbuyer_agent.Token") as mock_token_class: + with patch(f"{PATH}.Token") as mock_token_class: mock_token_instance = MagicMock() mock_token_instance.allowance.return_value = MAX_UINT mock_token_class.return_value = mock_token_instance yield mock_token_class +_MOCK_FEED_PPSS = None # (feed, ppss) + + +def _mock_feed_ppss(): + global _MOCK_FEED_PPSS + if _MOCK_FEED_PPSS is None: + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + _MOCK_FEED_PPSS = (feed, ppss) + return _MOCK_FEED_PPSS + + @pytest.fixture -def dfbuyer_config(): - config = DFBuyerConfig() - config.get_feeds = Mock() - addresses = [ZERO_ADDRESS[: -len(str(i))] + str(i) for i in range(1, 7)] - config.get_feeds.return_value = {address: mock_feed() for address in addresses} - return config +def mock_ppss(): + _, ppss = _mock_feed_ppss() + return ppss @pytest.fixture -def mock_get_address(): - with patch("pdr_backend.dfbuyer.dfbuyer_agent.get_address") as mock: - mock.return_value = ZERO_ADDRESS +def mock_PredictoorBatcher(mock_ppss): # pylint: disable=redefined-outer-name + with patch(f"{PATH}.PredictoorBatcher") as mock: + mock.return_value = mock_predictoor_batcher(mock_ppss.web3_pp) yield mock -# pylint: disable=redefined-outer-name, unused-argument @pytest.fixture -def dfbuyer_agent(mock_get_address, mock_token, dfbuyer_config): - return DFBuyerAgent(dfbuyer_config) +def mock_dfbuyer_agent( # pylint: disable=unused-argument, redefined-outer-name + mock_get_address, + mock_token, + mock_ppss, + mock_PredictoorBatcher, +): + return DFBuyerAgent(mock_ppss) diff --git a/pdr_backend/dfbuyer/test/test_dfbuyer_agent.py b/pdr_backend/dfbuyer/test/test_dfbuyer_agent.py index e54fc8a31..dee69ad5e 100644 --- a/pdr_backend/dfbuyer/test/test_dfbuyer_agent.py +++ b/pdr_backend/dfbuyer/test/test_dfbuyer_agent.py @@ -1,123 +1,169 @@ from unittest.mock import MagicMock, call, patch -from ccxt.base.exchange import math import pytest +from ccxt.base.exchange import math +from enforce_typing import enforce_types + +from pdr_backend.contract.predictoor_batcher import PredictoorBatcher from pdr_backend.dfbuyer.dfbuyer_agent import WEEK, DFBuyerAgent +from pdr_backend.ppss.dfbuyer_ss import DFBuyerSS +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.ppss.web3_pp import Web3PP from pdr_backend.util.constants import MAX_UINT, ZERO_ADDRESS from pdr_backend.util.web3_config import Web3Config +PATH = "pdr_backend.dfbuyer.dfbuyer_agent" -@patch("pdr_backend.dfbuyer.dfbuyer_agent.get_address") -def test_new_agent(mock_get_address, mock_token, dfbuyer_config): + +@enforce_types +def test_dfbuyer_agent_constructor( # pylint: disable=unused-argument + mock_get_address, + mock_token, + mock_ppss, + mock_PredictoorBatcher, +): mock_token.return_value.allowance.return_value = 0 - 
mock_get_address.return_value = ZERO_ADDRESS - agent = DFBuyerAgent(dfbuyer_config) + agent = DFBuyerAgent(mock_ppss) assert len(mock_get_address.call_args_list) == 2 call1 = mock_get_address.call_args_list[0] - assert call1 == call(dfbuyer_config.web3_config.w3.eth.chain_id, "PredictoorHelper") + assert call1 == call(mock_ppss.web3_pp, "PredictoorHelper") call2 = mock_get_address.call_args_list[1] - assert call2 == call(dfbuyer_config.web3_config.w3.eth.chain_id, "Ocean") + assert call2 == call(mock_ppss.web3_pp, "Ocean") - mock_token.assert_called_with(dfbuyer_config.web3_config, agent.token_addr) + mock_token.assert_called_with(mock_ppss.web3_pp, agent.OCEAN_addr) mock_token_instance = mock_token() mock_token_instance.approve.assert_called_with( agent.predictoor_batcher.contract_address, int(MAX_UINT), True ) -def test_get_expected_amount_per_feed(dfbuyer_agent): +@enforce_types +def test_dfbuyer_agent_constructor_empty(): + # test with no feeds + mock_ppss_empty = MagicMock(spec=PPSS) + mock_ppss_empty.dfbuyer_ss = MagicMock(spec=DFBuyerSS) + mock_ppss_empty.dfbuyer_ss.filter_feeds_from_candidates.return_value = {} + mock_ppss_empty.web3_pp = MagicMock(spec=Web3PP) + mock_ppss_empty.web3_pp.query_feed_contracts.return_value = {} + + with pytest.raises(ValueError, match="No feeds found"): + DFBuyerAgent(mock_ppss_empty) + + +@enforce_types +def test_dfbuyer_agent_get_expected_amount_per_feed(mock_dfbuyer_agent): ts = 1695211135 - amount_per_feed_per_interval = dfbuyer_agent.config.amount_per_interval / len( - dfbuyer_agent.feeds + amount_per_feed_per_interval = ( + mock_dfbuyer_agent.ppss.dfbuyer_ss.amount_per_interval + / len(mock_dfbuyer_agent.feeds) ) week_start = (math.floor(ts / WEEK)) * WEEK time_passed = ts - week_start - n_intervals = int(time_passed / dfbuyer_agent.config.consume_interval_seconds) + 1 + n_intervals = ( + int(time_passed / mock_dfbuyer_agent.ppss.dfbuyer_ss.consume_interval_seconds) + + 1 + ) expected_result = n_intervals * amount_per_feed_per_interval - result = dfbuyer_agent._get_expected_amount_per_feed(ts) + result = mock_dfbuyer_agent._get_expected_amount_per_feed(ts) assert result == expected_result -def test_get_expected_amount_per_feed_hardcoded(dfbuyer_agent): +def test_dfbuyer_agent_get_expected_amount_per_feed_hardcoded(mock_dfbuyer_agent): ts = 16958592000 end = ts + WEEK - 86400 # last day just_before_new_week = ts + WEEK - 1 # 1 second before next week - amount_per_feed_per_interval = dfbuyer_agent.config.amount_per_interval / len( - dfbuyer_agent.feeds + amount_per_feed_per_interval = ( + mock_dfbuyer_agent.ppss.dfbuyer_ss.amount_per_interval + / len(mock_dfbuyer_agent.feeds) ) - result1 = dfbuyer_agent._get_expected_amount_per_feed(ts) + result1 = mock_dfbuyer_agent._get_expected_amount_per_feed(ts) assert result1 == amount_per_feed_per_interval - assert result1 * len(dfbuyer_agent.feeds) == 37000 / 7 # first day + assert result1 * len(mock_dfbuyer_agent.feeds) == 37000 / 7 # first day - result2 = dfbuyer_agent._get_expected_amount_per_feed(end) + result2 = mock_dfbuyer_agent._get_expected_amount_per_feed(end) assert result2 == amount_per_feed_per_interval * 7 assert ( - result2 * len(dfbuyer_agent.feeds) == 37000 + result2 * len(mock_dfbuyer_agent.feeds) == 37000 ) # last day, should distribute all - result3 = dfbuyer_agent._get_expected_amount_per_feed(just_before_new_week) + result3 = mock_dfbuyer_agent._get_expected_amount_per_feed(just_before_new_week) assert result3 == amount_per_feed_per_interval * 7 - assert result3 * 
len(dfbuyer_agent.feeds) == 37000 # still last day + assert result3 * len(mock_dfbuyer_agent.feeds) == 37000 # still last day -@patch("pdr_backend.dfbuyer.dfbuyer_agent.get_consume_so_far_per_contract") -def test_get_consume_so_far(mock_get_consume_so_far, dfbuyer_agent): +@enforce_types +@patch(f"{PATH}.get_consume_so_far_per_contract") +def test_dfbuyer_agent_get_consume_so_far(mock_get_consume_so_far, mock_dfbuyer_agent): agent = MagicMock() - agent.config.web3_config.owner = "0x123" + agent.ppss.web3_pp.web3_config.owner = "0x123" agent.feeds = {"feed1": "0x1", "feed2": "0x2"} mock_get_consume_so_far.return_value = {"0x1": 10.5} expected_result = {"0x1": 10.5} - result = dfbuyer_agent._get_consume_so_far(0) + result = mock_dfbuyer_agent._get_consume_so_far(0) assert result == expected_result -@patch("pdr_backend.dfbuyer.dfbuyer_agent.PredictoorContract") -def test_get_prices(mock_contract, dfbuyer_agent): +@enforce_types +@patch(f"{PATH}.PredictoorContract") +def test_dfbuyer_agent_get_prices(mock_contract, mock_dfbuyer_agent): mock_contract_instance = MagicMock() mock_contract.return_value = mock_contract_instance mock_contract_instance.get_price.return_value = 10000 - result = dfbuyer_agent._get_prices(["0x1", "0x2"]) + result = mock_dfbuyer_agent._get_prices(["0x1", "0x2"]) assert result["0x1"] == 10000 / 1e18 assert result["0x2"] == 10000 / 1e18 -def test_prepare_batches(dfbuyer_agent): - dfbuyer_agent.config.batch_size = 10 - +@enforce_types +def test_dfbuyer_agent_prepare_batches(mock_dfbuyer_agent): addresses = [ZERO_ADDRESS[: -len(str(i))] + str(i) for i in range(1, 7)] - consume_times = dict(zip(addresses, [5, 15, 7, 3, 12, 8])) - result = dfbuyer_agent._prepare_batches(consume_times) + consume_times = dict(zip(addresses, [10, 30, 14, 6, 24, 16])) + result = mock_dfbuyer_agent._prepare_batches(consume_times) expected_result = [ - ([addresses[0], addresses[1]], [5, 5]), - ([addresses[1]], [10]), - ([addresses[2], addresses[3]], [7, 3]), - ([addresses[4]], [10]), - ([addresses[4], addresses[5]], [2, 8]), + ([addresses[0], addresses[1]], [10, 10]), + ([addresses[1]], [20]), + ([addresses[2], addresses[3]], [14, 6]), + ([addresses[4]], [20]), + ([addresses[4], addresses[5]], [4, 16]), ] assert result == expected_result -@patch.object(DFBuyerAgent, "_get_consume_so_far") -@patch.object(DFBuyerAgent, "_get_expected_amount_per_feed") -def test_get_missing_consumes( - mock_get_expected_amount_per_feed, mock_get_consume_so_far, dfbuyer_agent +@enforce_types +def test_dfbuyer_agent_get_missing_consumes( # pylint: disable=unused-argument + mock_get_address, + mock_token, + monkeypatch, ): - ts = 0 + ppss = MagicMock(spec=PPSS) + ppss.web3_pp = MagicMock(spec=Web3PP) + addresses = [ZERO_ADDRESS[: -len(str(i))] + str(i) for i in range(1, 7)] - consume_amts = { - addresses[0]: 10, - addresses[1]: 11, - addresses[2]: 32, - addresses[3]: 24, - addresses[4]: 41, - addresses[5]: 0, - } - mock_get_consume_so_far.return_value = consume_amts - mock_get_expected_amount_per_feed.return_value = 15 + feeds = {address: MagicMock() for address in addresses} + ppss.web3_pp.query_feed_contracts = MagicMock() + ppss.web3_pp.query_feed_contracts.return_value = feeds + + ppss.dfbuyer_ss = MagicMock(spec=DFBuyerSS) + ppss.dfbuyer_ss.batch_size = 3 + ppss.dfbuyer_ss.filter_feeds_from_candidates.return_value = feeds + + batcher_class = MagicMock(spec=PredictoorBatcher) + monkeypatch.setattr(f"{PATH}.PredictoorBatcher", batcher_class) + + dfbuyer_agent = DFBuyerAgent(ppss) + + consume_amts = dict(zip(addresses, 
[10, 11, 32, 24, 41, 0])) + dfbuyer_agent._get_consume_so_far = MagicMock() + dfbuyer_agent._get_consume_so_far.return_value = consume_amts + + dfbuyer_agent._get_expected_amount_per_feed = MagicMock() + dfbuyer_agent._get_expected_amount_per_feed.return_value = 15 + + ts = 0 result = dfbuyer_agent._get_missing_consumes(ts) expected_consume = dfbuyer_agent._get_expected_amount_per_feed(ts) expected_result = { @@ -126,25 +172,27 @@ def test_get_missing_consumes( if expected_consume - consume_amts[address] >= 0 } assert result == expected_result - mock_get_consume_so_far.assert_called_once_with(ts) + dfbuyer_agent._get_consume_so_far.assert_called_once_with(ts) -def test_get_missing_consume_times(dfbuyer_agent): +@enforce_types +def test_dfbuyer_agent_get_missing_consume_times(mock_dfbuyer_agent): missing_consumes = {"0x1": 10.5, "0x2": 20.3, "0x3": 30.7} prices = {"0x1": 2.5, "0x2": 3.3, "0x3": 4.7} - result = dfbuyer_agent._get_missing_consume_times(missing_consumes, prices) + result = mock_dfbuyer_agent._get_missing_consume_times(missing_consumes, prices) expected_result = {"0x1": 5, "0x2": 7, "0x3": 7} assert result == expected_result -@patch("pdr_backend.dfbuyer.dfbuyer_agent.wait_until_subgraph_syncs") +@enforce_types +@patch(f"{PATH}.wait_until_subgraph_syncs") @patch("time.sleep", return_value=None) @patch.object(DFBuyerAgent, "_get_missing_consumes") @patch.object(DFBuyerAgent, "_get_prices") @patch.object(DFBuyerAgent, "_get_missing_consume_times") @patch.object(DFBuyerAgent, "_batch_txs") @patch.object(Web3Config, "get_block") -def test_take_step( +def test_dfbuyer_agent_take_step( mock_get_block, mock_batch_txs, mock_get_missing_consume_times, @@ -152,7 +200,7 @@ def test_take_step( mock_get_missing_consumes, mock_sleep, mock_subgraph_sync, # pylint: disable=unused-argument - dfbuyer_agent, + mock_dfbuyer_agent, ): ts = 0 mock_get_missing_consumes.return_value = {"0x1": 10.5, "0x2": 20.3, "0x3": 30.7} @@ -160,7 +208,7 @@ def test_take_step( mock_get_missing_consume_times.return_value = {"0x1": 5, "0x2": 7, "0x3": 7} mock_get_block.return_value = {"timestamp": 120} mock_batch_txs.return_value = False - dfbuyer_agent.take_step(ts) + mock_dfbuyer_agent.take_step(ts) mock_get_missing_consumes.assert_called_once_with(ts) mock_get_prices.assert_called_once_with( list(mock_get_missing_consumes.return_value.keys()) @@ -172,18 +220,30 @@ def test_take_step( mock_get_block.assert_called_once_with("latest") mock_sleep.assert_called_once_with(86400 - 60) + # empty feeds + mock_dfbuyer_agent.feeds = [] + assert mock_dfbuyer_agent.take_step(ts) is None + +@enforce_types @patch.object(DFBuyerAgent, "take_step") @patch.object(Web3Config, "get_block") -def test_run(mock_get_block, mock_take_step, dfbuyer_agent): +def test_dfbuyer_agent_run(mock_get_block, mock_take_step, mock_dfbuyer_agent): mock_get_block.return_value = {"timestamp": 0} - dfbuyer_agent.run(testing=True) + mock_dfbuyer_agent.run(testing=True) mock_get_block.assert_called_once_with("latest") mock_take_step.assert_called_once_with(mock_get_block.return_value["timestamp"]) + # empty feeds + mock_dfbuyer_agent.feeds = [] + assert mock_dfbuyer_agent.run(testing=True) is None + + +@enforce_types +@patch(f"{PATH}.time.sleep", return_value=None) +def test_dfbuyer_agent_consume_method(mock_sleep, mock_dfbuyer_agent): + mock_batcher = mock_dfbuyer_agent.predictoor_batcher -@patch("pdr_backend.dfbuyer.dfbuyer_agent.time.sleep", return_value=None) -def test_consume_method(mock_sleep, dfbuyer_agent): addresses_to_consume = ["0x1", "0x2"] 
times_to_consume = [2, 3] @@ -191,30 +251,28 @@ def test_consume_method(mock_sleep, dfbuyer_agent): failed_tx = {"transactionHash": b"some_hash", "status": 0} exception_tx = Exception("Error") - with patch.object( - dfbuyer_agent, "predictoor_batcher", autospec=True - ) as mock_predictoor_batcher: - mock_predictoor_batcher.consume_multiple.return_value = successful_tx - assert dfbuyer_agent._consume(addresses_to_consume, times_to_consume) + mock_batcher.consume_multiple.return_value = successful_tx + assert mock_dfbuyer_agent._consume(addresses_to_consume, times_to_consume) - mock_predictoor_batcher.consume_multiple.return_value = failed_tx - assert not dfbuyer_agent._consume(addresses_to_consume, times_to_consume) + mock_batcher.consume_multiple.return_value = failed_tx + assert not mock_dfbuyer_agent._consume(addresses_to_consume, times_to_consume) - mock_predictoor_batcher.consume_multiple.side_effect = exception_tx - with pytest.raises(Exception, match="Error"): - dfbuyer_agent._consume(addresses_to_consume, times_to_consume) + mock_batcher.consume_multiple.side_effect = exception_tx + with pytest.raises(Exception, match="Error"): + mock_dfbuyer_agent._consume(addresses_to_consume, times_to_consume) - assert mock_sleep.call_count == dfbuyer_agent.config.max_request_tries + assert mock_sleep.call_count == mock_dfbuyer_agent.ppss.dfbuyer_ss.max_request_tries -def test_consume_batch_method(dfbuyer_agent): +@enforce_types +def test_dfbuyer_agent_consume_batch_method(mock_dfbuyer_agent): addresses_to_consume = ["0x1", "0x2"] times_to_consume = [2, 3] with patch.object( - dfbuyer_agent, "_consume", side_effect=[False, True, False, False, True] + mock_dfbuyer_agent, "_consume", side_effect=[False, True, False, False, True] ) as mock_consume: - dfbuyer_agent._consume_batch(addresses_to_consume, times_to_consume) + mock_dfbuyer_agent._consume_batch(addresses_to_consume, times_to_consume) calls = [ call(addresses_to_consume, times_to_consume), call([addresses_to_consume[0]], [times_to_consume[0]]), @@ -226,3 +284,25 @@ def test_consume_batch_method(dfbuyer_agent): ), ] mock_consume.assert_has_calls(calls) + + +@enforce_types +def test_dfbuyer_agent_batch_txs(mock_dfbuyer_agent): + addresses = [ZERO_ADDRESS[: -len(str(i))] + str(i) for i in range(1, 7)] + consume_times = dict(zip(addresses, [10, 30, 14, 6, 24, 16])) + + with patch.object( + mock_dfbuyer_agent, + "_consume_batch", + side_effect=[False, True, False, True, True], + ): + failures = mock_dfbuyer_agent._batch_txs(consume_times) + + assert failures + + with patch.object( + mock_dfbuyer_agent, "_consume_batch", side_effect=[True, True, True, True, True] + ): + failures = mock_dfbuyer_agent._batch_txs(consume_times) + + assert failures diff --git a/pdr_backend/dfbuyer/test/test_dfbuyer_config.py b/pdr_backend/dfbuyer/test/test_dfbuyer_config.py deleted file mode 100644 index b73087d96..000000000 --- a/pdr_backend/dfbuyer/test/test_dfbuyer_config.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -from pdr_backend.dfbuyer.dfbuyer_config import DFBuyerConfig -from pdr_backend.util.env import parse_filters - - -def test_trueval_config(): - config = DFBuyerConfig() - assert config.rpc_url == os.getenv("RPC_URL") - assert config.subgraph_url == os.getenv("SUBGRAPH_URL") - assert config.private_key == os.getenv("PRIVATE_KEY") - assert config.batch_size == int(os.getenv("CONSUME_BATCH_SIZE", "20")) - assert config.weekly_spending_limit == int( - os.getenv("WEEKLY_SPENDING_LIMIT", "37000") - ) - assert config.consume_interval_seconds == int( - 
os.getenv("CONSUME_INTERVAL_SECONDS", "86400") - ) - - (f0, f1, f2, f3) = parse_filters() - assert config.pair_filters == f0 - assert config.timeframe_filter == f1 - assert config.source_filter == f2 - assert config.owner_addresses == f3 diff --git a/pdr_backend/dfbuyer/test/test_main.py b/pdr_backend/dfbuyer/test/test_main.py deleted file mode 100644 index d23881a89..000000000 --- a/pdr_backend/dfbuyer/test/test_main.py +++ /dev/null @@ -1,14 +0,0 @@ -from unittest.mock import patch - -from pdr_backend.dfbuyer.dfbuyer_agent import DFBuyerAgent -from pdr_backend.dfbuyer.main import main - - -@patch.object(DFBuyerAgent, "run") -@patch.object( - DFBuyerAgent, "__init__", return_value=None -) # Mock the constructor to return None -def test_main(mock_agent_init, mock_agent_run): - main() - mock_agent_init.assert_called_once() - mock_agent_run.assert_called_once() diff --git a/pdr_backend/data_eng/constants.py b/pdr_backend/lake/constants.py similarity index 60% rename from pdr_backend/data_eng/constants.py rename to pdr_backend/lake/constants.py index 48986e4c7..63f72d41f 100644 --- a/pdr_backend/data_eng/constants.py +++ b/pdr_backend/lake/constants.py @@ -1,4 +1,5 @@ import numpy as np +import polars as pl OHLCV_COLS = ["open", "high", "low", "close", "volume"] OHLCV_DTYPES = [np.float64] * len(OHLCV_COLS) @@ -6,6 +7,13 @@ TOHLCV_COLS = ["timestamp"] + OHLCV_COLS TOHLCV_DTYPES = [np.int64] + OHLCV_DTYPES +OHLCV_DTYPES_PL = [pl.Float64] * len(OHLCV_COLS) +TOHLCV_DTYPES_PL = [pl.Int64] + OHLCV_DTYPES_PL + +TOHLCV_SCHEMA_PL = dict(zip(TOHLCV_COLS, TOHLCV_DTYPES_PL)) + # warn if OHLCV_MULT_MIN * timeframe < time-between-data < OHLCV_MULT_MAX * t OHLCV_MULT_MIN = 0.5 OHLCV_MULT_MAX = 2.5 + +DEFAULT_YAML_FILE = "ppss.yaml" diff --git a/pdr_backend/lake/fetch_ohlcv.py b/pdr_backend/lake/fetch_ohlcv.py new file mode 100644 index 000000000..1b09cdd3e --- /dev/null +++ b/pdr_backend/lake/fetch_ohlcv.py @@ -0,0 +1,110 @@ +from typing import List, Union + +from enforce_typing import enforce_types +import numpy as np + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.cli.timeframe import Timeframe +from pdr_backend.lake.constants import ( + OHLCV_MULT_MAX, + OHLCV_MULT_MIN, +) + + +@enforce_types +def safe_fetch_ohlcv( + exch, + symbol: str, + timeframe: str, + since: int, + limit: int, +) -> Union[List[tuple], None]: + """ + @description + calls ccxt.exchange.fetch_ohlcv() but if there's an error it + emits a warning and returns None, vs crashing everything + + @arguments + exch -- eg ccxt.binanceus() + symbol -- eg "BTC/USDT". NOT "BTC-USDT" + timeframe -- eg "1h", "1m" + since -- timestamp of first candle. In unix time (in ms) + limit -- max # candles to retrieve + + @return + raw_tohlcv_data -- [a TOHLCV tuple, for each timestamp]. + where row 0 is oldest + and TOHLCV = {unix time (in ms), Open, High, Low, Close, Volume} + """ + if "-" in symbol: + raise ValueError(f"Got symbol={symbol}. It must have '/' not '-'") + + try: + return exch.fetch_ohlcv( + symbol=symbol, + timeframe=timeframe, + since=since, + limit=limit, + ) + except Exception as e: + print(f" **WARNING exchange: {e}") + return None + + +@enforce_types +def clean_raw_ohlcv( + raw_tohlcv_data: Union[list, None], + feed: ArgFeed, + st_ut: int, + fin_ut: int, +) -> list: + """ + @description + From the raw data coming directly from exchange, + condition it and account for corner cases. + + @arguments + raw_tohlcv_data -- output of safe_fetch_ohlcv(), see below + feed - ArgFeed. 
eg Binance ETH/USDT + st_ut -- min allowed time. A timestamp, in ms, in UTC + fin_ut -- max allowed time. "" + + @return + tohlcv_data -- cleaned data + """ + tohlcv_data = raw_tohlcv_data or [] + uts = _ohlcv_to_uts(tohlcv_data) + _warn_if_uts_have_gaps(uts, feed.timeframe) + + tohlcv_data = _filter_within_timerange(tohlcv_data, st_ut, fin_ut) + + return tohlcv_data + + +@enforce_types +def _ohlcv_to_uts(tohlcv_data: list) -> list: + return [vec[0] for vec in tohlcv_data] + + +@enforce_types +def _warn_if_uts_have_gaps(uts: List[int], timeframe: Timeframe): + if len(uts) <= 1: + return + + # Ideally, time between ohclv candles is always 5m or 1h + # But exchange data often has gaps. Warn about worst violations + diffs_ms = np.array(uts[1:]) - np.array(uts[:-1]) # in ms + diffs_m = diffs_ms / 1000 / 60 # in minutes + mn_thr = timeframe.m * OHLCV_MULT_MIN + mx_thr = timeframe.m * OHLCV_MULT_MAX + + if min(diffs_m) < mn_thr: + print(f" **WARNING: short candle time: {min(diffs_m)} min") + if max(diffs_m) > mx_thr: + print(f" **WARNING: long candle time: {max(diffs_m)} min") + + +@enforce_types +def _filter_within_timerange(tohlcv_data: list, st_ut: int, fin_ut: int) -> list: + uts = _ohlcv_to_uts(tohlcv_data) + return [vec for ut, vec in zip(uts, tohlcv_data) if st_ut <= ut <= fin_ut] diff --git a/pdr_backend/lake/gql_data_factory.py b/pdr_backend/lake/gql_data_factory.py new file mode 100644 index 000000000..d2348965b --- /dev/null +++ b/pdr_backend/lake/gql_data_factory.py @@ -0,0 +1,241 @@ +import os +from typing import Callable, Dict + +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.lake.plutil import has_data, newest_ut +from pdr_backend.lake.table_pdr_predictions import ( + get_pdr_predictions_df, + predictions_schema, +) +from pdr_backend.lake.table_pdr_subscriptions import ( + get_pdr_subscriptions_df, + subscriptions_schema, +) +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_predictions import get_all_contract_ids_by_owner +from pdr_backend.util.networkutil import get_sapphire_postfix +from pdr_backend.util.timeutil import current_ut_ms, pretty_timestr + + +@enforce_types +class GQLDataFactory: + """ + Roles: + - From each GQL API, fill >=1 gql_dfs -> parquet files data lake + - From gql_dfs, calculate other dfs and stats + - All timestamps, after fetching, are transformed into milliseconds wherever appropriate + + Finally: + - "timestamp" values are ut: int is unix time, UTC, in ms (not s) + - "datetime" values ares python datetime.datetime, UTC + """ + + def __init__(self, ppss: PPSS): + self.ppss = ppss + + # filter by feed contract address + network = get_sapphire_postfix(ppss.web3_pp.network) + contract_list = get_all_contract_ids_by_owner( + owner_address=self.ppss.web3_pp.owner_addrs, + network=network, + ) + contract_list = [f.lower() for f in contract_list] + + # configure all tables that will be recorded onto lake + self.record_config = { + "pdr_predictions": { + "fetch_fn": get_pdr_predictions_df, + "schema": predictions_schema, + "config": { + "contract_list": contract_list, + }, + }, + "pdr_subscriptions": { + "fetch_fn": get_pdr_subscriptions_df, + "schema": subscriptions_schema, + "config": { + "contract_list": contract_list, + }, + }, + } + + def get_gql_dfs(self) -> Dict[str, pl.DataFrame]: + """ + @description + Get historical dataframes across many feeds and timeframes. + + @return + predictions_df -- *polars* Dataframe. 
See class docstring + """ + print("Get predictions data across many feeds and timeframes.") + + # Ss_timestamp is calculated dynamically if ss.fin_timestr = "now". + # But, we don't want fin_timestamp changing as we gather data here. + # To solve, for a given call to this method, we make a constant fin_ut + fin_ut = self.ppss.lake_ss.fin_timestamp + + print(f" Data start: {pretty_timestr(self.ppss.lake_ss.st_timestamp)}") + print(f" Data fin: {pretty_timestr(fin_ut)}") + + self._update(fin_ut) + gql_dfs = self._load_parquet(fin_ut) + + print("Get historical data across many subgraphs. Done.") + + # postconditions + assert len(gql_dfs.values()) > 0 + for df in gql_dfs.values(): + assert isinstance(df, pl.DataFrame) + + return gql_dfs + + def _update(self, fin_ut: int): + """ + @description + Iterate across all gql queries and update their parquet files: + - Predictoors + - Slots + - Claims + + Improve this by: + 1. Break out raw data from any transformed/cleaned data + 2. Integrate other queries and summaries + 3. Integrate config/pp if needed + @arguments + fin_ut -- a timestamp, in ms, in UTC + """ + + for k, record in self.record_config.items(): + filename = self._parquet_filename(k) + print(f" filename={filename}") + + st_ut = self._calc_start_ut(filename) + print(f" Aim to fetch data from start time: {pretty_timestr(st_ut)}") + if st_ut > min(current_ut_ms(), fin_ut): + print(" Given start time, no data to gather. Exit.") + continue + + # to satisfy mypy, get an explicit function pointer + do_fetch: Callable[[str, int, int, Dict], pl.DataFrame] = record["fetch_fn"] + + # call the function + print(f" Fetching {k}") + df = do_fetch(self.ppss.web3_pp.network, st_ut, fin_ut, record["config"]) + + # postcondition + if len(df) > 0: + assert df.schema == record["schema"] + + # save to parquet + self._save_parquet(filename, df) + + def _calc_start_ut(self, filename: str) -> int: + """ + @description + Calculate start timestamp, reconciling whether file exists and where + its data starts. If file exists, you can only append to end. + + @arguments + filename - parquet file with data. May or may not exist. 
+ + @return + start_ut - timestamp (ut) to start grabbing data for (in ms) + """ + if not os.path.exists(filename): + print(" No file exists yet, so will fetch all data") + return self.ppss.lake_ss.st_timestamp + + print(" File already exists") + if not has_data(filename): + print(" File has no data, so delete it") + os.remove(filename) + return self.ppss.lake_ss.st_timestamp + + file_utN = newest_ut(filename) + return file_utN + 1000 + + def _load_parquet(self, fin_ut: int) -> Dict[str, pl.DataFrame]: + """ + @arguments + fin_ut -- finish timestamp + + @return + gql_dfs -- dict of [gql_filename] : df + Where df has columns=GQL_COLS+"datetime", and index=timestamp + """ + print(" Load parquet.") + st_ut = self.ppss.lake_ss.st_timestamp + + dfs: Dict[str, pl.DataFrame] = {} # [parquet_filename] : df + + for k, record in self.record_config.items(): + filename = self._parquet_filename(k) + print(f" filename={filename}") + + # load all data from file + # check if file exists + # if file doesn't exist, return an empty dataframe with the expected schema + if os.path.exists(filename): + df = pl.read_parquet(filename) + else: + df = pl.DataFrame(schema=record["schema"]) + + df = df.filter( + (pl.col("timestamp") >= st_ut) & (pl.col("timestamp") <= fin_ut) + ) + + # postcondition + assert df.schema == record["schema"] + dfs[k] = df + + return dfs + + def _parquet_filename(self, filename_str: str) -> str: + """ + @description + Computes the lake-path for the parquet file. + + @arguments + filename_str -- eg "subgraph_predictions" + + @return + parquet_filename -- name for parquet file. + """ + basename = f"{filename_str}.parquet" + filename = os.path.join(self.ppss.lake_ss.parquet_dir, basename) + return filename + + @enforce_types + def _save_parquet(self, filename: str, df: pl.DataFrame): + """write to parquet file + parquet only supports appending via the pyarrow engine + """ + + # precondition + assert "timestamp" in df.columns and df["timestamp"].dtype == pl.Int64 + assert len(df) > 0 + if len(df) > 1: + assert ( + df.head(1)["timestamp"].to_list()[0] + <= df.tail(1)["timestamp"].to_list()[0] + ) + + if os.path.exists(filename): # "append" existing file + cur_df = pl.read_parquet(filename) + df = pl.concat([cur_df, df]) + + # check for duplicates and throw error if any found + duplicate_rows = df.filter(pl.struct("ID").is_duplicated()) + if len(duplicate_rows) > 0: + raise Exception( + f"Not saved. Duplicate rows found. 
{len(duplicate_rows)} rows: {duplicate_rows}" + ) + + df.write_parquet(filename) + n_new = df.shape[0] - cur_df.shape[0] + print(f" Just appended {n_new} df rows to file {filename}") + else: # write new file + df.write_parquet(filename) + print(f" Just saved df with {df.shape[0]} rows to new file {filename}") diff --git a/pdr_backend/lake/merge_df.py b/pdr_backend/lake/merge_df.py new file mode 100644 index 000000000..922947669 --- /dev/null +++ b/pdr_backend/lake/merge_df.py @@ -0,0 +1,126 @@ +from typing import List, Union + +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.lake.plutil import set_col_values + + +@enforce_types +def merge_rawohlcv_dfs(rawohlcv_dfs: dict) -> pl.DataFrame: + """ + @arguments + rawohlcv_dfs -- see class docstring + + @return + mergedohlcv_df -- see class docstring + """ + # preconditions + raw_dfs = rawohlcv_dfs + _verify_pair_strs(raw_dfs) + + # initialize merged_df with all timestamps seen + all_uts_set: set = set() + for exch_str in raw_dfs.keys(): + for raw_df in raw_dfs[exch_str].values(): + all_uts_set = all_uts_set.union(raw_df["timestamp"].to_list()) + all_uts: list = sorted(all_uts_set) + merged_df = pl.DataFrame({"timestamp": all_uts}) + + # merge in data from each raw_df. It can handle inconsistent # rows. + for exch_str in raw_dfs.keys(): + for pair_str, raw_df in rawohlcv_dfs[exch_str].items(): + for raw_col in raw_df.columns: + if raw_col == "timestamp": + continue + signal_str = raw_col # eg "close" + merged_col = f"{exch_str}:{pair_str}:{signal_str}" + merged_df = _add_df_col(merged_df, merged_col, raw_df, raw_col) + + # order the columns + merged_df = merged_df.select(_ordered_cols(merged_df.columns)) # type: ignore + # postconditions, return + _verify_df_cols(merged_df) + return merged_df + + +@enforce_types +def _add_df_col( + merged_df: Union[pl.DataFrame, None], + merged_col: str, # eg "binance:BTC/USDT:close" + raw_df: pl.DataFrame, + raw_col: str, # eg "close" +) -> pl.DataFrame: + """ + Does polars equivalent of: merged_df[merged_col] = raw_df[raw_col]. + Tuned for this factory, by keeping "timestamp" + """ + # if raw_df has no rows, then give it many rows with null value + # (this is needed to avoid issues in df.join() below) + if raw_df.shape[0] == 0: # empty + assert merged_df is not None + timestamps = merged_df["timestamp"].to_list() + d = {"timestamp": timestamps, raw_col: [None] * len(timestamps)} + raw_df = pl.DataFrame(d) + + # newraw_df = copy of raw_df, with raw_col renamed -> merged_col, +timestamp + newraw_df = raw_df.with_columns( + pl.col(raw_col).alias(merged_col), + ) + newraw_df = newraw_df.select(["timestamp", merged_col]) + + # now join the cols of newraw_df into merged_df + if merged_df is None: + merged_df = newraw_df + else: + merged_df = merged_df.join(newraw_df, on="timestamp", how="outer") + merged_df = merge_cols(merged_df, "timestamp", "timestamp_right") + + # re-order merged_df's columns + merged_df = merged_df.select(_ordered_cols(merged_df.columns)) # type: ignore + + # postconditions, return + _verify_df_cols(merged_df) + return merged_df + + +@enforce_types +def merge_cols(df: pl.DataFrame, col1: str, col2: str) -> pl.DataFrame: + """Keep the non-null versions of col1 & col2, in col1. 
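+    (Row-wise: where col1 is null or otherwise falsy, the value is taken from col2 instead.)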
Drop col2.""" + assert col1 in df + if col2 not in df: + return df + n_rows = df.shape[0] + new_vals = [df[col1][i] or df[col2][i] for i in range(n_rows)] + df = set_col_values(df, col1, new_vals) + df = df.drop(col2) + return df + + +@enforce_types +def _ordered_cols(merged_cols: List[str]) -> List[str]: + """Returns in order ["timestamp", item1, item2, item3, ...]""" + assert "timestamp" in merged_cols + assert len(set(merged_cols)) == len(merged_cols) + + ordered_cols = [] + ordered_cols += ["timestamp"] + ordered_cols += [col for col in merged_cols if col != "timestamp"] + return ordered_cols + + +@enforce_types +def _verify_df_cols(df: pl.DataFrame): + assert "timestamp" in df.columns + assert "datetime" not in df.columns + for col in df.columns: + assert "_right" not in col + assert "_left" not in col + assert df.columns == _ordered_cols(df.columns) + + +@enforce_types +def _verify_pair_strs(rawohlcv_dfs): + for exch_str in rawohlcv_dfs.keys(): + for pair_str in rawohlcv_dfs[exch_str].keys(): + assert "/" in str(pair_str), f"pair_str={pair_str} needs '/'" diff --git a/pdr_backend/lake/ohlcv_data_factory.py b/pdr_backend/lake/ohlcv_data_factory.py new file mode 100644 index 000000000..b2ab30f35 --- /dev/null +++ b/pdr_backend/lake/ohlcv_data_factory.py @@ -0,0 +1,243 @@ +import os +from typing import Dict + +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.cli.timeframe import Timeframe +from pdr_backend.lake.constants import ( + TOHLCV_SCHEMA_PL, + TOHLCV_COLS, +) +from pdr_backend.lake.fetch_ohlcv import clean_raw_ohlcv, safe_fetch_ohlcv +from pdr_backend.lake.merge_df import merge_rawohlcv_dfs +from pdr_backend.lake.plutil import ( + concat_next_df, + has_data, + initialize_rawohlcv_df, + load_rawohlcv_file, + newest_ut, + oldest_ut, + save_rawohlcv_file, +) +from pdr_backend.ppss.lake_ss import LakeSS +from pdr_backend.util.timeutil import current_ut_ms, pretty_timestr + + +@enforce_types +class OhlcvDataFactory: + """ + Roles: + - From each CEX API, fill >=1 rawohlcv_dfs -> rawohlcv files data lake + - From rawohlcv_dfs, fill 1 mergedohlcv_df -- all data across all CEXes + + Where: + rawohlcv_dfs -- dict of [exch_str][pair_str] : df + And df has columns of: "timestamp", "open", "high", .., "volume" + And NOT "datetime" column + Where pair_str must have '/' not '-', to avoid key issues + + + mergedohlcv_df -- polars DataFrame with cols like: + "timestamp", + "binanceus:ETH-USDT:open", + "binanceus:ETH-USDT:high", + "binanceus:ETH-USDT:low", + "binanceus:ETH-USDT:close", + "binanceus:ETH-USDT:volume", + ... + (NOT "datetime") + + #For each column: oldest first, newest at the end + + Finally: + - "timestamp" values are ut: int is unix time, UTC, in ms (not s) + """ + + def __init__(self, ss: LakeSS): + self.ss = ss + + def get_mergedohlcv_df(self) -> pl.DataFrame: + """ + @description + Get dataframe of all ohlcv data: merged from many exchanges & pairs. + + @return + mergedohlcv_df -- *polars* Dataframe. See class docstring + """ + print("Get historical data, across many exchanges & pairs: begin.") + + # Ss_timestamp is calculated dynamically if ss.fin_timestr = "now". + # But, we don't want fin_timestamp changing as we gather data here. 
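+        # (with fin_timestr = "now", re-reading fin_timestamp would keep pushing the endpoint forward mid-fetch)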
+ # To solve, for a given call to this method, we make a constant fin_ut + fin_ut = self.ss.fin_timestamp + + print(f" Data start: {pretty_timestr(self.ss.st_timestamp)}") + print(f" Data fin: {pretty_timestr(fin_ut)}") + + self._update_rawohlcv_files(fin_ut) + rawohlcv_dfs = self._load_rawohlcv_files(fin_ut) + mergedohlcv_df = merge_rawohlcv_dfs(rawohlcv_dfs) + + print("Get historical data, across many exchanges & pairs: done.") + + # postconditions + assert isinstance(mergedohlcv_df, pl.DataFrame) + return mergedohlcv_df + + def _update_rawohlcv_files(self, fin_ut: int): + print(" Update all rawohlcv files: begin") + for feed in self.ss.feeds: + self._update_rawohlcv_files_at_feed(feed, fin_ut) + + print() + print(" Update all rawohlcv files: done") + + def _update_rawohlcv_files_at_feed(self, feed: ArgFeed, fin_ut: int): + """ + @arguments + feed -- ArgFeed + fin_ut -- a timestamp, in ms, in UTC + """ + pair_str = str(feed.pair) + exch_str = str(feed.exchange) + assert "/" in str(pair_str), f"pair_str={pair_str} needs '/'" + print() + print( + f" Update rawohlcv file at exchange={exch_str}, pair={pair_str}: begin" + ) + + filename = self._rawohlcv_filename(feed) + print(f" filename={filename}") + + assert feed.timeframe + st_ut = self._calc_start_ut_maybe_delete(feed.timeframe, filename) + print(f" Aim to fetch data from start time: {pretty_timestr(st_ut)}") + if st_ut > min(current_ut_ms(), fin_ut): + print(" Given start time, no data to gather. Exit.") + return + + # empty ohlcv df + df = initialize_rawohlcv_df() + while True: + limit = 1000 + print(f" Fetch up to {limit} pts from {pretty_timestr(st_ut)}") + exch = feed.exchange.exchange_class() + raw_tohlcv_data = safe_fetch_ohlcv( + exch, + symbol=str(pair_str).replace("-", "/"), + timeframe=str(feed.timeframe), + since=st_ut, + limit=limit, + ) + tohlcv_data = clean_raw_ohlcv(raw_tohlcv_data, feed, st_ut, fin_ut) + + # concat both TOHLCV data + next_df = pl.DataFrame(tohlcv_data, schema=TOHLCV_SCHEMA_PL) + df = concat_next_df(df, next_df) + + if len(tohlcv_data) < limit: # no more data, we're at newest time + break + + # prep next iteration + newest_ut_value = df.tail(1)["timestamp"][0] + + print(f" newest_ut_value: {newest_ut_value}") + st_ut = newest_ut_value + feed.timeframe.ms + + # output to file + save_rawohlcv_file(filename, df) + + # done + print(f" Update rawohlcv file at exchange={exch_str}, pair={pair_str}: done") + + def _calc_start_ut_maybe_delete(self, timeframe: Timeframe, filename: str) -> int: + """ + @description + Calculate start timestamp, reconciling whether file exists and where + its data starts. Will delete file if it's inconvenient to re-use + + @arguments + timeframe - Timeframe + filename - csv file with data. May or may not exist. 
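+            (In practice this is the rawohlcv parquet file named by _rawohlcv_filename().)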
+ + @return + start_ut - timestamp (ut) to start grabbing data for + """ + if not os.path.exists(filename): + print(" No file exists yet, so will fetch all data") + return self.ss.st_timestamp + + print(" File already exists") + if not has_data(filename): + print(" File has no data, so delete it") + os.remove(filename) + return self.ss.st_timestamp + + file_ut0, file_utN = oldest_ut(filename), newest_ut(filename) + print(f" File starts at: {pretty_timestr(file_ut0)}") + print(f" File finishes at: {pretty_timestr(file_utN)}") + + if self.ss.st_timestamp >= file_ut0: + print(" User-specified start >= file start, so append file") + return file_utN + timeframe.ms + + print(" User-specified start < file start, so delete file") + os.remove(filename) + return self.ss.st_timestamp + + def _load_rawohlcv_files(self, fin_ut: int) -> Dict[str, Dict[str, pl.DataFrame]]: + """ + @arguments + fin_ut -- finish timestamp + + @return + rawohlcv_dfs -- dict of [exch_str][pair_str] : ohlcv_df + Where df has columns: TOHLCV_COLS + And pair_str is eg "BTC/USDT", *not* "BTC-USDT" + """ + print(" Load rawohlcv file.") + st_ut = self.ss.st_timestamp + + rawohlcv_dfs: Dict[str, Dict[str, pl.DataFrame]] = {} # [exch][pair] : df + for exch_str in self.ss.exchange_strs: + rawohlcv_dfs[exch_str] = {} + + for feed in self.ss.feeds: + pair_str = str(feed.pair) + exch_str = str(feed.exchange) + assert "/" in str(pair_str), f"pair_str={pair_str} needs '/'" + filename = self._rawohlcv_filename(feed) + cols = TOHLCV_COLS + rawohlcv_df = load_rawohlcv_file(filename, cols, st_ut, fin_ut) + + assert "timestamp" in rawohlcv_df.columns + assert "datetime" not in rawohlcv_df.columns + + rawohlcv_dfs[exch_str][pair_str] = rawohlcv_df + + # rawohlcv_dfs["kraken"] is a DF, with proper cols, and 0 rows + + return rawohlcv_dfs + + def _rawohlcv_filename(self, feed: ArgFeed) -> str: + """ + @description + Computes a filename for the rawohlcv data. + + @arguments + feed -- ArgFeed + + @return + rawohlcv_filename -- + + @notes + If pair_str has '/', it will become '-' in the filename. + """ + pair_str = str(feed.pair) + assert "/" in str(pair_str) or "-" in pair_str, pair_str + pair_str = str(pair_str).replace("/", "-") # filesystem needs "-" + basename = f"{feed.exchange}_{pair_str}_{feed.timeframe}.parquet" + filename = os.path.join(self.ss.parquet_dir, basename) + return filename diff --git a/pdr_backend/lake/plutil.py b/pdr_backend/lake/plutil.py new file mode 100644 index 000000000..8d2e0aa71 --- /dev/null +++ b/pdr_backend/lake/plutil.py @@ -0,0 +1,193 @@ +""" +plutil: polars dataframe & csv/parquet utilities. +These utilities are specific to the time-series dataframe columns we're using. 
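+Most helpers assume the TOHLCV column convention: "timestamp" (int, UTC, in ms), then "open", "high", "low", "close", "volume".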
+""" +import os +import shutil +from io import StringIO +from tempfile import mkdtemp +from typing import List, Dict + +import numpy as np +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.lake.constants import TOHLCV_COLS, TOHLCV_SCHEMA_PL + + +@enforce_types +def initialize_rawohlcv_df(cols: List[str] = []) -> pl.DataFrame: + """Start an empty df with the expected columns and schema + Applies transform to get columns + """ + df = pl.DataFrame(data=[], schema=TOHLCV_SCHEMA_PL) + df = df.select(cols if cols else "*") + return df + + +@enforce_types +def set_col_values(df: pl.DataFrame, col: str, new_vals: list) -> pl.DataFrame: + """Equivalent to: df[col] = new_vals""" + return df.with_columns(pl.Series(new_vals).alias(col)) + + +@enforce_types +def concat_next_df(df: pl.DataFrame, next_df: pl.DataFrame) -> pl.DataFrame: + """Add a next_df to existing df, with the expected columns etc. + Makes sure that both schemas match before concatenating. + """ + assert df.schema == next_df.schema + df = pl.concat([df, next_df]) + return df + + +@enforce_types +def save_rawohlcv_file(filename: str, df: pl.DataFrame): + """write to parquet file + parquet only supports appending via the pyarrow engine + """ + # preconditions + assert df.columns[:6] == TOHLCV_COLS + assert "datetime" not in df.columns + + # parquet column order: timestamp, O, H, L, C, V + columns = TOHLCV_COLS + + df = df.select(columns) + + if os.path.exists(filename): # append existing file + cur_df = pl.read_parquet(filename) + df = pl.concat([cur_df, df]) + df.write_parquet(filename) + n_new = df.shape[0] - cur_df.shape[0] + print(f" Just appended {n_new} df rows to file {filename}") + else: # write new file + df.write_parquet(filename) + print(f" Just saved df with {df.shape[0]} rows to new file {filename}") + + +@enforce_types +def load_rawohlcv_file(filename: str, cols=None, st=None, fin=None) -> pl.DataFrame: + """Load parquet file as a dataframe. + + Features: + - Ensure that all dtypes are correct + - Filter to just the input columns + - Filter to just the specified start & end times + - Memory stays reasonable + + @arguments + cols -- what columns to use, eg ["open","high"]. Set to None for all cols. + st -- starting timestamp, in ut. Set to 0 or None for very beginning + fin -- ending timestamp, in ut. Set to inf or None for very end + + @return + df -- dataframe + + @notes + Polars does not have an index. "timestamp" is a regular col + """ + # handle cols + if cols is None: + cols = TOHLCV_COLS + if "timestamp" not in cols: + cols = ["timestamp"] + cols + assert "datetime" not in cols + + # set st, fin + st = st if st is not None else 0 + fin = fin if fin is not None else np.inf + + # load tohlcv + df = pl.read_parquet( + filename, + columns=cols, + ) + df = df.filter((pl.col("timestamp") >= st) & (pl.col("timestamp") <= fin)) + + # initialize df and enforce schema + df0 = initialize_rawohlcv_df(cols) + df = concat_next_df(df0, df) + + # postconditions, return + assert "timestamp" in df.columns and df["timestamp"].dtype == pl.Int64 + assert "datetime" not in df.columns + + return df + + +@enforce_types +def has_data(filename: str) -> bool: + """Returns True if the file has >0 data entries""" + df = pl.read_parquet(filename, n_rows=1) + return not df.is_empty() + + +@enforce_types +def newest_ut(filename: str) -> int: + """ + Return the timestamp for the youngest entry in the file. 
+ The latest date should be the tail (row = n), or last entry in the file/dataframe + """ + df = _get_tail_df(filename, n=1) + ut = int(df["timestamp"][0]) + return ut + + +@enforce_types +def _get_tail_df(filename: str, n: int = 5) -> pl.DataFrame: + """Returns the last record in a parquet file, as a list""" + + df = pl.read_parquet(filename) + tail_df = df.tail(n) + if not tail_df.is_empty(): + return tail_df + raise ValueError(f"File {filename} has no entries") + + +@enforce_types +def oldest_ut(filename: str) -> int: + """ + Return the timestamp for the oldest entry in the parquet file. + The oldest date should be the head (row = 0), or the first entry in the file/dataframe + """ + df = _get_head_df(filename, n=1) + ut = int(df["timestamp"][0]) + return ut + + +@enforce_types +def _get_head_df(filename: str, n: int = 5) -> pl.DataFrame: + """Returns the head of parquet file, as a df""" + df = pl.read_parquet(filename) + head_df = df.head(n) + if not head_df.is_empty(): + return head_df + raise ValueError(f"File {filename} has no entries") + + +@enforce_types +def text_to_df(s: str) -> pl.DataFrame: + tmpdir = mkdtemp() + filename = os.path.join(tmpdir, "df.psv") + s = StringIO(s) # type: ignore + with open(filename, "w") as f: + for line in s: + f.write(line) + df = pl.scan_csv(filename, separator="|").collect() + shutil.rmtree(tmpdir) + return df + + +@enforce_types +def _object_list_to_df(objects: List[object], schema: Dict) -> pl.DataFrame: + """ + @description + Convert list objects to a dataframe using their __dict__ structure. + """ + # Get all predictions into a dataframe + obj_dicts = [object.__dict__ for object in objects] + obj_df = pl.DataFrame(obj_dicts, schema=schema) + assert obj_df.schema == schema + + return obj_df diff --git a/pdr_backend/lake/table_pdr_predictions.py b/pdr_backend/lake/table_pdr_predictions.py new file mode 100644 index 000000000..eac364e4b --- /dev/null +++ b/pdr_backend/lake/table_pdr_predictions.py @@ -0,0 +1,76 @@ +from typing import Dict + +import polars as pl +from enforce_typing import enforce_types +from polars import Boolean, Float64, Int64, Utf8 + +from pdr_backend.subgraph.subgraph_predictions import ( + FilterMode, + fetch_filtered_predictions, +) +from pdr_backend.lake.plutil import _object_list_to_df +from pdr_backend.util.networkutil import get_sapphire_postfix +from pdr_backend.util.timeutil import ms_to_seconds + +# RAW PREDICTOOR PREDICTIONS SCHEMA +predictions_schema = { + "ID": Utf8, + "pair": Utf8, + "timeframe": Utf8, + "prediction": Boolean, + "stake": Float64, + "trueval": Boolean, + "timestamp": Int64, + "source": Utf8, + "payout": Float64, + "slot": Int64, + "user": Utf8, +} + + +def _transform_timestamp_to_ms(df: pl.DataFrame) -> pl.DataFrame: + df = df.with_columns( + [ + pl.col("timestamp").mul(1000).alias("timestamp"), + ] + ) + return df + + +@enforce_types +def get_pdr_predictions_df( + network: str, st_ut: int, fin_ut: int, config: Dict +) -> pl.DataFrame: + """ + @description + Fetch raw predictions from predictoor subgraph + Update function for graphql query, returns raw data + + Transforms ts into ms as required for data factory + """ + network = get_sapphire_postfix(network) + + # fetch predictions + predictions = fetch_filtered_predictions( + ms_to_seconds(st_ut), + ms_to_seconds(fin_ut), + config["contract_list"], + network, + FilterMode.CONTRACT_TS, + payout_only=False, + trueval_only=False, + ) + + if len(predictions) == 0: + print(" No predictions to fetch. 
Exit.") + return pl.DataFrame() + + # convert predictions to df and transform timestamp into ms + predictions_df = _object_list_to_df(predictions, predictions_schema) + predictions_df = _transform_timestamp_to_ms(predictions_df) + + # cull any records outside of our time range and sort them by timestamp + predictions_df = predictions_df.filter( + pl.col("timestamp").is_between(st_ut, fin_ut) + ).sort("timestamp") + + return predictions_df diff --git a/pdr_backend/lake/table_pdr_subscriptions.py b/pdr_backend/lake/table_pdr_subscriptions.py new file mode 100644 index 000000000..b44d23ea2 --- /dev/null +++ b/pdr_backend/lake/table_pdr_subscriptions.py @@ -0,0 +1,59 @@ +from typing import Dict + +import polars as pl +from enforce_typing import enforce_types +from polars import Int64, Utf8, Float32 + +from pdr_backend.subgraph.subgraph_subscriptions import ( + fetch_filtered_subscriptions, +) +from pdr_backend.lake.table_pdr_predictions import _transform_timestamp_to_ms +from pdr_backend.lake.plutil import _object_list_to_df +from pdr_backend.util.networkutil import get_sapphire_postfix +from pdr_backend.util.timeutil import ms_to_seconds + + +# RAW PREDICTOOR SUBSCRIPTIONS SCHEMA +subscriptions_schema = { + "ID": Utf8, + "pair": Utf8, + "timeframe": Utf8, + "source": Utf8, + "tx_id": Utf8, + "last_price_value": Float32, + "timestamp": Int64, + "user": Utf8, +} + + +@enforce_types +def get_pdr_subscriptions_df( + network: str, st_ut: int, fin_ut: int, config: Dict +) -> pl.DataFrame: + """ + @description + Fetch raw subscription events from predictoor subgraph + Update function for graphql query, returns raw data + + Transforms ts into ms as required for data factory + """ + network = get_sapphire_postfix(network) + + # fetch subscriptions + subscriptions = fetch_filtered_subscriptions( + ms_to_seconds(st_ut), ms_to_seconds(fin_ut), config["contract_list"], network + ) + + if len(subscriptions) == 0: + print(" No subscriptions fetched. 
Exit.") + return pl.DataFrame() + + # convert subscriptions to df and transform timestamp into ms + subscriptions_df = _object_list_to_df(subscriptions, subscriptions_schema) + subscriptions_df = _transform_timestamp_to_ms(subscriptions_df) + + # cull any records outside of our time range and sort them by timestamp + subscriptions_df = subscriptions_df.filter( + pl.col("timestamp").is_between(st_ut, fin_ut) + ).sort("timestamp") + + return subscriptions_df diff --git a/pdr_backend/lake/test/conftest.py b/pdr_backend/lake/test/conftest.py new file mode 100644 index 000000000..1b9ff717d --- /dev/null +++ b/pdr_backend/lake/test/conftest.py @@ -0,0 +1,14 @@ +import pytest + +from pdr_backend.subgraph.prediction import mock_daily_predictions +from pdr_backend.subgraph.subscription import mock_subscriptions + + +@pytest.fixture() +def sample_daily_predictions(): + return mock_daily_predictions() + + +@pytest.fixture() +def sample_subscriptions(): + return mock_subscriptions() diff --git a/pdr_backend/lake/test/resources.py b/pdr_backend/lake/test/resources.py new file mode 100644 index 000000000..a61f38510 --- /dev/null +++ b/pdr_backend/lake/test/resources.py @@ -0,0 +1,184 @@ +import copy +from typing import Dict + +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory +from pdr_backend.lake.constants import TOHLCV_COLS, TOHLCV_SCHEMA_PL +from pdr_backend.lake.gql_data_factory import GQLDataFactory +from pdr_backend.lake.merge_df import merge_rawohlcv_dfs +from pdr_backend.lake.ohlcv_data_factory import OhlcvDataFactory +from pdr_backend.lake.plutil import concat_next_df, initialize_rawohlcv_df, text_to_df +from pdr_backend.ppss.predictoor_ss import PredictoorSS +from pdr_backend.ppss.lake_ss import LakeSS +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.ppss.web3_pp import mock_web3_pp + + +@enforce_types +def _mergedohlcv_df_ETHUSDT(tmpdir): + _, _, aimodel_data_factory = _predictoor_ss_1feed(tmpdir, "binanceus ETH/USDT h 5m") + mergedohlcv_df = merge_rawohlcv_dfs(ETHUSDT_RAWOHLCV_DFS) + return mergedohlcv_df, aimodel_data_factory + + +@enforce_types +def _predictoor_ss_1feed(tmpdir, feed): + predictoor_ss = _predictoor_ss(feed, [feed]) + lake_ss = _lake_ss(tmpdir, [feed]) + ohlcv_data_factory = OhlcvDataFactory(lake_ss) + aimodel_data_factory = AimodelDataFactory(predictoor_ss) + return predictoor_ss, ohlcv_data_factory, aimodel_data_factory + + +@enforce_types +def _lake_ss_1feed(tmpdir, feed, st_timestr=None, fin_timestr=None): + parquet_dir = str(tmpdir) + ss = _lake_ss(parquet_dir, [feed], st_timestr, fin_timestr) + ohlcv_data_factory = OhlcvDataFactory(ss) + return ss, ohlcv_data_factory + + +@enforce_types +def _gql_data_factory(tmpdir, feed, st_timestr=None, fin_timestr=None): + network = "sapphire-mainnet" + ppss = mock_ppss([feed], network, str(tmpdir), st_timestr, fin_timestr) + ppss.web3_pp = mock_web3_pp(network) + + # setup lake + parquet_dir = str(tmpdir) + lake_ss = _lake_ss(parquet_dir, [feed], st_timestr, fin_timestr) + ppss.lake_ss = lake_ss + + gql_data_factory = GQLDataFactory(ppss) + return ppss, gql_data_factory + + +def _filter_gql_config(record_config: Dict, record_filter: str) -> Dict: + # Return a filtered version of record_config for testing + return {k: v for k, v in record_config.items() if k == record_filter} + + +@enforce_types +def _predictoor_ss(feed, feeds): + return PredictoorSS( + { + "predict_feed": feed, + "timeframe": "5m", + "bot_only": 
{"s_until_epoch_end": 60, "stake_amount": 1}, + "aimodel_ss": { + "input_feeds": feeds, + "approach": "LIN", + "max_n_train": 7, + "autoregressive_n": 3, + }, + } + ) + + +@enforce_types +def _lake_ss(parquet_dir, feeds, st_timestr=None, fin_timestr=None): + return LakeSS( + { + "feeds": feeds, + "parquet_dir": parquet_dir, + "st_timestr": st_timestr or "2023-06-18", + "fin_timestr": fin_timestr or "2023-06-21", + "timeframe": "5m", + } + ) + + +# ================================================================== + + +@enforce_types +def _df_from_raw_data(raw_data: list) -> pl.DataFrame: + """Return a df for use in rawohlcv_dfs""" + df = initialize_rawohlcv_df(TOHLCV_COLS) + + next_df = pl.DataFrame(raw_data, schema=TOHLCV_SCHEMA_PL) + + df = concat_next_df(df, next_df) + + return df + + +BINANCE_ETH_DATA = [ # oldest first, newest on the bottom (at the end) + # time #o #h #l #c #v + [1686805500000, 0.5, 12, 0.12, 1.1, 7.0], + [1686805800000, 0.5, 11, 0.11, 2.2, 7.0], + [1686806100000, 0.5, 10, 0.10, 3.3, 7.0], + [1686806400000, 1.1, 9, 0.09, 4.4, 1.4], + [1686806700000, 3.5, 8, 0.08, 5.5, 2.8], + [1686807000000, 4.7, 7, 0.07, 6.6, 8.1], + [1686807300000, 4.5, 6, 0.06, 7.7, 8.1], + [1686807600000, 0.6, 5, 0.05, 8.8, 8.1], + [1686807900000, 0.9, 4, 0.04, 9.9, 8.1], + [1686808200000, 2.7, 3, 0.03, 10.10, 8.1], + [1686808500000, 0.7, 2, 0.02, 11.11, 8.1], + [1686808800000, 0.7, 1, 0.01, 12.12, 8.3], +] + + +@enforce_types +def _addval(DATA: list, val: float) -> list: + DATA2 = copy.deepcopy(DATA) + for row_i, row in enumerate(DATA2): + for col_j, _ in enumerate(row): + if col_j == 0: + continue + DATA2[row_i][col_j] += val + return DATA2 + + +BINANCE_BTC_DATA = _addval(BINANCE_ETH_DATA, 10000.0) +KRAKEN_ETH_DATA = _addval(BINANCE_ETH_DATA, 0.0001) +KRAKEN_BTC_DATA = _addval(BINANCE_ETH_DATA, 10000.0 + 0.0001) + +ETHUSDT_RAWOHLCV_DFS = { + "binanceus": { + "ETH/USDT": _df_from_raw_data(BINANCE_ETH_DATA), + } +} + +# ================================================================== + +RAW_DF1 = text_to_df( # binance BTC/USDT + """timestamp|open|close +0|10.0|11.0 +1|10.1|11.1 +3|10.3|11.3 +4|10.4|11.4 +""" +) # does not have: "2|10.2|11.2" to simulate missing vals from exchanges + +RAW_DF2 = text_to_df( # binance ETH/USDT + """timestamp|open|close +0|20.0|21.0 +1|20.1|21.1 +2|20.2|21.2 +3|20.3|21.3 +""" +) # does *not* have: "4|20.4|21.4" to simulate missing vals from exchanges + +RAW_DF3 = text_to_df( # kraken BTC/USDT + """timestamp|open|close +0|30.0|31.0 +1|30.1|31.1 +2|30.2|31.2 +3|30.3|31.3 +4|30.4|31.4 +""" +) + +RAW_DF4 = text_to_df( # kraken ETH/USDT + """timestamp|open|close +0|40.0|41.0 +1|40.1|41.1 +2|40.2|41.2 +3|40.3|41.3 +4|40.4|41.4 +""" +) diff --git a/pdr_backend/lake/test/test_fetch_ohlcv.py b/pdr_backend/lake/test/test_fetch_ohlcv.py new file mode 100644 index 000000000..93f6a8302 --- /dev/null +++ b/pdr_backend/lake/test/test_fetch_ohlcv.py @@ -0,0 +1,93 @@ +import ccxt +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.lake.fetch_ohlcv import safe_fetch_ohlcv, clean_raw_ohlcv +from pdr_backend.util.timeutil import timestr_to_ut + + +MPE = 300000 # ms per 5min epoch +T4, T5, T6, T7, T8, T10 = 4 * MPE, 5 * MPE, 6 * MPE, 7 * MPE, 8 * MPE, 10 * MPE + +# ut #o #h #l #c #v +RAW5 = [T5, 0.5, 12, 0.12, 1.1, 7.0] +RAW6 = [T6, 0.5, 11, 0.11, 2.2, 7.0] +RAW7 = [T7, 0.5, 10, 0.10, 3.3, 7.0] +RAW8 = [T8, 0.5, 9, 0.09, 4.4, 7.0] + + +@enforce_types +def test_clean_raw_ohlcv(): + feed = ArgFeed("binanceus", None, 
"ETH/USDT", "5m") + + assert clean_raw_ohlcv(None, feed, 0, 0) == [] + assert clean_raw_ohlcv([], feed, 0, 0) == [] + + # RAW5v is the shape of "raw_tohlcv_data" with just one candle + RAW5v = [RAW5] + assert clean_raw_ohlcv(RAW5v, feed, 0, 0) == [] + assert clean_raw_ohlcv(RAW5v, feed, 0, T4) == [] + assert clean_raw_ohlcv(RAW5v, feed, T6, T10) == [] + assert clean_raw_ohlcv(RAW5v, feed, T5, T5) == RAW5v + assert clean_raw_ohlcv(RAW5v, feed, 0, T10) == RAW5v + assert clean_raw_ohlcv(RAW5v, feed, 0, T5) == RAW5v + assert clean_raw_ohlcv(RAW5v, feed, T5, T10) == RAW5v + + # RAW5v is the shape of "raw_tohlcv_data" with four candles + RAW5678v = [RAW5, RAW6, RAW7, RAW8] + assert clean_raw_ohlcv(RAW5678v, feed, 0, 0) == [] + assert clean_raw_ohlcv(RAW5678v, feed, 0, T10) == RAW5678v + assert clean_raw_ohlcv(RAW5678v, feed, T5, T5) == [RAW5] + assert clean_raw_ohlcv(RAW5678v, feed, T6, T6) == [RAW6] + assert clean_raw_ohlcv(RAW5678v, feed, T5, T6) == [RAW5, RAW6] + assert clean_raw_ohlcv(RAW5678v, feed, T5, T8) == RAW5678v + assert clean_raw_ohlcv(RAW5678v, feed, T7, T8) == [RAW7, RAW8] + assert clean_raw_ohlcv(RAW5678v, feed, T8, T8) == [RAW8] + + +@enforce_types +@pytest.mark.parametrize("exch", [ccxt.binanceus(), ccxt.kraken()]) +def test_safe_fetch_ohlcv(exch): + since = timestr_to_ut("2023-06-18") + symbol, timeframe, limit = "ETH/USDT", "5m", 1000 + + # happy path + raw_tohlc_data = safe_fetch_ohlcv(exch, symbol, timeframe, since, limit) + assert_raw_tohlc_data_ok(raw_tohlc_data) + + # catch bad (but almost good) symbol + with pytest.raises(ValueError): + raw_tohlc_data = safe_fetch_ohlcv(exch, "ETH-USDT", timeframe, since, limit) + + # it will catch type errors, except for exch. Test an example of this. + with pytest.raises(TypeError): + raw_tohlc_data = safe_fetch_ohlcv(exch, 11, timeframe, since, limit) + with pytest.raises(TypeError): + raw_tohlc_data = safe_fetch_ohlcv(exch, symbol, 11, since, limit) + with pytest.raises(TypeError): + raw_tohlc_data = safe_fetch_ohlcv(exch, symbol, timeframe, "f", limit) + with pytest.raises(TypeError): + raw_tohlc_data = safe_fetch_ohlcv(exch, symbol, timeframe, since, "f") + + # should not crash, just give warning + safe_fetch_ohlcv("bad exch", symbol, timeframe, since, limit) + safe_fetch_ohlcv(exch, "bad symbol", timeframe, since, limit) + safe_fetch_ohlcv(exch, symbol, "bad timeframe", since, limit) + safe_fetch_ohlcv(exch, symbol, timeframe, -5, limit) + safe_fetch_ohlcv(exch, symbol, timeframe, since, -5) + + # ensure a None is returned when warning + v = safe_fetch_ohlcv("bad exch", symbol, timeframe, since, limit) + assert v is None + + +@enforce_types +def assert_raw_tohlc_data_ok(raw_tohlc_data): + assert raw_tohlc_data, raw_tohlc_data + assert isinstance(raw_tohlc_data, list) + for item in raw_tohlc_data: + assert len(item) == (6) + assert isinstance(item[0], int) + for val in item[1:]: + assert isinstance(val, float) diff --git a/pdr_backend/lake/test/test_gql_data_factory.py b/pdr_backend/lake/test/test_gql_data_factory.py new file mode 100644 index 000000000..eb2b39734 --- /dev/null +++ b/pdr_backend/lake/test/test_gql_data_factory.py @@ -0,0 +1,354 @@ +from typing import List +from unittest.mock import patch + +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.lake.table_pdr_predictions import predictions_schema +from pdr_backend.lake.test.resources import _gql_data_factory, _filter_gql_config +from pdr_backend.subgraph.subgraph_predictions import FilterMode +from pdr_backend.util.timeutil import 
timestr_to_ut + +# ==================================================================== +# test parquet updating +pdr_predictions_record = "pdr_predictions" + + +@patch("pdr_backend.lake.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql1( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, +): + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, + "2023-11-02_0:00", + "2023-11-04_21:00", + n_preds=3, + ) + + +@patch("pdr_backend.lake.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql2( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, +): + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, + "2023-11-02_0:00", + "2023-11-06_21:00", + n_preds=5, + ) + + +@patch("pdr_backend.lake.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql3( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, +): + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, + "2023-11-01_0:00", + "2023-11-07_0:00", + n_preds=6, + ) + + +@patch("pdr_backend.lake.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql_iteratively( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, +): + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + + iterations = [ + ("2023-11-02_0:00", "2023-11-04_0:00", 2), + ("2023-11-01_0:00", "2023-11-05_0:00", 3), # do not append to start + ("2023-11-02_0:00", "2023-11-07_0:00", 5), + ] + + for st_timestr, fin_timestr, n_preds in iterations: + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, + st_timestr, + fin_timestr, + n_preds=n_preds, + ) + + +@enforce_types +def _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_predictions, + st_timestr: str, + fin_timestr: str, + n_preds, +): + """ + @arguments + n_preds -- expected # predictions. Typically int. 
If '>1K', expect >1000 + """ + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr, + fin_timestr, + ) + + # Update predictions record only + gql_data_factory.record_config = _filter_gql_config( + gql_data_factory.record_config, pdr_predictions_record + ) + + # setup: filename + # everything will be inside the gql folder + filename = gql_data_factory._parquet_filename(pdr_predictions_record) + assert ".parquet" in filename + + fin_ut = timestr_to_ut(fin_timestr) + st_ut = gql_data_factory._calc_start_ut(filename) + + # calculate ms locally so we can filter raw Predictions + st_ut_sec = st_ut // 1000 + fin_ut_sec = fin_ut // 1000 + + # filter preds that will be returned from subgraph to client + target_preds = [ + x for x in sample_predictions if st_ut_sec <= x.timestamp <= fin_ut_sec + ] + mock_fetch_filtered_predictions.return_value = target_preds + + # work 1: update parquet + gql_data_factory._update(fin_ut) + + # assert params + mock_fetch_filtered_predictions.assert_called_with( + st_ut_sec, + fin_ut_sec, + ["0x123"], + "mainnet", + FilterMode.CONTRACT_TS, + payout_only=False, + trueval_only=False, + ) + + # read parquet and columns + def _preds_in_parquet(filename: str) -> List[int]: + df = pl.read_parquet(filename) + assert df.schema == predictions_schema + return df["timestamp"].to_list() + + # assert expected length of preds in parquet + preds: List[int] = _preds_in_parquet(filename) + if isinstance(n_preds, int): + assert len(preds) == n_preds + elif n_preds == ">1K": + assert len(preds) > 1000 + + # preds may not match start or end time + assert preds[0] != st_ut + assert preds[-1] != fin_ut + + # assert all target_preds are registered in parquet + target_preds_ts = [pred.__dict__["timestamp"] for pred in target_preds] + for target_pred in target_preds_ts: + assert target_pred * 1000 in preds + + +@patch("pdr_backend.lake.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_load_and_verify_schema( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, +): + st_timestr = "2023-11-02_0:00" + fin_timestr = "2023-11-07_0:00" + + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, + st_timestr, + fin_timestr, + n_preds=5, + ) + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr, + fin_timestr, + ) + gql_data_factory.record_config = _filter_gql_config( + gql_data_factory.record_config, pdr_predictions_record + ) + + fin_ut = timestr_to_ut(fin_timestr) + gql_dfs = gql_data_factory._load_parquet(fin_ut) + + assert len(gql_dfs) == 1 + assert len(gql_dfs[pdr_predictions_record]) == 5 + assert gql_dfs[pdr_predictions_record].schema == predictions_schema + + +# ==================================================================== +# test if appropriate calls are made + + +@enforce_types +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +@patch("pdr_backend.lake.gql_data_factory.GQLDataFactory._update") +@patch("pdr_backend.lake.gql_data_factory.GQLDataFactory._load_parquet") +def test_get_gql_dfs_calls( + mock_load_parquet, + mock_update, + mock_get_all_contract_ids_by_owner, + tmpdir, + sample_daily_predictions, +): + """Test core DataFactory functions are being called""" + + st_timestr = "2023-11-02_0:00" + fin_timestr = "2023-11-07_0:00" + + 
mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr, + fin_timestr, + ) + + # Update predictions record only + default_config = gql_data_factory.record_config + gql_data_factory.record_config = _filter_gql_config( + gql_data_factory.record_config, pdr_predictions_record + ) + + # calculate ms locally so we can filter raw Predictions + st_ut = timestr_to_ut(st_timestr) + fin_ut = timestr_to_ut(fin_timestr) + st_ut_sec = st_ut // 1000 + fin_ut_sec = fin_ut // 1000 + + # mock_load_parquet should return the values from a simple code block + mock_load_parquet.return_value = { + pdr_predictions_record: pl.DataFrame( + [ + x.__dict__ + for x in sample_daily_predictions + if st_ut_sec <= x.timestamp <= fin_ut_sec + ] + ).with_columns([pl.col("timestamp").mul(1000).alias("timestamp")]) + } + + # call and assert + gql_dfs = gql_data_factory.get_gql_dfs() + assert isinstance(gql_dfs, dict) + assert isinstance(gql_dfs[pdr_predictions_record], pl.DataFrame) + assert len(gql_dfs[pdr_predictions_record]) == 5 + + mock_update.assert_called_once() + mock_load_parquet.assert_called_once() + + # reset record config + gql_data_factory.record_config = default_config + + +# ==================================================================== +# test loading flow when there are pdr files missing + + +@enforce_types +@patch("pdr_backend.lake.table_pdr_predictions.fetch_filtered_predictions") +@patch("pdr_backend.lake.table_pdr_subscriptions.fetch_filtered_subscriptions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_load_missing_parquet( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_subscriptions, + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, +): + """Test core DataFactory functions are being called""" + + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + mock_fetch_filtered_subscriptions.return_value = [] + mock_fetch_filtered_predictions.return_value = [] + + st_timestr = "2023-11-02_0:00" + fin_timestr = "2023-11-04_0:00" + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr, + fin_timestr, + ) + + # Work 1: Fetch empty dataset + # (1) perform empty fetch + # (2) do not save to parquet + # (3) handle missing parquet file + # (4) assert we get empty dataframes with the expected schema + dfs = gql_data_factory.get_gql_dfs() + + predictions_table = "pdr_predictions" + subscriptions_table = "pdr_subscriptions" + + assert len(dfs[predictions_table]) == 0 + assert len(dfs[subscriptions_table]) == 0 + + assert ( + dfs[predictions_table].schema + == gql_data_factory.record_config[predictions_table]["schema"] + ) + assert ( + dfs[subscriptions_table].schema + == gql_data_factory.record_config[subscriptions_table]["schema"] + ) + + # Work 2: Fetch 1 dataset + # (1) perform 1 successful datafactory loops (predictions) + # (2) assert subscriptions parquet doesn't exist / has 0 records + _test_update_gql( + mock_fetch_filtered_predictions, + tmpdir, + sample_daily_predictions, + st_timestr, + fin_timestr, + n_preds=2, + ) + + dfs = gql_data_factory.get_gql_dfs() + assert len(dfs[predictions_table]) == 2 + assert len(dfs[subscriptions_table]) == 0 diff --git a/pdr_backend/data_eng/test/test_data_eng_constants.py b/pdr_backend/lake/test/test_lake_constants.py similarity index 81% rename from pdr_backend/data_eng/test/test_data_eng_constants.py rename to 
pdr_backend/lake/test/test_lake_constants.py index b8785746d..d2619af06 100644 --- a/pdr_backend/data_eng/test/test_data_eng_constants.py +++ b/pdr_backend/lake/test/test_lake_constants.py @@ -1,13 +1,14 @@ -from enforce_typing import enforce_types import numpy as np +from enforce_typing import enforce_types -from pdr_backend.data_eng.constants import ( +from pdr_backend.lake.constants import ( OHLCV_COLS, OHLCV_DTYPES, + OHLCV_MULT_MAX, + OHLCV_MULT_MIN, TOHLCV_COLS, TOHLCV_DTYPES, - OHLCV_MULT_MIN, - OHLCV_MULT_MAX, + TOHLCV_SCHEMA_PL, ) @@ -26,4 +27,7 @@ def test_constants(): assert np.float64 in TOHLCV_DTYPES assert np.int64 in TOHLCV_DTYPES + assert isinstance(TOHLCV_SCHEMA_PL, dict) + assert TOHLCV_COLS[0] in TOHLCV_SCHEMA_PL + assert 0 < OHLCV_MULT_MIN <= OHLCV_MULT_MAX < np.inf diff --git a/pdr_backend/lake/test/test_merge_df.py b/pdr_backend/lake/test/test_merge_df.py new file mode 100644 index 000000000..69b27aac0 --- /dev/null +++ b/pdr_backend/lake/test/test_merge_df.py @@ -0,0 +1,209 @@ +import polars as pl +import pytest +from enforce_typing import enforce_types + +from pdr_backend.lake.merge_df import ( + _add_df_col, + _ordered_cols, + merge_cols, + merge_rawohlcv_dfs, +) +from pdr_backend.lake.test.resources import ( + ETHUSDT_RAWOHLCV_DFS, + RAW_DF1, + RAW_DF2, + RAW_DF3, + RAW_DF4, +) + + +@enforce_types +def test_mergedohlcv_df_shape(): + mergedohlcv_df = merge_rawohlcv_dfs(ETHUSDT_RAWOHLCV_DFS) + assert isinstance(mergedohlcv_df, pl.DataFrame) + assert mergedohlcv_df.columns == [ + "timestamp", + "binanceus:ETH/USDT:open", + "binanceus:ETH/USDT:high", + "binanceus:ETH/USDT:low", + "binanceus:ETH/USDT:close", + "binanceus:ETH/USDT:volume", + ] + assert mergedohlcv_df.shape == (12, 6) + assert len(mergedohlcv_df["timestamp"]) == 12 + assert ( # pylint: disable=unsubscriptable-object + mergedohlcv_df["timestamp"][0] == 1686805500000 + ) + + +@enforce_types +def test_merge_rawohlcv_dfs_equal_dfs(): + raw_dfs = { + "binance": {"BTC/USDT": RAW_DF1, "ETH/USDT": RAW_DF2}, + "kraken": {"BTC/USDT": RAW_DF3, "ETH/USDT": RAW_DF4}, + } + + merged_df = merge_rawohlcv_dfs(raw_dfs) + + assert merged_df.columns == [ + "timestamp", + "binance:BTC/USDT:open", + "binance:BTC/USDT:close", + "binance:ETH/USDT:open", + "binance:ETH/USDT:close", + "kraken:BTC/USDT:open", + "kraken:BTC/USDT:close", + "kraken:ETH/USDT:open", + "kraken:ETH/USDT:close", + ] + assert merged_df["timestamp"][1] == 1 + assert merged_df["binance:BTC/USDT:close"][3] == 11.3 + assert merged_df["kraken:BTC/USDT:close"][3] == 31.3 + assert merged_df["kraken:ETH/USDT:open"][4] == 40.4 + + +@enforce_types +def test_merge_rawohlcv__empty_and_nonempty_df(): + df_btc = pl.DataFrame( + { + "timestamp": [], + "close": [], + } + ) + df_eth = pl.DataFrame( + { + "timestamp": [1, 2, 3], + "close": [5, 6, 7], + } + ) + + merged_df = merge_rawohlcv_dfs({"kraken": {"BTC/USDT": df_btc, "ETH/USDT": df_eth}}) + target_df = pl.DataFrame( + { + "timestamp": [1, 2, 3], + "kraken:BTC/USDT:close": [None, None, None], + "kraken:ETH/USDT:close": [5, 6, 7], + } + ) + assert merged_df.equals(target_df) + + +@enforce_types +def test_merge_rawohlcv__dfs_with_different_timestamps(): + df_eth = pl.DataFrame( + { + "timestamp": [1, 2, 3], + "close": [5, 6, 7], + } + ) + df_dot = pl.DataFrame( + { + "timestamp": [1, 5], + "close": [8, 9], + } + ) + merged_df = merge_rawohlcv_dfs({"kraken": {"ETH/USDT": df_eth, "DOT/USDT": df_dot}}) + target_df = pl.DataFrame( + { + "timestamp": [1, 2, 3, 5], + "kraken:ETH/USDT:close": [5, 6, 7, None], + 
"kraken:DOT/USDT:close": [8, None, None, 9], + } + ) + assert merged_df.equals(target_df) + + +@enforce_types +def test_add_df_col_unequal_dfs(): + # basic sanity test that floats are floats + assert isinstance(RAW_DF1["close"][1], float) + + # add a first RAW_DF + merged_df = _add_df_col(None, "binance:BTC/USDT:close", RAW_DF1, "close") + assert merged_df.columns == ["timestamp", "binance:BTC/USDT:close"] + assert merged_df.shape == (4, 2) + assert merged_df["timestamp"][1] == 1 + assert merged_df["binance:BTC/USDT:close"][3] == 11.4 + + # add a second RAW_DF + merged_df = _add_df_col(merged_df, "binance:ETH/USDT:open", RAW_DF2, "open") + assert merged_df.columns == [ + "timestamp", + "binance:BTC/USDT:close", + "binance:ETH/USDT:open", + ] + assert merged_df.shape == (5, 3) + assert merged_df["timestamp"][1] == 1 + assert merged_df["binance:BTC/USDT:close"][3] == 11.3 + assert merged_df["binance:ETH/USDT:open"][3] == 20.3 + + +@enforce_types +def test_add_df_col_equal_dfs(): + # basic sanity test that floats are floats + assert isinstance(RAW_DF3["close"][1], float) + + # add a first RAW_DF + merged_df = _add_df_col(None, "kraken:BTC/USDT:close", RAW_DF3, "close") + assert merged_df.columns == [ + "timestamp", + "kraken:BTC/USDT:close", + ] + assert merged_df.shape == (5, 2) + assert merged_df["timestamp"][1] == 1 + assert merged_df["kraken:BTC/USDT:close"][3] == 31.3 + + # add a second RAW_DF + merged_df = _add_df_col(merged_df, "kraken:ETH/USDT:open", RAW_DF4, "open") + assert merged_df.columns == [ + "timestamp", + "kraken:BTC/USDT:close", + "kraken:ETH/USDT:open", + ] + assert merged_df.shape == (5, 3) + assert merged_df["timestamp"][1] == 1 + assert merged_df["kraken:BTC/USDT:close"][3] == 31.3 + assert merged_df["kraken:ETH/USDT:open"][4] == 40.4 + + +@enforce_types +def test_merge_cols(): + df = pl.DataFrame( + { + "a": [1, 2, 3, 4], + "b": [5, 6, 7, None], + "c": [9, 9, 9, 9], + } + ) # None will become pl.Null + + df = df.select(["a", "b", "c"]) + assert df.columns == ["a", "b", "c"] + + # merge b into a + df2 = merge_cols(df, "a", "b") + assert df2.columns == ["a", "c"] + assert df2["a"].to_list() == [1, 2, 3, 4] + + # merge a into b + df3 = merge_cols(df, "b", "a") + assert df3.columns == ["b", "c"] + assert df3["b"].to_list() == [5, 6, 7, 4] # the 4 comes from "a" + + +@enforce_types +def test_ordered_cols(): + assert _ordered_cols(["timestamp"]) == ["timestamp"] + assert _ordered_cols(["a", "c", "b", "timestamp"]) == [ + "timestamp", + "a", + "c", + "b", + ] + + for bad_cols in [ + ["a", "c", "b"], # missing timestamp + ["a", "c", "b", "b", "timestamp"], # duplicates + ["a", "c", "b", "timestamp", "timestamp"], # duplicates + ]: + with pytest.raises(AssertionError): + _ordered_cols(bad_cols) diff --git a/pdr_backend/lake/test/test_ohlcv_data_factory.py b/pdr_backend/lake/test/test_ohlcv_data_factory.py new file mode 100644 index 000000000..589821382 --- /dev/null +++ b/pdr_backend/lake/test/test_ohlcv_data_factory.py @@ -0,0 +1,314 @@ +import os +import time +from typing import List +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types +import numpy as np +import polars as pl +import pytest + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.lake.constants import TOHLCV_SCHEMA_PL +from pdr_backend.lake.merge_df import merge_rawohlcv_dfs +from pdr_backend.lake.ohlcv_data_factory import OhlcvDataFactory +from pdr_backend.lake.plutil import ( + concat_next_df, + initialize_rawohlcv_df, + load_rawohlcv_file, + save_rawohlcv_file, +) 
+from pdr_backend.lake.test.resources import _lake_ss_1feed, _lake_ss +from pdr_backend.util.constants import S_PER_MIN +from pdr_backend.util.mathutil import all_nan, has_nan +from pdr_backend.util.timeutil import current_ut_ms, ut_to_timestr + +MS_PER_5M_EPOCH = 300000 + + +# ==================================================================== +# test update of rawohlcv files + + +@enforce_types +@pytest.mark.parametrize( + "st_timestr, fin_timestr, n_uts", + [ + ("2023-01-01_0:00", "2023-01-01_0:00", 1), + ("2023-01-01_0:00", "2023-01-01_0:05", 2), + ("2023-01-01_0:00", "2023-01-01_0:10", 3), + ("2023-01-01_0:00", "2023-01-01_0:45", 10), + ("2023-01-01", "2023-06-21", ">1K"), + ], +) +def test_update_rawohlcv_files(st_timestr: str, fin_timestr: str, n_uts, tmpdir): + """ + @arguments + n_uts -- expected # timestamps. Typically int. If '>1K', expect >1000 + """ + + # setup: uts helpers + def _calc_ut(since: int, i: int) -> int: + """Return a ut : unix time, in ms, in UTC time zone""" + return since + i * MS_PER_5M_EPOCH + + def _uts_in_range(st_ut: int, fin_ut: int, limit_N=100000) -> List[int]: + return [ + _calc_ut(st_ut, i) for i in range(limit_N) if _calc_ut(st_ut, i) <= fin_ut + ] + + # setup: exchange + class FakeExchange: + def __init__(self): + self.cur_ut: int = current_ut_ms() # fixed value, for easier testing + + # pylint: disable=unused-argument + def fetch_ohlcv(self, since, limit, *args, **kwargs) -> list: + uts: List[int] = _uts_in_range(since, self.cur_ut, limit) + return [[ut] + [1.0] * 5 for ut in uts] # 1.0 for open, high, .. + + ss, factory = _lake_ss_1feed( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr, + fin_timestr, + ) + + # setup: filename + # it's ok for input pair_str to have '/' or '-', it handles it + # but the output filename should not have '/' its pairstr part + feed = ArgFeed("binanceus", None, "ETH/USDT", "5m") + filename = factory._rawohlcv_filename(feed) + feed = ArgFeed("binanceus", None, "ETH/USDT", "5m") + filename2 = factory._rawohlcv_filename(feed) + assert filename == filename2 + assert "ETH-USDT" in filename and "ETH/USDT" not in filename + + # work 1: new rawohlcv file + feed = ArgFeed("binanceus", None, "ETH/USDT", "5m") + with patch("pdr_backend.cli.arg_exchange.ArgExchange.exchange_class") as mock: + mock.return_value = FakeExchange() + factory._update_rawohlcv_files_at_feed(feed, ss.fin_timestamp) + + def _uts_in_rawohlcv_file(filename: str) -> List[int]: + df = load_rawohlcv_file(filename) + return df["timestamp"].to_list() + + uts: List[int] = _uts_in_rawohlcv_file(filename) + if isinstance(n_uts, int): + assert len(uts) == n_uts + elif n_uts == ">1K": + assert len(uts) > 1000 + assert sorted(uts) == uts + assert uts[0] == ss.st_timestamp + assert uts[-1] == ss.fin_timestamp + assert uts == _uts_in_range(ss.st_timestamp, ss.fin_timestamp) + + # work 2: two more epochs at end --> it'll append existing file + ss.d["fin_timestr"] = ut_to_timestr(ss.fin_timestamp + 2 * MS_PER_5M_EPOCH) + with patch("pdr_backend.cli.arg_exchange.ArgExchange.exchange_class") as mock: + mock.return_value = FakeExchange() + factory._update_rawohlcv_files_at_feed(feed, ss.fin_timestamp) + uts2 = _uts_in_rawohlcv_file(filename) + assert uts2 == _uts_in_range(ss.st_timestamp, ss.fin_timestamp) + + # work 3: two more epochs at beginning *and* end --> it'll create new file + ss.d["st_timestr"] = ut_to_timestr(ss.st_timestamp - 2 * MS_PER_5M_EPOCH) + ss.d["fin_timestr"] = ut_to_timestr(ss.fin_timestamp + 4 * MS_PER_5M_EPOCH) + with 
patch("pdr_backend.cli.arg_exchange.ArgExchange.exchange_class") as mock: + mock.return_value = FakeExchange() + factory._update_rawohlcv_files_at_feed(feed, ss.fin_timestamp) + uts3 = _uts_in_rawohlcv_file(filename) + assert uts3 == _uts_in_range(ss.st_timestamp, ss.fin_timestamp) + + +# ==================================================================== +# test behavior of get_mergedohlcv_df() + + +@enforce_types +def test_get_mergedohlcv_df_happypath(tmpdir): + """Is get_mergedohlcv_df() executing e2e correctly? + Includes actual calls to the exchange API, eg binance or kraken, via ccxt. + + It may fail if the exchange is temporarily misbehaving, which + shows up as a FileNotFoundError. + So give it a few tries if needed. + """ + n_tries = 5 + for try_i in range(n_tries - 1): + try: + _test_get_mergedohlcv_df_happypath_onetry(tmpdir) + return # success + + except FileNotFoundError: + print(f"test_get_mergedohlcv_df_happypath try #{try_i+1}, file not found") + time.sleep(2) + + # last chance + _test_get_mergedohlcv_df_happypath_onetry(tmpdir) + + +@enforce_types +def _test_get_mergedohlcv_df_happypath_onetry(tmpdir): + parquet_dir = str(tmpdir) + + ss = _lake_ss( + parquet_dir, + ["binanceus BTC-USDT,ETH/USDT h 5m", "kraken BTC/USDT h 5m"], + st_timestr="2023-06-18", + fin_timestr="2023-06-19", + ) + factory = OhlcvDataFactory(ss) + + # call and assert + mergedohlcv_df = factory.get_mergedohlcv_df() + + # 289 records created + assert len(mergedohlcv_df) == 289 + + # binanceus is returning valid data + assert not has_nan(mergedohlcv_df["binanceus:BTC/USDT:high"]) + assert not has_nan(mergedohlcv_df["binanceus:ETH/USDT:high"]) + + # kraken is returning nans + assert has_nan(mergedohlcv_df["kraken:BTC/USDT:high"]) + + # assert head is oldest + head_timestamp = mergedohlcv_df.head(1)["timestamp"].to_list()[0] + tail_timestamp = mergedohlcv_df.tail(1)["timestamp"].to_list()[0] + assert head_timestamp < tail_timestamp + + +@enforce_types +def test_mergedohlcv_df__low_vs_high_level__1_no_nan(tmpdir): + _test_mergedohlcv_df__low_vs_high_level(tmpdir, ohlcv_val=12.1) + + +@enforce_types +def test_mergedohlcv_df__low_vs_high_level__2_all_nan(tmpdir): + _test_mergedohlcv_df__low_vs_high_level(tmpdir, ohlcv_val=np.nan) + + +@enforce_types +def _test_mergedohlcv_df__low_vs_high_level(tmpdir, ohlcv_val): + """Does high-level behavior of mergedohlcv_df() align with low-level implement'n? 
+ Should work whether no nans, or all nans (as set by ohlcv_val) + """ + + # setup + _, factory = _lake_ss_1feed(tmpdir, "binanceus BTC/USDT h 5m") + filename = factory._rawohlcv_filename( + ArgFeed("binanceus", "high", "BTC/USDT", "5m") + ) + st_ut = factory.ss.st_timestamp + fin_ut = factory.ss.fin_timestamp + + # mock + n_pts = 20 + + def mock_update(*args, **kwargs): # pylint: disable=unused-argument + s_per_epoch = S_PER_MIN * 5 + raw_tohlcv_data = [ + [st_ut + s_per_epoch * i] + [ohlcv_val] * 5 for i in range(n_pts) + ] + df = initialize_rawohlcv_df() + next_df = pl.DataFrame(raw_tohlcv_data, schema=TOHLCV_SCHEMA_PL) + df = concat_next_df(df, next_df) + save_rawohlcv_file(filename, df) + + factory._update_rawohlcv_files_at_feed = mock_update + + # test 1: get mergedohlcv_df via several low-level instrs, as get_mergedohlcv_df() does + factory._update_rawohlcv_files(fin_ut) + assert os.path.getsize(filename) > 500 + + df0 = pl.read_parquet(filename, columns=["high"]) + df1 = load_rawohlcv_file(filename, ["high"], st_ut, fin_ut) + rawohlcv_dfs = ( # pylint: disable=assignment-from-no-return + factory._load_rawohlcv_files(fin_ut) + ) + mergedohlcv_df = merge_rawohlcv_dfs(rawohlcv_dfs) + + assert len(df0) == len(df1) == len(df1["high"]) == len(mergedohlcv_df) == n_pts + if np.isnan(ohlcv_val): + assert all_nan(df0) + assert all_nan(df1["high"]) + assert all_nan(mergedohlcv_df["binanceus:BTC/USDT:high"]) + else: + assert not has_nan(df0) + assert not has_nan(df1["high"]) + assert not has_nan(mergedohlcv_df["binanceus:BTC/USDT:high"]) + + # cleanup for test 2 + os.remove(filename) + + # test 2: get mergedohlcv_df via a single high-level instr + mergedohlcv_df = factory.get_mergedohlcv_df() + assert os.path.getsize(filename) > 500 + assert len(mergedohlcv_df) == n_pts + if np.isnan(ohlcv_val): + assert all_nan(mergedohlcv_df["binanceus:BTC/USDT:high"]) + else: + assert not has_nan(mergedohlcv_df["binanceus:BTC/USDT:high"]) + + +@enforce_types +def test_exchange_hist_overlap(tmpdir): + """DataFactory get_mergedohlcv_df() and concat is executing e2e correctly""" + _, factory = _lake_ss_1feed( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr="2023-06-18", + fin_timestr="2023-06-19", + ) + + # call and assert + mergedohlcv_df = factory.get_mergedohlcv_df() + + # 289 records created + assert len(mergedohlcv_df) == 289 + + # assert head is oldest + head_timestamp = mergedohlcv_df.head(1)["timestamp"].to_list()[0] + tail_timestamp = mergedohlcv_df.tail(1)["timestamp"].to_list()[0] + assert head_timestamp < tail_timestamp + + # let's get more data from exchange with overlap + _, factory2 = _lake_ss_1feed( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr="2023-06-18", # same + fin_timestr="2023-06-20", # different + ) + mergedohlcv_df2 = factory2.get_mergedohlcv_df() + + # assert on expected values + # another 288 records appended + # head (index = 0) still points to oldest date with tail (index = n) being the latest date + assert len(mergedohlcv_df2) == 289 + 288 == 577 + assert ( + mergedohlcv_df2.head(1)["timestamp"].to_list()[0] + < mergedohlcv_df2.tail(1)["timestamp"].to_list()[0] + ) + + +@enforce_types +@patch("pdr_backend.lake.ohlcv_data_factory.merge_rawohlcv_dfs") +def test_get_mergedohlcv_df_calls( + mock_merge_rawohlcv_dfs, + tmpdir, +): + mock_merge_rawohlcv_dfs.return_value = Mock(spec=pl.DataFrame) + _, factory = _lake_ss_1feed(tmpdir, "binanceus ETH/USDT h 5m") + + factory._update_rawohlcv_files = Mock(return_value=None) + factory._load_rawohlcv_files = Mock(return_value=None) + 
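+    # the factory internals are mocked above, so this only verifies that get_mergedohlcv_df wires update -> load -> merge together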
+ mergedohlcv_df = factory.get_mergedohlcv_df() + + assert isinstance(mergedohlcv_df, pl.DataFrame) + + factory._update_rawohlcv_files.assert_called() + factory._load_rawohlcv_files.assert_called() + mock_merge_rawohlcv_dfs.assert_called() diff --git a/pdr_backend/lake/test/test_plutil.py b/pdr_backend/lake/test/test_plutil.py new file mode 100644 index 000000000..97cbdbb5f --- /dev/null +++ b/pdr_backend/lake/test/test_plutil.py @@ -0,0 +1,306 @@ +import os + +import numpy as np +import polars as pl +import pytest +from enforce_typing import enforce_types + +from pdr_backend.lake.constants import ( + OHLCV_COLS, + OHLCV_DTYPES_PL, + TOHLCV_COLS, + TOHLCV_DTYPES_PL, +) +from pdr_backend.lake.plutil import ( + _get_tail_df, + concat_next_df, + has_data, + initialize_rawohlcv_df, + load_rawohlcv_file, + newest_ut, + oldest_ut, + save_rawohlcv_file, + set_col_values, + text_to_df, +) + +FOUR_ROWS_RAW_TOHLCV_DATA = [ + [1686806100000, 1648.58, 1648.58, 1646.27, 1646.64, 7.4045], + [1686806400000, 1647.05, 1647.05, 1644.61, 1644.86, 14.452], + [1686806700000, 1644.57, 1646.41, 1642.49, 1645.81, 22.8612], + [1686807000000, 1645.77, 1646.2, 1645.23, 1646.05, 8.1741], +] + + +ONE_ROW_RAW_TOHLCV_DATA = [[1686807300000, 1646, 1647.2, 1646.23, 1647.05, 8.1742]] + + +@enforce_types +def test_initialize_rawohlcv_df(): + df = initialize_rawohlcv_df() + assert isinstance(df, pl.DataFrame) + assert list(df.schema.values()) == TOHLCV_DTYPES_PL + + _assert_TOHLCVd_cols_and_types(df) + + # test with just 2 cols + df = initialize_rawohlcv_df(OHLCV_COLS[:2]) + assert df.columns == OHLCV_COLS[:2] + assert list(df.schema.values())[:2] == OHLCV_DTYPES_PL[:2] + + # test with just ut + 2 cols + df = initialize_rawohlcv_df(TOHLCV_COLS[:3]) + assert df.columns == TOHLCV_COLS[:3] + assert list(df.schema.values()) == TOHLCV_DTYPES_PL[:3] + + +@enforce_types +def test_set_col_values(): + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + } + ) + + df2 = set_col_values(df, "a", [7, 8, 9]) + assert df2["a"].to_list() == [7, 8, 9] + + df2 = set_col_values(df, "a", [7.1, 8.1, 9.1]) + assert df2["a"].to_list() == [7.1, 8.1, 9.1] + + with pytest.raises(pl.exceptions.ShapeError): + set_col_values(df, "a", [7, 8]) + + +@enforce_types +def test_concat_next_df(): + # baseline data + df = initialize_rawohlcv_df(TOHLCV_COLS) + assert len(df) == 0 + + cand_dtypes = dict(zip(TOHLCV_COLS, TOHLCV_DTYPES_PL)) + schema = {col: cand_dtypes[col] for col in TOHLCV_COLS} + + next_df = pl.DataFrame(FOUR_ROWS_RAW_TOHLCV_DATA, schema=schema) + assert len(next_df) == 4 + + # add 4 rows to empty df + df = concat_next_df(df, next_df) + assert len(df) == 4 + + _assert_TOHLCVd_cols_and_types(df) + + # assert 1 more row + next_df = pl.DataFrame(ONE_ROW_RAW_TOHLCV_DATA, schema=schema) + assert len(next_df) == 1 + + # assert that concat verifies schemas match + next_df = pl.DataFrame(ONE_ROW_RAW_TOHLCV_DATA, schema=schema) + assert len(next_df) == 1 + assert "datetime" not in next_df.columns + + # add 1 row to existing 4 rows + df = concat_next_df(df, next_df) + assert len(df) == 4 + 1 + + # assert concat order + assert df.head(1)["timestamp"].to_list()[0] == FOUR_ROWS_RAW_TOHLCV_DATA[0][0] + assert df.tail(1)["timestamp"].to_list()[0] == ONE_ROW_RAW_TOHLCV_DATA[0][0] + _assert_TOHLCVd_cols_and_types(df) + + +@enforce_types +def _assert_TOHLCVd_cols_and_types(df: pl.DataFrame): + assert df.columns == TOHLCV_COLS + assert list(df.schema.values()) == TOHLCV_DTYPES_PL + assert "timestamp" in df.columns and df.schema["timestamp"] == pl.Int64 + + 
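+# Helper for the save/load tests below: path of the temporary rawohlcv file inside tmpdir.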
+def _filename(tmpdir) -> str: + return os.path.join(tmpdir, "foo.csv") + + +@enforce_types +def test_load_basic(tmpdir): + filename = _filename(tmpdir) + df = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) + save_rawohlcv_file(filename, df) + + # simplest specification. Don't specify cols, st or fin + df2 = load_rawohlcv_file(filename) + _assert_TOHLCVd_cols_and_types(df2) + assert len(df2) == 4 and str(df) == str(df2) + + # explicitly specify cols, but not st or fin + df2 = load_rawohlcv_file(filename, OHLCV_COLS) + _assert_TOHLCVd_cols_and_types(df2) + assert len(df2) == 4 and str(df) == str(df2) + + # explicitly specify cols, st, fin + df2 = load_rawohlcv_file(filename, OHLCV_COLS, st=None, fin=None) + _assert_TOHLCVd_cols_and_types(df2) + assert len(df2) == 4 and str(df) == str(df2) + + df2 = load_rawohlcv_file(filename, OHLCV_COLS, st=0, fin=np.inf) + _assert_TOHLCVd_cols_and_types(df2) + assert len(df2) == 4 and str(df) == str(df2) + + +@enforce_types +def test_load_append(tmpdir): + # save 4-row parquet to new file + filename = _filename(tmpdir) + df_4_rows = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) + save_rawohlcv_file(filename, df_4_rows) + + # append 1 row to existing file + df_1_row = _df_from_raw_data(ONE_ROW_RAW_TOHLCV_DATA) + save_rawohlcv_file(filename, df_1_row) + + # verify: doing a manual concat is the same as the load + schema = dict(zip(TOHLCV_COLS, TOHLCV_DTYPES_PL)) + df_1_row = pl.DataFrame(ONE_ROW_RAW_TOHLCV_DATA, schema=schema) + df_5_rows = concat_next_df(df_4_rows, df_1_row) + df_5_rows_loaded = load_rawohlcv_file(filename) + + _assert_TOHLCVd_cols_and_types(df_5_rows_loaded) + + assert len(df_5_rows_loaded) == 5 + assert str(df_5_rows) == str(df_5_rows_loaded) + + +@enforce_types +def test_load_filtered(tmpdir): + # save + filename = _filename(tmpdir) + df = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) + save_rawohlcv_file(filename, df) + + # load with filters on rows & columns + cols = OHLCV_COLS[:2] # ["open", "high"] + timestamps = [row[0] for row in FOUR_ROWS_RAW_TOHLCV_DATA] + st = timestamps[1] # 1686806400000 + fin = timestamps[2] # 1686806700000 + df2 = load_rawohlcv_file(filename, cols, st, fin) + + # test entries + assert len(df2) == 2 + assert "timestamp" in df2.columns + assert len(df2["timestamp"]) == 2 + assert df2["timestamp"].to_list() == timestamps[1:3] + + # test cols and types + assert df2["timestamp"].dtype == pl.Int64 + assert list(df2.columns) == TOHLCV_COLS[:3] + assert list(df2.schema.values()) == TOHLCV_DTYPES_PL[:3] + + +@enforce_types +def _df_from_raw_data(raw_data: list) -> pl.DataFrame: + df = initialize_rawohlcv_df(TOHLCV_COLS) + + schema = dict(zip(TOHLCV_COLS, TOHLCV_DTYPES_PL)) + next_df = pl.DataFrame(raw_data, schema=schema) + + df = concat_next_df(df, next_df) + return df + + +@enforce_types +def test_has_data(tmpdir): + filename0 = os.path.join(tmpdir, "f0.parquet") + save_rawohlcv_file(filename0, _df_from_raw_data([])) + assert not has_data(filename0) + + filename1 = os.path.join(tmpdir, "f1.parquet") + save_rawohlcv_file(filename1, _df_from_raw_data(ONE_ROW_RAW_TOHLCV_DATA)) + assert has_data(filename1) + + filename4 = os.path.join(tmpdir, "f4.parquet") + save_rawohlcv_file(filename4, _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA)) + assert has_data(filename4) + + +@enforce_types +def test_oldest_ut_and_newest_ut__with_data(tmpdir): + filename = _filename(tmpdir) + + # write out four rows + df = _df_from_raw_data(FOUR_ROWS_RAW_TOHLCV_DATA) + save_rawohlcv_file(filename, df) + + # assert head == oldest and tail == latest + 
ut0 = oldest_ut(filename) + utN = newest_ut(filename) + + assert ut0 == FOUR_ROWS_RAW_TOHLCV_DATA[0][0] + assert utN == FOUR_ROWS_RAW_TOHLCV_DATA[-1][0] + + # append and check newest/oldest + df = _df_from_raw_data(ONE_ROW_RAW_TOHLCV_DATA) + save_rawohlcv_file(filename, df) + + ut0 = oldest_ut(filename) + utN = newest_ut(filename) + + assert ut0 == FOUR_ROWS_RAW_TOHLCV_DATA[0][0] + assert utN == ONE_ROW_RAW_TOHLCV_DATA[0][0] + + +@enforce_types +def test_oldest_ut_and_newest_ut__no_data(tmpdir): + filename = _filename(tmpdir) + df = _df_from_raw_data([]) + save_rawohlcv_file(filename, df) + + with pytest.raises(ValueError): + oldest_ut(filename) + with pytest.raises(ValueError): + newest_ut(filename) + + +@enforce_types +def test_parquet_tail_records(tmpdir): + df = pl.DataFrame( + { + "timestamp": [0, 1, 2, 3], + "open": [100, 101, 102, 103], + "high": [100, 101, 102, 103], + "low": [100, 101, 102, 103], + "close": [100, 101, 102, 103], + "volume": [100, 101, 102, 103], + } + ) + + filename = os.path.join(tmpdir, "foo.parquet") + df.write_parquet(filename) + + target_tail_df = pl.DataFrame( + { + "timestamp": [3], + "open": [103], + "high": [103], + "low": [103], + "close": [103], + "volume": [103], + } + ) + + tail_df = _get_tail_df(filename, n=1) + assert tail_df.equals(target_tail_df) + + +@enforce_types +def test_text_to_df(): + df = text_to_df( + """timestamp|open|close +0|10.0|11.0 +1|10.1|11.1 +""" + ) + assert df.columns == ["timestamp", "open", "close"] + assert df.shape == (2, 3) + assert df["timestamp"][0] == 0 + assert df["open"][1] == 10.1 + assert isinstance(df["open"][1], float) diff --git a/pdr_backend/lake/test/test_table_subscriptions.py b/pdr_backend/lake/test/test_table_subscriptions.py new file mode 100644 index 000000000..d2e82d5e5 --- /dev/null +++ b/pdr_backend/lake/test/test_table_subscriptions.py @@ -0,0 +1,182 @@ +from typing import List +from unittest.mock import patch + +import polars as pl +from enforce_typing import enforce_types + +from pdr_backend.lake.table_pdr_subscriptions import subscriptions_schema +from pdr_backend.lake.test.resources import _gql_data_factory, _filter_gql_config +from pdr_backend.util.timeutil import timestr_to_ut + +# ==================================================================== +pdr_subscriptions_record = "pdr_subscriptions" + + +@patch("pdr_backend.lake.table_pdr_subscriptions.fetch_filtered_subscriptions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql1( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_subscriptions, + tmpdir, + sample_subscriptions, +): + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + _test_update_gql( + mock_fetch_filtered_subscriptions, + tmpdir, + sample_subscriptions, + "2023-11-02_0:00", + "2023-11-04_17:00", + n_subs=4, + ) + + +@patch("pdr_backend.lake.table_pdr_subscriptions.fetch_filtered_subscriptions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_update_gql_iteratively( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_subscriptions, + tmpdir, + sample_subscriptions, +): + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + iterations = [ + ("2023-11-02_0:00", "2023-11-04_17:00", 4), + ("2023-11-01_0:00", "2023-11-04_17:00", 4), # does not append to beginning + ("2023-11-01_0:00", "2023-11-05_17:00", 6), + ("2023-11-01_0:00", "2023-11-06_17:00", 7), + ] + + for st_timestr, fin_timestr, n_subs in iterations: + _test_update_gql( + 
mock_fetch_filtered_subscriptions, + tmpdir, + sample_subscriptions, + st_timestr, + fin_timestr, + n_subs=n_subs, + ) + + +@enforce_types +def _test_update_gql( + mock_fetch_filtered_subscriptions, + tmpdir, + sample_subscriptions, + st_timestr: str, + fin_timestr: str, + n_subs, +): + """ + @arguments + n_subs -- expected # subscriptions. Typically int. If '>1K', expect >1000 + """ + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr, + fin_timestr, + ) + + # Update subscriptions record only + default_config = gql_data_factory.record_config + gql_data_factory.record_config = _filter_gql_config( + gql_data_factory.record_config, pdr_subscriptions_record + ) + + # setup: filename + # everything will be inside the gql folder + filename = gql_data_factory._parquet_filename(pdr_subscriptions_record) + assert ".parquet" in filename + + fin_ut = timestr_to_ut(fin_timestr) + st_ut = gql_data_factory._calc_start_ut(filename) + + # calculate ms locally so we can filter raw subscriptions + st_ut_sec = st_ut // 1000 + fin_ut_sec = fin_ut // 1000 + + # filter subs that will be returned from subgraph to client + target_subs = [ + x for x in sample_subscriptions if st_ut_sec <= x.timestamp <= fin_ut_sec + ] + mock_fetch_filtered_subscriptions.return_value = target_subs + + # work 1: update parquet + gql_data_factory._update(fin_ut) + + # assert params + mock_fetch_filtered_subscriptions.assert_called_with( + st_ut_sec, + fin_ut_sec, + ["0x123"], + "mainnet", + ) + + # read parquet and columns + def _subs_in_parquet(filename: str) -> List[int]: + df = pl.read_parquet(filename) + assert df.schema == subscriptions_schema + return df["timestamp"].to_list() + + # assert expected length of subs in parquet + subs: List[int] = _subs_in_parquet(filename) + if isinstance(n_subs, int): + assert len(subs) == n_subs + elif n_subs == ">1K": + assert len(subs) > 1000 + + # subs may not match start or end time + assert subs[0] != st_ut + assert subs[-1] != fin_ut + + # assert all target_subs are registered in parquet + target_subs_ts = [sub.__dict__["timestamp"] for sub in target_subs] + for target_sub in target_subs_ts: + assert target_sub * 1000 in subs + + # reset record config + gql_data_factory.record_config = default_config + + +@patch("pdr_backend.lake.table_pdr_subscriptions.fetch_filtered_subscriptions") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +def test_load_and_verify_schema( + mock_get_all_contract_ids_by_owner, + mock_fetch_filtered_subscriptions, + tmpdir, + sample_subscriptions, +): + mock_get_all_contract_ids_by_owner.return_value = ["0x123"] + st_timestr = "2023-11-01_0:00" + fin_timestr = "2023-11-07_0:00" + + _test_update_gql( + mock_fetch_filtered_subscriptions, + tmpdir, + sample_subscriptions, + st_timestr, + fin_timestr, + n_subs=8, + ) + + _, gql_data_factory = _gql_data_factory( + tmpdir, + "binanceus ETH/USDT h 5m", + st_timestr, + fin_timestr, + ) + + # Update subscriptions record only + gql_data_factory.record_config = _filter_gql_config( + gql_data_factory.record_config, pdr_subscriptions_record + ) + + fin_ut = timestr_to_ut(fin_timestr) + gql_dfs = gql_data_factory._load_parquet(fin_ut) + + assert len(gql_dfs) == 1 + assert len(gql_dfs[pdr_subscriptions_record]) == 8 + assert round(gql_dfs[pdr_subscriptions_record]["last_price_value"].sum(), 2) == 24.0 + assert gql_dfs[pdr_subscriptions_record].schema == subscriptions_schema diff --git a/pdr_backend/data_eng/test/test_timeblock.py 
b/pdr_backend/lake/test/test_timeblock.py similarity index 91% rename from pdr_backend/data_eng/test/test_timeblock.py rename to pdr_backend/lake/test/test_timeblock.py index c74edd203..dc57bd3e1 100644 --- a/pdr_backend/data_eng/test/test_timeblock.py +++ b/pdr_backend/lake/test/test_timeblock.py @@ -1,8 +1,7 @@ -from enforce_typing import enforce_types - import numpy as np +from enforce_typing import enforce_types -from pdr_backend.data_eng import timeblock +from pdr_backend.lake import timeblock @enforce_types diff --git a/pdr_backend/data_eng/timeblock.py b/pdr_backend/lake/timeblock.py similarity index 100% rename from pdr_backend/data_eng/timeblock.py rename to pdr_backend/lake/timeblock.py index 5e4b0b61d..fbcbac55d 100644 --- a/pdr_backend/data_eng/timeblock.py +++ b/pdr_backend/lake/timeblock.py @@ -1,6 +1,6 @@ -from enforce_typing import enforce_types import numpy as np import pandas as pd +from enforce_typing import enforce_types @enforce_types diff --git a/pdr_backend/model_eng/model_ss.py b/pdr_backend/model_eng/model_ss.py deleted file mode 100644 index b92062bc9..000000000 --- a/pdr_backend/model_eng/model_ss.py +++ /dev/null @@ -1,14 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.util.strutil import StrMixin - -APPROACHES = ["LIN", "GPR", "SVR", "NuSVR", "LinearSVR"] - - -@enforce_types -class ModelSS(StrMixin): - def __init__(self, model_approach: str): - if model_approach not in APPROACHES: - raise ValueError(model_approach) - - self.model_approach = model_approach diff --git a/pdr_backend/model_eng/test/conftest.py b/pdr_backend/model_eng/test/conftest.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pdr_backend/model_eng/test/test_model_factory.py b/pdr_backend/model_eng/test/test_model_factory.py deleted file mode 100644 index 7ddfd34e3..000000000 --- a/pdr_backend/model_eng/test/test_model_factory.py +++ /dev/null @@ -1,35 +0,0 @@ -import warnings - -from enforce_typing import enforce_types -import numpy as np - -from pdr_backend.model_eng.model_factory import ModelFactory -from pdr_backend.model_eng.model_ss import APPROACHES, ModelSS - - -@enforce_types -def test_model_factory(): - for approach in APPROACHES: - model_ss = ModelSS(approach) - factory = ModelFactory(model_ss) - assert isinstance(factory.model_ss, ModelSS) - - X, y, Xtest = _data() - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") # ignore ConvergenceWarning, more - model = factory.build(X, y) - - ytest = model.predict(Xtest) - assert len(ytest) == 1 - - -def _data(): - X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]]) - - # y = 1 * x_0 + 2 * x_1 + 3 - y = np.dot(X, np.array([1, 2])) + 3 - - Xtest = np.array([[3, 5]]) - - return X, y, Xtest diff --git a/pdr_backend/model_eng/test/test_model_ss.py b/pdr_backend/model_eng/test/test_model_ss.py deleted file mode 100644 index a50b349ed..000000000 --- a/pdr_backend/model_eng/test/test_model_ss.py +++ /dev/null @@ -1,23 +0,0 @@ -from enforce_typing import enforce_types -import pytest - -from pdr_backend.model_eng.model_ss import APPROACHES, ModelSS - - -@enforce_types -def test_model_ss1(): - ss = ModelSS("LIN") - assert ss.model_approach == "LIN" - - assert "ModelSS" in str(ss) - assert "model_approach" in str(ss) - - -@enforce_types -def test_model_ss2(): - for approach in APPROACHES: - ss = ModelSS(approach) - assert approach in str(ss) - - with pytest.raises(ValueError): - ModelSS("foo_approach") diff --git a/pdr_backend/models/base_config.py b/pdr_backend/models/base_config.py deleted file mode 
100644 index 5e9414c12..000000000 --- a/pdr_backend/models/base_config.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Dict, List - -from enforce_typing import enforce_types - -from pdr_backend.models.feed import dictToFeed, Feed -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.models.slot import Slot -from pdr_backend.util.env import getenv_or_exit, parse_filters -from pdr_backend.util.strutil import StrMixin -from pdr_backend.util.subgraph import get_pending_slots, query_feed_contracts -from pdr_backend.util.web3_config import Web3Config - - -@enforce_types -class BaseConfig(StrMixin): - # pylint: disable=too-many-instance-attributes - def __init__(self): - self.rpc_url: str = getenv_or_exit("RPC_URL") # type: ignore - self.subgraph_url: str = getenv_or_exit("SUBGRAPH_URL") # type: ignore - self.private_key: str = getenv_or_exit("PRIVATE_KEY") # type: ignore - - (f0, f1, f2, f3) = parse_filters() - self.pair_filters: [List[str]] = f0 - self.timeframe_filter: [List[str]] = f1 - self.source_filter: [List[str]] = f2 - self.owner_addresses: [List[str]] = f3 - - self.web3_config = Web3Config(self.rpc_url, self.private_key) - - def get_pending_slots(self, timestamp: int) -> List[Slot]: - return get_pending_slots( - self.subgraph_url, - timestamp, - self.owner_addresses, - self.pair_filters, - self.timeframe_filter, - self.source_filter, - ) - - def get_feeds(self) -> Dict[str, Feed]: - """Return dict of [feed_addr] : {"name":.., "pair":.., ..}""" - feed_dicts = query_feed_contracts( - self.subgraph_url, - ",".join(self.pair_filters), - ",".join(self.timeframe_filter), - ",".join(self.source_filter), - ",".join(self.owner_addresses), - ) - feeds = {addr: dictToFeed(feed_dict) for addr, feed_dict in feed_dicts.items()} - return feeds - - def get_contracts(self, feed_addrs: List[str]) -> Dict[str, PredictoorContract]: - """Return dict of [feed_addr] : PredictoorContract}""" - contracts = {} - for addr in feed_addrs: - contracts[addr] = PredictoorContract(self.web3_config, addr) - return contracts diff --git a/pdr_backend/models/base_contract.py b/pdr_backend/models/base_contract.py deleted file mode 100644 index 6a580cb9c..000000000 --- a/pdr_backend/models/base_contract.py +++ /dev/null @@ -1,17 +0,0 @@ -from abc import ABC -from enforce_typing import enforce_types - -from pdr_backend.util.contract import get_contract_abi -from pdr_backend.util.web3_config import Web3Config - - -@enforce_types -class BaseContract(ABC): - def __init__(self, config: Web3Config, address: str, name: str): - super().__init__() - self.config = config - self.contract_address = config.w3.to_checksum_address(address) - self.contract_instance = config.w3.eth.contract( - address=config.w3.to_checksum_address(address), - abi=get_contract_abi(name), - ) diff --git a/pdr_backend/models/feed.py b/pdr_backend/models/feed.py deleted file mode 100644 index 519a80684..000000000 --- a/pdr_backend/models/feed.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any, Dict - -from enforce_typing import enforce_types - -from pdr_backend.util.pairstr import unpack_pair_str -from pdr_backend.util.strutil import StrMixin - - -class Feed(StrMixin): # pylint: disable=too-many-instance-attributes - @enforce_types - def __init__( - self, - name: str, - address: str, - symbol: str, - seconds_per_epoch: int, - seconds_per_subscription: int, - trueval_submit_timeout: int, - owner: str, - pair: str, - timeframe: str, - source: str, - ): - self.name = name - self.address = address - self.symbol = 
symbol - self.seconds_per_epoch = seconds_per_epoch - self.seconds_per_subscription = seconds_per_subscription - self.trueval_submit_timeout = trueval_submit_timeout - self.owner = owner - self.pair = pair - self.timeframe = timeframe - self.source = source - - @property - def base(self): - return unpack_pair_str(self.pair)[0] - - @property - def quote(self): - return unpack_pair_str(self.pair)[1] - - @enforce_types - def shortstr(self) -> str: - return ( - f"[Feed {self.address[:7]} {self.pair}" f"|{self.source}|{self.timeframe}]" - ) - - @enforce_types - def __str__(self) -> str: - return self.shortstr() - - -@enforce_types -def dictToFeed(feed_dict: Dict[str, Any]): - """ - @description - Convert a feed_dict into Feed format - - @arguments - feed_dict -- dict with values for "name", "address", etc - - @return - feed -- Feed - """ - d = feed_dict - feed = Feed( - name=d["name"], - address=d["address"], - symbol=d["symbol"], - seconds_per_epoch=int(d["seconds_per_epoch"]), - seconds_per_subscription=int(d["seconds_per_subscription"]), - trueval_submit_timeout=int(d["trueval_submit_timeout"]), - owner=d["owner"], - pair=d["pair"], - timeframe=d["timeframe"], - source=d["source"], - ) - return feed diff --git a/pdr_backend/models/fixed_rate.py b/pdr_backend/models/fixed_rate.py deleted file mode 100644 index 9ccc9340e..000000000 --- a/pdr_backend/models/fixed_rate.py +++ /dev/null @@ -1,15 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.models.base_contract import BaseContract -from pdr_backend.util.web3_config import Web3Config - - -@enforce_types -class FixedRate(BaseContract): - def __init__(self, config: Web3Config, address: str): - super().__init__(config, address, "FixedRateExchange") - - def get_dt_price(self, exchangeId): - return self.contract_instance.functions.calcBaseInGivenOutDT( - exchangeId, self.config.w3.to_wei("1", "ether"), 0 - ).call() diff --git a/pdr_backend/models/prediction.py b/pdr_backend/models/prediction.py deleted file mode 100644 index a6d1d1a3a..000000000 --- a/pdr_backend/models/prediction.py +++ /dev/null @@ -1,23 +0,0 @@ -class Prediction: - # pylint: disable=too-many-instance-attributes - def __init__( - self, - pair, - timeframe, - prediction, - stake, - trueval, - timestamp, - source, - payout, - user, - ) -> None: - self.pair = pair - self.timeframe = timeframe - self.prediction = prediction - self.stake = stake - self.trueval = trueval - self.timestamp = timestamp - self.source = source - self.payout = payout - self.user = user diff --git a/pdr_backend/models/predictoor_batcher.py b/pdr_backend/models/predictoor_batcher.py deleted file mode 100644 index f3b16b0d4..000000000 --- a/pdr_backend/models/predictoor_batcher.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import List -from enforce_typing import enforce_types -from eth_typing import ChecksumAddress -from pdr_backend.models.base_contract import BaseContract -from pdr_backend.util.web3_config import Web3Config - - -@enforce_types -class PredictoorBatcher(BaseContract): - def __init__(self, config: Web3Config, address: str): - super().__init__(config, address, "PredictoorHelper") - - def consume_multiple( - self, - addresses: List[ChecksumAddress], - times: List[int], - token_addr: str, - wait_for_receipt=True, - ): - gasPrice = self.config.w3.eth.gas_price - tx = self.contract_instance.functions.consumeMultiple( - addresses, times, token_addr - ).transact({"from": self.config.owner, "gasPrice": gasPrice, "gas": 14_000_000}) - if not wait_for_receipt: - return tx - return 
self.config.w3.eth.wait_for_transaction_receipt(tx) - - def submit_truevals_contracts( - self, - contract_addrs: List[ChecksumAddress], - epoch_starts: List[List[int]], - trueVals: List[List[bool]], - cancelRounds: List[List[bool]], - wait_for_receipt=True, - ): - gasPrice = self.config.w3.eth.gas_price - tx = self.contract_instance.functions.submitTruevalContracts( - contract_addrs, epoch_starts, trueVals, cancelRounds - ).transact({"from": self.config.owner, "gasPrice": gasPrice}) - if not wait_for_receipt: - return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) - - def submit_truevals( - self, - contract_addr: ChecksumAddress, - epoch_starts: List[int], - trueVals: List[bool], - cancelRounds: List[bool], - wait_for_receipt=True, - ): - gasPrice = self.config.w3.eth.gas_price - tx = self.contract_instance.functions.submitTruevals( - contract_addr, epoch_starts, trueVals, cancelRounds - ).transact({"from": self.config.owner, "gasPrice": gasPrice}) - if not wait_for_receipt: - return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) diff --git a/pdr_backend/models/predictoor_contract.py b/pdr_backend/models/predictoor_contract.py deleted file mode 100644 index 3e0b3bfd6..000000000 --- a/pdr_backend/models/predictoor_contract.py +++ /dev/null @@ -1,348 +0,0 @@ -from typing import List, Tuple - -from enforce_typing import enforce_types -from eth_keys import KeyAPI -from eth_keys.backends import NativeECCBackend - -from pdr_backend.models.fixed_rate import FixedRate -from pdr_backend.models.token import Token -from pdr_backend.models.base_contract import BaseContract -from pdr_backend.util.constants import ZERO_ADDRESS, MAX_UINT -from pdr_backend.util.networkutil import is_sapphire_network, send_encrypted_tx -from pdr_backend.util.web3_config import Web3Config - -_KEYS = KeyAPI(NativeECCBackend) - - -@enforce_types -class PredictoorContract(BaseContract): # pylint: disable=too-many-public-methods - def __init__(self, config: Web3Config, address: str): - super().__init__(config, address, "ERC20Template3") - stake_token = self.get_stake_token() - self.token = Token(config, stake_token) - self.last_allowance = 0 - - def is_valid_subscription(self): - return self.contract_instance.functions.isValidSubscription( - self.config.owner - ).call() - - def getid(self): - return self.contract_instance.functions.getId().call() - - def get_empty_provider_fee(self): - return { - "providerFeeAddress": ZERO_ADDRESS, - "providerFeeToken": ZERO_ADDRESS, - "providerFeeAmount": 0, - "v": 0, - "r": 0, - "s": 0, - "validUntil": 0, - "providerData": 0, - } - - def string_to_bytes32(self, data): - if len(data) > 32: - myBytes32 = data[:32] - else: - myBytes32 = data.ljust(32, "0") - return bytes(myBytes32, "utf-8") - - def get_auth_signature(self): - valid_until = self.config.get_block("latest").timestamp + 3600 - message_hash = self.config.w3.solidity_keccak( - ["address", "uint256"], - [self.config.owner, valid_until], - ) - pk = _KEYS.PrivateKey(self.config.account.key) - prefix = "\x19Ethereum Signed Message:\n32" - signable_hash = self.config.w3.solidity_keccak( - ["bytes", "bytes"], - [ - self.config.w3.to_bytes(text=prefix), - self.config.w3.to_bytes(message_hash), - ], - ) - signed = _KEYS.ecdsa_sign(message_hash=signable_hash, private_key=pk) - auth = { - "userAddress": self.config.owner, - "v": (signed.v + 27) if signed.v <= 1 else signed.v, - "r": self.config.w3.to_hex( - self.config.w3.to_bytes(signed.r).rjust(32, b"\0") - ), - "s": self.config.w3.to_hex( - 
self.config.w3.to_bytes(signed.s).rjust(32, b"\0") - ), - "validUntil": valid_until, - } - return auth - - def get_max_gas(self): - """Returns max block gas""" - block = self.config.get_block( - self.config.w3.eth.block_number, full_transactions=False - ) - return int(block["gasLimit"] * 0.99) - - def buy_and_start_subscription(self, gasLimit=None, wait_for_receipt=True): - """Buys 1 datatoken and starts a subscription""" - fixed_rates = self.get_exchanges() - if not fixed_rates: - return None - - (fixed_rate_address, exchange_str) = fixed_rates[0] - - # get datatoken price - exchange = FixedRate(self.config, fixed_rate_address) - (baseTokenAmount, _, _, _) = exchange.get_dt_price(exchange_str) - - # approve - self.token.approve(self.contract_instance.address, baseTokenAmount) - - # buy 1 DT - gasPrice = self.config.w3.eth.gas_price - provider_fees = self.get_empty_provider_fee() - try: - orderParams = ( - self.config.owner, - 0, - ( - ZERO_ADDRESS, - ZERO_ADDRESS, - 0, - 0, - self.string_to_bytes32(""), - self.string_to_bytes32(""), - provider_fees["validUntil"], - self.config.w3.to_bytes(b""), - ), - (ZERO_ADDRESS, ZERO_ADDRESS, 0), - ) - freParams = ( - self.config.w3.to_checksum_address(fixed_rate_address), - self.config.w3.to_bytes(exchange_str), - baseTokenAmount, - 0, - ZERO_ADDRESS, - ) - call_params = { - "from": self.config.owner, - "gasPrice": gasPrice, - # 'nonce': self.config.w3.eth.get_transaction_count(self.config.owner), - } - if gasLimit is None: - try: - gasLimit = self.contract_instance.functions.buyFromFreAndOrder( - orderParams, freParams - ).estimate_gas(call_params) - except Exception as e: - print("Estimate gas failed") - print(e) - gasLimit = self.get_max_gas() - call_params["gas"] = gasLimit + 1 - tx = self.contract_instance.functions.buyFromFreAndOrder( - orderParams, freParams - ).transact(call_params) - if not wait_for_receipt: - return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) - except Exception as e: - print(e) - return None - - def buy_many(self, how_many, gasLimit=None, wait_for_receipt=False): - """Buys multiple accesses and returns tx hashes""" - txs = [] - if how_many < 1: - return None - print(f"Buying {how_many} accesses....") - for _ in range(0, how_many): - txs.append(self.buy_and_start_subscription(gasLimit, wait_for_receipt)) - return txs - - def get_exchanges(self): - return self.contract_instance.functions.getFixedRates().call() - - def get_stake_token(self): - return self.contract_instance.functions.stakeToken().call() - - def get_price(self) -> int: - fixed_rates = self.get_exchanges() - if not fixed_rates: - return 0 - (fixed_rate_address, exchange_str) = fixed_rates[0] - # get datatoken price - exchange = FixedRate(self.config, fixed_rate_address) - (baseTokenAmount, _, _, _) = exchange.get_dt_price(exchange_str) - return baseTokenAmount - - def get_current_epoch(self) -> int: - # curEpoch returns the timestamp of current candle start - # this function returns the "epoch number" that increases - # by one each secondsPerEpoch seconds - current_epoch_ts = self.get_current_epoch_ts() - seconds_per_epoch = self.get_secondsPerEpoch() - return int(current_epoch_ts / seconds_per_epoch) - - def get_current_epoch_ts(self) -> int: - """returns the current candle start timestamp""" - return self.contract_instance.functions.curEpoch().call() - - def get_secondsPerEpoch(self) -> int: - return self.contract_instance.functions.secondsPerEpoch().call() - - def get_agg_predval(self, timestamp) -> Tuple[float, float]: - auth = 
self.get_auth_signature() - (nom_wei, denom_wei) = self.contract_instance.functions.getAggPredval( - timestamp, auth - ).call({"from": self.config.owner}) - nom = float(self.config.w3.from_wei(nom_wei, "ether")) - denom = float(self.config.w3.from_wei(denom_wei, "ether")) - return nom, denom - - def payout_multiple(self, slots: List[int], wait_for_receipt=True): - """Claims the payout for given slots""" - gasPrice = self.config.w3.eth.gas_price - try: - tx = self.contract_instance.functions.payoutMultiple( - slots, self.config.owner - ).transact({"from": self.config.owner, "gasPrice": gasPrice}) - if not wait_for_receipt: - return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) - except Exception as e: - print(e) - return None - - def payout(self, slot, wait_for_receipt=False): - """Claims the payout for a slot""" - gasPrice = self.config.w3.eth.gas_price - try: - tx = self.contract_instance.functions.payout( - slot, self.config.owner - ).transact({"from": self.config.owner, "gasPrice": gasPrice}) - if not wait_for_receipt: - return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) - except Exception as e: - print(e) - return None - - def soonest_timestamp_to_predict(self, timestamp): - return self.contract_instance.functions.soonestEpochToPredict(timestamp).call() - - def submit_prediction( - self, - predicted_value: bool, - stake_amount: float, - prediction_ts: int, - wait_for_receipt=True, - ): - """ - Submits a prediction with the specified stake amount, to the contract. - - @param predicted_value: The predicted value (True or False) - @param stake_amount: The amount of ETH to be staked on the prediction - @param prediction_ts: The prediction timestamp == start a candle. - @param wait_for_receipt: - If True, waits for tx receipt after submission. - If False, immediately after sending the transaction. - Default is True. - - @return: - If wait_for_receipt is True, returns the tx receipt. - If False, returns the tx hash immediately after sending. - If an exception occurs during the process, returns None. 
- """ - amount_wei = self.config.w3.to_wei(str(stake_amount), "ether") - - # Check allowance first, only approve if needed - if self.last_allowance <= 0: - self.last_allowance = self.token.allowance( - self.config.owner, self.contract_address - ) - if self.last_allowance < amount_wei: - try: - self.token.approve(self.contract_address, MAX_UINT) - self.last_allowance = MAX_UINT - except Exception as e: - print("Error while approving the contract to spend tokens:", e) - return None - - gasPrice = self.config.w3.eth.gas_price - try: - txhash = None - if is_sapphire_network(self.config.w3.eth.chain_id): - self.contract_instance.encodeABI( - fn_name="submitPredval", - args=[predicted_value, amount_wei, prediction_ts], - ) - sender = self.config.owner - receiver = self.contract_instance.address - pk = self.config.account.key.hex()[2:] - res, txhash = send_encrypted_tx( - self.contract_instance, - "submitPredval", - [predicted_value, amount_wei, prediction_ts], - pk, - sender, - receiver, - self.config.rpc_url, - 0, - 1000000, - 0, - 0, - ) - print("Encrypted transaction status code:", res) - else: - tx = self.contract_instance.functions.submitPredval( - predicted_value, amount_wei, prediction_ts - ).transact({"from": self.config.owner, "gasPrice": gasPrice}) - txhash = tx.hex() - self.last_allowance -= amount_wei - print(f"Submitted prediction, txhash: {txhash}") - if not wait_for_receipt: - return txhash - return self.config.w3.eth.wait_for_transaction_receipt(txhash) - except Exception as e: - print(e) - return None - - def get_trueValSubmitTimeout(self): - return self.contract_instance.functions.trueValSubmitTimeout().call() - - def get_prediction(self, slot: int, address: str): - auth_signature = self.get_auth_signature() - return self.contract_instance.functions.getPrediction( - slot, address, auth_signature - ).call({"from": self.config.owner}) - - def submit_trueval(self, trueval, timestamp, cancel_round, wait_for_receipt=True): - gasPrice = self.config.w3.eth.gas_price - tx = self.contract_instance.functions.submitTrueVal( - timestamp, trueval, cancel_round - ).transact({"from": self.config.owner, "gasPrice": gasPrice}) - print(f"Submitted trueval, txhash: {tx.hex()}") - if not wait_for_receipt: - return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) - - def redeem_unused_slot_revenue(self, timestamp, wait_for_receipt=True): - gasPrice = self.config.w3.eth.gas_price - try: - tx = self.contract_instance.functions.redeemUnusedSlotRevenue( - timestamp - ).transact({"from": self.config.owner, "gasPrice": gasPrice}) - if not wait_for_receipt: - return tx - return self.config.w3.eth.wait_for_transaction_receipt(tx) - except Exception as e: - print(e) - return None - - def get_block(self, block): - return self.config.get_block(block) - - def erc721_addr(self) -> str: - return self.contract_instance.functions.getERC721Address().call() diff --git a/pdr_backend/models/slot.py b/pdr_backend/models/slot.py deleted file mode 100644 index 6257bab12..000000000 --- a/pdr_backend/models/slot.py +++ /dev/null @@ -1,7 +0,0 @@ -from pdr_backend.models.feed import Feed - - -class Slot: - def __init__(self, slot_number: int, feed: Feed): - self.slot_number = slot_number - self.feed = feed diff --git a/pdr_backend/models/test/test_base_config.py b/pdr_backend/models/test/test_base_config.py deleted file mode 100644 index efba917d8..000000000 --- a/pdr_backend/models/test/test_base_config.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -from unittest.mock import patch, Mock - -from enforce_typing 
import enforce_types - -from pdr_backend.models.base_config import BaseConfig - -PRIV_KEY = os.getenv("PRIVATE_KEY") - -ADDR = "0xe8933f2950aec1080efad1ca160a6bb641ad245d" # predictoor contract addr - -FEED_DICT = { # info inside a predictoor contract - "name": "Contract Name", - "address": ADDR, - "symbol": "test", - "seconds_per_epoch": 300, - "seconds_per_subscription": 60, - "trueval_submit_timeout": 15, - "owner": "0xowner", - "pair": "BTC-ETH", - "timeframe": "1h", - "source": "binance", -} - - -@enforce_types -def test_base_config_with_filters(monkeypatch): - _setenvs(monkeypatch, have_filters=True) - c = BaseConfig() - - assert c.rpc_url == "http://foo" - assert c.subgraph_url == "http://bar" - assert c.private_key == PRIV_KEY - - assert c.pair_filters == ["BTC/USDT", "ETH/USDT"] - assert c.timeframe_filter == ["5m", "15m"] - assert c.source_filter == ["binance", "kraken"] - assert c.owner_addresses == ["0x123", "0x124"] - - assert c.web3_config is not None - - -@enforce_types -def test_base_config_no_filters(monkeypatch): - _setenvs(monkeypatch, have_filters=False) - c = BaseConfig() - assert c.pair_filters == [] - assert c.timeframe_filter == [] - assert c.source_filter == [] - assert c.owner_addresses == [] - - -@enforce_types -def test_base_config_pending_slots(monkeypatch): - _setenvs(monkeypatch, have_filters=False) - c = BaseConfig() - - # test get_pending_slots() - def _mock_get_pending_slots(*args): - timestamp = args[1] - return [f"1_{timestamp}", f"2_{timestamp}"] - - with patch( - "pdr_backend.models.base_config.get_pending_slots", _mock_get_pending_slots - ): - slots = c.get_pending_slots(6789) - assert slots == ["1_6789", "2_6789"] - - -@enforce_types -def test_base_config_feeds_contracts(monkeypatch): - _setenvs(monkeypatch, have_filters=False) - c = BaseConfig() - - # test get_feeds() - def _mock_query_feed_contracts(*args, **kwargs): # pylint: disable=unused-argument - feed_dicts = {ADDR: FEED_DICT} - return feed_dicts - - with patch( - "pdr_backend.models.base_config.query_feed_contracts", - _mock_query_feed_contracts, - ): - feeds = c.get_feeds() - feed_addrs = list(feeds.keys()) - assert feed_addrs == [ADDR] - - # test get_contracts(). 
Uses results from get_feeds - def _mock_contract(*args, **kwarg): # pylint: disable=unused-argument - m = Mock() - m.contract_address = ADDR - return m - - with patch("pdr_backend.models.base_config.PredictoorContract", _mock_contract): - contracts = c.get_contracts(feed_addrs) - assert list(contracts.keys()) == feed_addrs - assert contracts[ADDR].contract_address == ADDR - - -@enforce_types -def _setenvs(monkeypatch, have_filters: bool): - monkeypatch.setenv("RPC_URL", "http://foo") - monkeypatch.setenv("SUBGRAPH_URL", "http://bar") - monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) - - monkeypatch.setenv("SECONDS_TILL_EPOCH_END", "60") - monkeypatch.setenv("STAKE_AMOUNT", "30000") - - if have_filters: - monkeypatch.setenv("PAIR_FILTER", "BTC/USDT,ETH/USDT") - monkeypatch.setenv("TIMEFRAME_FILTER", "5m,15m") - monkeypatch.setenv("SOURCE_FILTER", "binance,kraken") - monkeypatch.setenv("OWNER_ADDRS", "0x123,0x124") diff --git a/pdr_backend/models/test/test_erc721_factory.py b/pdr_backend/models/test/test_erc721_factory.py deleted file mode 100644 index 28009d694..000000000 --- a/pdr_backend/models/test/test_erc721_factory.py +++ /dev/null @@ -1,54 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.models.erc721_factory import ERC721Factory -from pdr_backend.util.contract import get_address - - -@enforce_types -def test_ERC721Factory(web3_config): - factory = ERC721Factory(web3_config) - assert factory is not None - - ocean_address = get_address(web3_config.w3.eth.chain_id, "Ocean") - fre_address = get_address(web3_config.w3.eth.chain_id, "FixedPrice") - - rate = 3 - cut = 0.2 - - nft_data = ("TestToken", "TT", 1, "", True, web3_config.owner) - erc_data = ( - 3, - ["ERC20Test", "ET"], - [ - web3_config.owner, - web3_config.owner, - web3_config.owner, - ocean_address, - ocean_address, - ], - [2**256 - 1, 0, 300, 3000, 30000], - [], - ) - fre_data = ( - fre_address, - [ - ocean_address, - web3_config.owner, - web3_config.owner, - web3_config.owner, - ], - [ - 18, - 18, - web3_config.w3.to_wei(rate, "ether"), - web3_config.w3.to_wei(cut, "ether"), - 1, - ], - ) - - logs_nft, logs_erc = factory.createNftWithErc20WithFixedRate( - nft_data, erc_data, fre_data - ) - - assert len(logs_nft) > 0 - assert len(logs_erc) > 0 diff --git a/pdr_backend/models/test/test_feed.py b/pdr_backend/models/test/test_feed.py deleted file mode 100644 index 0486bd766..000000000 --- a/pdr_backend/models/test/test_feed.py +++ /dev/null @@ -1,82 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.models.feed import dictToFeed, Feed - - -@enforce_types -def test_feed__construct_directly(): - feed = Feed( - "Contract Name", - "0x12345", - "SYM:TEST", - 300, - 60, - 15, - "0xowner", - "BTC-USDT", - "1h", - "binance", - ) - - assert feed.name == "Contract Name" - assert feed.address == "0x12345" - assert feed.symbol == "SYM:TEST" - assert feed.seconds_per_epoch == 300 - assert feed.seconds_per_subscription == 60 - assert feed.trueval_submit_timeout == 15 - assert feed.owner == "0xowner" - assert feed.pair == "BTC-USDT" - assert feed.timeframe == "1h" - assert feed.source == "binance" - assert feed.quote == "USDT" - assert feed.base == "BTC" - - -@enforce_types -def test_feed__construct_via_dictToFeed(): - feed_dict = { - "name": "Contract Name", - "address": "0x12345", - "symbol": "SYM:TEST", - "seconds_per_epoch": "300", - "seconds_per_subscription": "60", - "trueval_submit_timeout": "15", - "owner": "0xowner", - "pair": "BTC-USDT", - "timeframe": "1h", - "source": "binance", - } - feed = 
dictToFeed(feed_dict) - assert isinstance(feed, Feed) - - assert feed.name == "Contract Name" - assert feed.address == "0x12345" - assert feed.symbol == "SYM:TEST" - assert feed.seconds_per_epoch == 300 - assert feed.seconds_per_subscription == 60 - assert feed.trueval_submit_timeout == 15 - assert feed.owner == "0xowner" - assert feed.pair == "BTC-USDT" - assert feed.timeframe == "1h" - assert feed.source == "binance" - assert feed.base == "BTC" - assert feed.quote == "USDT" - - # test pair with "/" (versus "-") - feed_dict["pair"] = "BTC/USDT" - feed = dictToFeed(feed_dict) - assert feed.base == "BTC" - assert feed.quote == "USDT" - - # test where ints are int, not str (should work for either) - feed_dict.update( - { - "seconds_per_epoch": 301, - "seconds_per_subscription": 61, - "trueval_submit_timeout": 16, - } - ) - feed = dictToFeed(feed_dict) - assert feed.seconds_per_epoch == 301 - assert feed.seconds_per_subscription == 61 - assert feed.trueval_submit_timeout == 16 diff --git a/pdr_backend/models/test/test_fixed_rate.py b/pdr_backend/models/test/test_fixed_rate.py deleted file mode 100644 index b21752851..000000000 --- a/pdr_backend/models/test/test_fixed_rate.py +++ /dev/null @@ -1,14 +0,0 @@ -from enforce_typing import enforce_types -from pytest import approx - -from pdr_backend.models.fixed_rate import FixedRate - - -@enforce_types -def test_FixedRate(predictoor_contract, web3_config): - exchanges = predictoor_contract.get_exchanges() - address = exchanges[0][0] - id_ = exchanges[0][1] - print(exchanges) - rate = FixedRate(web3_config, address) - assert rate.get_dt_price(id_)[0] / 1e18 == approx(3.603) diff --git a/pdr_backend/models/test/test_predictoor_contract.py b/pdr_backend/models/test/test_predictoor_contract.py deleted file mode 100644 index 5e670ddc2..000000000 --- a/pdr_backend/models/test/test_predictoor_contract.py +++ /dev/null @@ -1,169 +0,0 @@ -from enforce_typing import enforce_types -import pytest -from pytest import approx - -from pdr_backend.conftest_ganache import SECONDS_PER_EPOCH -from pdr_backend.models.token import Token -from pdr_backend.util.contract import get_address - - -@enforce_types -def test_get_id(predictoor_contract): - id_ = predictoor_contract.getid() - assert id_ == 3 - - -@enforce_types -def test_is_valid_subscription_initially(predictoor_contract): - is_valid_sub = predictoor_contract.is_valid_subscription() - assert not is_valid_sub - - -@enforce_types -def test_auth_signature(predictoor_contract): - auth_sig = predictoor_contract.get_auth_signature() - assert "v" in auth_sig - assert "r" in auth_sig - assert "s" in auth_sig - - -@enforce_types -def test_max_gas_limit(predictoor_contract): - max_gas_limit = predictoor_contract.get_max_gas() - # You'll have access to the config object if required, using predictoor_contract.config - expected_limit = int(predictoor_contract.config.get_block("latest").gasLimit * 0.99) - assert max_gas_limit == expected_limit - - -@enforce_types -def test_buy_and_start_subscription(predictoor_contract): - receipt = predictoor_contract.buy_and_start_subscription() - assert receipt["status"] == 1 - is_valid_sub = predictoor_contract.is_valid_subscription() - assert is_valid_sub - - -@enforce_types -def test_buy_many(predictoor_contract): - receipts = predictoor_contract.buy_many(2, None, True) - assert len(receipts) == 2 - - -@enforce_types -def test_get_exchanges(predictoor_contract): - exchanges = predictoor_contract.get_exchanges() - assert exchanges[0][0].startswith("0x") - - -@enforce_types -def 
test_get_stake_token(predictoor_contract, web3_config): - stake_token = predictoor_contract.get_stake_token() - ocean_address = get_address(web3_config.w3.eth.chain_id, "Ocean") - assert stake_token == ocean_address - - -@enforce_types -def test_get_price(predictoor_contract): - price = predictoor_contract.get_price() - assert price / 1e18 == approx(3.603) - - -@enforce_types -def test_get_current_epoch(predictoor_contract): - current_epoch = predictoor_contract.get_current_epoch() - now = predictoor_contract.config.get_block("latest").timestamp - assert current_epoch == int(now // SECONDS_PER_EPOCH) - - -def test_get_current_epoch_ts(predictoor_contract): - current_epoch = predictoor_contract.get_current_epoch_ts() - now = predictoor_contract.config.get_block("latest").timestamp - assert current_epoch == int(now // SECONDS_PER_EPOCH) * SECONDS_PER_EPOCH - - -@enforce_types -def test_get_seconds_per_epoch(predictoor_contract): - seconds_per_epoch = predictoor_contract.get_secondsPerEpoch() - assert seconds_per_epoch == SECONDS_PER_EPOCH - - -@enforce_types -def test_get_aggpredval(predictoor_contract): - current_epoch = predictoor_contract.get_current_epoch_ts() - aggpredval = predictoor_contract.get_agg_predval(current_epoch) - assert aggpredval == (0, 0) - - -@enforce_types -def test_soonest_timestamp_to_predict(predictoor_contract): - current_epoch = predictoor_contract.get_current_epoch_ts() - soonest_timestamp = predictoor_contract.soonest_timestamp_to_predict(current_epoch) - assert soonest_timestamp == current_epoch + SECONDS_PER_EPOCH * 2 - - -@enforce_types -def test_get_trueValSubmitTimeout(predictoor_contract): - trueValSubmitTimeout = predictoor_contract.get_trueValSubmitTimeout() - assert trueValSubmitTimeout == 3 * 24 * 60 * 60 - - -@enforce_types -def test_get_block(predictoor_contract): - block = predictoor_contract.get_block(0) - assert block.number == 0 - - -@enforce_types -def test_submit_prediction_aggpredval_payout(predictoor_contract, ocean_token: Token): - owner_addr = predictoor_contract.config.owner - balance_before = ocean_token.balanceOf(owner_addr) - current_epoch = predictoor_contract.get_current_epoch_ts() - soonest_timestamp = predictoor_contract.soonest_timestamp_to_predict(current_epoch) - receipt = predictoor_contract.submit_prediction(True, 1, soonest_timestamp, True) - assert receipt["status"] == 1 - - balance_after = ocean_token.balanceOf(owner_addr) - assert balance_before - balance_after == 1e18 - - prediction = predictoor_contract.get_prediction( - soonest_timestamp, predictoor_contract.config.owner - ) - assert prediction[0] - assert prediction[1] == 1e18 - - predictoor_contract.config.w3.provider.make_request( - "evm_increaseTime", [SECONDS_PER_EPOCH * 2] - ) - predictoor_contract.config.w3.provider.make_request("evm_mine", []) - receipt = predictoor_contract.submit_trueval(True, soonest_timestamp, False, True) - assert receipt["status"] == 1 - - receipt = predictoor_contract.payout(soonest_timestamp, True) - assert receipt["status"] == 1 - balance_final = ocean_token.balanceOf(owner_addr) - assert balance_before / 1e18 == approx(balance_final / 1e18) # + sub revenue - - -@enforce_types -def test_redeem_unused_slot_revenue(predictoor_contract): - current_epoch = predictoor_contract.get_current_epoch_ts() - SECONDS_PER_EPOCH * 123 - receipt = predictoor_contract.redeem_unused_slot_revenue(current_epoch, True) - assert receipt["status"] == 1 - - -@pytest.mark.parametrize( - "input_data,expected_output", - [ - ("short", b"short" + b"0" * 27), - ("this 
is exactly 32 chars", b"this is exactly 32 chars00000000"), - ( - "this is a very long string which is more than 32 chars", - b"this is a very long string which", - ), - ], -) -def test_string_to_bytes32(input_data, expected_output, predictoor_contract): - result = predictoor_contract.string_to_bytes32(input_data) - assert ( - result == expected_output - ), f"For {input_data}, expected {expected_output}, but got {result}" diff --git a/pdr_backend/payout/payout.py b/pdr_backend/payout/payout.py new file mode 100644 index 000000000..f4790c088 --- /dev/null +++ b/pdr_backend/payout/payout.py @@ -0,0 +1,105 @@ +import time +from typing import Any, List + +from enforce_typing import enforce_types + +from pdr_backend.contract.dfrewards import DFRewards +from pdr_backend.contract.predictoor_contract import PredictoorContract +from pdr_backend.contract.wrapped_token import WrappedToken +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_pending_payouts import query_pending_payouts +from pdr_backend.subgraph.subgraph_sync import wait_until_subgraph_syncs +from pdr_backend.util.constants import SAPPHIRE_MAINNET_CHAINID + + +@enforce_types +def batchify(data: List[Any], batch_size: int): + return [data[i : i + batch_size] for i in range(0, len(data), batch_size)] + + +@enforce_types +def request_payout_batches( + predictoor_contract: PredictoorContract, batch_size: int, timestamps: List[int] +): + batches = batchify(timestamps, batch_size) + for batch in batches: + retries = 0 + success = False + + while retries < 5 and not success: + try: + wait_for_receipt = True + predictoor_contract.payout_multiple(batch, wait_for_receipt) + print(".", end="", flush=True) + success = True + except Exception as e: + retries += 1 + print(f"Error: {e}. Retrying... {retries}/5", flush=True) + time.sleep(1) + + if not success: + print("\nFailed after 5 attempts. 
Moving to next batch.", flush=True) + + print("\nBatch completed") + + +@enforce_types +def do_ocean_payout(ppss: PPSS, check_network: bool = True): + web3_config = ppss.web3_pp.web3_config + subgraph_url: str = ppss.web3_pp.subgraph_url + + if check_network: + assert ppss.web3_pp.network == "sapphire-mainnet" + assert web3_config.w3.eth.chain_id == SAPPHIRE_MAINNET_CHAINID + + print("Starting payout") + wait_until_subgraph_syncs(web3_config, subgraph_url) + print("Finding pending payouts") + pending_payouts = query_pending_payouts(subgraph_url, web3_config.owner) + total_timestamps = sum(len(timestamps) for timestamps in pending_payouts.values()) + print(f"Found {total_timestamps} slots") + + for pdr_contract_addr in pending_payouts: + print(f"Claiming payouts for {pdr_contract_addr}") + pdr_contract = PredictoorContract(ppss.web3_pp, pdr_contract_addr) + request_payout_batches( + pdr_contract, ppss.payout_ss.batch_size, pending_payouts[pdr_contract_addr] + ) + + print("Payout done") + + +@enforce_types +def do_rose_payout(ppss: PPSS, check_network: bool = True): + web3_config = ppss.web3_pp.web3_config + + if check_network: + assert ppss.web3_pp.network == "sapphire-mainnet" + assert web3_config.w3.eth.chain_id == SAPPHIRE_MAINNET_CHAINID + + dfrewards_addr = "0xc37F8341Ac6e4a94538302bCd4d49Cf0852D30C0" + wROSE_addr = "0x8Bc2B030b299964eEfb5e1e0b36991352E56D2D3" + + dfrewards_contract = DFRewards(ppss.web3_pp, dfrewards_addr) + claimable_rewards = dfrewards_contract.get_claimable_rewards( + web3_config.owner, wROSE_addr + ) + print(f"Found {claimable_rewards} wROSE available to claim") + + if claimable_rewards > 0: + print("Claiming wROSE rewards...") + dfrewards_contract.claim_rewards(web3_config.owner, wROSE_addr) + else: + print("No rewards available to claim") + + print("Converting wROSE to ROSE") + time.sleep(10) + wROSE = WrappedToken(ppss.web3_pp, wROSE_addr) + wROSE_bal = wROSE.balanceOf(web3_config.owner) + if wROSE_bal == 0: + print("wROSE balance is 0") + else: + print(f"Found {wROSE_bal/1e18} wROSE, converting to ROSE...") + wROSE.withdraw(wROSE_bal) + + print("ROSE reward claim done") diff --git a/pdr_backend/predictoor/test/test_payout.py b/pdr_backend/payout/test/test_payout.py similarity index 56% rename from pdr_backend/predictoor/test/test_payout.py rename to pdr_backend/payout/test/test_payout.py index b9df3f6e3..91d2cb865 100644 --- a/pdr_backend/predictoor/test/test_payout.py +++ b/pdr_backend/payout/test/test_payout.py @@ -1,19 +1,21 @@ from unittest.mock import Mock, call, patch import pytest +from enforce_typing import enforce_types -from pdr_backend.models.dfrewards import DFRewards -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.models.wrapped_token import WrappedToken -from pdr_backend.predictoor.payout import ( +from pdr_backend.contract.dfrewards import DFRewards +from pdr_backend.contract.predictoor_contract import PredictoorContract +from pdr_backend.contract.wrapped_token import WrappedToken +from pdr_backend.payout.payout import ( batchify, - do_payout, + do_ocean_payout, do_rose_payout, request_payout_batches, ) -from pdr_backend.util.web3_config import Web3Config +from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str +@enforce_types def test_batchify(): assert batchify([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]] assert batchify([], 2) == [] @@ -23,6 +25,7 @@ def test_batchify(): batchify("12345", 2) +@enforce_types def test_request_payout_batches(): mock_contract = Mock(spec=PredictoorContract) 
mock_contract.payout_multiple = Mock() @@ -43,13 +46,10 @@ def test_request_payout_batches(): assert mock_contract.payout_multiple.call_count == 3 -def test_do_payout(): - mock_config = Mock() - mock_config.subgraph_url = "" - mock_config.web3_config = Mock(spec=Web3Config) - mock_config.web3_config.owner = "mock_owner" - - mock_batch_size = "5" +@enforce_types +def test_do_ocean_payout(tmpdir): + ppss = _ppss(tmpdir) + ppss.payout_ss.set_batch_size(5) mock_pending_payouts = { "address_1": [1, 2, 3], @@ -59,17 +59,13 @@ def test_do_payout(): mock_contract = Mock(spec=PredictoorContract) mock_contract.payout_multiple = Mock() - with patch( - "pdr_backend.predictoor.payout.BaseConfig", return_value=mock_config - ), patch("pdr_backend.predictoor.payout.wait_until_subgraph_syncs"), patch( - "os.getenv", return_value=mock_batch_size - ), patch( - "pdr_backend.predictoor.payout.query_pending_payouts", + with patch("pdr_backend.payout.payout.wait_until_subgraph_syncs"), patch( + "pdr_backend.payout.payout.query_pending_payouts", return_value=mock_pending_payouts, ), patch( - "pdr_backend.predictoor.payout.PredictoorContract", return_value=mock_contract + "pdr_backend.payout.payout.PredictoorContract", return_value=mock_contract ): - do_payout() + do_ocean_payout(ppss, check_network=False) print(mock_contract.payout_multiple.call_args_list) call_args = mock_contract.payout_multiple.call_args_list assert call_args[0] == call([1, 2, 3], True) @@ -79,13 +75,10 @@ def test_do_payout(): assert len(call_args) == 4 -def test_do_rose_payout(): - mock_config = Mock() - mock_config.subgraph_url = "" - mock_config.web3_config = Mock(spec=Web3Config) - mock_config.web3_config.w3 = Mock() - mock_config.web3_config.w3.eth.chain_id = 23294 - mock_config.web3_config.owner = "mock_owner" +@enforce_types +def test_do_rose_payout(tmpdir): + ppss = _ppss(tmpdir) + web3_config = ppss.web3_pp.web3_config mock_contract = Mock(spec=DFRewards) mock_contract.get_claimable_rewards = Mock() @@ -97,16 +90,19 @@ def test_do_rose_payout(): mock_wrose.balanceOf.return_value = 100 mock_wrose.withdraw = Mock() - with patch("pdr_backend.predictoor.payout.time"), patch( - "pdr_backend.predictoor.payout.BaseConfig", return_value=mock_config - ), patch( - "pdr_backend.predictoor.payout.WrappedToken", return_value=mock_wrose - ), patch( - "pdr_backend.predictoor.payout.DFRewards", return_value=mock_contract - ): - do_rose_payout() + with patch("pdr_backend.payout.payout.time"), patch( + "pdr_backend.payout.payout.WrappedToken", return_value=mock_wrose + ), patch("pdr_backend.payout.payout.DFRewards", return_value=mock_contract): + do_rose_payout(ppss, check_network=False) mock_contract.claim_rewards.assert_called_with( - "mock_owner", "0x8Bc2B030b299964eEfb5e1e0b36991352E56D2D3" + web3_config.owner, "0x8Bc2B030b299964eEfb5e1e0b36991352E56D2D3" ) mock_wrose.balanceOf.assert_called() mock_wrose.withdraw.assert_called_with(100) + + +@enforce_types +def _ppss(tmpdir): + s = fast_test_yaml_str(tmpdir) + ppss = PPSS(yaml_str=s, network="sapphire-mainnet") + return ppss diff --git a/pdr_backend/ppss/aimodel_ss.py b/pdr_backend/ppss/aimodel_ss.py new file mode 100644 index 000000000..649c65f2c --- /dev/null +++ b/pdr_backend/ppss/aimodel_ss.py @@ -0,0 +1,59 @@ +import copy + +import numpy as np +from enforce_typing import enforce_types + +from pdr_backend.ppss.base_ss import MultiFeedMixin +from pdr_backend.util.strutil import StrMixin + +APPROACHES = ["LIN", "GPR", "SVR", "NuSVR", "LinearSVR"] + + +@enforce_types +class 
AimodelSS(MultiFeedMixin, StrMixin): + __STR_OBJDIR__ = ["d"] + FEEDS_KEY = "input_feeds" + + def __init__(self, d: dict): + super().__init__( + d, assert_feed_attributes=["signal"] + ) # yaml_dict["aimodel_ss"] + + # test inputs + if self.approach not in APPROACHES: + raise ValueError(self.approach) + + assert 0 < self.max_n_train + assert 0 < self.autoregressive_n < np.inf + + # -------------------------------- + # yaml properties + + @property + def approach(self) -> str: + return self.d["approach"] # eg "LIN" + + @property + def max_n_train(self) -> int: + return self.d["max_n_train"] # eg 50000. S.t. what data is available + + @property + def autoregressive_n(self) -> int: + return self.d[ + "autoregressive_n" + ] # eg 10. model inputs ar_n past pts z[t-1], .., z[t-ar_n] + + # input feeds defined in base + + # -------------------------------- + # derivative properties + @property + def n(self) -> int: + """Number of input dimensions == # columns in X""" + return self.n_feeds * self.autoregressive_n + + @enforce_types + def copy(self): + d2 = copy.deepcopy(self.d) + + return AimodelSS(d2) diff --git a/pdr_backend/ppss/base_ss.py b/pdr_backend/ppss/base_ss.py new file mode 100644 index 000000000..b32c11d24 --- /dev/null +++ b/pdr_backend/ppss/base_ss.py @@ -0,0 +1,180 @@ +import copy +from typing import Dict, List, Optional, Set, Tuple, Union + +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.cli.arg_feeds import ArgFeeds +from pdr_backend.cli.arg_pair import ArgPair +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed + + +class MultiFeedMixin: + FEEDS_KEY = "" + + @enforce_types + def __init__(self, d: dict, assert_feed_attributes: Optional[List] = None): + assert self.__class__.FEEDS_KEY + self.d = d + feeds = ArgFeeds.from_strs(self.feeds_strs) + + if assert_feed_attributes: + missing_attributes = [] + for attr in assert_feed_attributes: + for feed in feeds: + if not getattr(feed, attr): + missing_attributes.append(attr) + + if missing_attributes: + raise AssertionError( + f"Missing attributes {missing_attributes} for some feeds." + ) + + # -------------------------------- + # yaml properties + @property + def feeds_strs(self) -> List[str]: + nested_attrs = self.__class__.FEEDS_KEY.split(".") + lookup = copy.deepcopy(self.d) + + # Iterate over each attribute in the nesting + for attr in nested_attrs: + try: + # Attempt to access the next level in the dict + lookup = lookup[attr] + except KeyError as exc: + raise ValueError( + f"Could not find nested attribute {attr} in {nested_attrs}" + ) from exc + + assert isinstance(lookup, list) + return lookup # eg ["binance BTC/USDT ohlcv",..] 
+ + # -------------------------------- + + @property + def n_exchs(self) -> int: + return len(self.exchange_strs) + + @property + def exchange_strs(self) -> Set[str]: + return set(str(feed.exchange) for feed in self.feeds) + + @property + def n_feeds(self) -> int: + return len(self.feeds) + + @property + def feeds(self) -> ArgFeeds: + """Return list of ArgFeed(exchange_str, signal_str, pair_str)""" + return ArgFeeds.from_strs(self.feeds_strs) + + @property + def exchange_pair_tups(self) -> Set[Tuple[str, str]]: + """Return set of unique (exchange_str, pair_str) tuples""" + return set((feed.exchange, str(feed.pair)) for feed in self.feeds) + + @enforce_types + def filter_feeds_from_candidates( + self, cand_feeds: Dict[str, SubgraphFeed] + ) -> Dict[str, SubgraphFeed]: + result: Dict[str, SubgraphFeed] = {} + + allowed_tups = [ + (str(feed.exchange), str(feed.pair), str(feed.timeframe)) + for feed in self.feeds + ] + + for sg_key, sg_feed in cand_feeds.items(): + assert isinstance(sg_feed, SubgraphFeed) + + if (sg_feed.source, sg_feed.pair, sg_feed.timeframe) in allowed_tups: + result[sg_key] = sg_feed + + return result + + +class SingleFeedMixin: + FEED_KEY = "" + + def __init__(self, d: dict, assert_feed_attributes: Optional[List] = None): + assert self.__class__.FEED_KEY + self.d = d + if assert_feed_attributes: + for attr in assert_feed_attributes: + assert getattr(self.feed, attr) + + # -------------------------------- + # yaml properties + @property + def feed(self) -> ArgFeed: + """Which feed to use for predictions. Eg "feed1".""" + return ArgFeed.from_str(self.d[self.__class__.FEED_KEY]) + + # -------------------------------- + + @property + def pair_str(self) -> ArgPair: + """Return e.g. 'ETH/USDT'. Only applicable when 1 feed.""" + return self.feed.pair + + @property + def exchange_str(self) -> str: + """Return e.g. 'binance'. Only applicable when 1 feed.""" + return str(self.feed.exchange) + + @property + def exchange_class(self) -> str: + return self.feed.exchange.exchange_class + + @property + def signal_str(self) -> str: + """Return e.g. 'high'. Only applicable when 1 feed.""" + return str(self.feed.signal) if self.feed.signal else "" + + @property + def base_str(self) -> str: + """Return e.g. 'ETH'. Only applicable when 1 feed.""" + return ArgPair(self.pair_str).base_str or "" + + @property + def quote_str(self) -> str: + """Return e.g. 'USDT'. 
Only applicable when 1 feed.""" + return ArgPair(self.pair_str).quote_str or "" + + @property + def timeframe(self) -> str: + return str(self.feed.timeframe) + + @property + def timeframe_ms(self) -> int: + """Returns timeframe, in ms""" + return self.feed.timeframe.ms if self.feed.timeframe else 0 + + @property + def timeframe_s(self) -> int: + """Returns timeframe, in s""" + return self.feed.timeframe.s if self.feed.timeframe else 0 + + @property + def timeframe_m(self) -> int: + """Returns timeframe, in minutes""" + return self.feed.timeframe.m if self.feed.timeframe else 0 + + @enforce_types + def get_feed_from_candidates( + self, cand_feeds: Dict[str, SubgraphFeed] + ) -> Union[None, SubgraphFeed]: + allowed_tup = ( + self.timeframe, + self.feed.exchange, + self.feed.pair, + ) + + for feed in cand_feeds.values(): + assert isinstance(feed, SubgraphFeed) + feed_tup = (feed.timeframe, feed.source, feed.pair) + if feed_tup == allowed_tup: + return feed + + return None diff --git a/pdr_backend/ppss/dfbuyer_ss.py b/pdr_backend/ppss/dfbuyer_ss.py new file mode 100644 index 000000000..b4a356ed5 --- /dev/null +++ b/pdr_backend/ppss/dfbuyer_ss.py @@ -0,0 +1,54 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.base_ss import MultiFeedMixin +from pdr_backend.util.strutil import StrMixin + + +class DFBuyerSS(MultiFeedMixin, StrMixin): + __STR_OBJDIR__ = ["d"] + FEEDS_KEY = "feeds" + + @enforce_types + def __init__(self, d: dict): + # yaml_dict["dfbuyer_ss"] + super().__init__(d, assert_feed_attributes=["timeframe"]) + + # -------------------------------- + # yaml properties + @property + def weekly_spending_limit(self) -> int: + """ + Target consume amount in OCEAN per week + """ + return self.d["weekly_spending_limit"] + + @property + def batch_size(self) -> int: + """ + Number of pairs to consume in a batch + """ + return self.d["batch_size"] + + @property + def consume_interval_seconds(self) -> int: + """ + Frequency of consume in seconds + """ + return self.d["consume_interval_seconds"] + + @property + def max_request_tries(self) -> int: + """ + Number of times to retry the request, hardcoded to 5 + """ + return 5 + + # feeds defined in base + + # -------------------------------- + # derived values + @property + def amount_per_interval(self): + return float( + self.weekly_spending_limit / (7 * 24 * 3600) * self.consume_interval_seconds + ) diff --git a/pdr_backend/ppss/lake_ss.py b/pdr_backend/ppss/lake_ss.py new file mode 100644 index 000000000..0e50e060a --- /dev/null +++ b/pdr_backend/ppss/lake_ss.py @@ -0,0 +1,91 @@ +import copy +import os + +import numpy as np +from enforce_typing import enforce_types + +from pdr_backend.ppss.base_ss import MultiFeedMixin +from pdr_backend.util.timeutil import pretty_timestr, timestr_to_ut + + +class LakeSS(MultiFeedMixin): + FEEDS_KEY = "feeds" + + @enforce_types + def __init__(self, d: dict): + # yaml_dict["lake_ss"] + super().__init__(d, assert_feed_attributes=["timeframe"]) + + # handle parquet_dir + assert self.parquet_dir == os.path.abspath(self.parquet_dir) + if not os.path.exists(self.parquet_dir): + print(f"Could not find parquet dir, creating one at: {self.parquet_dir}") + os.makedirs(self.parquet_dir) + + # test inputs + assert ( + 0 + <= timestr_to_ut(self.st_timestr) + <= timestr_to_ut(self.fin_timestr) + <= np.inf + ) + + # -------------------------------- + # yaml properties + @property + def parquet_dir(self) -> str: + s = self.d["parquet_dir"] + if s != os.path.abspath(s): # rel path given; needs an abs path + return 
os.path.abspath(s) + # abs path given + return s + + @property + def st_timestr(self) -> str: + return self.d["st_timestr"] # eg "2019-09-13_04:00" (earliest) + + @property + def fin_timestr(self) -> str: + return self.d["fin_timestr"] # eg "now","2023-09-23_17:55","2023-09-23" + + # feeds defined in base + + # -------------------------------- + # derivative properties + @property + def st_timestamp(self) -> int: + """ + Return start timestamp, in ut: unix time, in ms, in UTC time zone + Calculated from self.st_timestr. + """ + return timestr_to_ut(self.st_timestr) + + @property + def fin_timestamp(self) -> int: + """ + Return fin timestamp, in ut: unix time, in ms, in UTC time zone + Calculated from self.fin_timestr. + + ** This value will change dynamically if fin_timestr is "now". + """ + return timestr_to_ut(self.fin_timestr) + + @enforce_types + def __str__(self) -> str: + s = "LakeSS:\n" + s += f"feeds_strs={self.feeds_strs}" + s += f" -> n_inputfeeds={self.n_feeds}\n" + s += f"st_timestr={self.st_timestr}" + s += f" -> st_timestamp={pretty_timestr(self.st_timestamp)}\n" + s += f"fin_timestr={self.fin_timestr}" + s += f" -> fin_timestamp={pretty_timestr(self.fin_timestamp)}\n" + s += f" -> n_exchs={self.n_exchs}\n" + s += f"parquet_dir={self.parquet_dir}\n" + s += "-" * 10 + "\n" + return s + + @enforce_types + def copy(self): + d2 = copy.deepcopy(self.d) + + return LakeSS(d2) diff --git a/pdr_backend/ppss/payout_ss.py b/pdr_backend/ppss/payout_ss.py new file mode 100644 index 000000000..5828c6b33 --- /dev/null +++ b/pdr_backend/ppss/payout_ss.py @@ -0,0 +1,22 @@ +from enforce_typing import enforce_types + +from pdr_backend.util.strutil import StrMixin + + +class PayoutSS(StrMixin): + __STR_OBJDIR__ = ["d"] + + @enforce_types + def __init__(self, d: dict): + self.d = d # yaml_dict["payout_ss"] + + # -------------------------------- + # yaml properties + @property + def batch_size(self) -> int: + return self.d["batch_size"] + + # -------------------------------- + # setters + def set_batch_size(self, batch_size: int): + self.d["batch_size"] = batch_size diff --git a/pdr_backend/ppss/ppss.py b/pdr_backend/ppss/ppss.py new file mode 100644 index 000000000..00b79e1b2 --- /dev/null +++ b/pdr_backend/ppss/ppss.py @@ -0,0 +1,229 @@ +import os +import tempfile +from typing import List, Optional, Tuple + +import yaml +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feeds import ArgFeeds +from pdr_backend.ppss.dfbuyer_ss import DFBuyerSS +from pdr_backend.ppss.lake_ss import LakeSS +from pdr_backend.ppss.payout_ss import PayoutSS +from pdr_backend.ppss.predictoor_ss import PredictoorSS +from pdr_backend.ppss.publisher_ss import PublisherSS +from pdr_backend.ppss.sim_ss import SimSS +from pdr_backend.ppss.trader_ss import TraderSS +from pdr_backend.ppss.trueval_ss import TruevalSS +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed, mock_feed + + +@enforce_types +class PPSS: # pylint: disable=too-many-instance-attributes + def __init__( + self, + yaml_filename: Optional[str] = None, + yaml_str: Optional[str] = None, + network: Optional[str] = None, # eg "development", "sapphire-testnet" + ): + # preconditions + assert ( + yaml_filename or yaml_str and not (yaml_filename and yaml_str) + ), "need to set yaml_filename_ or yaml_str but not both" + + # get d + if yaml_filename is not None: + with open(yaml_filename, "r") as f: + d = yaml.safe_load(f) + else: + d = yaml.safe_load(str(yaml_str)) + + # fill attributes from d. 
Same order as ppss.yaml, to help reading + self.lake_ss = LakeSS(d["lake_ss"]) + self.predictoor_ss = PredictoorSS(d["predictoor_ss"]) + self.trader_ss = TraderSS(d["trader_ss"]) + self.sim_ss = SimSS(d["sim_ss"]) + self.publisher_ss = PublisherSS(network, d["publisher_ss"]) + self.trueval_ss = TruevalSS(d["trueval_ss"]) + self.dfbuyer_ss = DFBuyerSS(d["dfbuyer_ss"]) + self.payout_ss = PayoutSS(d["payout_ss"]) + self.web3_pp = Web3PP(d["web3_pp"], network) + + self.verify_feed_dependencies() + + def verify_feed_dependencies(self): + """Raise ValueError if a feed dependency is violated""" + lake_fs = self.lake_ss.feeds + predict_f = self.predictoor_ss.feed + aimodel_fs = self.predictoor_ss.aimodel_ss.feeds + + # is predictoor_ss.predict_feed in lake feeds? + # - check for matching {exchange, pair, timeframe} but not {signal} + # because lake holds all signals o,h,l,c,v + if not lake_fs.contains_combination( + predict_f.exchange, predict_f.pair, predict_f.timeframe + ): + s = "predictoor_ss.predict_feed not in lake_ss.feeds" + s += f"\n lake_ss.feeds = {lake_fs} (ohlcv)" + s += f"\n predictoor_ss.predict_feed = {predict_f}" + raise ValueError(s) + + # do all aimodel_ss input feeds conform to predict feed timeframe? + for aimodel_f in aimodel_fs: + if aimodel_f.timeframe != predict_f.timeframe: + s = "at least one ai_model_ss.input_feeds' timeframe incorrect" + s += f"\n target={predict_f.timeframe}, in predictoor_ss.feed" + s += f"\n found={aimodel_f.timeframe}, in this aimodel feed:" + s += f" {aimodel_f}" + raise ValueError(s) + + # is each predictoor_ss.aimodel_ss.input_feeds in lake feeds? + # - check for matching {exchange, pair, timeframe} but not {signal} + for aimodel_f in aimodel_fs: + if not lake_fs.contains_combination( + aimodel_f.exchange, aimodel_f.pair, aimodel_f.timeframe + ): + s = "at least one aimodel_ss.input_feeds not in lake_ss.feeds" + s += f"\n lake_ss.feeds = {lake_fs} (ohlcv)" + s += f"\n predictoor_ss.ai_model.input_feeds = {aimodel_fs}" + s += f"\n (input_feed not found: {aimodel_f})" + raise ValueError(s) + + # is predictoor_ss.predict_feed in aimodel_ss.input_feeds? 
+ # - check for matching {exchange, pair, timeframe AND signal} + if predict_f not in aimodel_fs: + s = "predictoor_ss.predict_feed not in aimodel_ss.input_feeds" + s += " (accounting for signal too)" + s += f"\n predictoor_ss.ai_model.input_feeds = {aimodel_fs}" + s += f"\n predictoor_ss.predict_feed = {predict_f}" + raise ValueError(s) + + def __str__(self): + s = "" + s += f"lake_ss={self.lake_ss}\n" + s += f"dfbuyer_ss={self.dfbuyer_ss}\n" + s += f"payout_ss={self.payout_ss}\n" + s += f"predictoor_ss={self.predictoor_ss}\n" + s += f"trader_ss={self.trader_ss}\n" + s += f"sim_ss={self.sim_ss}\n" + s += f"trueval_ss={self.trueval_ss}\n" + s += f"web3_pp={self.web3_pp}\n" + return s + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def mock_feed_ppss( + timeframe, + exchange, + pair, + network: Optional[str] = None, + tmpdir=None, +) -> Tuple[SubgraphFeed, PPSS]: + feed = mock_feed(timeframe, exchange, pair) + ppss = mock_ppss( + [f"{exchange} {pair} c {timeframe}"], + network, + tmpdir, + ) + return (feed, ppss) + + +@enforce_types +def mock_ppss( + feeds: List[str], # eg ["binance BTC/USDT ETH/USDT c 5m", "kraken ..."] + network: Optional[str] = None, + tmpdir: Optional[str] = None, + st_timestr: Optional[str] = "2023-06-18", + fin_timestr: Optional[str] = "2023-06-21", +) -> PPSS: + network = network or "development" + yaml_str = fast_test_yaml_str(tmpdir) + + ppss = PPSS(yaml_str=yaml_str, network=network) + + if tmpdir is None: + tmpdir = tempfile.mkdtemp() + + arg_feeds: ArgFeeds = ArgFeeds.from_strs(feeds) + onefeed = str(arg_feeds[0]) # eg "binance BTC/USDT c 5m" + + assert hasattr(ppss, "lake_ss") + ppss.lake_ss = LakeSS( + { + "feeds": feeds, + "parquet_dir": os.path.join(tmpdir, "parquet_data"), + "st_timestr": st_timestr, + "fin_timestr": fin_timestr, + } + ) + + assert hasattr(ppss, "predictoor_ss") + ppss.predictoor_ss = PredictoorSS( + { + "predict_feed": onefeed, + "bot_only": {"s_until_epoch_end": 60, "stake_amount": 1}, + "aimodel_ss": { + "input_feeds": feeds, + "approach": "LIN", + "max_n_train": 7, + "autoregressive_n": 3, + }, + } + ) + + assert hasattr(ppss, "trader_ss") + ppss.trader_ss = TraderSS( + { + "feed": onefeed, + "sim_only": { + "buy_amt": "10 USD", + }, + "bot_only": {"min_buffer": 30, "max_tries": 10, "position_size": 3}, + } + ) + + assert hasattr(ppss, "trueval_ss") + assert "feeds" in ppss.trueval_ss.d + ppss.trueval_ss.d["feeds"] = feeds + + assert hasattr(ppss, "dfbuyer_ss") + ppss.dfbuyer_ss = DFBuyerSS( + { + "feeds": feeds, + "batch_size": 20, + "consume_interval_seconds": 86400, + "weekly_spending_limit": 37000, + } + ) + return ppss + + +_CACHED_YAML_FILE_S = None + + +@enforce_types +def fast_test_yaml_str(tmpdir=None): + """Use this for testing. 
It has fast runtime.""" + global _CACHED_YAML_FILE_S + if _CACHED_YAML_FILE_S is None: + filename = os.path.abspath("ppss.yaml") + with open(filename) as f: + _CACHED_YAML_FILE_S = f.read() + + s = _CACHED_YAML_FILE_S + + if tmpdir is not None: + assert "parquet_dir: parquet_data" in s + s = s.replace( + "parquet_dir: parquet_data", + f"parquet_dir: {os.path.join(tmpdir, 'parquet_data')}", + ) + + assert "log_dir: logs" in s + s = s.replace("log_dir: logs", f"log_dir: {os.path.join(tmpdir, 'logs')}") + + return s diff --git a/pdr_backend/ppss/predictoor_ss.py b/pdr_backend/ppss/predictoor_ss.py new file mode 100644 index 000000000..0c967032c --- /dev/null +++ b/pdr_backend/ppss/predictoor_ss.py @@ -0,0 +1,30 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.base_ss import SingleFeedMixin +from pdr_backend.ppss.aimodel_ss import AimodelSS +from pdr_backend.util.strutil import StrMixin + + +class PredictoorSS(SingleFeedMixin, StrMixin): + __STR_OBJDIR__ = ["d"] + FEED_KEY = "predict_feed" + + @enforce_types + def __init__(self, d: dict): + super().__init__(d, assert_feed_attributes=["timeframe"]) + self.aimodel_ss = AimodelSS(d["aimodel_ss"]) + + # -------------------------------- + # yaml properties + @property + def s_until_epoch_end(self) -> int: + return self.d["bot_only"]["s_until_epoch_end"] + + @property + def stake_amount(self) -> int: + return self.d["bot_only"]["stake_amount"] + + # feed defined in base + + # -------------------------------- + # derivative properties diff --git a/pdr_backend/ppss/publisher_ss.py b/pdr_backend/ppss/publisher_ss.py new file mode 100644 index 000000000..921ed368f --- /dev/null +++ b/pdr_backend/ppss/publisher_ss.py @@ -0,0 +1,53 @@ +from typing import Dict + +from enforce_typing import enforce_types + +from pdr_backend.ppss.base_ss import MultiFeedMixin +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.util.strutil import StrMixin + + +class PublisherSS(MultiFeedMixin, StrMixin): + @enforce_types + def __init__(self, network: str, d: dict): + self.network = network # e.g. 
"sapphire-testnet", "sapphire-mainnet" + self.__class__.FEEDS_KEY = network + ".feeds" + super().__init__( + d, assert_feed_attributes=["signal", "timeframe"] + ) # yaml_dict["publisher_ss"] + + # -------------------------------- + # yaml properties + @property + def fee_collector_address(self) -> str: + """ + Returns the address of FeeCollector of the current network + """ + return self.d[self.network]["fee_collector_address"] + + @enforce_types + def filter_feeds_from_candidates( + self, cand_feeds: Dict[str, SubgraphFeed] + ) -> Dict[str, SubgraphFeed]: + raise NotImplementedError("PublisherSS should not filter subgraph feeds.") + + +@enforce_types +def mock_publisher_ss(network) -> PublisherSS: + if network in ["development", "barge-pytest", "barge-predictoor-bot"]: + feeds = ["binance BTC/USDT ETH/USDT XRP/USDT c 5m"] + else: + # sapphire-testnet, sapphire-mainnet + feeds = [ + "binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT" + " ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT" + " c 5m,1h" + ] + + d = { + network: { + "fee_collector_address": "0x1", + "feeds": feeds, + } + } + return PublisherSS(network, d) diff --git a/pdr_backend/ppss/sim_ss.py b/pdr_backend/ppss/sim_ss.py new file mode 100644 index 000000000..578ff34ca --- /dev/null +++ b/pdr_backend/ppss/sim_ss.py @@ -0,0 +1,41 @@ +import os + +import numpy as np +from enforce_typing import enforce_types + +from pdr_backend.util.strutil import StrMixin + + +@enforce_types +class SimSS(StrMixin): + __STR_OBJDIR__ = ["d"] + + def __init__(self, d: dict): + self.d = d # yaml_dict["sim_ss"] + + # handle log_dir + assert self.log_dir == os.path.abspath(self.log_dir) + if not os.path.exists(self.log_dir): + print(f"Could not find log dir, creating one at: {self.log_dir}") + os.makedirs(self.log_dir) + + if not (0 < int(self.test_n) < np.inf): # pylint: disable=superfluous-parens + raise ValueError(f"test_n={self.test_n}, must be an int >0 and bool: + return self.d["do_plot"] + + @property + def log_dir(self) -> str: + s = self.d["log_dir"] + if s != os.path.abspath(s): # rel path given; needs an abs path + return os.path.abspath(s) + # abs path given + return s + + @property + def test_n(self) -> int: + return self.d["test_n"] # eg 200 diff --git a/pdr_backend/ppss/test/test_aimodel_ss.py b/pdr_backend/ppss/test/test_aimodel_ss.py new file mode 100644 index 000000000..cc7d2f722 --- /dev/null +++ b/pdr_backend/ppss/test/test_aimodel_ss.py @@ -0,0 +1,80 @@ +import re + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.cli.arg_feeds import ArgFeeds +from pdr_backend.ppss.aimodel_ss import APPROACHES, AimodelSS + + +@enforce_types +def test_aimodel_ss_happy1(): + d = { + "approach": "LIN", + "max_n_train": 7, + "autoregressive_n": 3, + "input_feeds": ["kraken ETH/USDT hc", "binanceus ETH/USDT,TRX/DAI h"], + } + ss = AimodelSS(d) + assert isinstance(ss.copy(), AimodelSS) + + # yaml properties + assert ss.feeds_strs == ["kraken ETH/USDT hc", "binanceus ETH/USDT,TRX/DAI h"] + assert ss.approach == "LIN" + assert ss.max_n_train == 7 + assert ss.autoregressive_n == 3 + + # derivative properties + assert ss.feeds == ArgFeeds( + [ + ArgFeed("kraken", "high", "ETH/USDT"), + ArgFeed("kraken", "close", "ETH/USDT"), + ArgFeed("binanceus", "high", "ETH/USDT"), + ArgFeed("binanceus", "high", "TRX/DAI"), + ] + ) + + # str + assert "AimodelSS" in str(ss) + assert "approach" in str(ss) + + +@enforce_types +def test_aimodel_ss_happy2(): + for approach in APPROACHES: + ss = 
AimodelSS( + { + "approach": approach, + "max_n_train": 7, + "autoregressive_n": 3, + "input_feeds": ["binance BTC/USDT c"], + } + ) + assert approach in str(ss) + + with pytest.raises(ValueError): + AimodelSS( + { + "approach": "foo_approach", + "max_n_train": 7, + "autoregressive_n": 3, + "input_feeds": ["binance BTC/USDT c"], + } + ) + + +@enforce_types +def test_aimodel_ss_unhappy1(): + d = { + "approach": "LIN", + "max_n_train": 7, + "autoregressive_n": 3, + "input_feeds": ["kraken ETH/USDT"], # missing a signal like "c" + } + + # it should complain that it's missing a signal in input feeds + with pytest.raises( + AssertionError, match=re.escape("Missing attributes ['signal'] for some feeds") + ): + AimodelSS(d) diff --git a/pdr_backend/ppss/test/test_base_ss.py b/pdr_backend/ppss/test/test_base_ss.py new file mode 100644 index 000000000..2c933250d --- /dev/null +++ b/pdr_backend/ppss/test/test_base_ss.py @@ -0,0 +1,286 @@ +import ccxt +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed +from pdr_backend.cli.arg_feeds import ArgFeeds +from pdr_backend.ppss.base_ss import MultiFeedMixin, SingleFeedMixin +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed + + +class MultiFeedMixinTest(MultiFeedMixin): + FEEDS_KEY = "feeds" + + +class NestedMultiFeedMixinTest(MultiFeedMixin): + FEEDS_KEY = "abc.xyz.feeds" + + +class SingleFeedMixinTest(SingleFeedMixin): + FEED_KEY = "feed" + + +@enforce_types +def test_multi_feed(): + d = { + "feeds": ["kraken ETH/USDT hc", "binanceus ETH/USDT,TRX/DAI h"], + } + ss = MultiFeedMixinTest(d) + MultiFeedMixinTest(d, assert_feed_attributes=[]) + MultiFeedMixinTest(d, assert_feed_attributes=["exchange", "pair", "signal"]) + + with pytest.raises(AssertionError): + MultiFeedMixinTest(d, assert_feed_attributes=["timeframe"]) + + assert ss.feeds_strs == ["kraken ETH/USDT hc", "binanceus ETH/USDT,TRX/DAI h"] + assert ss.n_exchs == 2 + assert ss.exchange_strs == {"kraken", "binanceus"} + assert ss.n_feeds == 4 + assert ss.feeds == ArgFeeds( + [ + ArgFeed("kraken", "high", "ETH/USDT"), + ArgFeed("kraken", "close", "ETH/USDT"), + ArgFeed("binanceus", "high", "ETH/USDT"), + ArgFeed("binanceus", "high", "TRX/DAI"), + ] + ) + assert ss.exchange_pair_tups == { + ("kraken", "ETH/USDT"), + ("binanceus", "ETH/USDT"), + ("binanceus", "TRX/DAI"), + } + + +@enforce_types +def test_nested_multi_feed(): + d = { + "abc": { + "xyz": { + "feeds": ["kraken ETH/USDT hc", "binanceus ETH/USDT,TRX/DAI h"], + } + } + } + ss = NestedMultiFeedMixinTest(d) + assert ss.n_feeds == 4 + + wrong_d = { + "abc": { + "feeds": ["kraken ETH/USDT hc", "binanceus ETH/USDT,TRX/DAI h"], + } + } + + with pytest.raises(ValueError, match="Could not find nested attribute xyz"): + ss = NestedMultiFeedMixinTest(wrong_d) + + +@enforce_types +def test_multi_feed_filter(): + d = { + "feeds": ["kraken ETH/USDT 5m", "binanceus ETH/USDT,TRX/DAI 1h"], + } + ss = MultiFeedMixinTest(d) + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "BTC/ETH", + "1h", + "binance", + ) + } + + assert ss.filter_feeds_from_candidates(cand_feeds) == {} + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "5m", + "kraken", + ) + } + + assert ss.filter_feeds_from_candidates(cand_feeds) == cand_feeds + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "5m", + "kraken", + ), + "0x67890": SubgraphFeed( + 
"", + "0x67890", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "1h", + "binanceus", + ), + } + + assert ss.filter_feeds_from_candidates(cand_feeds) == cand_feeds + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "5m", + "kraken", + ), + "0x67890": SubgraphFeed( + "", + "0x67890", + "test", + 60, + 15, + "0xowner", + "ETH/DAI", + "1h", + "binanceus", + ), + } + + assert ss.filter_feeds_from_candidates(cand_feeds) == { + "0x12345": cand_feeds["0x12345"] + } + + +@enforce_types +def test_single_feed(): + d = {"feed": "kraken ETH/USDT h"} + ss = SingleFeedMixinTest(d) + SingleFeedMixinTest(d, assert_feed_attributes=[]) + SingleFeedMixinTest(d, assert_feed_attributes=["exchange", "pair", "signal"]) + + with pytest.raises(AssertionError): + SingleFeedMixinTest(d, assert_feed_attributes=["timeframe"]) + + d = {"feed": "kraken ETH/USDT 1h"} + ss = SingleFeedMixinTest(d) + assert ss.feed == ArgFeed("kraken", None, "ETH/USDT", "1h") + assert ss.pair_str == "ETH/USDT" + assert ss.exchange_str == "kraken" + assert ss.exchange_class == ccxt.kraken + assert ss.signal_str == "" + assert ss.base_str == "ETH" + assert ss.quote_str == "USDT" + assert ss.timeframe == "1h" + assert ss.timeframe_ms == 3600000 + assert ss.timeframe_s == 3600 + assert ss.timeframe_m == 60 + + +@enforce_types +def test_single_feed_filter(): + d = {"feed": "kraken ETH/USDT 5m"} + ss = SingleFeedMixinTest(d) + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "BTC/ETH", + "1h", + "binance", + ) + } + + assert ss.get_feed_from_candidates(cand_feeds) is None + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "5m", + "kraken", + ) + } + + assert ss.get_feed_from_candidates(cand_feeds) == cand_feeds["0x12345"] + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "5m", + "kraken", + ), + "0x67890": SubgraphFeed( + "", + "0x67890", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "1h", + "binanceus", + ), + } + + assert ss.get_feed_from_candidates(cand_feeds) == cand_feeds["0x12345"] + + cand_feeds = { + "0x12345": SubgraphFeed( + "", + "0x12345", + "test", + 60, + 15, + "0xowner", + "ETH/USDT", + "5m", + "kraken", + ), + "0x67890": SubgraphFeed( + "", + "0x67890", + "test", + 60, + 15, + "0xowner", + "ETH/DAI", + "1h", + "binanceus", + ), + } + + assert ss.get_feed_from_candidates(cand_feeds) == cand_feeds["0x12345"] diff --git a/pdr_backend/ppss/test/test_dfbuyer_ss.py b/pdr_backend/ppss/test/test_dfbuyer_ss.py new file mode 100644 index 000000000..cdcf9832c --- /dev/null +++ b/pdr_backend/ppss/test/test_dfbuyer_ss.py @@ -0,0 +1,16 @@ +from pdr_backend.ppss.dfbuyer_ss import DFBuyerSS + + +def test_trueval_config(): + ss = DFBuyerSS( + { + "batch_size": 42, + "consume_interval_seconds": 42, + "weekly_spending_limit": 42 * 7 * 24 * 3600, + "feeds": ["binance BTC/USDT c 5m"], + } + ) + assert ss.batch_size == 42 + assert ss.weekly_spending_limit == 42 * 7 * 24 * 3600 + assert ss.consume_interval_seconds == 42 + assert ss.amount_per_interval == 42 * 42 diff --git a/pdr_backend/ppss/test/test_lake_ss.py b/pdr_backend/ppss/test/test_lake_ss.py new file mode 100644 index 000000000..4bf256f07 --- /dev/null +++ b/pdr_backend/ppss/test/test_lake_ss.py @@ -0,0 +1,85 @@ +import copy +import os + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_feed import ArgFeed 
+from pdr_backend.cli.arg_feeds import ArgFeeds +from pdr_backend.ppss.lake_ss import LakeSS +from pdr_backend.util.timeutil import timestr_to_ut + +_D = { + "feeds": ["kraken ETH/USDT 5m", "binanceus ETH/USDT,TRX/DAI 1h"], + "parquet_dir": "parquet_data", + "st_timestr": "2023-06-18", + "fin_timestr": "2023-06-21", +} + + +@enforce_types +def test_lake_ss_basic(): + ss = LakeSS(_D) + + # yaml properties + assert "parquet_data" in ss.parquet_dir + assert ss.st_timestr == "2023-06-18" + assert ss.fin_timestr == "2023-06-21" + + assert ss.exchange_strs == set(["binanceus", "kraken"]) + + # derivative properties + assert ss.st_timestamp == timestr_to_ut("2023-06-18") + assert ss.fin_timestamp == timestr_to_ut("2023-06-21") + assert ss.feeds == ArgFeeds( + [ + ArgFeed("kraken", None, "ETH/USDT", "5m"), + ArgFeed("binanceus", None, "ETH/USDT", "1h"), + ArgFeed("binanceus", None, "TRX/DAI", "1h"), + ] + ) + assert ss.exchange_pair_tups == set( + [ + ("kraken", "ETH/USDT"), + ("binanceus", "ETH/USDT"), + ("binanceus", "TRX/DAI"), + ] + ) + assert len(ss.feeds) == ss.n_feeds == 3 + assert ss.n_exchs == 2 + assert len(ss.exchange_strs) == 2 + assert "binanceus" in ss.exchange_strs + + # test str + assert "LakeSS" in str(ss) + + assert isinstance(ss.copy(), LakeSS) + + +@enforce_types +def test_lake_ss_now(): + d = copy.deepcopy(_D) + d["fin_timestr"] = "now" + ss = LakeSS(d) + + assert ss.fin_timestr == "now" + + ut2 = timestr_to_ut("now") + assert ss.fin_timestamp / 1000 == pytest.approx(ut2 / 1000, 1.0) + + +@enforce_types +def test_parquet_dir(tmpdir): + # rel path given; needs an abs path + d = copy.deepcopy(_D) + d["parquet_dir"] = "parquet_data" + ss = LakeSS(d) + target_parquet_dir = os.path.abspath("parquet_data") + assert ss.parquet_dir == target_parquet_dir + + # abs path given + d = copy.deepcopy(_D) + d["parquet_dir"] = os.path.join(tmpdir, "parquet_data") + ss = LakeSS(d) + target_parquet_dir = os.path.join(tmpdir, "parquet_data") + assert ss.parquet_dir == target_parquet_dir diff --git a/pdr_backend/ppss/test/test_payout_ss.py b/pdr_backend/ppss/test/test_payout_ss.py new file mode 100644 index 000000000..f33a5e8f9 --- /dev/null +++ b/pdr_backend/ppss/test/test_payout_ss.py @@ -0,0 +1,13 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.payout_ss import PayoutSS + + +@enforce_types +def test_payout_ss(): + d = {"batch_size": 50} + ss = PayoutSS(d) + assert ss.batch_size == 50 + + ss.set_batch_size(5) + assert ss.batch_size == 5 diff --git a/pdr_backend/ppss/test/test_ppss.py b/pdr_backend/ppss/test/test_ppss.py new file mode 100644 index 000000000..5244ecb59 --- /dev/null +++ b/pdr_backend/ppss/test/test_ppss.py @@ -0,0 +1,186 @@ +from copy import deepcopy +import os + +from enforce_typing import enforce_types +import pytest + +from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str, mock_feed_ppss, mock_ppss + + +@enforce_types +def test_ppss_main_from_file(tmpdir): + yaml_str = fast_test_yaml_str(tmpdir) + yaml_filename = os.path.join(tmpdir, "ppss.yaml") + with open(yaml_filename, "a") as f: + f.write(yaml_str) + + _test_ppss(yaml_filename=yaml_filename, network="development") + + +@enforce_types +def test_ppss_main_from_str(tmpdir): + yaml_str = fast_test_yaml_str(tmpdir) + _test_ppss(yaml_str=yaml_str, network="development") + + +@enforce_types +def _test_ppss(yaml_filename=None, yaml_str=None, network=None): + # construct + ppss = PPSS(yaml_filename, yaml_str, network) + + # yaml properties - test lightly, since each *_pp and *_ss has its tests + # - so 
just do one test for each of this class's pp/ss attribute + assert ppss.trader_ss.timeframe in ["5m", "1h"] + assert isinstance(ppss.lake_ss.st_timestr, str) + assert ppss.dfbuyer_ss.weekly_spending_limit >= 0 + assert ppss.predictoor_ss.aimodel_ss.approach == "LIN" + assert ppss.payout_ss.batch_size >= 0 + assert 1 <= ppss.predictoor_ss.s_until_epoch_end <= 120 + assert isinstance(ppss.sim_ss.do_plot, bool) + assert 0.0 <= ppss.trader_ss.fee_percent <= 0.99 + assert "USD" in ppss.trader_ss.buy_amt_str + assert ppss.trueval_ss.batch_size >= 0 + assert isinstance(ppss.web3_pp.address_file, str) + + # str + s = str(ppss) + assert "lake_ss" in s + assert "dfbuyer_ss" in s + assert "payout_ss" in s + assert "predictoor_ss" in s + assert "sim_ss" in s + assert "trader_ss" in s + assert "trueval_ss" in s + assert "web3_pp" in s + + +@enforce_types +def test_mock_feed_ppss(): + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT", "sapphire-mainnet") + + assert feed.timeframe == "5m" + assert feed.source == "binance" + assert feed.pair == "BTC/USDT" + + assert ppss.predictoor_ss.timeframe == "5m" + assert str(ppss.predictoor_ss.feed) == "binance BTC/USDT c 5m" + assert ppss.lake_ss.feeds_strs == ["binance BTC/USDT c 5m"] + assert ppss.web3_pp.network == "sapphire-mainnet" + + +@enforce_types +def test_mock_ppss_simple(): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet") + assert ppss.web3_pp.network == "sapphire-mainnet" + + +@enforce_types +def test_mock_ppss_default_network_development(): + ppss = mock_ppss(["binance BTC/USDT c 5m"]) + assert ppss.web3_pp.network == "development" + + +@enforce_types +@pytest.mark.parametrize( + "feed_str", + [ + "binance BTC/USDT c 5m", + "binance ETH/USDT c 5m", + "binance BTC/USDT o 5m", + "binance BTC/USDT c 1h", + "kraken ETH/USDT c 5m", + ], +) +def test_mock_ppss_onefeed1(feed_str): + """Thorough test that the 1-feed arg is used everywhere""" + + ppss = mock_ppss([feed_str], "sapphire-mainnet") + + assert ppss.lake_ss.d["feeds"] == [feed_str] + assert ppss.predictoor_ss.d["predict_feed"] == feed_str + assert ppss.predictoor_ss.aimodel_ss.d["input_feeds"] == [feed_str] + assert ppss.trader_ss.d["feed"] == feed_str + assert ppss.trueval_ss.d["feeds"] == [feed_str] + assert ppss.dfbuyer_ss.d["feeds"] == [feed_str] + + ppss.verify_feed_dependencies() + + +@enforce_types +def test_mock_ppss_manyfeed(): + """Thorough test that the many-feed arg is used everywhere""" + + feed_strs = ["binance BTC/USDT ETH/USDT c 5m", "kraken BTC/USDT c 5m"] + feed_str = "binance BTC/USDT c 5m" # must be the first in feed_strs + ppss = mock_ppss(feed_strs, "sapphire-mainnet") + + assert ppss.lake_ss.d["feeds"] == feed_strs + assert ppss.predictoor_ss.d["predict_feed"] == feed_str + assert ppss.predictoor_ss.aimodel_ss.d["input_feeds"] == feed_strs + assert ppss.trader_ss.d["feed"] == feed_str + assert ppss.trueval_ss.d["feeds"] == feed_strs + assert ppss.dfbuyer_ss.d["feeds"] == feed_strs + + ppss.verify_feed_dependencies() + + +@enforce_types +def test_verify_feed_dependencies(): + ppss = mock_ppss( + ["binance BTC/USDT c 5m", "kraken ETH/USDT c 5m"], + "sapphire-mainnet", + ) + ppss.verify_feed_dependencies() + + # don't fail if aimodel needs more ohlcv feeds for same exchange/pair/time + ppss2 = deepcopy(ppss) + ppss2.predictoor_ss.aimodel_ss.d["input_feeds"] = ["binance BTC/USDT ohlcv 5m"] + ppss2.verify_feed_dependencies() + + # fail check: is predictoor_ss.predict_feed in lake feeds? 
+ # - check for matching {exchange, pair, timeframe} but not {signal} + assert "predict_feed" in ppss.predictoor_ss.d + for wrong_feed in [ + "binance BTC/USDT o 5m", + "binance ETH/USDT c 5m", + "binance BTC/USDT c 1h", + "kraken BTC/USDT c 5m", + ]: + ppss2 = deepcopy(ppss) + ppss2.predictoor_ss.d["predict_feed"] = wrong_feed + with pytest.raises(ValueError): + ppss2.verify_feed_dependencies() + + # fail check: do all aimodel_ss input feeds conform to predict feed timeframe? + ppss2 = deepcopy(ppss) + ppss2.predictoor_ss.aimodel_ss.d["input_feeds"] = [ + "binance BTC/USDT c 5m", + "binance BTC/USDT c 1h", + ] # 0th ok, 1st bad + with pytest.raises(ValueError): + ppss2.verify_feed_dependencies() + + # fail check: is each predictoor_ss.aimodel_ss.input_feeds in lake feeds? + # - check for matching {exchange, pair, timeframe} but not {signal} + for wrong_feed in [ + "kraken BTC/USDT c 5m", + "binance ETH/USDT c 5m", + "binance BTC/USDT c 1h", + ]: + ppss2 = deepcopy(ppss) + ppss2.predictoor_ss.aimodel_ss.d["input_feeds"] = [wrong_feed] + with pytest.raises(ValueError): + ppss2.verify_feed_dependencies() + + # fail check: is predictoor_ss.predict_feed in aimodel_ss.input_feeds? + # - check for matching {exchange, pair, timeframe AND signal} + for wrong_feed in [ + "mexc BTC/USDT c 5m", + "binance DOT/USDT c 5m", + "binance BTC/USDT c 1h", + "binance BTC/USDT o 5m", + ]: + ppss2 = deepcopy(ppss) + ppss2.predictoor_ss.d["predict_feed"] = wrong_feed + with pytest.raises(ValueError): + ppss2.verify_feed_dependencies() diff --git a/pdr_backend/ppss/test/test_predictoor_ss.py b/pdr_backend/ppss/test/test_predictoor_ss.py new file mode 100644 index 000000000..a5e2db9ec --- /dev/null +++ b/pdr_backend/ppss/test/test_predictoor_ss.py @@ -0,0 +1,28 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.predictoor_ss import PredictoorSS + + +@enforce_types +def test_predictoor_ss(): + d = { + "predict_feed": "binance BTC/USDT c 5m", + "bot_only": { + "s_until_epoch_end": 60, + "stake_amount": 1, + }, + "aimodel_ss": { + "input_feeds": ["binance BTC/USDT c"], + "approach": "LIN", + "max_n_train": 7, + "autoregressive_n": 3, + }, + } + ss = PredictoorSS(d) + + # yaml properties + assert ss.s_until_epoch_end == 60 + assert ss.stake_amount == 1 + + # str + assert "PredictoorSS" in str(ss) diff --git a/pdr_backend/ppss/test/test_publisher_ss.py b/pdr_backend/ppss/test/test_publisher_ss.py new file mode 100644 index 000000000..e4a276c08 --- /dev/null +++ b/pdr_backend/ppss/test/test_publisher_ss.py @@ -0,0 +1,39 @@ +import pytest +from enforce_typing import enforce_types + +from pdr_backend.ppss.publisher_ss import PublisherSS, mock_publisher_ss + + +@enforce_types +def test_publisher_ss(): + sapphire_feeds = [ + "binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT" + " ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT" + " c 5m,1h" + ] + + d = { + "sapphire-mainnet": { + "fee_collector_address": "0x1", + "feeds": sapphire_feeds, + }, + "sapphire-testnet": { + "fee_collector_address": "0x2", + "feeds": sapphire_feeds, + }, + } + + ss1 = PublisherSS("sapphire-mainnet", d) + assert ss1.fee_collector_address == "0x1" + + ss2 = PublisherSS("sapphire-testnet", d) + assert ss2.fee_collector_address == "0x2" + + with pytest.raises(NotImplementedError): + ss1.filter_feeds_from_candidates({}) + + +@enforce_types +def test_mock_publisher_ss(): + publisher_ss = mock_publisher_ss("development") + assert isinstance(publisher_ss, PublisherSS) diff --git a/pdr_backend/ppss/test/test_sim_ss.py 
b/pdr_backend/ppss/test/test_sim_ss.py new file mode 100644 index 000000000..0497dc85e --- /dev/null +++ b/pdr_backend/ppss/test/test_sim_ss.py @@ -0,0 +1,54 @@ +import copy +import os +import pytest + +from enforce_typing import enforce_types + +from pdr_backend.ppss.sim_ss import SimSS + +_D = {"do_plot": False, "log_dir": "logs", "test_n": 2} + + +@enforce_types +def test_sim_ss(): + ss = SimSS(_D) + + # yaml properties + assert not ss.do_plot + assert "logs" in ss.log_dir + assert ss.test_n == 2 + + # str + assert "SimSS" in str(ss) + + +@enforce_types +def test_sim_ss_bad(): + bad = copy.deepcopy(_D) + bad["test_n"] = -3 + + with pytest.raises(ValueError): + SimSS(bad) + + bad = copy.deepcopy(_D) + bad["test_n"] = "lit" + + with pytest.raises(ValueError): + SimSS(bad) + + +@enforce_types +def test_log_dir(tmpdir): + # rel path given; needs an abs path + d = copy.deepcopy(_D) + d["log_dir"] = "logs" + ss = SimSS(d) + target_log_dir = os.path.abspath("logs") + assert ss.log_dir == target_log_dir + + # abs path given + d = copy.deepcopy(_D) + d["log_dir"] = os.path.join(tmpdir, "logs") + ss = SimSS(d) + target_log_dir = os.path.join(tmpdir, "logs") + assert ss.log_dir == target_log_dir diff --git a/pdr_backend/ppss/test/test_trader_ss.py b/pdr_backend/ppss/test/test_trader_ss.py new file mode 100644 index 000000000..9b7473115 --- /dev/null +++ b/pdr_backend/ppss/test/test_trader_ss.py @@ -0,0 +1,59 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.trader_ss import TraderSS, inplace_make_trader_fast + +_D = { + "sim_only": { + "buy_amt": "10 USD", + "fee_percent": 0.01, + "init_holdings": ["10000.0 USDT", "0 BTC"], + }, + "bot_only": {"min_buffer": 60, "max_tries": 10, "position_size": 3}, + "feed": "kraken ETH/USDT h 5m", +} + + +@enforce_types +def test_trader_ss(): + ss = TraderSS(_D) + + # yaml properties + assert ss.buy_amt_str == "10 USD" + assert ss.min_buffer == 60 + assert ss.max_tries == 10 + assert ss.position_size == 3 + assert str(ss.feed) == "kraken ETH/USDT h 5m" + assert ss.fee_percent == 0.01 + assert ss.init_holdings_strs == ["10000.0 USDT", "0 BTC"] + + assert ss.signal_str == "high" + assert ss.pair_str == "ETH/USDT" + assert ss.base_str == "ETH" + assert ss.quote_str == "USDT" + + # derivative properties + assert ss.buy_amt_usd == 10.0 + assert ss.init_holdings["USDT"] == 10000.0 + + # setters + ss.set_max_tries(12) + assert ss.max_tries == 12 + + ss.set_min_buffer(59) + assert ss.min_buffer == 59 + + ss.set_position_size(15) + assert ss.position_size == 15 + + # str + assert "TraderSS" in str(ss) + + +@enforce_types +def test_inplace_make_trader_fast(): + ss = TraderSS(_D) + inplace_make_trader_fast(ss) + + assert ss.max_tries == 10 + assert ss.position_size == 10.0 + assert ss.min_buffer == 20 diff --git a/pdr_backend/ppss/test/test_trueval_ss.py b/pdr_backend/ppss/test/test_trueval_ss.py new file mode 100644 index 000000000..5b37b4304 --- /dev/null +++ b/pdr_backend/ppss/test/test_trueval_ss.py @@ -0,0 +1,15 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.trueval_ss import TruevalSS + + +@enforce_types +def test_trueval_ss(): + d = { + "sleep_time": 30, + "batch_size": 50, + "feeds": ["binance BTC/USDT c 5m"], + } + ss = TruevalSS(d) + assert ss.sleep_time == 30 + assert ss.batch_size == 50 diff --git a/pdr_backend/ppss/test/test_web3_pp.py b/pdr_backend/ppss/test/test_web3_pp.py new file mode 100644 index 000000000..7043a2f20 --- /dev/null +++ b/pdr_backend/ppss/test/test_web3_pp.py @@ -0,0 +1,200 @@ +import os +from 
unittest.mock import Mock, patch + +import pytest +from enforce_typing import enforce_types +from eth_account.signers.local import LocalAccount +from web3 import Web3 + +from pdr_backend.contract.predictoor_contract import mock_predictoor_contract +from pdr_backend.ppss.web3_pp import ( + Web3PP, + inplace_mock_feedgetters, + inplace_mock_get_contracts, + inplace_mock_query_feed_contracts, + mock_web3_pp, +) +from pdr_backend.subgraph.subgraph_feed import mock_feed +from pdr_backend.util.web3_config import Web3Config + +PRIV_KEY = os.getenv("PRIVATE_KEY") + +_D1 = { + "address_file": "address.json 1", + "rpc_url": "rpc url 1", + "subgraph_url": "subgraph url 1", + "owner_addrs": "0xOwner1", +} +_D2 = { + "address_file": "address.json 2", + "rpc_url": "rpc url 2", + "subgraph_url": "subgraph url 2", + "owner_addrs": "0xOwner2", +} +_D = { + "network1": _D1, + "network2": _D2, +} + + +@enforce_types +def test_web3_pp__bad_network(): + with pytest.raises(ValueError): + Web3PP(_D, "bad network") + + +@enforce_types +def test_web3_pp__yaml_dict(): + pp = Web3PP(_D, "network1") + + assert pp.network == "network1" + assert pp.dn == _D1 + assert pp.address_file == "address.json 1" + assert pp.rpc_url == "rpc url 1" + assert pp.subgraph_url == "subgraph url 1" + assert pp.owner_addrs == "0xOwner1" + assert isinstance(pp.account, LocalAccount) + + # network2 + pp2 = Web3PP(_D, "network2") + assert pp2.network == "network2" + assert pp2.dn == _D2 + assert pp2.address_file == "address.json 2" + + +@enforce_types +def test_web3_pp__JIT_cached_properties(monkeypatch): + monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) + web3_pp = Web3PP(_D, "network1") + + # test web3_config + assert web3_pp._web3_config is None + + c = web3_pp.web3_config # calcs & caches web3_pp._web3_config + assert isinstance(c, Web3Config) + assert id(c) == id(web3_pp.web3_config) + assert c.rpc_url == web3_pp.rpc_url + assert c.private_key == PRIV_KEY + assert isinstance(c.w3, Web3) + + # test derived properties + assert web3_pp.private_key == PRIV_KEY + assert isinstance(web3_pp.w3, Web3) + assert id(web3_pp.w3) == id(c.w3) + + # test setter + web3_pp.set_web3_config("foo") + assert web3_pp.web3_config == "foo" + + # str + assert "Web3PP=" in str(web3_pp) + + +@enforce_types +def test_web3_pp__get_pending_slots(monkeypatch): + monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) + web3_pp = Web3PP(_D, "network1") + + def _mock_get_pending_slots(*args, **kwargs): + if len(args) >= 2: + timestamp = args[1] + else: + timestamp = kwargs["timestamp"] + return [f"1_{timestamp}", f"2_{timestamp}"] + + with patch("pdr_backend.ppss.web3_pp.get_pending_slots", _mock_get_pending_slots): + slots = web3_pp.get_pending_slots(6789) + assert slots == ["1_6789", "2_6789"] + + +@enforce_types +def test_web3_pp__query_feed_contracts__get_contracts(monkeypatch): + # test get_feeds() & get_contracts() at once, because one flows into other + monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) + web3_pp = Web3PP(_D, "network1") + + feed = mock_feed("5m", "binance", "BTC/USDT") + + # test get_feeds(). Uses results from get_feeds + def _mock_subgraph_query_feed_contracts( + *args, **kwargs + ): # pylint: disable=unused-argument + return {feed.address: feed} + + with patch( + "pdr_backend.ppss.web3_pp.query_feed_contracts", + _mock_subgraph_query_feed_contracts, + ): + feeds = web3_pp.query_feed_contracts() + + assert list(feeds.keys()) == [feed.address] + + # test get_contracts(). 
Uses results from get_feeds + def _mock_contract(*args, **kwarg): # pylint: disable=unused-argument + m = Mock() + m.contract_address = feed.address + return m + + with patch( + "pdr_backend.contract.predictoor_contract.PredictoorContract", + _mock_contract, + ): + contracts = web3_pp.get_contracts([feed.address]) + assert list(contracts.keys()) == [feed.address] + assert contracts[feed.address].contract_address == feed.address + + +# ========================================================================= +# test utilities for testing + + +@enforce_types +def test_mock_web3_pp(): + web3_pp = mock_web3_pp("development") + assert isinstance(web3_pp, Web3PP) + assert web3_pp.network == "development" + + web3_pp = mock_web3_pp("sapphire-mainnet") + assert web3_pp.network == "sapphire-mainnet" + + +@enforce_types +def test_inplace_mocks(): + web3_pp = mock_web3_pp("development") + feed = mock_feed("5m", "binance", "BTC/USDT") + + # basic sanity test: can we call it without a fail? + inplace_mock_feedgetters(web3_pp, feed) + inplace_mock_query_feed_contracts(web3_pp, feed) + + c = mock_predictoor_contract(feed.address) + inplace_mock_get_contracts(web3_pp, feed, c) + + +@enforce_types +def test_tx_gas_price__and__tx_call_params(): + web3_pp = mock_web3_pp("sapphire-testnet") + eth_mock = Mock() + eth_mock.gas_price = 12 + web3_pp.web3_config.w3.eth = eth_mock + web3_pp.web3_config.owner = "0xowner" + + web3_pp.network = "sapphire-testnet" + assert web3_pp.tx_gas_price() == 12 + assert web3_pp.tx_call_params() == {"from": "0xowner", "gasPrice": 12} + + web3_pp.network = "sapphire-mainnet" + assert web3_pp.tx_gas_price() == 12 + + web3_pp.network = "development" + assert web3_pp.tx_gas_price() == 0 + assert web3_pp.tx_call_params() == {"from": "0xowner", "gasPrice": 0} + + web3_pp.network = "barge-pytest" + assert web3_pp.tx_gas_price() == 0 + + web3_pp.network = "foo" + with pytest.raises(ValueError): + web3_pp.tx_gas_price() + with pytest.raises(ValueError): + web3_pp.tx_call_params() diff --git a/pdr_backend/ppss/trader_ss.py b/pdr_backend/ppss/trader_ss.py new file mode 100644 index 000000000..c7190b0c8 --- /dev/null +++ b/pdr_backend/ppss/trader_ss.py @@ -0,0 +1,92 @@ +from typing import Dict, List, Union + +from enforce_typing import enforce_types + +from pdr_backend.ppss.base_ss import SingleFeedMixin +from pdr_backend.util.strutil import StrMixin + + +class TraderSS(SingleFeedMixin, StrMixin): + __STR_OBJDIR__ = ["d"] + FEED_KEY = "feed" + + @enforce_types + def __init__(self, d: dict): + super().__init__( + d, assert_feed_attributes=["timeframe"] + ) # yaml_dict["trader_ss"] + + # -------------------------------- + # yaml properties: sim only + @property + def buy_amt_str(self) -> Union[int, float]: + """How much to buy. Eg 10.""" + return self.d["sim_only"]["buy_amt"] + + @property + def fee_percent(self) -> str: + return self.d["sim_only"]["fee_percent"] # Eg 0.001 is 0.1%.Trading fee + + @property + def init_holdings_strs(self) -> List[str]: + return self.d["sim_only"]["init_holdings"] # eg ["1000 USDT", ..] + + # feed defined in base + + # -------------------------------- + # yaml properties: bot only + @property + def min_buffer(self) -> int: + """Only trade if there's > this time left. Denominated in s.""" + return self.d["bot_only"]["min_buffer"] + + @property + def max_tries(self) -> int: + """Max no. attempts to process a feed. Eg 10""" + return self.d["bot_only"]["max_tries"] + + @property + def position_size(self) -> Union[int, float]: + """Trading size. 
Eg 10""" + return self.d["bot_only"]["position_size"] + + # -------------------------------- + # setters (add as needed) + @enforce_types + def set_max_tries(self, max_tries): + self.d["bot_only"]["max_tries"] = max_tries + + @enforce_types + def set_min_buffer(self, min_buffer): + self.d["bot_only"]["min_buffer"] = min_buffer + + @enforce_types + def set_position_size(self, position_size): + self.d["bot_only"]["position_size"] = position_size + + # -------------------------------- + # derivative properties + @property + def buy_amt_usd(self): + amt_s, _ = self.buy_amt_str.split() + return float(amt_s) + + @property + def init_holdings(self) -> Dict[str, float]: + d = {} + for s in self.init_holdings_strs: + amt_s, coin = s.split() + amt = float(amt_s) + d[coin] = amt + return d + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def inplace_make_trader_fast(trader_ss: TraderSS): + trader_ss.set_max_tries(10) + trader_ss.set_position_size(10.0) + trader_ss.set_min_buffer(20) diff --git a/pdr_backend/ppss/trueval_ss.py b/pdr_backend/ppss/trueval_ss.py new file mode 100644 index 000000000..01042a963 --- /dev/null +++ b/pdr_backend/ppss/trueval_ss.py @@ -0,0 +1,24 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.base_ss import MultiFeedMixin +from pdr_backend.util.strutil import StrMixin + + +@enforce_types +class TruevalSS(MultiFeedMixin, StrMixin): + __STR_OBJDIR__ = ["d"] + FEEDS_KEY = "feeds" + + # -------------------------------- + # yaml properties + @property + def sleep_time(self) -> int: + """# seconds to wait between batches""" + return self.d["sleep_time"] + + @property + def batch_size(self) -> int: + """# slots to process in a batch""" + return self.d["batch_size"] + + # feeds defined in base diff --git a/pdr_backend/ppss/web3_pp.py b/pdr_backend/ppss/web3_pp.py new file mode 100644 index 000000000..fa5edd649 --- /dev/null +++ b/pdr_backend/ppss/web3_pp.py @@ -0,0 +1,302 @@ +import random +from os import getenv +from typing import Any, Dict, List, Optional +from unittest.mock import Mock + +from enforce_typing import enforce_types +from eth_account.signers.local import LocalAccount +from web3 import Web3 + +from pdr_backend.cli.arg_feeds import ArgFeeds +from pdr_backend.contract.slot import Slot +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.subgraph.subgraph_feed_contracts import query_feed_contracts +from pdr_backend.subgraph.subgraph_pending_slots import get_pending_slots +from pdr_backend.util.strutil import StrMixin +from pdr_backend.util.web3_config import Web3Config + + +class Web3PP(StrMixin): + __STR_OBJDIR__ = ["network", "d"] + + @enforce_types + def __init__(self, d: dict, network: str): + if network not in d: + raise ValueError(f"network '{network}' not found in dict") + + self.network = network # e.g. "sapphire-testnet", "sapphire-mainnet" + self.d = d # yaml_dict["web3_pp"] + + self._web3_config: Optional[Web3Config] = None + + # -------------------------------- + # JIT cached properties - only do the work if requested + # (and therefore don't complain if missing envvar) + + @property + def web3_config(self) -> Web3Config: + if self._web3_config is None: + rpc_url = self.rpc_url + private_key = getenv("PRIVATE_KEY") + self._web3_config = Web3Config(rpc_url, private_key) + return self._web3_config # type: ignore[return-value] + + # -------------------------------- + # yaml properties + @property + def dn(self) -> str: # "d at network". 
Compact on purpose. + return self.d[self.network] + + @property + def address_file(self) -> str: + return self.dn["address_file"] # type: ignore[index] + + @property + def rpc_url(self) -> str: + return self.dn["rpc_url"] # type: ignore[index] + + @property + def subgraph_url(self) -> str: + return self.dn["subgraph_url"] # type: ignore[index] + + @property + def owner_addrs(self) -> str: + return self.dn["owner_addrs"] # type: ignore[index] + + # -------------------------------- + # setters (add as needed) + @enforce_types + def set_web3_config(self, web3_config): + self._web3_config = web3_config + + # -------------------------------- + # derived properties + @property + def private_key(self) -> Optional[str]: + return self.web3_config.private_key + + @property + def account(self) -> Optional[LocalAccount]: + return self.web3_config.account + + @property + def w3(self) -> Optional[Web3]: + return self.web3_config.w3 + + # -------------------------------- + # onchain feed data + @enforce_types + def query_feed_contracts(self) -> Dict[str, SubgraphFeed]: + """ + @description + Gets all feeds, only filtered by self.owner_addrs + + @return + feeds -- dict of [feed_addr] : SubgraphFeed + """ + feeds = query_feed_contracts( + subgraph_url=self.subgraph_url, + owners_string=self.owner_addrs, + ) + # postconditions + for feed in feeds.values(): + assert isinstance(feed, SubgraphFeed) + return feeds + + @enforce_types + def get_contracts(self, feed_addrs: List[str]) -> Dict[str, Any]: + """ + @description + Get contracts for specified feeds + + @arguments + feed_addrs -- which feeds we want + + @return + contracts -- dict of [feed_addr] : PredictoorContract + """ + # pylint: disable=import-outside-toplevel + from pdr_backend.contract.predictoor_contract import PredictoorContract + + contracts = {} + for addr in feed_addrs: + contracts[addr] = PredictoorContract(self, addr) + return contracts + + @enforce_types + def get_pending_slots( + self, + timestamp: int, + allowed_feeds: Optional[ArgFeeds] = None, + ) -> List[Slot]: + """ + @description + Query chain to get Slots that have status "Pending". 
+ + @return + pending_slots -- List[Slot] + """ + return get_pending_slots( + subgraph_url=self.subgraph_url, + timestamp=timestamp, + owner_addresses=[self.owner_addrs] if self.owner_addrs else None, + allowed_feeds=allowed_feeds, + ) + + @enforce_types + def tx_call_params(self, gas=None) -> dict: + call_params = { + "from": self.web3_config.owner, + "gasPrice": self.tx_gas_price(), + } + if gas is not None: + call_params["gas"] = gas + return call_params + + @enforce_types + def tx_gas_price(self) -> int: + """Return gas price for use in call_params of transaction calls.""" + network = self.network + if network in ["sapphire-testnet", "sapphire-mainnet"]: + return self.web3_config.w3.eth.gas_price + # return 100000000000 + if network in ["development", "barge-predictoor-bot", "barge-pytest"]: + return 0 + raise ValueError(f"Unknown network {network}") + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def mock_web3_pp(network: str) -> Web3PP: + D1 = { + "address_file": "address.json 1", + "rpc_url": "http://example.com/rpc", + "subgraph_url": "http://example.com/subgraph", + "owner_addrs": "0xOwner1", + } + D = { + network: D1, + } + return Web3PP(D, network) + + +@enforce_types +def inplace_mock_feedgetters(web3_pp, feed: SubgraphFeed): + # pylint: disable=import-outside-toplevel + from pdr_backend.contract.predictoor_contract import mock_predictoor_contract + + inplace_mock_query_feed_contracts(web3_pp, feed) + + c = mock_predictoor_contract(feed.address) + inplace_mock_get_contracts(web3_pp, feed, c) + + +@enforce_types +def inplace_mock_query_feed_contracts(web3_pp: Web3PP, feed: SubgraphFeed): + web3_pp.query_feed_contracts = Mock() + web3_pp.query_feed_contracts.return_value = {feed.address: feed} + + +@enforce_types +def inplace_mock_get_contracts( + web3_pp: Web3PP, feed: SubgraphFeed, predictoor_contract +): + # pylint: disable=import-outside-toplevel + from pdr_backend.contract.predictoor_contract import PredictoorContract + + assert isinstance(predictoor_contract, PredictoorContract) + web3_pp.get_contracts = Mock() + web3_pp.get_contracts.return_value = {feed.address: predictoor_contract} + + +@enforce_types +class _MockEthWithTracking: + def __init__(self, init_timestamp: int, init_block_number: int): + self.timestamp: int = init_timestamp + self.block_number: int = init_block_number + self._timestamps_seen: List[int] = [init_timestamp] + + def get_block( + self, block_number: int, full_transactions: bool = False + ): # pylint: disable=unused-argument + mock_block = {"timestamp": self.timestamp} + return mock_block + + +@enforce_types +class _MockPredictoorContractWithTracking: + def __init__(self, w3, s_per_epoch: int, contract_address: str): + self._w3 = w3 + self.s_per_epoch = s_per_epoch + self.contract_address: str = contract_address + self._prediction_slots: List[int] = [] + + def get_current_epoch(self) -> int: + """Returns an epoch number""" + return self.get_current_epoch_ts() // self.s_per_epoch + + def get_current_epoch_ts(self) -> int: + """Returns a timestamp""" + return self._w3.eth.timestamp // self.s_per_epoch * self.s_per_epoch + + def get_secondsPerEpoch(self) -> int: + return self.s_per_epoch + + def submit_prediction( + self, + predicted_value: bool, + stake_amt: float, + prediction_ts: int, + wait_for_receipt: bool = True, + ): # pylint: disable=unused-argument + assert stake_amt <= 3 + if prediction_ts in self._prediction_slots: + print(f" (Replace prev pred at time slot 
{prediction_ts})") + self._prediction_slots.append(prediction_ts) + + +@enforce_types +def inplace_mock_w3_and_contract_with_tracking( + web3_pp: Web3PP, + init_timestamp: int, + init_block_number: int, + timeframe_s: int, + feed_address: str, + monkeypatch, +): + """ + Updates web3_pp.web3_config.w3 with a mock. + Includes a mock of time.sleep(), which advances the (mock) blockchain + Includes a mock of web3_pp.PredictoorContract(); returns it for convenience + """ + mock_w3 = Mock() # pylint: disable=not-callable + mock_w3.eth = _MockEthWithTracking(init_timestamp, init_block_number) + _mock_pdr_contract = _MockPredictoorContractWithTracking( + mock_w3, + timeframe_s, + feed_address, + ) + + mock_contract_func = Mock() + mock_contract_func.return_value = _mock_pdr_contract + monkeypatch.setattr( + "pdr_backend.contract.predictoor_contract.PredictoorContract", + mock_contract_func, + ) + + def advance_func(*args, **kwargs): # pylint: disable=unused-argument + do_advance_block = random.random() < 0.40 + if do_advance_block: + mock_w3.eth.timestamp += random.randint(3, 12) + mock_w3.eth.block_number += 1 + mock_w3.eth._timestamps_seen.append(mock_w3.eth.timestamp) + + monkeypatch.setattr("time.sleep", advance_func) + + assert hasattr(web3_pp.web3_config, "w3") + web3_pp.web3_config.w3 = mock_w3 + + return _mock_pdr_contract diff --git a/pdr_backend/predictoor/README.md b/pdr_backend/predictoor/README.md deleted file mode 100644 index 603601440..000000000 --- a/pdr_backend/predictoor/README.md +++ /dev/null @@ -1 +0,0 @@ -See [READMEs/predictoor.md](../../READMEs/predictoor.md). diff --git a/pdr_backend/predictoor/approach1/predictoor_agent1.py b/pdr_backend/predictoor/approach1/predictoor_agent1.py index b8594c9e4..4467d4fdc 100644 --- a/pdr_backend/predictoor/approach1/predictoor_agent1.py +++ b/pdr_backend/predictoor/approach1/predictoor_agent1.py @@ -4,22 +4,18 @@ from enforce_typing import enforce_types from pdr_backend.predictoor.base_predictoor_agent import BasePredictoorAgent -from pdr_backend.predictoor.approach1.predictoor_config1 import PredictoorConfig1 @enforce_types class PredictoorAgent1(BasePredictoorAgent): - predictoor_config_class = PredictoorConfig1 - def get_prediction( - self, addr: str, timestamp: int # pylint: disable=unused-argument + self, timestamp: int # pylint: disable=unused-argument ) -> Tuple[bool, float]: """ @description - Given a feed, let's predict for a given timestamp. + Predict for a given timestamp. @arguments - addr -- str -- address of the trading pair. 
Info in self.feeds[addr] timestamp -- int -- when to make prediction for (unix time) @return diff --git a/pdr_backend/predictoor/approach1/predictoor_config1.py b/pdr_backend/predictoor/approach1/predictoor_config1.py deleted file mode 100644 index f41823619..000000000 --- a/pdr_backend/predictoor/approach1/predictoor_config1.py +++ /dev/null @@ -1,8 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.predictoor.base_predictoor_config import BasePredictoorConfig - - -@enforce_types -class PredictoorConfig1(BasePredictoorConfig): - pass diff --git a/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py b/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py index 0b6e82eb9..f8dbc8d7b 100644 --- a/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py +++ b/pdr_backend/predictoor/approach1/test/test_predictoor_agent1.py @@ -1,167 +1,35 @@ -import os -import random -from typing import List -from unittest.mock import Mock +from unittest.mock import MagicMock +import pytest from enforce_typing import enforce_types -from pdr_backend.predictoor.approach1.predictoor_config1 import PredictoorConfig1 +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.ppss.predictoor_ss import PredictoorSS +from pdr_backend.ppss.web3_pp import Web3PP from pdr_backend.predictoor.approach1.predictoor_agent1 import PredictoorAgent1 -from pdr_backend.util.constants import S_PER_MIN, S_PER_DAY +from pdr_backend.predictoor.test.predictoor_agent_runner import run_agent_test -PRIV_KEY = os.getenv("PRIVATE_KEY") -ADDR = "0xe8933f2950aec1080efad1ca160a6bb641ad245d" +def test_predictoor_agent1(tmpdir, monkeypatch): + run_agent_test(str(tmpdir), monkeypatch, PredictoorAgent1) -SOURCE = "binance" -PAIR = "BTC-USDT" -TIMEFRAME, S_PER_EPOCH = "5m", 5 * S_PER_MIN # must change both at once -SECONDS_TILL_EPOCH_END = 60 # how soon to start making predictions? 
-FEED_S = f"{PAIR}|{SOURCE}|{TIMEFRAME}" -S_PER_SUBSCRIPTION = 1 * S_PER_DAY -FEED_DICT = { # info inside a predictoor contract - "name": f"Feed of {FEED_S}", - "address": ADDR, - "symbol": f"FEED:{FEED_S}", - "seconds_per_epoch": S_PER_EPOCH, - "seconds_per_subscription": S_PER_SUBSCRIPTION, - "trueval_submit_timeout": 15, - "owner": "0xowner", - "pair": PAIR, - "timeframe": TIMEFRAME, - "source": SOURCE, -} -INIT_TIMESTAMP = 107 -INIT_BLOCK_NUMBER = 13 +def test_run(): + mock_predictoor_agent1 = MagicMock(spec=PredictoorAgent1) + take_step = mock_predictoor_agent1.take_step + take_step.return_value = None -@enforce_types -def test_predictoor_agent1(monkeypatch): - _setenvs(monkeypatch) - - # mock query_feed_contracts() - def mock_query_feed_contracts(*args, **kwargs): # pylint: disable=unused-argument - feed_dicts = {ADDR: FEED_DICT} - return feed_dicts - - monkeypatch.setattr( - "pdr_backend.models.base_config.query_feed_contracts", - mock_query_feed_contracts, - ) - - # mock w3.eth.block_number, w3.eth.get_block() - @enforce_types - class MockEth: - def __init__(self): - self.timestamp = INIT_TIMESTAMP - self.block_number = INIT_BLOCK_NUMBER - self._timestamps_seen: List[int] = [INIT_TIMESTAMP] - - def get_block( - self, block_number: int, full_transactions: bool = False - ): # pylint: disable=unused-argument - mock_block = {"timestamp": self.timestamp} - return mock_block - - mock_w3 = Mock() # pylint: disable=not-callable - mock_w3.eth = MockEth() - - # mock PredictoorContract - @enforce_types - def toEpochStart(timestamp: int) -> int: - return timestamp // S_PER_EPOCH * S_PER_EPOCH - - @enforce_types - class MockContract: - def __init__(self, w3): - self._w3 = w3 - self.contract_address: str = ADDR - self._prediction_slots: List[int] = [] - - def get_current_epoch(self) -> int: # returns an epoch number - return self.get_current_epoch_ts() // S_PER_EPOCH - - def get_current_epoch_ts(self) -> int: # returns a timestamp - curEpoch_ts = toEpochStart(self._w3.eth.timestamp) - return curEpoch_ts - - def get_secondsPerEpoch(self) -> int: - return S_PER_EPOCH - - def submit_prediction( - self, predval: bool, stake: float, timestamp: int, wait: bool = True - ): # pylint: disable=unused-argument - assert stake <= 3 - if timestamp in self._prediction_slots: - print(f" (Replace prev pred at time slot {timestamp})") - self._prediction_slots.append(timestamp) - - mock_contract = MockContract(mock_w3) + mock_predictoor_agent1.run() - def mock_contract_func(*args, **kwargs): # pylint: disable=unused-argument - return mock_contract - monkeypatch.setattr( - "pdr_backend.models.base_config.PredictoorContract", mock_contract_func - ) - - # mock time.sleep() - def advance_func(*args, **kwargs): # pylint: disable=unused-argument - do_advance_block = random.random() < 0.40 - if do_advance_block: - mock_w3.eth.timestamp += random.randint(3, 12) - mock_w3.eth.block_number += 1 - mock_w3.eth._timestamps_seen.append(mock_w3.eth.timestamp) - - monkeypatch.setattr("time.sleep", advance_func) - - # now we're done the mocking, time for the real work!! 
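Both the old test above and the new run_agent_test() helper introduced later in this patch rely on the same trick: monkeypatch time.sleep() so that every "sleep" may advance a fake chain (timestamp and block number), which is what drives the agent's take_step() loop forward. A minimal, self-contained sketch of that pattern follows; the helper names here are illustrative only, not identifiers from the codebase.

    import random
    from unittest.mock import Mock

    def make_fake_eth(init_timestamp: int, init_block_number: int) -> Mock:
        """Build a minimal stand-in for w3.eth whose clock/blocks advance only on demand."""
        eth = Mock()
        eth.timestamp = init_timestamp
        eth.block_number = init_block_number
        return eth

    def install_time_advance(monkeypatch, eth: Mock) -> None:
        """Replace time.sleep() so each 'sleep' has a ~40% chance of mining a new block."""
        def advance(*_args, **_kwargs):
            if random.random() < 0.40:
                eth.timestamp += random.randint(3, 12)
                eth.block_number += 1
        monkeypatch.setattr("time.sleep", advance)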
- - # real work: initialize - c = PredictoorConfig1() - agent = PredictoorAgent1(c) - - # last bit of mocking - agent.config.web3_config.w3 = mock_w3 - - # real work: main iterations - for _ in range(1000): - agent.take_step() - - # log some final results for debubbing / inspection - print("\n" + "=" * 80) - print("Done iterations") - print( - f"init block_number = {INIT_BLOCK_NUMBER}" - f", final = {mock_w3.eth.block_number}" - ) - print() - print(f"init timestamp = {INIT_TIMESTAMP}, final = {mock_w3.eth.timestamp}") - print(f"all timestamps seen = {mock_w3.eth._timestamps_seen}") - print() - print( - "unique prediction_slots = " f"{sorted(set(mock_contract._prediction_slots))}" - ) - print(f"all prediction_slots = {mock_contract._prediction_slots}") - - # relatively basic sanity tests - assert mock_contract._prediction_slots - assert (mock_w3.eth.timestamp + 2 * S_PER_EPOCH) >= max( - mock_contract._prediction_slots - ) - - -def _setenvs(monkeypatch): - # envvars handled by PredictoorConfig1 - monkeypatch.setenv("SECONDS_TILL_EPOCH_END", "60") - monkeypatch.setenv("STAKE_AMOUNT", "30000") - - # envvars handled by BaseConfig - monkeypatch.setenv("RPC_URL", "http://foo") - monkeypatch.setenv("SUBGRAPH_URL", "http://bar") - monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) - - monkeypatch.setenv("PAIR_FILTER", PAIR.replace("-", "/")) - monkeypatch.setenv("TIMEFRAME_FILTER", TIMEFRAME) - monkeypatch.setenv("SOURCE_FILTER", SOURCE) - monkeypatch.setenv("OWNER_ADDRS", FEED_DICT["owner"]) +@enforce_types +def test_agent_constructor_empty(): + # test with no feeds + mock_ppss_empty = MagicMock(spec=PPSS) + mock_ppss_empty.predictoor_ss = MagicMock(spec=PredictoorSS) + mock_ppss_empty.predictoor_ss.get_feed_from_candidates.return_value = None + mock_ppss_empty.web3_pp = MagicMock(spec=Web3PP) + mock_ppss_empty.web3_pp.query_feed_contracts.return_value = {} + + with pytest.raises(ValueError, match="No feeds found"): + PredictoorAgent1(mock_ppss_empty) diff --git a/pdr_backend/predictoor/approach1/test/test_predictoor_config1.py b/pdr_backend/predictoor/approach1/test/test_predictoor_config1.py deleted file mode 100644 index b05d6bee4..000000000 --- a/pdr_backend/predictoor/approach1/test/test_predictoor_config1.py +++ /dev/null @@ -1,46 +0,0 @@ -import os - -from enforce_typing import enforce_types - -from pdr_backend.predictoor.approach1.predictoor_config1 import PredictoorConfig1 - -ADDR = "0xe8933f2950aec1080efad1ca160a6bb641ad245d" # predictoor contract addr -PRIV_KEY = os.getenv("PRIVATE_KEY") - - -@enforce_types -def test_predictoor_config_basic(monkeypatch): - _setenvs(monkeypatch) - c = PredictoorConfig1() - - # values handled by PredictoorConfig1 - assert c.s_until_epoch_end == 60 - assert c.stake_amount == 30000 - - # values handled by BaseConfig - assert c.rpc_url == "http://foo" - assert c.subgraph_url == "http://bar" - assert c.private_key == PRIV_KEY - - assert c.pair_filters == ["BTC/USDT", "ETH/USDT"] - assert c.timeframe_filter == ["5m", "15m"] - assert c.source_filter == ["binance", "kraken"] - assert c.owner_addresses == ["0x123", "0x124"] - - assert c.web3_config is not None - - -def _setenvs(monkeypatch): - # envvars handled by PredictoorConfig1 - monkeypatch.setenv("SECONDS_TILL_EPOCH_END", "60") - monkeypatch.setenv("STAKE_AMOUNT", "30000") - - # envvars handled by BaseConfig - monkeypatch.setenv("RPC_URL", "http://foo") - monkeypatch.setenv("SUBGRAPH_URL", "http://bar") - monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) - - monkeypatch.setenv("PAIR_FILTER", "BTC/USDT,ETH/USDT") - 
monkeypatch.setenv("TIMEFRAME_FILTER", "5m,15m") - monkeypatch.setenv("SOURCE_FILTER", "binance,kraken") - monkeypatch.setenv("OWNER_ADDRS", "0x123,0x124") diff --git a/pdr_backend/predictoor/approach2/main2.py b/pdr_backend/predictoor/approach2/main2.py deleted file mode 100644 index 4fe31d84b..000000000 --- a/pdr_backend/predictoor/approach2/main2.py +++ /dev/null @@ -1,270 +0,0 @@ -import csv -import os -from os import getenv -import sys -import time -from typing import List - -import ccxt -import numpy as np -import pandas as pd - -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.predictoor.approach2.predict import predict_function -from pdr_backend.util.env import getenv_or_exit -from pdr_backend.util.subgraph import query_feed_contracts -from pdr_backend.util.web3_config import Web3Config - -# set envvar model MODELDIR before calling main.py. eg ~/code/pdr-model-simple/ -# then, the pickled trained models live in $MODELDIR/trained_models/ -# and, OceanModel module lives in $MODELDIR/model.py -model_dir: str = getenv_or_exit("MODELDIR") -trained_models_dir = os.path.join(model_dir, "trained_models") -sys.path.append(model_dir) -from model import OceanModel # type: ignore # fmt: off # pylint: disable=wrong-import-order, wrong-import-position - -rpc_url = getenv_or_exit("RPC_URL") -subgraph_url = getenv_or_exit("SUBGRAPH_URL") -private_key = getenv_or_exit("PRIVATE_KEY") -pair_filters = getenv("PAIR_FILTER") -timeframe_filter = getenv("TIMEFRAME_FILTER") -source_filter = getenv("SOURCE_FILTER") -owner_addresses = getenv("OWNER_ADDRS") - -exchange_str = "binance" -pair = "BTC/USDT" -timeframe = "5m" - -# =================== -# done imports and constants. Now start running... - -last_block_time = 0 -topics: List[dict] = [] - -exchange_class = getattr(ccxt, exchange_str) -exchange_ccxt = exchange_class({"timeout": 30000}) - -web3_config = Web3Config(rpc_url, private_key) -owner = web3_config.owner - -models = [ - OceanModel(exchange_str, pair, timeframe), -] - - -def process_block(block, model, main_pd): - """ - Process each contract. 
- If needed, get a prediction, submit it and claim revenue for past epoch - """ - global topics - if not topics: - topics = query_feed_contracts( - subgraph_url, - pair_filters, - timeframe_filter, - source_filter, - owner_addresses, - ) - - print(f"Got new block: {block['number']} with {len(topics)} topics") - - for address in topics: - topic = topics[address] - predictoor_contract = PredictoorContract(web3_config, address) - epoch = predictoor_contract.get_current_epoch() - seconds_per_epoch = predictoor_contract.get_secondsPerEpoch() - seconds_till_epoch_end = ( - epoch * seconds_per_epoch + seconds_per_epoch - block["timestamp"] - ) - print( - f"\t{topic['name']} (at address {topic['address']} is at " - f"epoch {epoch}, seconds_per_epoch: {seconds_per_epoch}" - f", seconds_till_epoch_end: {seconds_till_epoch_end}" - ) - - if seconds_till_epoch_end <= int(getenv("SECONDS_TILL_EPOCH_END", "60")): - # Timestamp of prediction - target_time = (epoch + 2) * seconds_per_epoch - - # Fetch the prediction - (predicted_value, predicted_confidence) = predict_function( - topic, target_time, model, main_pd - ) - - if predicted_value is not None and predicted_confidence > 0: - # We have a prediction, let's submit it - stake_amount = ( - float(getenv("STAKE_AMOUNT", "1")) * predicted_confidence - ) # TO DO have a customizable function to handle this - print( - f"Contract:{predictoor_contract.contract_address} - " - f"Submitting prediction for slot:{target_time}" - ) - predictoor_contract.submit_prediction( - predicted_value, stake_amount, target_time, True - ) - topics[address]["last_submited_epoch"] = epoch - return predicted_value - - print( - "We do not submit, prediction function returned " - f"({predicted_value}, {predicted_confidence})" - ) - return None - - -def log_loop(blockno, model, main_pd): - global last_block_time - block = web3_config.get_block(blockno, full_transactions=False) - if block: - last_block_time = block["timestamp"] - prediction = process_block(block, model, main_pd) - if prediction is not None: - return prediction - return None - - -def do_main2(): # pylint: disable=too-many-statements - print("Starting main loop...") - - ts_now = int(time.time()) - - results_path = "results" - if not os.path.exists(results_path): - os.makedirs(results_path) - - results_csv_name = ( - "./" - + results_path - + "/" - + exchange_str - + "_" - + models[0].pair - + "_" - + models[0].timeframe - + "_" - + str(ts_now) - + ".csv" - ) - - columns_short = ["datetime", "open", "high", "low", "close", "volume"] - - columns_models = [] - for model in models: - model.unpickle_model(trained_models_dir) - columns_models.append(model.model_name) # prediction column. 
0 or 1 - - all_columns = columns_short + columns_models - - # write csv header for results - size = 0 - try: - files_stats = os.stat(results_csv_name) - size = files_stats.st_size - except: # pylint: disable=bare-except - pass - if size == 0: - with open(results_csv_name, "a") as f: - writer = csv.writer(f) - writer.writerow(all_columns) - - # read initial set of candles - candles = exchange_ccxt.fetch_ohlcv(pair, "5m") - # load past data - main_pd = pd.DataFrame(columns=all_columns) - for ohl in candles: - ohlc = { - "timestamp": int(ohl[0] / 1000), - "open": float(ohl[1]), - "close": float(ohl[4]), - "low": float(ohl[3]), - "high": float(ohl[2]), - "volume": float(ohl[5]), - } - main_pd.loc[ohlc["timestamp"]] = ohlc - main_pd["datetime"] = pd.to_datetime(main_pd.index.values, unit="s", utc=True) - - lastblock = 0 - last_finalized_timestamp = 0 - while True: - candles = exchange_ccxt.fetch_ohlcv(pair, "5m") - - # update last two candles - for ohl in candles[-2:]: - t = int(ohl[0] / 1000) - main_pd.loc[t, ["datetime"]] = pd.to_datetime(t, unit="s", utc=True) - main_pd.loc[t, ["open"]] = float(ohl[1]) - main_pd.loc[t, ["close"]] = float(ohl[4]) - main_pd.loc[t, ["low"]] = float(ohl[3]) - main_pd.loc[t, ["high"]] = float(ohl[2]) - main_pd.loc[t, ["volume"]] = float(ohl[5]) - - timestamp = main_pd.index.values[-2] - - block = web3_config.w3.eth.block_number - if block > lastblock: - lastblock = block - - # #we have a new candle - if last_finalized_timestamp < timestamp: - last_finalized_timestamp = timestamp - - should_write = False - for model in models: - prediction = main_pd.iloc[-2][model.model_name] - if not np.isnan(prediction): - should_write = True - - if should_write: - with open(results_csv_name, "a") as f: - writer = csv.writer(f) - row = [ - main_pd.index.values[-2], - main_pd.iloc[-2]["datetime"], - main_pd.iloc[-2]["open"], - main_pd.iloc[-2]["high"], - main_pd.iloc[-2]["low"], - main_pd.iloc[-2]["close"], - main_pd.iloc[-2]["volume"], - ] - for model in models: - row.append(main_pd.iloc[-2][model.model_name]) - writer.writerow(row) - - for model in models: - index = main_pd.index.values[-1] - current_prediction = main_pd.iloc[-1][model.model_name] - if np.isnan(current_prediction): - max_retries = 5 - for attempt in range(max_retries): - try: - prediction = log_loop( - block, - model, - main_pd.drop(columns_models + ["datetime"], axis=1), - ) - if prediction is not None: - main_pd.loc[index, [model.model_name]] = float( - prediction - ) - break - except Exception as e: - if attempt < max_retries - 1: - print(f"Attempt {attempt + 1} failed. Retrying...") - continue - print(f"Attempt {attempt + 1} failed. 
No more retries.") - raise e - - print( - main_pd.loc[ - :, ~main_pd.columns.isin(["volume", "open", "high", "low"]) - ].tail(15) - ) - - else: - time.sleep(1) - - -if __name__ == "__main__": - do_main2() diff --git a/pdr_backend/predictoor/approach2/predict.py b/pdr_backend/predictoor/approach2/predict.py deleted file mode 100644 index 69f36cee8..000000000 --- a/pdr_backend/predictoor/approach2/predict.py +++ /dev/null @@ -1,38 +0,0 @@ -def predict_function(topic, estimated_time, model, main_pd): - """Given a topic, let's predict - Topic object looks like: - - { - "name":"ETH-USDT", - "address":"0x54b5ebeed85f4178c6cb98dd185067991d058d55", - "symbol":"ETH-USDT", - "blocks_per_epoch":"60", - "blocks_per_subscription":"86400", - "last_submited_epoch":0, - "pair":"eth-usdt", - "base":"eth", - "quote":"usdt", - "source":"kraken", - "timeframe":"5m" - } - - """ - print( - f" We were asked to predict {topic['name']} " - f"(contract: {topic['address']}) value " - f"at estimated timestamp: {estimated_time}" - ) - predicted_confidence = None - predicted_value = None - - try: - predicted_value, predicted_confidence = model.predict(main_pd) - predicted_value = bool(predicted_value) - print( - f"Predicting {predicted_value} with a confidence of {predicted_confidence}" - ) - - except Exception as e: - print(e) - - return (predicted_value, predicted_confidence) diff --git a/pdr_backend/predictoor/approach2/test/conftest.py b/pdr_backend/predictoor/approach2/test/conftest.py deleted file mode 100644 index fc9ed034c..000000000 --- a/pdr_backend/predictoor/approach2/test/conftest.py +++ /dev/null @@ -1 +0,0 @@ -from pdr_backend.conftest_ganache import * # pylint: disable=wildcard-import diff --git a/pdr_backend/predictoor/approach2/test/test_predictoor_approach2_predict.py b/pdr_backend/predictoor/approach2/test/test_predictoor_approach2_predict.py deleted file mode 100644 index b8f1f3446..000000000 --- a/pdr_backend/predictoor/approach2/test/test_predictoor_approach2_predict.py +++ /dev/null @@ -1,7 +0,0 @@ -from pdr_backend.predictoor.approach2.predict import ( # pylint: disable=unused-import - predict_function, -) - - -def test_predictoor_approach2_predict_function(): - pass diff --git a/pdr_backend/predictoor/approach2/test/test_predictoor_main2.py b/pdr_backend/predictoor/approach2/test/test_predictoor_main2.py deleted file mode 100644 index 000857a97..000000000 --- a/pdr_backend/predictoor/approach2/test/test_predictoor_main2.py +++ /dev/null @@ -1,6 +0,0 @@ -# don't import this for now, since main.py needs to import an OceanModel from somewhere -# from pdr_backend.predictoor.approach2.main import process_block, log_loop, main - - -def test_predictoor_approach2_main(): - pass diff --git a/pdr_backend/predictoor/approach3/predictoor_agent3.py b/pdr_backend/predictoor/approach3/predictoor_agent3.py index 59bde6f73..2c776606d 100644 --- a/pdr_backend/predictoor/approach3/predictoor_agent3.py +++ b/pdr_backend/predictoor/approach3/predictoor_agent3.py @@ -2,58 +2,37 @@ from enforce_typing import enforce_types -from pdr_backend.data_eng.data_factory import DataFactory -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.model_eng.model_factory import ModelFactory - +from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory +from pdr_backend.aimodel.aimodel_factory import AimodelFactory +from pdr_backend.lake.ohlcv_data_factory import OhlcvDataFactory from pdr_backend.predictoor.base_predictoor_agent import BasePredictoorAgent -from 
pdr_backend.predictoor.approach3.predictoor_config3 import PredictoorConfig3 @enforce_types class PredictoorAgent3(BasePredictoorAgent): - predictoor_config_class = PredictoorConfig3 - - def __init__(self, config: PredictoorConfig3): - super().__init__(config) - self.config: PredictoorConfig3 = config - def get_prediction( - self, addr: str, timestamp: int # pylint: disable=unused-argument + self, timestamp: int # pylint: disable=unused-argument ) -> Tuple[bool, float]: """ @description - Given a feed, let's predict for a given timestamp. + Predict for a given timestamp. @arguments - addr -- str -- address of the trading pair. Info in self.feeds[addr] timestamp -- int -- when to make prediction for (unix time) @return predval -- bool -- if True, it's predicting 'up'. If False, 'down' stake -- int -- amount to stake, in units of Eth """ - feed = self.feeds[addr] - - # user-uncontrollable params, at data-eng level - data_pp = DataPP( - feed.timeframe, # eg "5m" - f"{feed.source} c {feed.base}/{feed.quote}", - N_test=1, # N/A for this context - ) - - # user-controllable params, at data-eng level - data_ss = self.config.data_ss.copy_with_yval(data_pp) - - # user-controllable params, at model-eng level - model_ss = self.config.model_ss + # Compute aimodel_ss + lake_ss = self.ppss.lake_ss - # do work... - data_factory = DataFactory(data_pp, data_ss) + # From lake_ss, build X/y + pq_data_factory = OhlcvDataFactory(lake_ss) + mergedohlcv_df = pq_data_factory.get_mergedohlcv_df() - # Compute X/y - hist_df = data_factory.get_hist_df() - X, y, _ = data_factory.create_xy(hist_df, testshift=0) + model_data_factory = AimodelDataFactory(self.ppss.predictoor_ss) + X, y, _ = model_data_factory.create_xy(mergedohlcv_df, testshift=0) # Split X/y into train & test data st, fin = 0, X.shape[0] - 1 @@ -61,15 +40,15 @@ def get_prediction( y_train, _ = y[st:fin], y[fin : fin + 1] # Compute the model from train data - model_factory = ModelFactory(model_ss) - model = model_factory.build(X_train, y_train) + aimodel_factory = AimodelFactory(self.ppss.predictoor_ss.aimodel_ss) + model = aimodel_factory.build(X_train, y_train) # Predict from test data predprice = model.predict(X_test)[0] curprice = y_train[-1] predval = predprice > curprice - # Stake what was set via envvar STAKE_AMOUNT - stake = self.config.stake_amount + # Stake amount + stake = self.ppss.predictoor_ss.stake_amount return (bool(predval), stake) diff --git a/pdr_backend/predictoor/approach3/predictoor_config3.py b/pdr_backend/predictoor/approach3/predictoor_config3.py deleted file mode 100644 index 6b12c8f70..000000000 --- a/pdr_backend/predictoor/approach3/predictoor_config3.py +++ /dev/null @@ -1,36 +0,0 @@ -import os -from enforce_typing import enforce_types - -from pdr_backend.data_eng.data_ss import DataSS -from pdr_backend.model_eng.model_ss import ModelSS -from pdr_backend.predictoor.base_predictoor_config import BasePredictoorConfig - -# To try different strategies, simply change any of the arguments to any -# of the constructors below. -# -# - It does *not* use envvars PAIR_FILTER, TIMEFRAME_FILTER, or SOURCE_FILTER. -# Why: to avoid ambiguity. Eg is PAIR_FILTER for yval_coin, or input data? - - -@enforce_types -class PredictoorConfig3(BasePredictoorConfig): - def __init__(self): - super().__init__() - - # **Note: the values below are magic numbers. That's ok for now, - # this is how config files work right now. 
(Will change with ppss.yaml) - self.model_ss = ModelSS("LIN") # LIN, GPR, SVR, NuSVR, LinearSVR - - self.data_ss = DataSS( # user-controllable params, at data-eng level - ["binanceus c BTC/USDT,ETH/USDT"], - csv_dir=os.path.abspath("csvs"), # eg "csvs". abs or rel loc'n of csvs dir - st_timestr="2023-01-31", # eg "2019-09-13_04:00" (earliest), "2019-09-13" - fin_timestr="now", # eg "now", "2023-09-23_17:55", "2023-09-23" - max_n_train=5000, # eg 50000. # if inf, only limited by data available - autoregressive_n=10, # eg 10. model inputs ar_n past pts z[t-1], .., z[t-ar_n] - ) - - # Note: Inside PredictoorAgent3::get_prediction(), - # it's given a yval to predict with {signal, coin, exchange_str}. - # If that yval isn't in data_ss input vars {signals, coins, exchanges} - # then it will update {signals, coins, exchanges} to include it diff --git a/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py b/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py index f1e918691..b9b4af3da 100644 --- a/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py +++ b/pdr_backend/predictoor/approach3/test/test_predictoor_agent3.py @@ -1,171 +1,6 @@ -import os -import random -from typing import List -from unittest.mock import Mock - -from enforce_typing import enforce_types - -from pdr_backend.predictoor.approach3.predictoor_config3 import PredictoorConfig3 from pdr_backend.predictoor.approach3.predictoor_agent3 import PredictoorAgent3 -from pdr_backend.util.constants import S_PER_MIN, S_PER_DAY - -PRIV_KEY = os.getenv("PRIVATE_KEY") - -ADDR = "0xe8933f2950aec1080efad1ca160a6bb641ad245d" - -SOURCE = "binanceus" -PAIR = "BTC-USDT" -TIMEFRAME, S_PER_EPOCH = "5m", 5 * S_PER_MIN # must change both at once -SECONDS_TILL_EPOCH_END = 60 # how soon to start making predictions? 
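The deleted PredictoorConfig3 above notes that autoregressive_n controls how many past points feed the model ("model inputs ar_n past pts z[t-1], .., z[t-ar_n]"). A schematic sketch of what that means for X/y construction is below; it is illustrative only, not the actual AimodelDataFactory implementation, and the function name is made up.

    import numpy as np

    def make_autoregressive_xy(z: np.ndarray, ar_n: int = 10):
        """Schematic: row t of X holds the ar_n most recent values z[t-ar_n..t-1];
        y[t] is the value one step ahead."""
        X, y = [], []
        for t in range(ar_n, len(z)):
            X.append(z[t - ar_n : t])
            y.append(z[t])
        return np.array(X), np.array(y)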
-FEED_S = f"{PAIR}|{SOURCE}|{TIMEFRAME}" -S_PER_SUBSCRIPTION = 1 * S_PER_DAY -FEED_DICT = { # info inside a predictoor contract - "name": f"Feed of {FEED_S}", - "address": ADDR, - "symbol": f"FEED:{FEED_S}", - "seconds_per_epoch": S_PER_EPOCH, - "seconds_per_subscription": S_PER_SUBSCRIPTION, - "trueval_submit_timeout": 15, - "owner": "0xowner", - "pair": PAIR, - "timeframe": TIMEFRAME, - "source": SOURCE, -} -INIT_TIMESTAMP = 107 -INIT_BLOCK_NUMBER = 13 - - -@enforce_types -def toEpochStart(timestamp: int) -> int: - return timestamp // S_PER_EPOCH * S_PER_EPOCH - - -@enforce_types -class MockEth: - def __init__(self): - self.timestamp = INIT_TIMESTAMP - self.block_number = INIT_BLOCK_NUMBER - self._timestamps_seen: List[int] = [INIT_TIMESTAMP] - - def get_block( - self, block_number: int, full_transactions: bool = False - ): # pylint: disable=unused-argument - mock_block = {"timestamp": self.timestamp} - return mock_block - - -@enforce_types -class MockContract: - def __init__(self, w3): - self._w3 = w3 - self.contract_address: str = ADDR - self._prediction_slots: List[int] = [] - - def get_current_epoch(self) -> int: # returns an epoch number - return self.get_current_epoch_ts() // S_PER_EPOCH - - def get_current_epoch_ts(self) -> int: # returns a timestamp - curEpoch_ts = toEpochStart(self._w3.eth.timestamp) - return curEpoch_ts - - def get_secondsPerEpoch(self) -> int: - return S_PER_EPOCH - - def submit_prediction( - self, predval: bool, stake: float, timestamp: int, wait: bool = True - ): # pylint: disable=unused-argument - assert stake <= 3 - if timestamp in self._prediction_slots: - print(f" (Replace prev pred at time slot {timestamp})") - self._prediction_slots.append(timestamp) - - -@enforce_types -def test_predictoor_agent3(monkeypatch): - _setenvs(monkeypatch) - - # mock query_feed_contracts() - def mock_query_feed_contracts(*args, **kwargs): # pylint: disable=unused-argument - feed_dicts = {ADDR: FEED_DICT} - return feed_dicts - - monkeypatch.setattr( - "pdr_backend.models.base_config.query_feed_contracts", - mock_query_feed_contracts, - ) - - # mock w3.eth.block_number, w3.eth.get_block() - - mock_w3 = Mock() # pylint: disable=not-callable - mock_w3.eth = MockEth() - - # mock PredictoorContract - mock_contract = MockContract(mock_w3) - - def mock_contract_func(*args, **kwargs): # pylint: disable=unused-argument - return mock_contract - - monkeypatch.setattr( - "pdr_backend.models.base_config.PredictoorContract", mock_contract_func - ) - - # mock time.sleep() - def advance_func(*args, **kwargs): # pylint: disable=unused-argument - do_advance_block = random.random() < 0.40 - if do_advance_block: - mock_w3.eth.timestamp += random.randint(3, 12) - mock_w3.eth.block_number += 1 - mock_w3.eth._timestamps_seen.append(mock_w3.eth.timestamp) - - monkeypatch.setattr("time.sleep", advance_func) - - # now we're done the mocking, time for the real work!! 
- - # real work: initialize - c = PredictoorConfig3() - agent = PredictoorAgent3(c) - - # last bit of mocking - agent.config.web3_config.w3 = mock_w3 - - # real work: main iterations - for _ in range(1000): - agent.take_step() - - # log some final results for debubbing / inspection - print("\n" + "=" * 80) - print("Done iterations") - print( - f"init block_number = {INIT_BLOCK_NUMBER}" - f", final = {mock_w3.eth.block_number}" - ) - print() - print(f"init timestamp = {INIT_TIMESTAMP}, final = {mock_w3.eth.timestamp}") - print(f"all timestamps seen = {mock_w3.eth._timestamps_seen}") - print() - print( - "unique prediction_slots = " f"{sorted(set(mock_contract._prediction_slots))}" - ) - print(f"all prediction_slots = {mock_contract._prediction_slots}") - - # relatively basic sanity tests - assert mock_contract._prediction_slots - assert (mock_w3.eth.timestamp + 2 * S_PER_EPOCH) >= max( - mock_contract._prediction_slots - ) - - -def _setenvs(monkeypatch): - # envvars handled by PredictoorConfig3 - monkeypatch.setenv("SECONDS_TILL_EPOCH_END", "60") - monkeypatch.setenv("STAKE_AMOUNT", "1") +from pdr_backend.predictoor.test.predictoor_agent_runner import run_agent_test - # envvars handled by BaseConfig - monkeypatch.setenv("RPC_URL", "http://foo") - monkeypatch.setenv("SUBGRAPH_URL", "http://bar") - monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) - monkeypatch.setenv("PAIR_FILTER", PAIR.replace("-", "/")) - monkeypatch.setenv("TIMEFRAME_FILTER", TIMEFRAME) - monkeypatch.setenv("SOURCE_FILTER", SOURCE) - monkeypatch.setenv("OWNER_ADDRS", FEED_DICT["owner"]) +def test_predictoor_agent3(tmpdir, monkeypatch): + run_agent_test(str(tmpdir), monkeypatch, PredictoorAgent3) diff --git a/pdr_backend/predictoor/approach3/test/test_predictoor_config3.py b/pdr_backend/predictoor/approach3/test/test_predictoor_config3.py deleted file mode 100644 index 8c36f326e..000000000 --- a/pdr_backend/predictoor/approach3/test/test_predictoor_config3.py +++ /dev/null @@ -1,46 +0,0 @@ -import os - -from enforce_typing import enforce_types - -from pdr_backend.predictoor.approach3.predictoor_config3 import PredictoorConfig3 - -ADDR = "0xe8933f2950aec1080efad1ca160a6bb641ad245d" # predictoor contract addr -PRIV_KEY = os.getenv("PRIVATE_KEY") - - -@enforce_types -def test_predictoor_config_basic(monkeypatch): - _setenvs(monkeypatch) - c = PredictoorConfig3() - - # values handled by PredictoorConfig3 - assert c.s_until_epoch_end == 60 - assert c.stake_amount == 30000 - - # values handled by BaseConfig - assert c.rpc_url == "http://foo" - assert c.subgraph_url == "http://bar" - assert c.private_key == PRIV_KEY - - assert c.pair_filters == ["BTC/USDT", "ETH/USDT"] - assert c.timeframe_filter == ["5m", "15m"] - assert c.source_filter == ["binance", "kraken"] - assert c.owner_addresses == ["0x123", "0x124"] - - assert c.web3_config is not None - - -def _setenvs(monkeypatch): - # envvars handled by PredictoorConfig3 - monkeypatch.setenv("SECONDS_TILL_EPOCH_END", "60") - monkeypatch.setenv("STAKE_AMOUNT", "30000") - - # envvars handled by BaseConfig - monkeypatch.setenv("RPC_URL", "http://foo") - monkeypatch.setenv("SUBGRAPH_URL", "http://bar") - monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) - - monkeypatch.setenv("PAIR_FILTER", "BTC/USDT,ETH/USDT") - monkeypatch.setenv("TIMEFRAME_FILTER", "5m,15m") - monkeypatch.setenv("SOURCE_FILTER", "binance,kraken") - monkeypatch.setenv("OWNER_ADDRS", "0x123,0x124") diff --git a/pdr_backend/predictoor/base_predictoor_agent.py b/pdr_backend/predictoor/base_predictoor_agent.py index 
99ad85245..ebcfa308b 100644 --- a/pdr_backend/predictoor/base_predictoor_agent.py +++ b/pdr_backend/predictoor/base_predictoor_agent.py @@ -1,128 +1,161 @@ -from abc import ABC, abstractmethod -import sys +import os import time -from typing import Dict, List, Tuple +from abc import ABC, abstractmethod +from typing import List, Tuple from enforce_typing import enforce_types -from pdr_backend.models.feed import Feed -from pdr_backend.predictoor.base_predictoor_config import BasePredictoorConfig +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_feed import print_feeds +from pdr_backend.util.mathutil import sole_value -@enforce_types class BasePredictoorAgent(ABC): """ What it does - Fetches Predictoor contracts from subgraph, and filters them - Monitors each contract for epoch changes. - - When a value can be predicted, call predict.py::predict_function() + - When a value can be predicted, call get_prediction() """ - def __init__(self, config: BasePredictoorConfig): - self.config = config + @enforce_types + def __init__(self, ppss: PPSS): + # ppss + self.ppss = ppss + print("\n" + "-" * 180) + print(self.ppss) + + # set self.feeds + cand_feeds = ppss.web3_pp.query_feed_contracts() + print_feeds(cand_feeds, f"cand feeds, owner={ppss.web3_pp.owner_addrs}") - self.feeds: Dict[str, Feed] = self.config.get_feeds() # [addr] : Feed + feed = ppss.predictoor_ss.get_feed_from_candidates(cand_feeds) + if not feed: + raise ValueError("No feeds found.") - if not self.feeds: - print("No feeds found. Exiting") - sys.exit() + print_feeds({feed.address: feed}, "filtered feed") + self.feed = feed - feed_addrs = list(self.feeds.keys()) - self.contracts = self.config.get_contracts(feed_addrs) # [addr] : contract + contracts = ppss.web3_pp.get_contracts([feed.address]) + self.feed_contract = sole_value(contracts) + # set attribs to track block self.prev_block_timestamp: int = 0 self.prev_block_number: int = 0 - self.prev_submit_epochs_per_feed: Dict[str, List[int]] = { - addr: [] for addr in self.feeds - } - - print("\n" + "-" * 80) - print("Config:") - print(self.config) - - print("\n" + "." * 80) - print("Feeds (detailed):") - for feed in self.feeds.values(): - print(f" {feed.longstr()}") - - print("\n" + "." * 80) - print("Feeds (succinct):") - for addr, feed in self.feeds.items(): - print(f" {feed}, {feed.seconds_per_epoch} s/epoch, addr={addr}") + self.prev_submit_epochs: List[int] = [] + @enforce_types def run(self): - print("Starting main loop...") + print("Starting main loop.") + print(self.status_str()) + print("Waiting...", end="") while True: self.take_step() + if os.getenv("TEST") == "true": + break + @enforce_types def take_step(self): - w3 = self.config.web3_config.w3 - print("\n" + "-" * 80) - print("Take_step() begin.") - - # new block? - block_number = w3.eth.block_number - print(f" block_number={block_number}, prev={self.prev_block_number}") - if block_number <= self.prev_block_number: - print(" Done step: block_number hasn't advanced yet. So sleep.") + # at new block number yet? + if self.cur_block_number <= self.prev_block_number: + print(".", end="", flush=True) time.sleep(1) return - block = self.config.web3_config.get_block(block_number, full_transactions=False) - if not block: - print(" Done step: block not ready yet") - return - self.prev_block_number = block_number - self.prev_block_timestamp = block["timestamp"] - - # do work at new block - print(f" Got new block. 
Timestamp={block['timestamp']}")
-        for addr in self.feeds:
-            self._process_block_at_feed(addr, block["timestamp"])
-
-    def _process_block_at_feed(self, addr: str, timestamp: int) -> tuple:
-        """Returns (predval, stake, submitted)"""
-        # base data
-        feed, contract = self.feeds[addr], self.contracts[addr]
-        epoch = contract.get_current_epoch()
-        s_per_epoch = feed.seconds_per_epoch
-        epoch_s_left = epoch * s_per_epoch + s_per_epoch - timestamp
-        # print status
-        print(f"  Process {feed} at epoch={epoch}")
+        # is new block ready yet?
+        if not self.cur_block:
+            return
+        self.prev_block_number = self.cur_block_number
+        self.prev_block_timestamp = self.cur_timestamp

         # within the time window to predict?
-        print(
-            f"  {epoch_s_left} s left in epoch"
-            f" (predict if <= {self.config.s_until_epoch_end} s left)"
-        )
-        too_early = epoch_s_left > self.config.s_until_epoch_end
-        if too_early:
-            print("  Done feed: too early to predict")
-            return (None, None, False)
+        if self.cur_epoch_s_left > self.epoch_s_thr:
+            return
+
+        print()
+        print(self.status_str())

         # compute prediction; exit if no good
-        target_time = (epoch + 2) * s_per_epoch
-        print(f"  Predict for time slot = {target_time}...")
+        submit_epoch, target_slot = self.cur_epoch, self.target_slot
+        print(f"Predict for time slot = {self.target_slot}...")

-        predval, stake = self.get_prediction(addr, target_time)
-        print(f"  -> Predict result: predval={predval}, stake={stake}")
+        predval, stake = self.get_prediction(target_slot)
+        print(f"-> Predict result: predval={predval}, stake={stake}")

         if predval is None or stake <= 0:
-            print("  Done feed: can't use predval/stake")
-            return (None, None, False)
+            print("Done: can't use predval/stake")
+            return

         # submit prediction to chain
-        print("  Submit predict tx chain...")
-        contract.submit_prediction(predval, stake, target_time, True)
-        self.prev_submit_epochs_per_feed[addr].append(epoch)
-        print("  " + "=" * 80)
-        print("  -> Submit predict tx result: success.")
-        print("  " + "=" * 80)
-        print("  Done feed: success.")
-        return (predval, stake, True)
+        print("Submit predict tx to chain...")
+        self.feed_contract.submit_prediction(
+            predval,
+            stake,
+            target_slot,
+            wait_for_receipt=True,
+        )
+        self.prev_submit_epochs.append(submit_epoch)
+        print("-> Submit predict tx result: success.")
+        print("" + "=" * 180)
+
+        # start printing for next round
+        print(self.status_str())
+        print("Waiting...", end="")
+
+    @property
+    def cur_epoch(self) -> int:
+        return self.feed_contract.get_current_epoch()
+
+    @property
+    def cur_block(self):
+        return self.ppss.web3_pp.web3_config.get_block(
+            self.cur_block_number, full_transactions=False
+        )
+
+    @property
+    def cur_block_number(self) -> int:
+        return self.ppss.web3_pp.w3.eth.block_number
+
+    @property
+    def cur_timestamp(self) -> int:
+        return self.cur_block["timestamp"]
+
+    @property
+    def epoch_s_thr(self):
+        """Start predicting once there is <= this much time left in the epoch"""
+        return self.ppss.predictoor_ss.s_until_epoch_end
+
+    @property
+    def s_per_epoch(self) -> int:
+        return self.feed.seconds_per_epoch
+
+    @property
+    def next_slot(self) -> int:  # a timestamp
+        return (self.cur_epoch + 1) * self.s_per_epoch
+
+    @property
+    def target_slot(self) -> int:  # a timestamp
+        return (self.cur_epoch + 2) * self.s_per_epoch
+
+    @property
+    def cur_epoch_s_left(self) -> int:
+        return self.next_slot - self.cur_timestamp
+
+    def status_str(self) -> str:
+        s = ""
+        s += f"cur_epoch={self.cur_epoch}"
+        s += f", cur_block_number={self.cur_block_number}"
+        s += f", cur_timestamp={self.cur_timestamp}"
+
s += f", next_slot={self.next_slot}" + s += f", target_slot={self.target_slot}" + s += f". {self.cur_epoch_s_left} s left in epoch" + s += f" (predict if <= {self.epoch_s_thr} s left)" + s += f". s_per_epoch={self.s_per_epoch}" + return s @abstractmethod def get_prediction( - self, addr: str, timestamp: int # pylint: disable=unused-argument + self, + timestamp: int, # pylint: disable=unused-argument ) -> Tuple[bool, float]: pass diff --git a/pdr_backend/predictoor/base_predictoor_config.py b/pdr_backend/predictoor/base_predictoor_config.py deleted file mode 100644 index ed116b08c..000000000 --- a/pdr_backend/predictoor/base_predictoor_config.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -## About SECONDS_TILL_EPOCH_END - -If we want to predict the value for epoch E, we need to do it in epoch E - 2 -(latest. Though we could predict for a distant future epoch if desired) - -And to do so, our tx needs to be confirmed in the last block of epoch -(otherwise, it's going to be part of next epoch and our prediction tx - will revert) - -But, for every prediction, there are several steps. Each takes time: -- time to compute prediction (e.g. run model inference) -- time to generate the tx -- time until your pending tx in mempool is picked by miner -- time until your tx is confirmed in a block - -To help, you can set envvar `SECONDS_TILL_EPOCH_END`. It controls how many -seconds in advance of the epoch ending you want the prediction process to -start. A predictoor can submit multiple predictions. However, only the final -submission made before the deadline is considered valid. - -To clarify further: if this value is set to 60, the predictoor will be asked -to predict in every block during the last 60 seconds before the epoch -concludes. -""" - -from abc import ABC -from os import getenv - -from enforce_typing import enforce_types - -from pdr_backend.models.base_config import BaseConfig -from pdr_backend.util.strutil import StrMixin - - -@enforce_types -class BasePredictoorConfig(BaseConfig, ABC, StrMixin): - def __init__(self): - super().__init__() - self.s_until_epoch_end = int(getenv("SECONDS_TILL_EPOCH_END", "60")) - - # For approach 1 stake amount is randomly determined this has no effect. - # For approach 2 stake amount is determined by: - # `STAKE_AMOUNT * confidence` where confidence is between 0 and 1. - # For approach 3 this is the stake amount. - self.stake_amount = float(getenv("STAKE_AMOUNT", "1")) # stake amount in eth diff --git a/pdr_backend/predictoor/main.py b/pdr_backend/predictoor/main.py deleted file mode 100644 index 7791a5f9c..000000000 --- a/pdr_backend/predictoor/main.py +++ /dev/null @@ -1,68 +0,0 @@ -import importlib -import sys - -from enforce_typing import enforce_types - -HELP = """Predictoor runner. - -Usage: python pdr_backend/predictoor/main.py APPROACH - - where APPROACH=1 - does random predictions - APPROACH=2 - uses a static model to predict. Needs MODELDIR specified. - APPROACH=3 - uses a dynamic model to predict - APPROACH=payout - claim all unclaimed payouts. - APPROACH=roseclaim - claim ROSE DF rewards. 
-""" - - -@enforce_types -def do_help(): - print(HELP) - sys.exit() - - -@enforce_types -def do_main(): - if len(sys.argv) <= 1: - do_help() - - arg1 = sys.argv[1] - if arg1 in ["1", "3"]: # approach1, approach3 - agent_module = importlib.import_module( - f"pdr_backend.predictoor.approach{arg1}.predictoor_agent{arg1}" - ) - agent_class = getattr(agent_module, f"PredictoorAgent{arg1}") - config_class = agent_class.predictoor_config_class - config = config_class() - agent = agent_class(config) - agent.run() - - elif arg1 == "2": # approach2 - # To be integrated similar to "1" - from pdr_backend.predictoor.approach2.main2 import ( # pylint: disable=import-outside-toplevel,line-too-long - do_main2, - ) - - do_main2() - - elif arg1 == "payout": - # pylint: disable=import-outside-toplevel - from pdr_backend.predictoor.payout import do_payout - - do_payout() - - elif arg1 == "roseclaim": - # pylint: disable=import-outside-toplevel - from pdr_backend.predictoor.payout import do_rose_payout - - do_rose_payout() - - elif arg1 == "help": - do_help() - - else: - do_help() - - -if __name__ == "__main__": - do_main() diff --git a/pdr_backend/predictoor/payout.py b/pdr_backend/predictoor/payout.py deleted file mode 100644 index 0cf1894f1..000000000 --- a/pdr_backend/predictoor/payout.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import time -from typing import Any, List -from enforce_typing import enforce_types -from pdr_backend.models.base_config import BaseConfig - -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.models.wrapped_token import WrappedToken -from pdr_backend.util.subgraph import query_pending_payouts, wait_until_subgraph_syncs -from pdr_backend.models.dfrewards import DFRewards - - -@enforce_types -def batchify(data: List[Any], batch_size: int): - return [data[i : i + batch_size] for i in range(0, len(data), batch_size)] - - -@enforce_types -def request_payout_batches( - predictoor_contract: PredictoorContract, batch_size: int, timestamps: List[int] -): - batches = batchify(timestamps, batch_size) - for batch in batches: - retries = 0 - success = False - - while retries < 5 and not success: - try: - predictoor_contract.payout_multiple(batch, True) - print(".", end="", flush=True) - success = True - except Exception as e: - retries += 1 - print(f"Error: {e}. Retrying... {retries}/5", flush=True) - time.sleep(1) - - if not success: - print("\nFailed after 5 attempts. 
Moving to next batch.", flush=True) - - print("\nBatch completed") - - -def do_payout(): - config = BaseConfig() - owner = config.web3_config.owner - BATCH_SIZE = int(os.getenv("BATCH_SIZE", "250")) - print("Starting payout") - wait_until_subgraph_syncs(config.web3_config, config.subgraph_url) - print("Finding pending payouts") - pending_payouts = query_pending_payouts(config.subgraph_url, owner) - total_timestamps = sum(len(timestamps) for timestamps in pending_payouts.values()) - print(f"Found {total_timestamps} slots") - - for contract_address in pending_payouts: - print(f"Claiming payouts for {contract_address}") - contract = PredictoorContract(config.web3_config, contract_address) - request_payout_batches(contract, BATCH_SIZE, pending_payouts[contract_address]) - - -def do_rose_payout(): - address = "0xc37F8341Ac6e4a94538302bCd4d49Cf0852D30C0" - wrapped_rose = "0x8Bc2B030b299964eEfb5e1e0b36991352E56D2D3" - config = BaseConfig() - owner = config.web3_config.owner - if config.web3_config.w3.eth.chain_id != 23294: - raise Exception("Unsupported network") - contract = DFRewards(config.web3_config, address) - claimable_rewards = contract.get_claimable_rewards(owner, wrapped_rose) - print(f"Found {claimable_rewards} wROSE available to claim") - - if claimable_rewards > 0: - print("Claiming wROSE rewards...") - contract.claim_rewards(owner, wrapped_rose) - else: - print("No rewards available to claim") - - print("Converting wROSE to ROSE") - time.sleep(10) - wrose = WrappedToken(config.web3_config, wrapped_rose) - wrose_balance = wrose.balanceOf(config.web3_config.owner) - if wrose_balance == 0: - print("wROSE balance is 0") - else: - print(f"Found {wrose_balance/1e18} wROSE, converting to ROSE...") - wrose.withdraw(wrose_balance) - - print("ROSE reward claim done") diff --git a/pdr_backend/predictoor/test/predictoor_agent_runner.py b/pdr_backend/predictoor/test/predictoor_agent_runner.py new file mode 100644 index 000000000..3ef6d150d --- /dev/null +++ b/pdr_backend/predictoor/test/predictoor_agent_runner.py @@ -0,0 +1,67 @@ +""" +This file exposes run_agent_test() +which is used by test_predictoor_agent{1,3}.py +""" +import os + +from enforce_typing import enforce_types + +from pdr_backend.ppss.ppss import mock_feed_ppss +from pdr_backend.ppss.web3_pp import ( + inplace_mock_query_feed_contracts, + inplace_mock_w3_and_contract_with_tracking, +) + +PRIV_KEY = os.getenv("PRIVATE_KEY") +OWNER_ADDR = "0xowner" +INIT_TIMESTAMP = 107 +INIT_BLOCK_NUMBER = 13 + + +@enforce_types +def run_agent_test(tmpdir: str, monkeypatch, predictoor_agent_class): + monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) + feed, ppss = mock_feed_ppss("5m", "binanceus", "BTC/USDT", tmpdir=tmpdir) + inplace_mock_query_feed_contracts(ppss.web3_pp, feed) + + _mock_pdr_contract = inplace_mock_w3_and_contract_with_tracking( + ppss.web3_pp, + INIT_TIMESTAMP, + INIT_BLOCK_NUMBER, + ppss.predictoor_ss.timeframe_s, + feed.address, + monkeypatch, + ) + + # now we're done the mocking, time for the real work!! 
+
+    # real work: initialize
+    agent = predictoor_agent_class(ppss)
+
+    # real work: main iterations
+    for _ in range(500):
+        agent.take_step()
+
+    # log some final results for debugging / inspection
+    mock_w3 = ppss.web3_pp.web3_config.w3
+    print("\n" + "/" * 160)
+    print("Done iterations")
+    print(
+        f"init block_number = {INIT_BLOCK_NUMBER}"
+        f", final = {mock_w3.eth.block_number}"
+    )
+    print()
+    print(f"init timestamp = {INIT_TIMESTAMP}, final = {mock_w3.eth.timestamp}")
+    print(f"all timestamps seen = {mock_w3.eth._timestamps_seen}")
+    print()
+    print(
+        "unique prediction_slots = "
+        f"{sorted(set(_mock_pdr_contract._prediction_slots))}"
+    )
+    print(f"all prediction_slots = {_mock_pdr_contract._prediction_slots}")
+
+    # relatively basic sanity tests
+    assert _mock_pdr_contract._prediction_slots
+    assert (mock_w3.eth.timestamp + 2 * ppss.predictoor_ss.timeframe_s) >= max(
+        _mock_pdr_contract._prediction_slots
+    )
diff --git a/pdr_backend/predictoor/test/test_predictoor_agent.py b/pdr_backend/predictoor/test/test_predictoor_agent.py
deleted file mode 100644
index 200a2ed5d..000000000
--- a/pdr_backend/predictoor/test/test_predictoor_agent.py
+++ /dev/null
@@ -1,189 +0,0 @@
-import os
-import random
-from typing import List
-from unittest.mock import Mock
-
-from enforce_typing import enforce_types
-
-from pdr_backend.predictoor.approach1.predictoor_config1 import PredictoorConfig1
-from pdr_backend.predictoor.approach1.predictoor_agent1 import PredictoorAgent1
-from pdr_backend.util.constants import S_PER_MIN, S_PER_DAY
-
-PRIV_KEY = os.getenv("PRIVATE_KEY")
-
-ADDR = "0xe8933f2950aec1080efad1ca160a6bb641ad245d"
-
-SOURCE = "kraken"
-PAIR = "BTC-USDT"
-TIMEFRAME, S_PER_EPOCH = "5m", 5 * S_PER_MIN  # must change both at once
-SECONDS_TILL_EPOCH_END = 60  # how soon to start making predictions?
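The slot arithmetic behind these constants is the same one implemented by the new next_slot / target_slot / cur_epoch_s_left properties of BasePredictoorAgent earlier in this patch: predict for the slot two epochs ahead, and start doing so once the current epoch is nearly over. A worked example, with assumed numbers (a 5m feed gives s_per_epoch = 300, and s_until_epoch_end = 60 as in the old SECONDS_TILL_EPOCH_END default):

    s_per_epoch = 300            # 5m feed
    s_until_epoch_end = 60       # start predicting in the last 60 s of an epoch
    now = 1_700_000_123          # example chain timestamp (assumed)

    cur_epoch = now // s_per_epoch                # 5_666_667
    next_slot = (cur_epoch + 1) * s_per_epoch     # 1_700_000_400
    target_slot = (cur_epoch + 2) * s_per_epoch   # 1_700_000_700: the slot being predicted
    s_left = next_slot - now                      # 277 > 60, so still too early to predict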
-FEED_S = f"{PAIR}|{SOURCE}|{TIMEFRAME}" -S_PER_SUBSCRIPTION = 1 * S_PER_DAY -FEED_DICT = { # info inside a predictoor contract - "name": f"Feed of {FEED_S}", - "address": ADDR, - "symbol": f"FEED:{FEED_S}", - "seconds_per_epoch": S_PER_EPOCH, - "seconds_per_subscription": S_PER_SUBSCRIPTION, - "trueval_submit_timeout": 15, - "owner": "0xowner", - "pair": PAIR, - "timeframe": TIMEFRAME, - "source": SOURCE, -} -INIT_TIMESTAMP = 107 -INIT_BLOCK_NUMBER = 13 - - -@enforce_types -def toEpochStart(timestamp: int) -> int: - return timestamp // S_PER_EPOCH * S_PER_EPOCH - - -@enforce_types -class MockEth: - def __init__(self): - self.timestamp = INIT_TIMESTAMP - self.block_number = INIT_BLOCK_NUMBER - self._timestamps_seen: List[int] = [INIT_TIMESTAMP] - - def get_block( - self, block_number: int, full_transactions: bool = False - ): # pylint: disable=unused-argument - mock_block = {"timestamp": self.timestamp} - return mock_block - - -@enforce_types -class MockContract: - def __init__(self, w3): - self._w3 = w3 - self.contract_address: str = ADDR - self._prediction_slots: List[int] = [] - - def get_current_epoch(self) -> int: # returns an epoch number - return self.get_current_epoch_ts() // S_PER_EPOCH - - def get_current_epoch_ts(self) -> int: # returns a timestamp - curEpoch_ts = toEpochStart(self._w3.eth.timestamp) - return curEpoch_ts - - def get_secondsPerEpoch(self) -> int: - return S_PER_EPOCH - - def submit_prediction( - self, predval: bool, stake: float, timestamp: int, wait: bool = True - ): # pylint: disable=unused-argument - assert stake <= 3 - if timestamp in self._prediction_slots: - print(f" (Replace prev pred at time slot {timestamp})") - self._prediction_slots.append(timestamp) - - -@enforce_types -class MockStack: - def __init__(self, monkeypatch) -> None: - self.mock_w3 = Mock() # pylint: disable=not-callable - self.mock_w3.eth = MockEth() - self.mock_contract = MockContract(self.mock_w3) - - # mock query_feed_contracts() - def mock_query_feed_contracts( - *args, **kwargs - ): # pylint: disable=unused-argument - feed_dicts = {ADDR: FEED_DICT} - return feed_dicts - - def mock_contract_func(*args, **kwargs): # pylint: disable=unused-argument - return self.mock_contract - - monkeypatch.setattr( - "pdr_backend.models.base_config.query_feed_contracts", - mock_query_feed_contracts, - ) - - monkeypatch.setattr( - "pdr_backend.models.base_config.PredictoorContract", mock_contract_func - ) - - # mock time.sleep() - def advance_func(*args, **kwargs): # pylint: disable=unused-argument - do_advance_block = random.random() < 0.40 - if do_advance_block: - self.mock_w3.eth.timestamp += random.randint(3, 12) - self.mock_w3.eth.block_number += 1 - self.mock_w3.eth._timestamps_seen.append(self.mock_w3.eth.timestamp) - - monkeypatch.setattr("time.sleep", advance_func) - - def run_tests(self): - # Initialize the Agent - # Take steps with the Agent - # Log final results for debugging / inspection - pass - - -@enforce_types -def test_predictoor_base_agent(monkeypatch): - _setenvs(monkeypatch) - - @enforce_types - class MockBaseAgent(MockStack): - def run_tests(self): - # real work: initialize - c = PredictoorConfig1() - agent = PredictoorAgent1(c) - - # last bit of mocking - agent.config.web3_config.w3 = self.mock_w3 - - # real work: main iterations - for _ in range(1000): - agent.take_step() - - # log some final results for debubbing / inspection - print("\n" + "=" * 80) - print("Done iterations") - print( - f"init block_number = {INIT_BLOCK_NUMBER}" - f", final = {self.mock_w3.eth.block_number}" - 
) - print() - print( - f"init timestamp = {INIT_TIMESTAMP}, final = {self.mock_w3.eth.timestamp}" - ) - print(f"all timestamps seen = {self.mock_w3.eth._timestamps_seen}") - print() - print( - "No unique prediction_slots = " - f"{sorted(set(self.mock_contract._prediction_slots))}" - ) - print(f"No prediction_slots = {self.mock_contract._prediction_slots}") - - agent = MockBaseAgent(monkeypatch) - agent.run_tests() - - # relatively basic sanity tests - assert agent.mock_contract._prediction_slots - print(agent.mock_contract) - print(agent.mock_w3.eth.timestamp) - - assert (agent.mock_w3.eth.timestamp + 2 * S_PER_EPOCH) >= max( - agent.mock_contract._prediction_slots - ) - - -def _setenvs(monkeypatch): - # envvars handled by PredictoorConfig1 - monkeypatch.setenv("SECONDS_TILL_EPOCH_END", "60") - monkeypatch.setenv("STAKE_AMOUNT", "30000") - - # envvars handled by BaseConfig - monkeypatch.setenv("RPC_URL", "http://foo") - monkeypatch.setenv("SUBGRAPH_URL", "http://bar") - monkeypatch.setenv("PRIVATE_KEY", PRIV_KEY) - - monkeypatch.setenv("PAIR_FILTER", PAIR.replace("-", "/")) - monkeypatch.setenv("TIMEFRAME_FILTER", TIMEFRAME) - monkeypatch.setenv("SOURCE_FILTER", SOURCE) - monkeypatch.setenv("OWNER_ADDRS", FEED_DICT["owner"]) diff --git a/pdr_backend/publisher/main.py b/pdr_backend/publisher/main.py deleted file mode 100644 index 174a5af14..000000000 --- a/pdr_backend/publisher/main.py +++ /dev/null @@ -1,149 +0,0 @@ -from pdr_backend.models.token import Token -from pdr_backend.publisher.publish import publish, fund_dev_accounts -from pdr_backend.util.contract import get_address -from pdr_backend.util.env import getenv_or_exit -from pdr_backend.util.web3_config import Web3Config - - -def main(): - rpc_url = getenv_or_exit("RPC_URL") - private_key = getenv_or_exit("PRIVATE_KEY") - - # pairs to deploy on testnet and mainnet - pair_list = ["BTC", "ETH", "BNB", "XRP", "ADA", "DOGE", "SOL", "LTC", "TRX", "DOT"] - - # token price - rate = 3 / (1 + 0.2 + 0.001) - - web3_config = Web3Config(rpc_url, private_key) - - if web3_config.w3.eth.chain_id == 8996: - print("Funding dev accounts and publishing pairs on local network...") - ocean_address = get_address(web3_config.w3.eth.chain_id, "Ocean") - OCEAN = Token(web3_config, ocean_address) - accounts_to_fund = [ - # account_key_env, OCEAN_to_send - ("PREDICTOOR_PRIVATE_KEY", 2000.0), - ("PREDICTOOR2_PRIVATE_KEY", 2000.0), - ("PREDICTOOR3_PRIVATE_KEY", 2000.0), - ("TRADER_PRIVATE_KEY", 2000.0), - ("DFBUYER_PRIVATE_KEY", 10000.0), - ("PDR_WEBSOCKET_KEY", 10000.0), - ("PDR_MM_USER", 10000.0), - ] - - fund_dev_accounts(accounts_to_fund, web3_config.owner, OCEAN) - - publish( - s_per_epoch=300, - s_per_subscription=60 * 60 * 24, - base="ETH", - quote="USDT", - source="binance", - timeframe="5m", - trueval_submitter_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", # on barge - feeCollector_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", - rate=rate, - cut=0.2, - web3_config=web3_config, - ) - - publish( - s_per_epoch=300, - s_per_subscription=60 * 60 * 24, - base="BTC", - quote="USDT", - source="binance", - timeframe="5m", - trueval_submitter_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", - feeCollector_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", - rate=rate, - cut=0.2, - web3_config=web3_config, - ) - - publish( - s_per_epoch=300, - s_per_subscription=60 * 60 * 24, - base="XRP", - quote="USDT", - source="binance", - timeframe="5m", - trueval_submitter_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", - 
feeCollector_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", - rate=rate, - cut=0.2, - web3_config=web3_config, - ) - print("Publish done") - - if web3_config.w3.eth.chain_id == 23295: - print("Publishing pairs on testnet") - helper_contract = get_address(web3_config.w3.eth.chain_id, "PredictoorHelper") - fee_collector = getenv_or_exit("FEE_COLLECTOR") - for pair in pair_list: - publish( - s_per_epoch=300, - s_per_subscription=60 * 60 * 24, - base=pair, - quote="USDT", - source="binance", - timeframe="5m", - trueval_submitter_addr=helper_contract, - feeCollector_addr=fee_collector, - rate=rate, - cut=0.2, - web3_config=web3_config, - ) - publish( - s_per_epoch=3600, - s_per_subscription=60 * 60 * 24, - base=pair, - quote="USDT", - source="binance", - timeframe="1h", - trueval_submitter_addr=helper_contract, - feeCollector_addr=fee_collector, - rate=rate, - cut=0.2, - web3_config=web3_config, - ) - print("Publish done") - - if web3_config.w3.eth.chain_id == 23294: - print("Publishing pairs on mainnet") - helper_contract = get_address(web3_config.w3.eth.chain_id, "PredictoorHelper") - fee_collector = getenv_or_exit("FEE_COLLECTOR") - for pair in pair_list: - publish( - s_per_epoch=300, - s_per_subscription=60 * 60 * 24, - base=pair, - quote="USDT", - source="binance", - timeframe="5m", - trueval_submitter_addr=helper_contract, - feeCollector_addr=fee_collector, - rate=rate, - cut=0.2, - web3_config=web3_config, - ) - publish( - s_per_epoch=3600, - s_per_subscription=60 * 60 * 24, - base=pair, - quote="USDT", - source="binance", - timeframe="1h", - trueval_submitter_addr=helper_contract, - feeCollector_addr=fee_collector, - rate=rate, - cut=0.2, - web3_config=web3_config, - ) - print("Publish done") - - -if __name__ == "__main__": - print("Publisher start") - main() diff --git a/pdr_backend/publisher/publish.py b/pdr_backend/publisher/publish_asset.py similarity index 64% rename from pdr_backend/publisher/publish.py rename to pdr_backend/publisher/publish_asset.py index 7374f006e..79e0b9093 100644 --- a/pdr_backend/publisher/publish.py +++ b/pdr_backend/publisher/publish_asset.py @@ -1,34 +1,18 @@ -import os -from typing import List, Union +from typing import Union from enforce_typing import enforce_types -from eth_account import Account -from pdr_backend.models.data_nft import DataNft -from pdr_backend.models.erc721_factory import ERC721Factory -from pdr_backend.models.token import Token +from pdr_backend.contract.data_nft import DataNft +from pdr_backend.contract.erc721_factory import Erc721Factory +from pdr_backend.ppss.web3_pp import Web3PP from pdr_backend.util.contract import get_address +from pdr_backend.util.mathutil import to_wei MAX_UINT256 = 2**256 - 1 @enforce_types -def fund_dev_accounts(accounts_to_fund: List[tuple], owner: str, token: Token): - for private_key_name, amount in accounts_to_fund: - if private_key_name in os.environ: - private_key = os.getenv(private_key_name) - account = Account.from_key( # pylint: disable=no-value-for-parameter - private_key - ) - print( - f"Sending OCEAN to account defined by envvar {private_key_name}" - f", with address {account.address}" - ) - token.transfer(account.address, amount * 1e18, owner) - - -@enforce_types -def publish( +def publish_asset( s_per_epoch: int, s_per_subscription: int, base: str, @@ -39,20 +23,22 @@ def publish( feeCollector_addr: str, rate: Union[int, float], cut: Union[int, float], - web3_config, + web3_pp: Web3PP, ): + """Publish one specific asset to chain.""" + web3_config = web3_pp.web3_config pair = base + "/" 
+ quote trueval_timeout = 60 * 60 * 24 * 3 owner = web3_config.owner - ocean_address = get_address(web3_config.w3.eth.chain_id, "Ocean") - fre_address = get_address(web3_config.w3.eth.chain_id, "FixedPrice") - factory = ERC721Factory(web3_config) + ocean_address = get_address(web3_pp, "Ocean") + fre_address = get_address(web3_pp, "FixedPrice") + factory = Erc721Factory(web3_pp) feeCollector = web3_config.w3.to_checksum_address(feeCollector_addr) trueval_submiter = web3_config.w3.to_checksum_address(trueval_submitter_addr) - rate_wei: int = web3_config.w3.to_wei(rate, "ether") - cut_wei: int = web3_config.w3.to_wei(cut, "ether") + rate_wei = to_wei(rate) + cut_wei = to_wei(cut) nft_name: str = base + "-" + quote + "-" + source + "-" + timeframe nft_symbol: str = pair @@ -85,7 +71,7 @@ def publish( data_nft_address: str = logs_nft["newTokenAddress"] print(f"Deployed NFT: {data_nft_address}") - data_nft = DataNft(web3_config, data_nft_address) + data_nft = DataNft(web3_pp, data_nft_address) tx = data_nft.set_data("pair", pair) print(f"Pair set to {pair} in {tx.hex()}") diff --git a/pdr_backend/publisher/publish_assets.py b/pdr_backend/publisher/publish_assets.py new file mode 100644 index 000000000..20f07a0bd --- /dev/null +++ b/pdr_backend/publisher/publish_assets.py @@ -0,0 +1,46 @@ +from enforce_typing import enforce_types + +from pdr_backend.ppss.publisher_ss import PublisherSS +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.publisher.publish_asset import publish_asset +from pdr_backend.util.contract import get_address + +_CUT = 0.2 +_RATE = 3 / (1 + _CUT + 0.001) # token price +_S_PER_SUBSCRIPTION = 60 * 60 * 24 + + +@enforce_types +def publish_assets(web3_pp: Web3PP, publisher_ss: PublisherSS): + """ + Publish assets, with opinions on % cut, token price, subscription length, + timeframe, and choices of feeds. + Meant to be used from CLI. 
+ """ + print(f"Publish on network = {web3_pp.network}") + if web3_pp.network == "development" or "barge" in web3_pp.network: + trueval_submitter_addr = "0xe2DD09d719Da89e5a3D0F2549c7E24566e947260" + fee_collector_addr = "0xe2DD09d719Da89e5a3D0F2549c7E24566e947260" + elif "sapphire" in web3_pp.network: + trueval_submitter_addr = get_address(web3_pp, "PredictoorHelper") + fee_collector_addr = publisher_ss.fee_collector_address + else: + raise ValueError(web3_pp.network) + + for feed in publisher_ss.feeds: + publish_asset( + # timeframe is already asserted in PublisherSS + s_per_epoch=feed.timeframe.s, # type: ignore[union-attr] + s_per_subscription=_S_PER_SUBSCRIPTION, + base=feed.pair.base_str, + quote=feed.pair.quote_str, + source=feed.exchange, + timeframe=str(feed.timeframe), + trueval_submitter_addr=trueval_submitter_addr, + feeCollector_addr=fee_collector_addr, + rate=_RATE, + cut=_CUT, + web3_pp=web3_pp, + ) + + print("Done publishing.") diff --git a/pdr_backend/publisher/test/test_fund_dev_accounts.py b/pdr_backend/publisher/test/test_fund_dev_accounts.py deleted file mode 100644 index 1fce1b77a..000000000 --- a/pdr_backend/publisher/test/test_fund_dev_accounts.py +++ /dev/null @@ -1,31 +0,0 @@ -import os -from unittest.mock import Mock, call - -from eth_account import Account - -from pdr_backend.models.token import Token -from pdr_backend.publisher.publish import fund_dev_accounts - - -def test_fund_dev_accounts(monkeypatch): - pk = os.getenv("PRIVATE_KEY") - monkeypatch.setenv("PREDICTOOR_PRIVATE_KEY", pk) - monkeypatch.setenv("PREDICTOOR2_PRIVATE_KEY", pk) - - mock_token = Mock(spec=Token) - mock_account = Mock(spec=str) - - accounts_to_fund = [ - ("PREDICTOOR_PRIVATE_KEY", 2000), - ("PREDICTOOR2_PRIVATE_KEY", 3000), - ] - - fund_dev_accounts(accounts_to_fund, mock_account, mock_token) - - a = Account.from_key(private_key=pk) # pylint: disable=no-value-for-parameter - mock_token.transfer.assert_has_calls( - [ - call(a.address, 2e21, mock_account), - call(a.address, 3e21, mock_account), - ] - ) diff --git a/pdr_backend/publisher/test/test_publish.py b/pdr_backend/publisher/test/test_publish_asset.py similarity index 55% rename from pdr_backend/publisher/test/test_publish.py rename to pdr_backend/publisher/test/test_publish_asset.py index 53d81ff53..8eb87b164 100644 --- a/pdr_backend/publisher/test/test_publish.py +++ b/pdr_backend/publisher/test/test_publish_asset.py @@ -1,45 +1,40 @@ -# comment out until more fleshed out -# from pdr_backend.publisher.publish import fund_dev_accounts, publish - - -import os - +from enforce_typing import enforce_types from pytest import approx -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.publisher.publish import publish + +from pdr_backend.contract.predictoor_contract import PredictoorContract +from pdr_backend.publisher.publish_asset import publish_asset from pdr_backend.util.contract import get_address -from pdr_backend.util.web3_config import Web3Config -def test_publisher_publish(): - config = Web3Config(os.getenv("RPC_URL"), os.getenv("PRIVATE_KEY")) +@enforce_types +def test_publish_asset(web3_pp, web3_config): base = "ETH" quote = "USDT" source = "kraken" timeframe = "5m" seconds_per_epoch = 300 seconds_per_subscription = 60 * 60 * 24 - nft_data, _, _, _, logs_erc = publish( + nft_data, _, _, _, logs_erc = publish_asset( s_per_epoch=seconds_per_epoch, s_per_subscription=seconds_per_subscription, base=base, quote=quote, source=source, timeframe=timeframe, - trueval_submitter_addr=config.owner, - 
feeCollector_addr=config.owner, + trueval_submitter_addr=web3_config.owner, + feeCollector_addr=web3_config.owner, rate=3, cut=0.2, - web3_config=config, + web3_pp=web3_pp, ) nft_name = base + "-" + quote + "-" + source + "-" + timeframe nft_symbol = base + "/" + quote - assert nft_data == (nft_name, nft_symbol, 1, "", True, config.owner) + assert nft_data == (nft_name, nft_symbol, 1, "", True, web3_config.owner) dt_addr = logs_erc["newTokenAddress"] - assert config.w3.is_address(dt_addr) + assert web3_config.w3.is_address(dt_addr) - contract = PredictoorContract(config, dt_addr) + contract = PredictoorContract(web3_pp, dt_addr) assert contract.get_secondsPerEpoch() == seconds_per_epoch assert ( @@ -48,7 +43,7 @@ def test_publisher_publish(): ) assert contract.get_price() / 1e18 == approx(3 * (1.201)) - ocean_address = get_address(config.w3.eth.chain_id, "Ocean") + ocean_address = get_address(web3_pp, "Ocean") assert contract.get_stake_token() == ocean_address assert contract.get_trueValSubmitTimeout() == 3 * 24 * 60 * 60 diff --git a/pdr_backend/publisher/test/test_publish_assets.py b/pdr_backend/publisher/test/test_publish_assets.py new file mode 100644 index 000000000..55ee49f70 --- /dev/null +++ b/pdr_backend/publisher/test/test_publish_assets.py @@ -0,0 +1,74 @@ +from unittest.mock import Mock + +from enforce_typing import enforce_types + +from pdr_backend.ppss.publisher_ss import mock_publisher_ss +from pdr_backend.ppss.web3_pp import mock_web3_pp +from pdr_backend.publisher.publish_assets import publish_assets + +_PATH = "pdr_backend.publisher.publish_assets" + + +def test_publish_assets_development(monkeypatch): + _test_barge("development", monkeypatch) + + +def test_publish_assets_barge_pytest(monkeypatch): + _test_barge("barge-pytest", monkeypatch) + + +def test_publish_assets_barge_pdr_bot(monkeypatch): + _test_barge("barge-predictoor-bot", monkeypatch) + + +@enforce_types +def _test_barge(network, monkeypatch): + mock_publish_asset, web3_pp = _setup_and_publish(network, monkeypatch) + + n_calls = len(mock_publish_asset.call_args_list) + assert n_calls == 1 * 3 + + mock_publish_asset.assert_any_call( + s_per_epoch=300, + s_per_subscription=60 * 60 * 24, + base="ETH", + quote="USDT", + source="binance", + timeframe="5m", + trueval_submitter_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", + feeCollector_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", + rate=3 / (1 + 0.2 + 0.001), + cut=0.2, + web3_pp=web3_pp, + ) + + +def test_publish_assets_sapphire_testnet(monkeypatch): + _test_sapphire("sapphire-testnet", monkeypatch) + + +def test_publish_assets_sapphire_mainnet(monkeypatch): + _test_sapphire("sapphire-mainnet", monkeypatch) + + +@enforce_types +def _test_sapphire(network, monkeypatch): + mock_publish_asset, _ = _setup_and_publish(network, monkeypatch) + + n_calls = len(mock_publish_asset.call_args_list) + assert n_calls == 2 * 10 + + +def _setup_and_publish(network, monkeypatch): + web3_pp = mock_web3_pp(network) + publisher_ss = mock_publisher_ss(network) + + monkeypatch.setattr(f"{_PATH}.get_address", Mock()) + + mock_publish_asset = Mock() + monkeypatch.setattr(f"{_PATH}.publish_asset", mock_publish_asset) + + # main call + publish_assets(web3_pp, publisher_ss) + + return mock_publish_asset, web3_pp diff --git a/pdr_backend/publisher/test/test_publisher_main.py b/pdr_backend/publisher/test/test_publisher_main.py deleted file mode 100644 index cd323875c..000000000 --- a/pdr_backend/publisher/test/test_publisher_main.py +++ /dev/null @@ -1,88 +0,0 @@ -from 
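
As context for the hard-coded economics in the new publish_assets.py, a small illustrative check (plain Python, not part of the diff). It assumes, as the `3 * (1.201)` price assertion in the publish test suggests, that the effective consumer price is rate * (1 + cut + 0.001):

# Illustrative arithmetic only: with _CUT = 0.2 and the 0.001 community swap
# fee, _RATE = 3 / (1 + 0.2 + 0.001) makes the effective price land back on
# ~3 OCEAN per subscription.
CUT = 0.2
COMMUNITY_FEE = 0.001
RATE = 3 / (1 + CUT + COMMUNITY_FEE)        # same formula as _RATE above
price = RATE * (1 + CUT + COMMUNITY_FEE)    # assumed effective consumer price
assert abs(price - 3.0) < 1e-9
print(f"rate = {RATE:.6f}, effective price = {price:.3f} OCEAN")
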
unittest.mock import MagicMock, Mock, patch -import pytest -from pdr_backend.publisher.main import main -from pdr_backend.util.web3_config import Web3Config - - -@pytest.fixture -def mock_getenv_or_exit(): - with patch("pdr_backend.publisher.main.getenv_or_exit") as mock: - yield mock - - -@pytest.fixture -def mock_web3_config(): - with patch("pdr_backend.publisher.main.Web3Config") as mock: - mock_instance = MagicMock(spec=Web3Config) - mock_instance.w3 = Mock() - mock_instance.owner = "0x1" - mock.return_value = mock_instance - yield mock_instance - - -@pytest.fixture -def mock_token(): - with patch("pdr_backend.publisher.main.Token") as mock: - mock_instance = MagicMock() - mock.return_value = mock_instance - yield mock_instance - - -@pytest.fixture -def mock_get_address(): - with patch("pdr_backend.publisher.main.get_address") as mock: - mock_instance = MagicMock() - mock.return_value = mock_instance - yield mock - - -@pytest.fixture -def mock_fund_dev_accounts(): - with patch("pdr_backend.publisher.main.fund_dev_accounts") as mock: - yield mock - - -@pytest.fixture -def mock_publish(): - with patch("pdr_backend.publisher.main.publish") as mock: - yield mock - - -# pylint: disable=redefined-outer-name -def test_main( - mock_getenv_or_exit, - mock_web3_config, - mock_token, - mock_get_address, - mock_fund_dev_accounts, - mock_publish, -): - mock_getenv_or_exit.side_effect = [ - "mock_rpc_url", - "mock_private_key", - "mock_fee_collector", - ] - mock_web3_config.w3.eth.chain_id = 8996 - mock_get_address.return_value = "mock_ocean_address" - mock_token_instance = MagicMock() - mock_token.return_value = mock_token_instance - - main() - - mock_getenv_or_exit.assert_any_call("RPC_URL") - mock_getenv_or_exit.assert_any_call("PRIVATE_KEY") - mock_get_address.assert_called_once_with(8996, "Ocean") - mock_fund_dev_accounts.assert_called_once() - mock_publish.assert_any_call( - s_per_epoch=300, - s_per_subscription=60 * 60 * 24, - base="ETH", - quote="USDT", - source="binance", - timeframe="5m", - trueval_submitter_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", - feeCollector_addr="0xe2DD09d719Da89e5a3D0F2549c7E24566e947260", - rate=3 / (1 + 0.2 + 0.001), - cut=0.2, - web3_config=mock_web3_config, - ) diff --git a/pdr_backend/simulation/trade_engine.py b/pdr_backend/sim/sim_engine.py similarity index 73% rename from pdr_backend/simulation/trade_engine.py rename to pdr_backend/sim/sim_engine.py index 21643e13f..7e919a77c 100644 --- a/pdr_backend/simulation/trade_engine.py +++ b/pdr_backend/sim/sim_engine.py @@ -1,22 +1,19 @@ +import copy import os from typing import List -from enforce_typing import enforce_types import matplotlib.pyplot as plt import numpy as np -import pandas as pd +import polars as pl +from enforce_typing import enforce_types from statsmodels.stats.proportion import proportion_confint -from pdr_backend.data_eng.data_factory import DataFactory -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.data_eng.data_ss import DataSS -from pdr_backend.model_eng.model_factory import ModelFactory -from pdr_backend.model_eng.model_ss import ModelSS -from pdr_backend.simulation.sim_ss import SimSS -from pdr_backend.simulation.trade_ss import TradeSS -from pdr_backend.simulation.trade_pp import TradePP +from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory +from pdr_backend.aimodel.aimodel_factory import AimodelFactory +from pdr_backend.lake.ohlcv_data_factory import OhlcvDataFactory +from pdr_backend.ppss.ppss import PPSS from pdr_backend.util.mathutil 
import nmse -from pdr_backend.util.timeutil import current_ut, pretty_timestr +from pdr_backend.util.timeutil import current_ut_ms, pretty_timestr FONTSIZE = 12 @@ -30,39 +27,23 @@ def __init__(self): # pylint: disable=too-many-instance-attributes -class TradeEngine: +class SimEngine: @enforce_types - def __init__( - self, - data_pp: DataPP, - data_ss: DataSS, - model_ss: ModelSS, - trade_pp: TradePP, - trade_ss: TradeSS, - sim_ss: SimSS, - ): - """ - @arguments - data_pp -- user-uncontrollable params, at data level - data_ss -- user-controllable params, at data level - model_ss -- user-controllable params, at model level - trade_pp -- user-uncontrollable params, at trading level - trade_ss -- user-controllable params, at trading level - sim_ss -- user-controllable params, at sim level - """ - # ensure training data has the target yval - assert data_pp.predict_feed_tup in data_ss.input_feed_tups + def __init__(self, ppss: PPSS): + # preconditions + predict_feed = ppss.predictoor_ss.feed + + # timeframe doesn't need to match + assert ( + str(predict_feed.exchange), + str(predict_feed.pair), + ) in ppss.predictoor_ss.aimodel_ss.exchange_pair_tups # pp & ss values - self.data_pp = data_pp - self.data_ss = data_ss - self.model_ss = model_ss - self.trade_pp = trade_pp - self.trade_ss = trade_ss - self.sim_ss = sim_ss + self.ppss = ppss # state - self.holdings = self.trade_pp.init_holdings + self.holdings = copy.copy(self.ppss.trader_ss.init_holdings) self.tot_profit_usd = 0.0 self.nmses_train: List[float] = [] self.ys_test: List[float] = [] @@ -71,28 +52,26 @@ def __init__( self.profit_usds: List[float] = [] self.tot_profit_usds: List[float] = [] - self.data_factory = DataFactory(self.data_pp, self.data_ss) - self.logfile = "" self.plot_state = None - if self.sim_ss.do_plot: + if self.ppss.sim_ss.do_plot: self.plot_state = PlotState() @property def tokcoin(self) -> str: """Return e.g. 'ETH'""" - return self.data_pp.base_str + return self.ppss.predictoor_ss.base_str @property def usdcoin(self) -> str: """Return e.g. 'USDT'""" - return self.data_pp.quote_str + return self.ppss.predictoor_ss.quote_str @enforce_types def _init_loop_attributes(self): - filebase = f"out_{current_ut()}.txt" - self.logfile = os.path.join(self.sim_ss.logpath, filebase) + filebase = f"out_{current_ut_ms()}.txt" + self.logfile = os.path.join(self.ppss.sim_ss.log_dir, filebase) with open(self.logfile, "w") as f: f.write("\n") @@ -105,11 +84,13 @@ def run(self): self._init_loop_attributes() log = self._log log("Start run") + # main loop! 
- hist_df = self.data_factory.get_hist_df() - for test_i in range(self.data_pp.N_test): - self.run_one_iter(test_i, hist_df) - self._plot(test_i, self.data_pp.N_test) + pq_data_factory = OhlcvDataFactory(self.ppss.lake_ss) + mergedohlcv_df: pl.DataFrame = pq_data_factory.get_mergedohlcv_df() + for test_i in range(self.ppss.sim_ss.test_n): + self.run_one_iter(test_i, mergedohlcv_df) + self._plot(test_i, self.ppss.sim_ss.test_n) log("Done all iters.") @@ -118,17 +99,18 @@ def run(self): log(f"Final nmse_train={nmse_train:.5f}, nmse_test={nmse_test:.5f}") @enforce_types - def run_one_iter(self, test_i: int, hist_df: pd.DataFrame): + def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame): log = self._log - testshift = self.data_pp.N_test - test_i - 1 # eg [99, 98, .., 2, 1, 0] - X, y, _ = self.data_factory.create_xy(hist_df, testshift) + testshift = self.ppss.sim_ss.test_n - test_i - 1 # eg [99, 98, .., 2, 1, 0] + model_data_factory = AimodelDataFactory(self.ppss.predictoor_ss) + X, y, _ = model_data_factory.create_xy(mergedohlcv_df, testshift) st, fin = 0, X.shape[0] - 1 X_train, X_test = X[st:fin, :], X[fin : fin + 1] y_train, y_test = y[st:fin], y[fin : fin + 1] - model_factory = ModelFactory(self.model_ss) - model = model_factory.build(X_train, y_train) + aimodel_factory = AimodelFactory(self.ppss.predictoor_ss.aimodel_ss) + model = aimodel_factory.build(X_train, y_train) y_trainhat = model.predict(X_train) # eg yhat=zhat[y-5] @@ -136,7 +118,8 @@ def run_one_iter(self, test_i: int, hist_df: pd.DataFrame): self.nmses_train.append(nmse_train) # current time - ut = int(hist_df.index.values[-1]) - testshift * self.data_pp.timeframe_ms + recent_ut = int(mergedohlcv_df["timestamp"].to_list()[-1]) + ut = recent_ut - testshift * self.ppss.predictoor_ss.timeframe_ms # current price curprice = y_train[-1] @@ -148,7 +131,7 @@ def run_one_iter(self, test_i: int, hist_df: pd.DataFrame): # simulate buy. Buy 'amt_usd' worth of TOK if we think price going up usdcoin_holdings_before = self.holdings[self.usdcoin] if self._do_buy(predprice, curprice): - self._buy(curprice, self.trade_ss.buy_amt_usd) + self._buy(curprice, self.ppss.trader_ss.buy_amt_usd) # observe true price trueprice = y_test[0] @@ -174,7 +157,7 @@ def run_one_iter(self, test_i: int, hist_df: pd.DataFrame): self.corrects.append(correct) acc = float(sum(self.corrects)) / len(self.corrects) * 100 log( - f"Iter #{test_i+1:3}/{self.data_pp.N_test}: " + f"Iter #{test_i+1:3}/{self.ppss.sim_ss.test_n}: " f" ut{pretty_timestr(ut)[9:][:-9]}" # f". Predval|true|err {predprice:.2f}|{trueprice:.2f}|{err:6.2f}" f". 
Preddir|true|correct = {pred_dir}|{true_dir}|{correct_s}" @@ -209,7 +192,7 @@ def _buy(self, price: float, usdcoin_amt_spend: float): usdcoin_amt_sent = min(usdcoin_amt_spend, self.holdings[self.usdcoin]) self.holdings[self.usdcoin] -= usdcoin_amt_sent - p = self.trade_pp.fee_percent + p = self.ppss.trader_ss.fee_percent usdcoin_amt_fee = p * usdcoin_amt_sent tokcoin_amt_recd = (1 - p) * usdcoin_amt_sent / price self.holdings[self.tokcoin] += tokcoin_amt_recd @@ -232,7 +215,7 @@ def _sell(self, price: float, tokcoin_amt_sell: float): tokcoin_amt_sent = tokcoin_amt_sell self.holdings[self.tokcoin] -= tokcoin_amt_sent - p = self.trade_pp.fee_percent + p = self.ppss.trader_ss.fee_percent usdcoin_amt_fee = p * tokcoin_amt_sent * price usdcoin_amt_recd = (1 - p) * tokcoin_amt_sent * price self.holdings[self.usdcoin] += usdcoin_amt_recd @@ -245,7 +228,7 @@ def _sell(self, price: float, tokcoin_amt_sell: float): @enforce_types def _plot(self, i, N): - if not self.sim_ss.do_plot: + if not self.ppss.sim_ss.do_plot: return # don't plot first 5 iters -> not interesting @@ -261,7 +244,11 @@ def _plot(self, i, N): N = len(y0) x = list(range(0, N)) ax0.plot(x, y0, "g-") - ax0.set_title("Trading profit vs time", fontsize=FONTSIZE, fontweight="bold") + ax0.set_title( + f"Trading profit vs time. Current: ${y0[-1]:.2f}", + fontsize=FONTSIZE, + fontweight="bold", + ) ax0.set_xlabel("time", fontsize=FONTSIZE) ax0.set_ylabel("trading profit (USD)", fontsize=FONTSIZE) @@ -279,7 +266,9 @@ def _plot(self, i, N): ax1.fill_between(x, y1_l, y1_u, color="b", alpha=0.15) now_s = f"{y1_est[-1]:.2f}% [{y1_l[-1]:.2f}%, {y1_u[-1]:.2f}%]" ax1.set_title( - f"% correct vs time. {now_s}", fontsize=FONTSIZE, fontweight="bold" + f"% correct vs time. Current: {now_s}", + fontsize=FONTSIZE, + fontweight="bold", ) ax1.set_xlabel("time", fontsize=FONTSIZE) ax1.set_ylabel("% correct", fontsize=FONTSIZE) diff --git a/pdr_backend/data_eng/test/conftest.py b/pdr_backend/sim/test/conftest.py similarity index 100% rename from pdr_backend/data_eng/test/conftest.py rename to pdr_backend/sim/test/conftest.py diff --git a/pdr_backend/sim/test/test_sim_engine.py b/pdr_backend/sim/test/test_sim_engine.py new file mode 100644 index 000000000..ba1a0873a --- /dev/null +++ b/pdr_backend/sim/test/test_sim_engine.py @@ -0,0 +1,52 @@ +import os +from unittest import mock + +from enforce_typing import enforce_types + +from pdr_backend.ppss.lake_ss import LakeSS +from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str +from pdr_backend.ppss.predictoor_ss import PredictoorSS +from pdr_backend.ppss.sim_ss import SimSS +from pdr_backend.sim.sim_engine import SimEngine + + +@enforce_types +def test_sim_engine(tmpdir): + s = fast_test_yaml_str(tmpdir) + ppss = PPSS(yaml_str=s, network="development") + + ppss.predictoor_ss = PredictoorSS( + { + "predict_feed": "binanceus BTC/USDT c 5m", + "bot_only": {"s_until_epoch_end": 60, "stake_amount": 1}, + "aimodel_ss": { + "input_feeds": ["binanceus BTC/USDT ETH/USDT oc"], + "max_n_train": 100, + "autoregressive_n": 2, + "approach": "LIN", + }, + } + ) + + ppss.lake_ss = LakeSS( + { + "feeds": ["binanceus BTC/USDT ETH/USDT oc 5m"], + "parquet_dir": os.path.join(tmpdir, "parquet_data"), + "st_timestr": "2023-06-18", + "fin_timestr": "2023-06-30", + "timeframe": "5m", + } + ) + + assert hasattr(ppss, "sim_ss") + ppss.sim_ss = SimSS( + { + "do_plot": True, + "log_dir": os.path.join(tmpdir, "logs"), + "test_n": 10, + } + ) + + with mock.patch("pdr_backend.sim.sim_engine.plt.show"): + sim_engine = SimEngine(ppss) + 
sim_engine.run() diff --git a/pdr_backend/simulation/runtrade.py b/pdr_backend/simulation/runtrade.py deleted file mode 100755 index 6963a91a0..000000000 --- a/pdr_backend/simulation/runtrade.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python -import os - -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.data_eng.data_ss import DataSS -from pdr_backend.model_eng.model_ss import ModelSS -from pdr_backend.simulation.sim_ss import SimSS -from pdr_backend.simulation.trade_engine import TradeEngine -from pdr_backend.simulation.trade_pp import TradePP -from pdr_backend.simulation.trade_ss import TradeSS - -# To play with simulation, simply change any of the arguments to any -# of the constructors below. -# -# - It does *not* use envvars PAIR_FILTER, TIMEFRAME_FILTER, or SOURCE_FILTER. -# Why: to avoid ambiguity. Eg is PAIR_FILTER for yval_coin, or input data? - -data_pp = DataPP( # user-uncontrollable params, at data-eng level - "1h", # "5m" or "1h" - "binance c BTC/USDT", # c = "close" - N_test=200, # 50000 . num points to test on, 1 at a time (online) -) - -data_ss = DataSS( # user-controllable params, at data-eng level - ["binance c BTC/USDT,ETH/USDT"], - csv_dir=os.path.abspath("csvs"), # eg "csvs". abs or rel loc'n of csvs dir - st_timestr="2022-06-30", # eg "2019-09-13_04:00" (earliest), "2019-09-13" - fin_timestr="now", # eg "now", "2023-09-23_17:55", "2023-09-23" - max_n_train=5000, # eg 50000. # if inf, only limited by data available - autoregressive_n=20, # eg 10. model inputs past pts z[t-1], .., z[t-ar_n] -) - -model_ss = ModelSS( # user-controllable params, at model-eng level - "LIN" # eg "LIN", "GPR", "SVR", "NuSVR", or "LinearSVR" -) - -trade_pp = TradePP( # user-uncontrollable params, at trading level - fee_percent=0.00, # Eg 0.001 is 0.1%. Trading fee (simulated) - init_holdings={data_pp.base_str: 0.0, data_pp.quote_str: 100000.0}, -) - -trade_ss = TradeSS( # user-controllable params, at trading level - buy_amt_usd=100000.00, # How much to buy at a time. In USD -) - -sim_ss = SimSS( # user-controllable params, at sim level - do_plot=True, # plot at end? 
- logpath=os.path.abspath("./"), # where to save logs to -) - -# ================================================================== -# print setup -print(f"data_pp={data_pp}") -print(f"data_ss={data_ss}") -print(f"model_ss={model_ss}") -print(f"trade_pp={trade_pp}") -print(f"trade_ss={trade_ss}") -print(f"sim_ss={sim_ss}") - -# ================================================================== -# do work -trade_engine = TradeEngine(data_pp, data_ss, model_ss, trade_pp, trade_ss, sim_ss) - -trade_engine.run() diff --git a/pdr_backend/simulation/sim_ss.py b/pdr_backend/simulation/sim_ss.py deleted file mode 100644 index 91ffaa293..000000000 --- a/pdr_backend/simulation/sim_ss.py +++ /dev/null @@ -1,16 +0,0 @@ -import os - -from enforce_typing import enforce_types - -from pdr_backend.util.strutil import StrMixin - - -@enforce_types -class SimSS(StrMixin): - """User-controllable strategy params related to the simulation itself""" - - def __init__(self, do_plot: bool, logpath: str): - assert os.path.exists(logpath) - - self.do_plot = do_plot - self.logpath = logpath # directory, not file diff --git a/pdr_backend/simulation/test/conftest.py b/pdr_backend/simulation/test/conftest.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pdr_backend/simulation/test/test_sim_ss.py b/pdr_backend/simulation/test/test_sim_ss.py deleted file mode 100644 index cbeda1a3e..000000000 --- a/pdr_backend/simulation/test/test_sim_ss.py +++ /dev/null @@ -1,14 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.simulation.sim_ss import SimSS - - -@enforce_types -def test_sim_ss(tmpdir): - ss = SimSS( - do_plot=False, - logpath=str(tmpdir), - ) - assert not ss.do_plot - assert ss.logpath == str(tmpdir) - assert "SimSS" in str(ss) diff --git a/pdr_backend/simulation/test/test_trade_engine.py b/pdr_backend/simulation/test/test_trade_engine.py deleted file mode 100644 index 0e8779184..000000000 --- a/pdr_backend/simulation/test/test_trade_engine.py +++ /dev/null @@ -1,63 +0,0 @@ -import os - -from enforce_typing import enforce_types - -from pdr_backend.data_eng.data_pp import DataPP -from pdr_backend.data_eng.data_ss import DataSS -from pdr_backend.model_eng.model_ss import ModelSS -from pdr_backend.simulation.trade_engine import TradeEngine -from pdr_backend.simulation.sim_ss import SimSS -from pdr_backend.simulation.trade_pp import TradePP -from pdr_backend.simulation.trade_ss import TradeSS - - -@enforce_types -def test_TradeEngine(tmpdir): - logpath = str(tmpdir) - - data_pp = DataPP( # user-uncontrollable params, at data level - "5m", - "binanceus c BTC/USDT", - N_test=100, - ) - - data_ss = DataSS( # user-controllable params, at data level - ["binanceus oc ETH/USDT,BTC/USDT"], - csv_dir=os.path.abspath("csvs"), # use the usual data (worksforme) - st_timestr="2023-06-22", - fin_timestr="2023-06-24", - max_n_train=500, - autoregressive_n=2, - ) - - model_ss = ModelSS( # user-controllable params, at model level - "LIN", - ) - - trade_pp = TradePP( # user-uncontrollable params, at trading level - fee_percent=0.0, # Eg 0.001 is 0.1%. Trading fee (simulated) - init_holdings={"USDT": 100000.0, "BTC": 0.0}, - ) - - trade_ss = TradeSS( # user-controllable params, at trading level - buy_amt_usd=100000.00, # How much to buy at a time. In USD - ) - - sim_ss = SimSS( # user-controllable params, at sim level - do_plot=False, # plot at end? 
- logpath=logpath, # where to save logs to - ) - - # ================================================================== - # print setup - print(f"data_pp={data_pp}") - print(f"data_ss={data_ss}") - print(f"model_ss={model_ss}") - print(f"trade_pp={trade_pp}") - print(f"trade_ss={trade_ss}") - print(f"sim_ss={sim_ss}") - - # ================================================================== - # do work - trade_engine = TradeEngine(data_pp, data_ss, model_ss, trade_pp, trade_ss, sim_ss) - trade_engine.run() diff --git a/pdr_backend/simulation/test/test_trade_pp.py b/pdr_backend/simulation/test/test_trade_pp.py deleted file mode 100644 index 77b72862a..000000000 --- a/pdr_backend/simulation/test/test_trade_pp.py +++ /dev/null @@ -1,15 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.simulation.trade_pp import TradePP - - -@enforce_types -def test_trade_pp(): - pp = TradePP( - fee_percent=0.01, - init_holdings={"USDT": 10000.0, "BTC": 0.0}, - ) - assert pp.fee_percent == 0.01 - assert pp.init_holdings["USDT"] == 10000.0 - assert "TradePP" in str(pp) - assert "fee_percent" in str(pp) diff --git a/pdr_backend/simulation/test/test_trade_ss.py b/pdr_backend/simulation/test/test_trade_ss.py deleted file mode 100644 index 17c5ed336..000000000 --- a/pdr_backend/simulation/test/test_trade_ss.py +++ /dev/null @@ -1,10 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.simulation.trade_ss import TradeSS - - -@enforce_types -def test_trade_ss(): - ss = TradeSS(buy_amt_usd=100.0) - assert ss.buy_amt_usd == 100.0 - assert "TradeSS" in str(ss) diff --git a/pdr_backend/simulation/trade_pp.py b/pdr_backend/simulation/trade_pp.py deleted file mode 100644 index d4800da3a..000000000 --- a/pdr_backend/simulation/trade_pp.py +++ /dev/null @@ -1,16 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.util.strutil import StrMixin - - -@enforce_types -class TradePP(StrMixin): - """User-uncontrollable parameters, at trading level""" - - def __init__( - self, - fee_percent: float, # Eg 0.001 is 0.1%. Trading fee (simulated) - init_holdings: dict, # Eg {"USDT": 100000.00} - ): - self.fee_percent = fee_percent - self.init_holdings = init_holdings diff --git a/pdr_backend/simulation/trade_ss.py b/pdr_backend/simulation/trade_ss.py deleted file mode 100644 index a098de9e2..000000000 --- a/pdr_backend/simulation/trade_ss.py +++ /dev/null @@ -1,11 +0,0 @@ -from enforce_typing import enforce_types - -from pdr_backend.util.strutil import StrMixin - - -@enforce_types -class TradeSS(StrMixin): - """User-controllable parameters, at trading level""" - - def __init__(self, buy_amt_usd: float): - self.buy_amt_usd = buy_amt_usd diff --git a/pdr_backend/subgraph/core_subgraph.py b/pdr_backend/subgraph/core_subgraph.py new file mode 100644 index 000000000..1ead768c3 --- /dev/null +++ b/pdr_backend/subgraph/core_subgraph.py @@ -0,0 +1,35 @@ +import time +from typing import Dict + +import requests +from enforce_typing import enforce_types + +from pdr_backend.util.constants import SUBGRAPH_MAX_TRIES + + +@enforce_types +def query_subgraph( + subgraph_url: str, query: str, tries: int = 3, timeout: float = 30.0 +) -> Dict[str, dict]: + """ + @arguments + subgraph_url -- e.g. http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph # pylint: disable=line-too-long + query -- e.g. in docstring above + + @return + result -- e.g. 
{"data" : {"predictContracts": ..}} + """ + response = requests.post(subgraph_url, "", json={"query": query}, timeout=timeout) + if response.status_code != 200: + # pylint: disable=broad-exception-raised + if tries < SUBGRAPH_MAX_TRIES: + time.sleep(((tries + 1) / 2) ** (2) * 10) + return query_subgraph(subgraph_url, query, tries + 1) + + raise Exception( + f"Query failed. Url: {subgraph_url}. Return code is {response.status_code}\n{query}" + ) + + result = response.json() + + return result diff --git a/pdr_backend/subgraph/info725.py b/pdr_backend/subgraph/info725.py new file mode 100644 index 000000000..eb10f23cd --- /dev/null +++ b/pdr_backend/subgraph/info725.py @@ -0,0 +1,100 @@ +from typing import Dict, Optional, Union + +from enforce_typing import enforce_types +from web3 import Web3 + + +@enforce_types +def key_to_key725(key: str): + key725 = Web3.keccak(key.encode("utf-8")).hex() + return key725 + + +@enforce_types +def value_to_value725(value: Union[str, None]): + if value is None: + value725 = None + else: + value725 = Web3.to_hex(text=value) + return value725 + + +@enforce_types +def value725_to_value(value725) -> Union[str, None]: + if value725 is None: + value = None + else: + value = Web3.to_text(hexstr=value725) + return value + + +@enforce_types +def info_to_info725(info: Dict[str, Union[str, None]]) -> list: + """ + @arguments + info -- eg { + "pair": "ETH/USDT", + "timeframe": "5m", + "source": None, + "extra1" : "extra1_value", + "extra2" : None, + } + where info may/may not have keys for "pair", "timeframe", source" + and may have extra keys + + @return + info725 -- eg [ + {"key":encoded("pair"), "value":encoded("ETH/USDT")}, + {"key":encoded("timeframe"), "value":encoded("5m") }, + ... + ] + Where info725 may or may not have each of these keys: + "pair", "timeframe", "source" + """ + keys = ["pair", "timeframe", "source"] + info_keys = list(info.keys()) + for info_key in info_keys: + if info_key not in keys: + keys.append(info_key) + + info725 = [] + for key in keys: + if key in info_keys: + value = info[key] + else: + value = None + key725 = key_to_key725(key) + value725 = value_to_value725(value) + info725.append({"key": key725, "value": value725}) + + return info725 + + +@enforce_types +def info725_to_info(info725: list) -> Dict[str, Optional[str]]: + """ + @arguments + info725 -- eg [{"key":encoded("pair"), "value":encoded("ETH/USDT")}, + {"key":encoded("timeframe"), "value":encoded("5m") }, + ... ] + where info725 may/may not have keys for "pair", "timeframe", source" + and may have extra keys + + @return + info -- e.g. 
{"pair": "ETH/USDT", + "timeframe": "5m", + "source": None} + where info always has keys "pair", "timeframe", "source" + """ + info: Dict[str, Optional[str]] = {} + target_keys = ["pair", "timeframe", "source"] + for key in target_keys: + info[key] = None + for item725 in info725: + key725, value725 = item725["key"], item725["value"] + if key725 == key_to_key725(key): + value = value725_to_value(value725) + info[key] = value + break + + return info diff --git a/pdr_backend/subgraph/prediction.py b/pdr_backend/subgraph/prediction.py new file mode 100644 index 000000000..ed1cccf24 --- /dev/null +++ b/pdr_backend/subgraph/prediction.py @@ -0,0 +1,268 @@ +from typing import List, Union + +from enforce_typing import enforce_types + + +@enforce_types +class Prediction: + # pylint: disable=too-many-instance-attributes + def __init__( + self, + ID: str, + pair: str, + timeframe: str, + prediction: Union[bool, None], # prediction = subgraph.predicted_value + stake: Union[float, None], + trueval: Union[bool, None], + timestamp: int, # timestamp == prediction submitted timestamp + source: str, + payout: Union[float, None], + slot: int, # slot/epoch timestamp + user: str, + ) -> None: + self.ID = ID + self.pair = pair + self.timeframe = timeframe + self.prediction = prediction + self.stake = stake + self.trueval = trueval + self.timestamp = timestamp + self.source = source + self.payout = payout + self.slot = slot + self.user = user + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def mock_prediction(prediction_tuple: tuple) -> Prediction: + ( + pair_str, + timeframe_str, + prediction, + stake, + trueval, + timestamp, + source, + payout, + slot, + user, + ) = prediction_tuple + + ID = f"{pair_str}-{timeframe_str}-{slot}-{user}" + return Prediction( + ID=ID, + pair=pair_str, + timeframe=timeframe_str, + prediction=prediction, + stake=stake, + trueval=trueval, + timestamp=timestamp, + source=source, + payout=payout, + slot=slot, + user=user, + ) + + +@enforce_types +def mock_first_predictions() -> List[Prediction]: + return [ + mock_prediction(prediction_tuple) for prediction_tuple in _FIRST_PREDICTION_TUPS + ] + + +@enforce_types +def mock_second_predictions() -> List[Prediction]: + return [ + mock_prediction(prediction_tuple) + for prediction_tuple in _SECOND_PREDICTION_TUPS + ] + + +@enforce_types +def mock_daily_predictions() -> List[Prediction]: + return [ + mock_prediction(prediction_tuple) for prediction_tuple in _DAILY_PREDICTION_TUPS + ] + + +_FIRST_PREDICTION_TUPS = [ + ( + "ADA/USDT", + "5m", + True, + 0.0500, + False, + 1701503000, + "binance", + 0.0, + 1701503100, + "0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BTC/USDT", + "5m", + True, + 0.0500, + True, + 1701589400, + "binance", + 0.0, + 1701589500, + "0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd", + ), +] + +_SECOND_PREDICTION_TUPS = [ + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1701675800, + "binance", + 0.0500, + 1701675900, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BTC/USDT", + "1h", + True, + 0.0500, + False, + 1701503100, + "binance", + 0.0, + 1701503000, + "0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ADA/USDT", + "5m", + True, + 0.0500, + True, + 1701589400, + "binance", + 0.0500, + 1701589500, + "0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BNB/USDT", + "1h", + True, + 0.0500, + True, + 1701675800, + "kraken", + 0.0500, + 1701675900, + "0xbbbb4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + 
"ETH/USDT", + "1h", + True, + None, + False, + 1701589400, + "binance", + 0.0, + 1701589500, + "0xcccc4cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1701675800, + "binance", + 0.0500, + 1701675900, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), +] + +_DAILY_PREDICTION_TUPS = [ + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1698865200, # Nov 01 2023 19:00:00 GMT + "binance", + 0.0500, + 1698865200, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BTC/USDT", + "1h", + True, + 0.0500, + False, + 1698951600, # Nov 02 2023 19:00:00 GMT + "binance", + 0.0, + 1698951600, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ADA/USDT", + "5m", + True, + 0.0500, + True, + 1699038000, # Nov 03 2023 19:00:00 GMT + "binance", + 0.0500, + 1699038000, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "BNB/USDT", + "1h", + True, + 0.0500, + True, + 1699124400, # Nov 04 2023 19:00:00 GMT + "kraken", + 0.0500, + 1699124400, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ETH/USDT", + "1h", + True, + None, + False, + 1699214400, # Nov 05 2023 19:00:00 GMT + "binance", + 0.0, + 1701589500, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), + ( + "ETH/USDT", + "5m", + True, + 0.0500, + True, + 1699300800, # Nov 06 2023 19:00:00 GMT + "binance", + 0.0500, + 1699300800, + "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", + ), +] diff --git a/pdr_backend/subgraph/subgraph_consume_so_far.py b/pdr_backend/subgraph/subgraph_consume_so_far.py new file mode 100644 index 000000000..090e86941 --- /dev/null +++ b/pdr_backend/subgraph/subgraph_consume_so_far.py @@ -0,0 +1,81 @@ +from collections import defaultdict +from typing import Dict, List + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.core_subgraph import query_subgraph + + +@enforce_types +def get_consume_so_far_per_contract( + subgraph_url: str, + user_address: str, + since_timestamp: int, + contract_addresses: List[str], +) -> Dict[str, float]: + chunk_size = 1000 # max for subgraph = 1000 + offset = 0 + consume_so_far: Dict[str, float] = defaultdict(float) + print("Getting consume so far...") + while True: # pylint: disable=too-many-nested-blocks + query = """ + { + predictContracts(first:1000, where: {id_in: %s}){ + id + token{ + id + name + symbol + nft { + owner { + id + } + nftData { + key + value + } + } + orders(where: {createdTimestamp_gt:%s, consumer_in:["%s"]}, first: %s, skip: %s){ + createdTimestamp + consumer { + id + } + lastPriceValue + } + } + secondsPerEpoch + secondsPerSubscription + truevalSubmitTimeout + } + } + """ % ( + str(contract_addresses).replace("'", '"'), + since_timestamp, + user_address.lower(), + chunk_size, + offset, + ) + offset += chunk_size + result = query_subgraph(subgraph_url, query, 3, 30.0) + if "data" not in result or "predictContracts" not in result["data"]: + break + contracts = result["data"]["predictContracts"] + if contracts == []: + break + no_of_zeroes = 0 + for contract in contracts: + contract_address = contract["id"] + if contract_address not in contract_addresses: + no_of_zeroes += 1 + continue + order_count = len(contract["token"]["orders"]) + if order_count == 0: + no_of_zeroes += 1 + for buy in contract["token"]["orders"]: + # 1.2 20% fee + # 0.001 0.01% community swap fee + consume_amt = float(buy["lastPriceValue"]) * 1.201 + consume_so_far[contract_address] += consume_amt + if no_of_zeroes == len(contracts): + break + return consume_so_far diff --git 
a/pdr_backend/dfbuyer/subgraph.py b/pdr_backend/subgraph/subgraph_dfbuyer.py similarity index 96% rename from pdr_backend/dfbuyer/subgraph.py rename to pdr_backend/subgraph/subgraph_dfbuyer.py index 1d1cde1b5..4e3672a3c 100644 --- a/pdr_backend/dfbuyer/subgraph.py +++ b/pdr_backend/subgraph/subgraph_dfbuyer.py @@ -1,4 +1,4 @@ -from pdr_backend.util.subgraph import query_subgraph +from pdr_backend.subgraph.core_subgraph import query_subgraph def get_consume_so_far( diff --git a/pdr_backend/subgraph/subgraph_feed.py b/pdr_backend/subgraph/subgraph_feed.py new file mode 100644 index 000000000..96c884d18 --- /dev/null +++ b/pdr_backend/subgraph/subgraph_feed.py @@ -0,0 +1,93 @@ +import random +from typing import Dict, Optional + +from enforce_typing import enforce_types + +from pdr_backend.cli.arg_pair import ArgPair +from pdr_backend.cli.timeframe import Timeframe +from pdr_backend.util.strutil import StrMixin + + +class SubgraphFeed(StrMixin): # pylint: disable=too-many-instance-attributes + @enforce_types + def __init__( + self, + name: str, # eg 'Feed: binance | BTC/USDT | 5m' + address: str, # eg '0x123' + symbol: str, # eg 'binance-BTC/USDT-5m' + seconds_per_subscription: int, # eg 60 * 60 * 24 + trueval_submit_timeout: int, # eg 60 + owner: str, # eg '0x456' + pair: str, # eg 'BTC/USDT' + timeframe: str, # eg '5m' + source: str, # eg 'binance' + ): + self.name: str = name + self.address: str = address + self.symbol: str = symbol + self.seconds_per_subscription: int = seconds_per_subscription + self.trueval_submit_timeout: int = trueval_submit_timeout + self.owner: str = owner + self.pair: str = pair.replace("-", "/") + self.timeframe: str = timeframe + self.source: str = source + + @property + def seconds_per_epoch(self): + return Timeframe(self.timeframe).s + + @property + def base(self): + return ArgPair(self.pair).base_str + + @property + def quote(self): + return ArgPair(self.pair).quote_str + + @enforce_types + def shortstr(self) -> str: + return f"Feed: {self.timeframe} {self.source} {self.pair} {self.address}" + + @enforce_types + def __str__(self) -> str: + return self.shortstr() + + +@enforce_types +def print_feeds(feeds: Dict[str, SubgraphFeed], label: Optional[str] = None): + label = label or "feeds" + print(f"{len(feeds)} {label}:") + if not feeds: + print(" ") + return + for feed in feeds.values(): + print(f" {feed}") + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def _rnd_eth_addr() -> str: + """Generate a random address with Ethereum format.""" + addr = "0x" + "".join([str(random.randint(0, 9)) for i in range(40)]) + return addr + + +@enforce_types +def mock_feed(timeframe_str: str, exchange_str: str, pair_str: str) -> SubgraphFeed: + addr = _rnd_eth_addr() + name = f"Feed {addr} {pair_str}|{exchange_str}|{timeframe_str}" + feed = SubgraphFeed( + name=name, + address=addr, + symbol=f"SYM: {addr}", + seconds_per_subscription=86400, + trueval_submit_timeout=60, + owner="0xowner", + pair=pair_str, + timeframe=timeframe_str, + source=exchange_str, + ) + return feed diff --git a/pdr_backend/subgraph/subgraph_feed_contracts.py b/pdr_backend/subgraph/subgraph_feed_contracts.py new file mode 100644 index 000000000..2a0af348f --- /dev/null +++ b/pdr_backend/subgraph/subgraph_feed_contracts.py @@ -0,0 +1,121 @@ +from typing import Dict, Optional + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.subgraph.info725 import 
info725_to_info +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed + +_N_ERRORS = {} # exception_str : num_occurrences +_N_THR = 3 + + +@enforce_types +def query_feed_contracts( + subgraph_url: str, + owners_string: Optional[str] = None, +) -> Dict[str, SubgraphFeed]: + """ + @description + Query the chain for prediction feed contracts. + + @arguments + subgraph_url -- e.g. + owners -- E.g. filter to "0x123,0x124". If None or "", allow all + + @return + feeds -- dict of [feed_addr] : SubgraphFeed + """ + owners = None + if owners_string: + owners = owners_string.lower().split(",") + + chunk_size = 1000 # max for subgraph = 1000 + offset = 0 + feeds: Dict[str, SubgraphFeed] = {} + + while True: + query = """ + { + predictContracts(skip:%s, first:%s){ + id + token { + id + name + symbol + nft { + owner { + id + } + nftData { + key + value + } + } + } + secondsPerEpoch + secondsPerSubscription + truevalSubmitTimeout + } + } + """ % ( + offset, + chunk_size, + ) + offset += chunk_size + try: + result = query_subgraph(subgraph_url, query) + contract_list = result["data"]["predictContracts"] + if not contract_list: + break + for contract in contract_list: + info725 = contract["token"]["nft"]["nftData"] + info = info725_to_info(info725) # {"pair": "ETH/USDT", } + + pair = info["pair"] + timeframe = info["timeframe"] + source = info["source"] + if None in (pair, timeframe, source): + continue + + # filter out unwanted + owner_id = contract["token"]["nft"]["owner"]["id"] + if owners and (owner_id not in owners): + continue + + # ok, add this one + feed = SubgraphFeed( + name=contract["token"]["name"], + address=contract["id"], + symbol=contract["token"]["symbol"], + seconds_per_subscription=int(contract["secondsPerSubscription"]), + trueval_submit_timeout=int(contract["truevalSubmitTimeout"]), + owner=owner_id, + pair=pair, + timeframe=timeframe, + source=source, + ) + feeds[feed.address] = feed + + except Exception as e: + e_str = str(e) + e_key = e_str + if "Connection object" in e_str: + i = e_str.find("Connection object") + len("Connection object") + e_key = e_key[:i] + + if e_key not in _N_ERRORS: + _N_ERRORS[e_key] = 0 + _N_ERRORS[e_key] += 1 + + if _N_ERRORS[e_key] <= _N_THR: + print(e_str) + if _N_ERRORS[e_key] == _N_THR: + print("Future errors like this will be hidden") + return {} + + # postconditions + for feed in feeds.values(): + assert isinstance(feed, SubgraphFeed) + + return feeds diff --git a/pdr_backend/subgraph/subgraph_pending_payouts.py b/pdr_backend/subgraph/subgraph_pending_payouts.py new file mode 100644 index 000000000..36569e0f5 --- /dev/null +++ b/pdr_backend/subgraph/subgraph_pending_payouts.py @@ -0,0 +1,56 @@ +from typing import Dict, List + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.core_subgraph import query_subgraph + + +@enforce_types +def query_pending_payouts(subgraph_url: str, addr: str) -> Dict[str, List[int]]: + chunk_size = 1000 + offset = 0 + pending_slots: Dict[str, List[int]] = {} + addr = addr.lower() + + while True: + query = """ + { + predictPredictions( + where: {user: "%s", payout: null, slot_: {status: "Paying"} }, first: %s, skip: %s + ) { + id + timestamp + slot { + id + slot + predictContract { + id + } + } + } + } + """ % ( + addr, + chunk_size, + offset, + ) + offset += chunk_size + print(".", end="", flush=True) + try: + result = query_subgraph(subgraph_url, query) + if "data" not in result or not result["data"]: + print("No data in result") + break + predict_predictions = 
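
A hypothetical use of query_feed_contracts together with print_feeds from the files added above; the subgraph URL and owner address are placeholders:

from pdr_backend.subgraph.subgraph_feed import print_feeds
from pdr_backend.subgraph.subgraph_feed_contracts import query_feed_contracts

subgraph_url = "http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph"  # placeholder
feeds = query_feed_contracts(subgraph_url, owners_string="0x...")  # or None to allow all owners
print_feeds(feeds, label="feeds for this owner")  # one "Feed: ..." line per SubgraphFeed
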
result["data"].get("predictPredictions", []) + if not predict_predictions: + break + for prediction in predict_predictions: + contract_address = prediction["slot"]["predictContract"]["id"] + timestamp = prediction["slot"]["slot"] + pending_slots.setdefault(contract_address, []).append(timestamp) + except Exception as e: + print("An error occured", e) + break + + print() # print new line + return pending_slots diff --git a/pdr_backend/subgraph/subgraph_pending_slots.py b/pdr_backend/subgraph/subgraph_pending_slots.py new file mode 100644 index 000000000..fe6e1f5b8 --- /dev/null +++ b/pdr_backend/subgraph/subgraph_pending_slots.py @@ -0,0 +1,120 @@ +import time +from typing import List, Optional + +from pdr_backend.cli.arg_feeds import ArgFeeds +from pdr_backend.contract.slot import Slot +from pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.subgraph.info725 import info725_to_info +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed + + +# don't use @enforce_types here, it causes issues +def get_pending_slots( + subgraph_url: str, + timestamp: int, + owner_addresses: Optional[List[str]], + allowed_feeds: Optional[ArgFeeds] = None, +): + chunk_size = 1000 + offset = 0 + owners: Optional[List[str]] = owner_addresses + + slots: List[Slot] = [] + + now_ts = time.time() + # rounds older than 3 days are canceled + 10 min buffer + three_days_ago = int(now_ts - 60 * 60 * 24 * 3 + 10 * 60) + + while True: + query = """ + { + predictSlots(where: {slot_gt: %s, slot_lte: %s, status: "Pending"}, skip:%s, first:%s){ + id + slot + status + trueValues { + id + } + predictContract { + id + token { + id + name + symbol + nft { + owner { + id + } + nftData { + key + value + } + } + } + secondsPerEpoch + secondsPerSubscription + truevalSubmitTimeout + } + } + } + """ % ( + three_days_ago, + timestamp, + offset, + chunk_size, + ) + + offset += chunk_size + try: + result = query_subgraph(subgraph_url, query) + if not "data" in result: + print("No data in result") + break + slot_list = result["data"]["predictSlots"] + if slot_list == []: + break + for slot in slot_list: + if slot["trueValues"] != []: + continue + + contract = slot["predictContract"] + info725 = contract["token"]["nft"]["nftData"] + info = info725_to_info(info725) + + pair = info["pair"] + timeframe = info["timeframe"] + source = info["source"] + assert pair, "need a pair" + assert timeframe, "need a timeframe" + assert source, "need a source" + + owner_id = contract["token"]["nft"]["owner"]["id"] + if owners and (owner_id not in owners): + continue + + if allowed_feeds and not allowed_feeds.contains_combination( + source, pair, timeframe + ): + continue + + feed = SubgraphFeed( + name=contract["token"]["name"], + address=contract["id"], + symbol=contract["token"]["symbol"], + seconds_per_subscription=int(contract["secondsPerSubscription"]), + trueval_submit_timeout=int(contract["truevalSubmitTimeout"]), + owner=contract["token"]["nft"]["owner"]["id"], + pair=pair, + timeframe=timeframe, + source=source, + ) + + slot_number = int(slot["slot"]) + slot = Slot(slot_number, feed) + slots.append(slot) + + except Exception as e: + print(e) + break + + return slots diff --git a/pdr_backend/util/subgraph_predictions.py b/pdr_backend/subgraph/subgraph_predictions.py similarity index 64% rename from pdr_backend/util/subgraph_predictions.py rename to pdr_backend/subgraph/subgraph_predictions.py index 70a935ba8..9ddef8474 100644 --- a/pdr_backend/util/subgraph_predictions.py +++ b/pdr_backend/subgraph/subgraph_predictions.py 
@@ -1,22 +1,26 @@ -from typing import List, TypedDict import json from enum import Enum +from typing import List, TypedDict + from enforce_typing import enforce_types -from pdr_backend.util.subgraph import query_subgraph, info_from_725 -from pdr_backend.models.prediction import Prediction +from pdr_backend.subgraph.prediction import Prediction +from pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.subgraph.info725 import info725_to_info from pdr_backend.util.networkutil import get_subgraph_url class ContractIdAndSPE(TypedDict): - id: str + ID: str seconds_per_epoch: int name: str class FilterMode(Enum): + NONE = 0 CONTRACT = 1 PREDICTOOR = 2 + CONTRACT_TS = 3 @enforce_types @@ -26,6 +30,8 @@ def fetch_filtered_predictions( filters: List[str], network: str, filter_mode: FilterMode, + payout_only: bool = True, + trueval_only: bool = True, ) -> List[Prediction]: """ Fetches predictions from a subgraph within a specified time range @@ -63,16 +69,21 @@ def fetch_filtered_predictions( filters = [f.lower() for f in filters] # pylint: disable=line-too-long - if filter_mode == FilterMode.CONTRACT: - where_clause = f"where: {{slot_: {{predictContract_in: {json.dumps(filters)}, slot_gt: {start_ts}, slot_lt: {end_ts}}}}}" + if filter_mode == FilterMode.NONE: + where_clause = f", where: {{timestamp_gt: {start_ts}, timestamp_lt: {end_ts}}}" + elif filter_mode == FilterMode.CONTRACT_TS: + where_clause = f", where: {{timestamp_gt: {start_ts}, timestamp_lt: {end_ts}, slot_: {{predictContract_in: {json.dumps(filters)}}}}}" + elif filter_mode == FilterMode.CONTRACT: + where_clause = f", where: {{slot_: {{predictContract_in: {json.dumps(filters)}, slot_gt: {start_ts}, slot_lt: {end_ts}}}}}" elif filter_mode == FilterMode.PREDICTOOR: - where_clause = f"where: {{user_: {{id_in: {json.dumps(filters)}}}, slot_: {{slot_gt: {start_ts}, slot_lt: {end_ts}}}}}" + where_clause = f", where: {{user_: {{id_in: {json.dumps(filters)}}}, slot_: {{slot_gt: {start_ts}, slot_lt: {end_ts}}}}}" while True: query = f""" {{ - predictPredictions(skip: {offset}, first: {chunk_size}, {where_clause}) {{ + predictPredictions(skip: {offset}, first: {chunk_size} {where_clause}) {{ id + timestamp user {{ id }} @@ -110,46 +121,60 @@ def fetch_filtered_predictions( offset += chunk_size - if not "data" in result: + if "data" not in result or not result["data"]: break - data = result["data"]["predictPredictions"] + data = result["data"].get("predictPredictions", []) if len(data) == 0: break - for prediction in data: - info725 = prediction["slot"]["predictContract"]["token"]["nft"]["nftData"] - info = info_from_725(info725) - pair_name = info["pair"] + for prediction_sg_dict in data: + info725 = prediction_sg_dict["slot"]["predictContract"]["token"]["nft"][ + "nftData" + ] + info = info725_to_info(info725) + pair = info["pair"] timeframe = info["timeframe"] source = info["source"] - timestamp = prediction["slot"]["slot"] + timestamp = prediction_sg_dict["timestamp"] + slot = prediction_sg_dict["slot"]["slot"] + user = prediction_sg_dict["user"]["id"] + + trueval = None + payout = None + predicted_value = None + stake = None - if prediction["payout"] is None: + if payout_only is True and prediction_sg_dict["payout"] is None: continue - trueval = prediction["payout"]["trueValue"] - payout = float(prediction["payout"]["payout"]) + if not prediction_sg_dict["payout"] is None: + stake = float(prediction_sg_dict["stake"]) + trueval = prediction_sg_dict["payout"]["trueValue"] + predicted_value = 
prediction_sg_dict["payout"]["predictedValue"] + payout = float(prediction_sg_dict["payout"]["payout"]) - if trueval is None: + if trueval_only is True and trueval is None: continue - predictedValue = prediction["payout"]["predictedValue"] - stake = float(prediction["stake"]) - predictoor_user = prediction["user"]["id"] - - prediction_obj = Prediction( - pair_name, - timeframe, - predictedValue, - stake, - trueval, - timestamp, - source, - payout, - predictoor_user, + prediction = Prediction( + ID=prediction_sg_dict["id"], + pair=pair, + timeframe=timeframe, + prediction=predicted_value, + stake=stake, + trueval=trueval, + timestamp=timestamp, + source=source, + payout=payout, + slot=slot, + user=user, ) - predictions.append(prediction_obj) + predictions.append(prediction) + + # avoids doing next fetch if we've reached the end + if len(data) < chunk_size: + break return predictions @@ -213,21 +238,17 @@ def fetch_contract_id_and_spe( contract_addresses: List[str], network: str ) -> List[ContractIdAndSPE]: """ - This function queries a GraphQL endpoint to retrieve contract details such as - the contract ID and seconds per epoch for each contract address provided. - It supports querying both mainnet and testnet networks. - - Args: - contract_addresses (List[str]): A list of contract addresses to query. - network (str): The blockchain network to query ('mainnet' or 'testnet'). + @description + Query a GraphQL endpoint to retrieve details of contracts, like + contract ID and seconds per epoch. - Raises: - Exception: If the network is not 'mainnet' or 'testnet', or if no data is returned. + @arguments + contract_addresses - contract addresses to query + network - where to query. Eg 'mainnet' or 'testnet' - Returns: - List[ContractDetail]: A list of dictionaries containing contract details. 
+ @return + contracts_list - where each item has contract details """ - if network not in ("mainnet", "testnet"): raise Exception("Invalid network, pick mainnet or testnet") @@ -254,15 +275,15 @@ def fetch_contract_id_and_spe( if "data" not in result: raise Exception("Error fetching contracts: No data returned") - # Parse the results and construct ContractDetail objects - contract_data = result["data"]["predictContracts"] - contracts: List[ContractIdAndSPE] = [ - { - "id": contract["id"], - "seconds_per_epoch": contract["secondsPerEpoch"], - "name": contract["token"]["name"], + contracts_sg_dict = result["data"]["predictContracts"] + + contracts_list: List[ContractIdAndSPE] = [] + for contract_sg_dict in contracts_sg_dict: + contract_item: ContractIdAndSPE = { + "ID": contract_sg_dict["id"], + "seconds_per_epoch": contract_sg_dict["secondsPerEpoch"], + "name": contract_sg_dict["token"]["name"], } - for contract in contract_data - ] + contracts_list.append(contract_item) - return contracts + return contracts_list diff --git a/pdr_backend/subgraph/subgraph_slot.py b/pdr_backend/subgraph/subgraph_slot.py new file mode 100644 index 000000000..02e6e9e98 --- /dev/null +++ b/pdr_backend/subgraph/subgraph_slot.py @@ -0,0 +1,175 @@ +from dataclasses import dataclass +from typing import Any, Dict, List + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.util.networkutil import get_subgraph_url + + +@dataclass +class PredictSlot: + ID: str + slot: str + trueValues: List[Dict[str, Any]] + roundSumStakesUp: float + roundSumStakes: float + + +@enforce_types +def get_predict_slots_query( + asset_ids: List[str], initial_slot: int, last_slot: int, first: int, skip: int +) -> str: + """ + Constructs a GraphQL query string to fetch prediction slot data for + specified assets within a slot range. + + Args: + asset_ids: A list of asset identifiers to include in the query. + initial_slot: The starting slot number for the query range. + last_slot: The ending slot number for the query range. + first: The number of records to fetch per query (pagination limit). + skip: The number of records to skip (pagination offset). + + Returns: + A string representing the GraphQL query. + """ + asset_ids_str = str(asset_ids).replace("[", "[").replace("]", "]").replace("'", '"') + + return """ + query { + predictSlots ( + first: %s + skip: %s + where: { + slot_lte: %s + slot_gte: %s + predictContract_in: %s + } + ) { + id + slot + trueValues { + id + trueValue + } + roundSumStakesUp + roundSumStakes + } + } + """ % ( + first, + skip, + initial_slot, + last_slot, + asset_ids_str, + ) + + +@enforce_types +def get_slots( + addresses: List[str], + end_ts_param: int, + start_ts_param: int, + skip: int, + slots: List[PredictSlot], + network: str = "mainnet", +) -> List[PredictSlot]: + """ + Retrieves slots information for given addresses and a specified time range from a subgraph. + + Args: + addresses: A list of contract addresses to query. + end_ts_param: The Unix timestamp representing the end of the time range. + start_ts_param: The Unix timestamp representing the start of the time range. + skip: The number of records to skip for pagination. + slots: An existing list of slots to which new data will be appended. + network: The blockchain network to query ('mainnet' or 'testnet'). + + Returns: + A list of PredictSlot TypedDicts with the queried slot information. 
+ """ + + slots = slots or [] + + records_per_page = 1000 + + query = get_predict_slots_query( + addresses, + end_ts_param, + start_ts_param, + records_per_page, + skip, + ) + + result = query_subgraph( + get_subgraph_url(network), + query, + timeout=20.0, + ) + + new_slots = result["data"]["predictSlots"] or [] + + # Convert the list of dicts to a list of PredictSlot objects + # by passing the dict as keyword arguments + # convert roundSumStakesUp and roundSumStakes to float + new_slots = [ + PredictSlot( + **{ + "ID": slot["id"], + "slot": slot["slot"], + "trueValues": slot["trueValues"], + "roundSumStakesUp": float(slot["roundSumStakesUp"]), + "roundSumStakes": float(slot["roundSumStakes"]), + } + ) + for slot in new_slots + ] + + slots.extend(new_slots) + if len(new_slots) == records_per_page: + return get_slots( + addresses, + end_ts_param, + start_ts_param, + skip + records_per_page, + slots, + network, + ) + return slots + + +@enforce_types +def fetch_slots_for_all_assets( + asset_ids: List[str], + start_ts_param: int, + end_ts_param: int, + network: str = "mainnet", +) -> Dict[str, List[PredictSlot]]: + """ + Fetches slots for all provided asset IDs within a given time range and organizes them by asset. + + Args: + asset_ids: A list of asset identifiers for which slots will be fetched. + start_ts_param: The Unix timestamp marking the beginning of the desired time range. + end_ts_param: The Unix timestamp marking the end of the desired time range. + network: The blockchain network to query ('mainnet' or 'testnet'). + + Returns: + A dictionary mapping asset IDs to lists of PredictSlot dataclass + containing slot information. + """ + + all_slots = get_slots(asset_ids, end_ts_param, start_ts_param, 0, [], network) + + slots_by_asset: Dict[str, List[PredictSlot]] = {} + for slot in all_slots: + slot_id = slot.ID + # split the id to get the asset id + asset_id = slot_id.split("-")[0] + if asset_id not in slots_by_asset: + slots_by_asset[asset_id] = [] + + slots_by_asset[asset_id].append(slot) + + return slots_by_asset diff --git a/pdr_backend/subgraph/subgraph_subscriptions.py b/pdr_backend/subgraph/subgraph_subscriptions.py new file mode 100644 index 000000000..2199f727c --- /dev/null +++ b/pdr_backend/subgraph/subgraph_subscriptions.py @@ -0,0 +1,133 @@ +import json +from typing import List + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subscription import Subscription +from pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.subgraph.info725 import info725_to_info +from pdr_backend.util.networkutil import get_subgraph_url + + +@enforce_types +def fetch_filtered_subscriptions( + start_ts: int, + end_ts: int, + contracts: List[str], + network: str, +) -> List[Subscription]: + """ + Fetches subscriptions from predictoor subgraph within a specified time range + and according to given contracts. + + This function supports querying subscriptions based on contract + addresses. It iteratively queries the subgraph in chunks to retrieve all relevant + subscriptions and returns a dataframe as a result. + + Args: + start_ts: The starting Unix timestamp for the query range. + end_ts: The ending Unix timestamp for the query range. + contracts: A list of strings representing the filter + values (contract addresses). + network: A string indicating the blockchain network to query ('mainnet' or 'testnet'). 
+ + Returns: + A dataframe of predictSubscriptions objects that match the filter criteria + + Raises: + Exception: If the specified network is neither 'mainnet' nor 'testnet'. + """ + + if network not in ["mainnet", "testnet"]: + raise Exception("Invalid network, pick mainnet or testnet") + + chunk_size = 1000 + offset = 0 + subscriptions: List[Subscription] = [] + + # Convert contracts to lowercase + contracts = [f.lower() for f in contracts] + + # pylint: disable=line-too-long + if len(contracts) > 0: + where_clause = f", where: {{predictContract_: {{id_in: {json.dumps(contracts)}}}, timestamp_gt: {start_ts}, timestamp_lt: {end_ts}}}" + else: + where_clause = f", where: {{timestamp_gt: {start_ts}, timestamp_lt: {end_ts}}}" + + # pylint: disable=line-too-long + while True: + query = f""" + {{ + predictSubscriptions(skip: {offset}, first: {chunk_size} {where_clause}) {{ + id + txId + timestamp + user {{ + id + }} + predictContract {{ + id + token {{ + id + name + lastPriceValue + nft{{ + nftData {{ + key + value + }} + }} + }} + }} + }} + }}""" + + print("Querying subgraph...", query) + result = query_subgraph( + get_subgraph_url(network), + query, + timeout=20.0, + ) + + offset += chunk_size + + if "data" not in result or not result["data"]: + break + + data = result["data"].get("predictSubscriptions", []) + if len(data) == 0: + break + + for subscription_sg_dict in data: + info725 = subscription_sg_dict["predictContract"]["token"]["nft"]["nftData"] + info = info725_to_info(info725) + pair = info["pair"] + timeframe = info["timeframe"] + source = info["source"] + timestamp = subscription_sg_dict["timestamp"] + tx_id = subscription_sg_dict["txId"] + last_price_value = ( + float( + subscription_sg_dict["predictContract"]["token"]["lastPriceValue"] + ) + * 1.201 + ) + user = subscription_sg_dict["user"]["id"] + + subscription = Subscription( + ID=subscription_sg_dict["id"], + pair=pair, + timeframe=timeframe, + source=source, + timestamp=timestamp, + tx_id=tx_id, + last_price_value=last_price_value, + user=user, + ) + subscriptions.append(subscription) + + # avoids doing next fetch if we've reached the end + if len(data) < chunk_size: + break + + return subscriptions diff --git a/pdr_backend/subgraph/subgraph_sync.py b/pdr_backend/subgraph/subgraph_sync.py new file mode 100644 index 000000000..dc026a392 --- /dev/null +++ b/pdr_backend/subgraph/subgraph_sync.py @@ -0,0 +1,33 @@ +import time + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.util.web3_config import Web3Config + + +@enforce_types +def block_number_is_synced(subgraph_url: str, block_number: int) -> bool: + query = """ + { + predictContracts(block:{number:%s}){ + id + } + } + """ % ( + block_number + ) + try: + result = query_subgraph(subgraph_url, query) + except Exception: + return False + + return "errors" not in result + + +@enforce_types +def wait_until_subgraph_syncs(web3_config: Web3Config, subgraph_url: str): + block_number = web3_config.w3.eth.block_number + while block_number_is_synced(subgraph_url, block_number) is not True: + print("Subgraph is out of sync, trying again in 5 seconds") + time.sleep(5) diff --git a/pdr_backend/subgraph/subscription.py b/pdr_backend/subgraph/subscription.py new file mode 100644 index 000000000..544e57310 --- /dev/null +++ b/pdr_backend/subgraph/subscription.py @@ -0,0 +1,149 @@ +from typing import List + +from enforce_typing import enforce_types + + +@enforce_types +class Subscription: + # pylint: 
disable=too-many-instance-attributes + def __init__( + self, + ID: str, + pair: str, + timeframe: str, + source: str, + timestamp: int, # timestamp == subscription purchased timestamp + tx_id: str, + last_price_value: float, + user: str, + ) -> None: + self.ID = ID + self.pair = pair + self.timeframe = timeframe + self.source = source + self.timestamp = timestamp + self.tx_id = tx_id + self.last_price_value = last_price_value + self.user = user + + +# ========================================================================= +# utilities for testing + + +@enforce_types +def mock_subscription(subscription_tuple: tuple) -> Subscription: + ( + pair_str, + timeframe_str, + source, + timestamp, + tx_id, + last_price_value, + event_index, + user, + ) = subscription_tuple + + ID = f"{pair_str}-{tx_id}-{event_index}" + return Subscription( + ID=ID, + pair=pair_str, + timeframe=timeframe_str, + source=source, + timestamp=timestamp, + tx_id=tx_id, + last_price_value=float(last_price_value) * 1.201, + user=user, + ) + + +@enforce_types +def mock_subscriptions() -> List[Subscription]: + return [ + mock_subscription(subscription_tuple) + for subscription_tuple in _SUBSCRIPTION_TUPS + ] + + +_SUBSCRIPTION_TUPS = [ + ( + "ETH/USDT", + "5m", + "binance", + 1698850800, # Nov 01 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809592", + "2.4979184013322233", + 98, + "0x2433e002ed10b5d6a3d8d1e0c5d2083be9e37f1d", + ), + ( + "BTC/USDT", + "5m", + "kraken", + 1698937200, # Nov 02 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809593", + "2.4979184013322233", + 99, + "0xabcdef0123456789abcdef0123456789abcdef01", + ), + ( + "LTC/USDT", + "1h", + "kraken", + 1699110000, # Nov 04 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809594", + "2.4979184013322233", + 100, + "0x123456789abcdef0123456789abcdef01234567", + ), + ( + "XRP/USDT", + "5m", + "binance", + 1699110000, # Nov 04 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809595", + "2.4979184013322233", + 101, + "0xabcdef0123456789abcdef0123456789abcdef02", + ), + ( + "DOGE/USDT", + "5m", + "kraken", + 1699110000, # Nov 04 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809596", + "2.4979184013322233", + 102, + "0xabcdef0123456789abcdef0123456789abcdef03", + ), + ( + "ADA/USDT", + "1h", + "kraken", + 1699200000, # Nov 05 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809597", + "2.4979184013322233", + 103, + "0xabcdef0123456789abcdef0123456789abcdef04", + ), + ( + "DOT/USDT", + "5m", + "binance", + 1699200000, # Nov 05 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809598", + "2.4979184013322233", + 104, + "0xabcdef0123456789abcdef0123456789abcdef05", + ), + ( + "LINK/USDT", + "1h", + "kraken", + 1699286400, # Nov 06 2023 15:00 GMT/UTC + "0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809599", + "2.4979184013322233", + 105, + "0xabcdef0123456789abcdef0123456789abcdef06", + ), +] diff --git a/pdr_backend/subgraph/test/resources.py b/pdr_backend/subgraph/test/resources.py new file mode 100644 index 000000000..955a4e4dc --- /dev/null +++ b/pdr_backend/subgraph/test/resources.py @@ -0,0 +1,24 @@ +from enforce_typing import enforce_types + + +@enforce_types +class MockResponse: + def __init__(self, contract_list: list, status_code: int): + self.contract_list = contract_list + self.status_code = status_code + 
self.num_queries = 0 + + def json(self) -> dict: + self.num_queries += 1 + if self.num_queries > 1: + self.contract_list = [] + return {"data": {"predictContracts": self.contract_list}} + + +@enforce_types +class MockPost: + def __init__(self, contract_list: list = [], status_code: int = 200): + self.response = MockResponse(contract_list, status_code) + + def __call__(self, *args, **kwargs): + return self.response diff --git a/pdr_backend/subgraph/test/test_core_subgraph.py b/pdr_backend/subgraph/test/test_core_subgraph.py new file mode 100644 index 000000000..4256e29f5 --- /dev/null +++ b/pdr_backend/subgraph/test/test_core_subgraph.py @@ -0,0 +1,20 @@ +import pytest +import requests +from enforce_typing import enforce_types + +from pdr_backend.subgraph.core_subgraph import query_subgraph +from pdr_backend.subgraph.test.resources import MockPost + + +@enforce_types +def test_query_subgraph_happypath(monkeypatch): + monkeypatch.setattr(requests, "post", MockPost(status_code=200)) + result = query_subgraph(subgraph_url="foo", query="bar") + assert result == {"data": {"predictContracts": []}} + + +@enforce_types +def test_query_subgraph_badpath(monkeypatch): + monkeypatch.setattr(requests, "post", MockPost(status_code=400)) + with pytest.raises(Exception): + query_subgraph(subgraph_url="foo", query="bar") diff --git a/pdr_backend/subgraph/test/test_info725.py b/pdr_backend/subgraph/test/test_info725.py new file mode 100644 index 000000000..567c9e4eb --- /dev/null +++ b/pdr_backend/subgraph/test/test_info725.py @@ -0,0 +1,85 @@ +from enforce_typing import enforce_types +from web3 import Web3 + +from pdr_backend.subgraph.info725 import ( + info725_to_info, + info_to_info725, + key_to_key725, + value725_to_value, + value_to_value725, +) + + +@enforce_types +def test_key(): + key = "name" + key725 = key_to_key725(key) + assert key725 == Web3.keccak(key.encode("utf-8")).hex() + + +@enforce_types +def test_value(): + value = "ETH/USDT" + value725 = value_to_value725(value) + value_again = value725_to_value(value725) + + assert value == value_again + assert value == Web3.to_text(hexstr=value725) + + +@enforce_types +def test_value_None(): + assert value_to_value725(None) is None + assert value725_to_value(None) is None + + +@enforce_types +def test_info_to_info725_and_back(): + info = {"pair": "BTC/USDT", "timeframe": "5m", "source": "binance"} + info725 = [ + {"key": key_to_key725("pair"), "value": value_to_value725("BTC/USDT")}, + {"key": key_to_key725("timeframe"), "value": value_to_value725("5m")}, + {"key": key_to_key725("source"), "value": value_to_value725("binance")}, + ] + assert info_to_info725(info) == info725 + assert info725_to_info(info725) == info + + +@enforce_types +def test_info_to_info725_and_back__some_None(): + info = {"pair": "BTC/USDT", "timeframe": "5m", "source": None} + info725 = [ + {"key": key_to_key725("pair"), "value": value_to_value725("BTC/USDT")}, + {"key": key_to_key725("timeframe"), "value": value_to_value725("5m")}, + {"key": key_to_key725("source"), "value": None}, + ] + assert info_to_info725(info) == info725 + assert info725_to_info(info725) == info + + +@enforce_types +def test_info_to_info725__extraval(): + info = { + "pair": "BTC/USDT", + "timeframe": "5m", + "source": "binance", + "extrakey": "extraval", + } + info725 = [ + {"key": key_to_key725("pair"), "value": value_to_value725("BTC/USDT")}, + {"key": key_to_key725("timeframe"), "value": value_to_value725("5m")}, + {"key": key_to_key725("source"), "value": value_to_value725("binance")}, + {"key": 
key_to_key725("extrakey"), "value": value_to_value725("extraval")}, + ] + assert info_to_info725(info) == info725 + + +@enforce_types +def test_info_to_info725__missingkey(): + info = {"pair": "BTC/USDT", "timeframe": "5m"} # no "source" + info725 = [ + {"key": key_to_key725("pair"), "value": value_to_value725("BTC/USDT")}, + {"key": key_to_key725("timeframe"), "value": value_to_value725("5m")}, + {"key": key_to_key725("source"), "value": None}, + ] + assert info_to_info725(info) == info725 diff --git a/pdr_backend/subgraph/test/test_prediction.py b/pdr_backend/subgraph/test/test_prediction.py new file mode 100644 index 000000000..a73e584aa --- /dev/null +++ b/pdr_backend/subgraph/test/test_prediction.py @@ -0,0 +1,20 @@ +from enforce_typing import enforce_types + +from pdr_backend.subgraph.prediction import Prediction, mock_first_predictions + + +@enforce_types +def test_predictions(): + predictions = mock_first_predictions() + + assert len(predictions) == 2 + assert isinstance(predictions[0], Prediction) + assert isinstance(predictions[1], Prediction) + assert ( + predictions[0].ID + == "ADA/USDT-5m-1701503100-0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd" + ) + assert ( + predictions[1].ID + == "BTC/USDT-5m-1701589500-0xaaaa4cb4ff2584bad80ff5f109034a891c3d88dd" + ) diff --git a/pdr_backend/subgraph/test/test_subgraph_consume_so_far.py b/pdr_backend/subgraph/test/test_subgraph_consume_so_far.py new file mode 100644 index 000000000..b5a24abdb --- /dev/null +++ b/pdr_backend/subgraph/test/test_subgraph_consume_so_far.py @@ -0,0 +1,129 @@ +from unittest.mock import patch + +from enforce_typing import enforce_types +from pytest import approx + +from pdr_backend.subgraph.info725 import key_to_key725, value_to_value725 +from pdr_backend.subgraph.subgraph_consume_so_far import get_consume_so_far_per_contract + +SAMPLE_CONTRACT_DATA = [ + { + "id": "contract1", + "token": { + "id": "token1", + "name": "ether", + "symbol": "ETH", + "orders": [ + { + "createdTimestamp": 1695288424, + "consumer": {"id": "0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a"}, + "lastPriceValue": "2.4979184013322233", + }, + { + "createdTimestamp": 1695288724, + "consumer": {"id": "0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a"}, + "lastPriceValue": "2.4979184013322233", + }, + ], + "nft": { + "owner": {"id": "0xowner1"}, + "nftData": [ + { + "key": key_to_key725("pair"), + "value": value_to_value725("ETH/USDT"), + }, + { + "key": key_to_key725("timeframe"), + "value": value_to_value725("5m"), + }, + { + "key": key_to_key725("source"), + "value": value_to_value725("binance"), + }, + ], + }, + }, + "secondsPerEpoch": 7, + "secondsPerSubscription": 700, + "truevalSubmitTimeout": 5, + } +] + + +@enforce_types +def test_get_consume_so_far_per_contract(): + call_count = 0 + + def mock_query_subgraph( + subgraph_url, query, tries, timeout + ): # pylint:disable=unused-argument + nonlocal call_count + slot_data = SAMPLE_CONTRACT_DATA + + if call_count > 0: + slot_data[0]["token"]["orders"] = [] + + call_count += 1 + return {"data": {"predictContracts": slot_data}} + + PATH = "pdr_backend.subgraph.subgraph_consume_so_far" + with patch(f"{PATH}.query_subgraph", mock_query_subgraph): + consumes = get_consume_so_far_per_contract( + subgraph_url="foo", + user_address="0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a", + since_timestamp=2000, + contract_addresses=["contract1"], + ) + + assert consumes["contract1"] == approx(6, 0.001) + + +@enforce_types +def test_get_consume_so_far_per_contract_empty_data(): + def mock_query_subgraph( + subgraph_url, 
query, tries, timeout + ): # pylint:disable=unused-argument + return {} + + PATH = "pdr_backend.subgraph.subgraph_consume_so_far" + with patch( + f"{PATH}.query_subgraph", mock_query_subgraph + ): # pylint:disable=unused-argument + consumes = get_consume_so_far_per_contract( + subgraph_url="foo", + user_address="0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a", + since_timestamp=2000, + contract_addresses=["contract1"], + ) + + assert consumes == {} + + def mock_query_subgraph_2( + subgraph_url, query, tries, timeout + ): # pylint:disable=unused-argument + return {"data": {"predictContracts": []}} + + with patch(f"{PATH}.query_subgraph", mock_query_subgraph_2): + consumes = get_consume_so_far_per_contract( + subgraph_url="foo", + user_address="0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a", + since_timestamp=2000, + contract_addresses=["contract1"], + ) + + assert consumes == {} + + def mock_query_subgraph_3( + subgraph_url, query, tries, timeout + ): # pylint:disable=unused-argument + return {"data": {"predictContracts": [{"id": "contract2"}]}} + + with patch(f"{PATH}.query_subgraph", mock_query_subgraph_3): + consumes = get_consume_so_far_per_contract( + subgraph_url="foo", + user_address="0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a", + since_timestamp=2000, + contract_addresses=["contract1"], + ) + + assert consumes == {} diff --git a/pdr_backend/subgraph/test/test_subgraph_feed.py b/pdr_backend/subgraph/test/test_subgraph_feed.py new file mode 100644 index 000000000..eda39f14e --- /dev/null +++ b/pdr_backend/subgraph/test/test_subgraph_feed.py @@ -0,0 +1,82 @@ +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed, mock_feed, print_feeds + + +@enforce_types +def test_feed(): + feed = SubgraphFeed( + "Contract Name", + "0x12345", + "SYM:TEST", + 60, + 15, + "0xowner", + "BTC/USDT", + "5m", + "binance", + ) + + assert feed.name == "Contract Name" + assert feed.address == "0x12345" + assert feed.symbol == "SYM:TEST" + assert feed.seconds_per_subscription == 60 + assert feed.trueval_submit_timeout == 15 + assert feed.owner == "0xowner" + assert feed.pair == "BTC/USDT" + assert feed.timeframe == "5m" + assert feed.source == "binance" + + assert feed.seconds_per_epoch == 5 * 60 + assert feed.quote == "USDT" + assert feed.base == "BTC" + + assert str(feed) == feed.shortstr() + s = str(feed) + for want_s in ["Feed", "5m", "BTC/USDT", "binance", feed.address]: + assert want_s in s + + +@enforce_types +def test_mock_feed(): + feed = mock_feed("5m", "binance", "BTC/USDT") + assert feed.timeframe == "5m" + assert feed.source == "binance" + assert feed.pair == "BTC/USDT" + assert feed.address[:2] == "0x" + assert len(feed.address) == 42 # ethereum sized address + + +@enforce_types +def test_feed__seconds_per_epoch(): + # 5m + feed = mock_feed("5m", "binance", "BTC/USDT") + assert feed.timeframe == "5m" + assert feed.seconds_per_epoch == 5 * 60 + + # 1h + feed = mock_feed("1h", "binance", "BTC/USDT") + assert feed.timeframe == "1h" + assert feed.seconds_per_epoch == 60 * 60 + + +@enforce_types +def test_feed__convert_pair(): + # start with '/', no convert needed + feed = mock_feed("5m", "binance", "BTC/USDT") + assert feed.pair == "BTC/USDT" + + # start with '-', convert to '/' + feed = mock_feed("5m", "binance", "BTC-USDT") + assert feed.pair == "BTC/USDT" + + +@enforce_types +def test_print_feeds(): + f1 = mock_feed("5m", "binance", "BTC/USDT") + f2 = mock_feed("1h", "kraken", "BTC/USDT") + feeds = {f1.address: f1, f2.address: f2} + + print_feeds(feeds) + 
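The MockPost/MockResponse pair from test/resources.py is the shared fixture behind several of these tests; a minimal sketch of that pattern, assuming pytest's monkeypatch fixture (the test name and contract data below are hypothetical).

import requests

from pdr_backend.subgraph.core_subgraph import query_subgraph
from pdr_backend.subgraph.test.resources import MockPost


def test_query_subgraph_with_mock_post(monkeypatch):
    # Route requests.post to the fake; the first .json() call returns the
    # contract list, later calls return an empty list so pagination stops.
    contract_list = [{"id": "contract1"}]
    monkeypatch.setattr(requests, "post", MockPost(contract_list))

    result = query_subgraph(subgraph_url="foo", query="{ predictContracts { id } }")
    assert result == {"data": {"predictContracts": contract_list}}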
print_feeds(feeds, label=None) + print_feeds(feeds, "my feeds") diff --git a/pdr_backend/subgraph/test/test_subgraph_feed_contracts.py b/pdr_backend/subgraph/test/test_subgraph_feed_contracts.py new file mode 100644 index 000000000..b999394f8 --- /dev/null +++ b/pdr_backend/subgraph/test/test_subgraph_feed_contracts.py @@ -0,0 +1,113 @@ +import pytest +import requests +from enforce_typing import enforce_types + +from pdr_backend.cli.timeframe import Timeframe +from pdr_backend.subgraph.info725 import info_to_info725 +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.subgraph.subgraph_feed_contracts import query_feed_contracts +from pdr_backend.subgraph.test.resources import MockPost + + +@enforce_types +def mock_contract(info: dict, symbol: str) -> dict: + info725 = info_to_info725(info) + + contract = { + "id": "0xNFT1", + "token": { + "id": "0xDT1", + "name": f"Name:{symbol}", + "symbol": symbol, + "nft": { + "owner": { + "id": "0xowner1", + }, + "nftData": info725, + }, + }, + "secondsPerEpoch": Timeframe(info["timeframe"]).s, + "secondsPerSubscription": 700, + "truevalSubmitTimeout": 5, + } + return contract + + +@enforce_types +def test_query_feed_contracts__emptychain(monkeypatch): + contract_list = [] + monkeypatch.setattr(requests, "post", MockPost(contract_list)) + contracts = query_feed_contracts(subgraph_url="foo") + assert contracts == {} + + +@enforce_types +def test_query_feed_contracts__fullchain(monkeypatch): + # This test is a simple-as-possible happy path. Start here. + # Then follow up with test_filter() below, which is complex but thorough + + info = {"pair": "BTC/USDT", "timeframe": "5m", "source": "binance"} + contract = mock_contract(info, "contract1") + contract_addr = contract["id"] + + contract_list = [contract] + monkeypatch.setattr(requests, "post", MockPost(contract_list)) + + feeds = query_feed_contracts(subgraph_url="foo") + + assert len(feeds) == 1 + assert contract_addr in feeds + feed = feeds[contract_addr] + assert isinstance(feed, SubgraphFeed) + + assert feed.name == "Name:contract1" + assert feed.address == "0xNFT1" + assert feed.symbol == "contract1" + assert feed.seconds_per_subscription == 700 + assert feed.trueval_submit_timeout == 5 + assert feed.owner == "0xowner1" + assert feed.pair == "BTC/USDT" + assert feed.timeframe == "5m" + assert feed.source == "binance" + assert feed.seconds_per_epoch == 5 * 60 + + +@enforce_types +@pytest.mark.parametrize( + "expect_result, owners", + [ + (True, None), + (True, ""), + (True, "0xowner1"), + (False, "0xowner2"), + (True, "0xowner1,0xowner2"), + (False, "0xowner2,0xowner3"), + ], +) +def test_query_feed_contracts__filter(monkeypatch, expect_result, owners): + info = {"pair": "BTC/USDT", "timeframe": "5m", "source": "binance"} + info725 = info_to_info725(info) + + contract1 = { + "id": "contract1", + "token": { + "id": "token1", + "name": "ether", + "symbol": "ETH", + "nft": { + "owner": { + "id": "0xowner1", + }, + "nftData": info725, + }, + }, + "secondsPerEpoch": 7, + "secondsPerSubscription": 700, + "truevalSubmitTimeout": 5, + } + contract_list = [contract1] + + monkeypatch.setattr(requests, "post", MockPost(contract_list)) + feed_dicts = query_feed_contracts("foo", owners) + + assert bool(feed_dicts) == bool(expect_result) diff --git a/pdr_backend/subgraph/test/test_subgraph_pending_payouts.py b/pdr_backend/subgraph/test/test_subgraph_pending_payouts.py new file mode 100644 index 000000000..6c926837f --- /dev/null +++ 
b/pdr_backend/subgraph/test/test_subgraph_pending_payouts.py @@ -0,0 +1,61 @@ +from unittest.mock import patch + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subgraph_pending_payouts import query_pending_payouts + +SAMPLE_PENDING_DATA = [ + { + "id": "slot1", + "timestamp": 1000, + "slot": { + "id": "slot1", + "slot": 2000, + "predictContract": { + "id": "contract1", + }, + }, + } +] + + +@enforce_types +def test_query_pending_payouts(): + call_count = 0 + + def mock_query_subgraph(subgraph_url, query): # pylint:disable=unused-argument + nonlocal call_count + pending_payout_data = SAMPLE_PENDING_DATA if call_count < 1 else [] + call_count += 1 + return {"data": {"predictPredictions": pending_payout_data}} + + PATH = "pdr_backend.subgraph.subgraph_pending_payouts" + with patch(f"{PATH}.query_subgraph", mock_query_subgraph): + pending_payouts = query_pending_payouts( + subgraph_url="foo", + addr="0x123", + ) + + assert pending_payouts == {"contract1": [2000]} + + +@enforce_types +def test_query_pending_payouts_edge_cases(): + def mock_query_subgraph(subgraph_url, query): # pylint:disable=unused-argument + return {"data": {}} + + PATH = "pdr_backend.subgraph.subgraph_pending_payouts" + with patch(f"{PATH}.query_subgraph", mock_query_subgraph): + query_pending_payouts( + subgraph_url="foo", + addr="0x123", + ) + + def mock_query_subgraph_2(subgraph_url, query): # pylint:disable=unused-argument + return {"data": {"predictPredictions": []}} + + with patch(f"{PATH}.query_subgraph", mock_query_subgraph_2): + query_pending_payouts( + subgraph_url="foo", + addr="0x123", + ) diff --git a/pdr_backend/subgraph/test/test_subgraph_pending_slots.py b/pdr_backend/subgraph/test/test_subgraph_pending_slots.py new file mode 100644 index 000000000..448206be2 --- /dev/null +++ b/pdr_backend/subgraph/test/test_subgraph_pending_slots.py @@ -0,0 +1,69 @@ +from unittest.mock import patch + +from enforce_typing import enforce_types + +from pdr_backend.contract.slot import Slot +from pdr_backend.subgraph.info725 import key_to_key725, value_to_value725 +from pdr_backend.subgraph.subgraph_pending_slots import get_pending_slots + +SAMPLE_SLOT_DATA = [ + { + "id": "slot1", + "slot": 1000, + "trueValues": [], + "predictContract": { + "id": "contract1", + "token": { + "id": "token1", + "name": "ether", + "symbol": "ETH", + "nft": { + "owner": {"id": "0xowner1"}, + "nftData": [ + { + "key": key_to_key725("pair"), + "value": value_to_value725("ETH/USDT"), + }, + { + "key": key_to_key725("timeframe"), + "value": value_to_value725("5m"), + }, + { + "key": key_to_key725("source"), + "value": value_to_value725("binance"), + }, + ], + }, + }, + "secondsPerEpoch": 7, + "secondsPerSubscription": 700, + "truevalSubmitTimeout": 5, + }, + } +] + + +@enforce_types +def test_get_pending_slots(): + call_count = 0 + + def mock_query_subgraph(subgraph_url, query): # pylint:disable=unused-argument + nonlocal call_count + slot_data = SAMPLE_SLOT_DATA if call_count <= 1 else [] + call_count += 1 + return {"data": {"predictSlots": slot_data}} + + PATH = "pdr_backend.subgraph.subgraph_pending_slots" + with patch(f"{PATH}.query_subgraph", mock_query_subgraph): + slots = get_pending_slots( + subgraph_url="foo", + timestamp=2000, + owner_addresses=None, + allowed_feeds=None, + ) + + assert len(slots) == 2 + slot0 = slots[0] + assert isinstance(slot0, Slot) + assert slot0.slot_number == 1000 + assert slot0.feed.name == "ether" diff --git a/pdr_backend/util/test_ganache/test_subgraph_predictions.py 
b/pdr_backend/subgraph/test/test_subgraph_predictions.py similarity index 60% rename from pdr_backend/util/test_ganache/test_subgraph_predictions.py rename to pdr_backend/subgraph/test/test_subgraph_predictions.py index 4c5d78d84..d42053400 100644 --- a/pdr_backend/util/test_ganache/test_subgraph_predictions.py +++ b/pdr_backend/subgraph/test/test_subgraph_predictions.py @@ -1,23 +1,29 @@ from typing import Dict from unittest.mock import patch + +import pytest from enforce_typing import enforce_types -from pdr_backend.util.subgraph_predictions import ( - fetch_filtered_predictions, - get_all_contract_ids_by_owner, - fetch_contract_id_and_spe, + +from pdr_backend.subgraph.subgraph_predictions import ( FilterMode, Prediction, + fetch_contract_id_and_spe, + fetch_filtered_predictions, + get_all_contract_ids_by_owner, ) SAMPLE_PREDICTION = Prediction( + # pylint: disable=line-too-long + ID="0x18f54cc21b7a2fdd011bea06bba7801b280e3151-1698527100-0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", pair="ADA/USDT", timeframe="5m", prediction=True, stake=0.050051425480971974, trueval=False, - timestamp=1698527100, + timestamp=1698527000, source="binance", payout=0.0, + slot=1698527100, user="0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", ) @@ -29,6 +35,7 @@ "id": "0x18f54cc21b7a2fdd011bea06bba7801b280e3151-1698527100-0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", "user": {"id": "0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd"}, "stake": "0.050051425480971974", + "timestamp": 1698527000, "payout": {"payout": "0", "trueValue": False, "predictedValue": True}, "slot": { "slot": 1698527100, @@ -91,7 +98,7 @@ @enforce_types -@patch("pdr_backend.util.subgraph_predictions.query_subgraph") +@patch("pdr_backend.subgraph.subgraph_predictions.query_subgraph") def test_fetch_filtered_predictions(mock_query_subgraph): mock_query_subgraph.side_effect = [ MOCK_PREDICTIONS_RESPONSE_FIRST_CALL, @@ -111,12 +118,51 @@ def test_fetch_filtered_predictions(mock_query_subgraph): assert predictions[0].pair == "ADA/USDT" assert predictions[0].trueval is False assert predictions[0].prediction is True - assert mock_query_subgraph.call_count == 2 + assert mock_query_subgraph.call_count == 1 + + +@enforce_types +def test_fetch_filtered_predictions_no_data(): + # network not supported + with pytest.raises(Exception): + fetch_filtered_predictions( + start_ts=1622547000, + end_ts=1622548800, + filters=["0x18f54cc21b7a2fdd011bea06bba7801b280e3151"], + network="xyz", + filter_mode=FilterMode.PREDICTOOR, + ) + + with patch( + "pdr_backend.subgraph.subgraph_predictions.query_subgraph" + ) as mock_query_subgraph: + mock_query_subgraph.return_value = {"data": {}} + predictions = fetch_filtered_predictions( + start_ts=1622547000, + end_ts=1622548800, + filters=["0x18f54cc21b7a2fdd011bea06bba7801b280e3151"], + network="mainnet", + filter_mode=FilterMode.PREDICTOOR, + ) + assert len(predictions) == 0 + + with patch( + "pdr_backend.subgraph.subgraph_predictions.query_subgraph" + ) as mock_query_subgraph: + mock_query_subgraph.return_value = {"data": {"predictPredictions": []}} + predictions = fetch_filtered_predictions( + start_ts=1622547000, + end_ts=1622548800, + filters=["0x18f54cc21b7a2fdd011bea06bba7801b280e3151"], + network="mainnet", + filter_mode=FilterMode.PREDICTOOR, + ) + assert len(predictions) == 0 @enforce_types @patch( - "pdr_backend.util.subgraph_predictions.query_subgraph", + "pdr_backend.subgraph.subgraph_predictions.query_subgraph", return_value=MOCK_CONTRACTS_RESPONSE, ) def test_get_all_contract_ids_by_owner( @@ -131,24 +177,53 @@ def 
test_get_all_contract_ids_by_owner( assert "token2" in contract_ids mock_query_subgraph.assert_called_once() + with patch( + "pdr_backend.subgraph.subgraph_predictions.query_subgraph", + return_value={"data": {}}, + ): + with pytest.raises(Exception): + get_all_contract_ids_by_owner(owner_address="0xOwner", network="mainnet") + + # network not supported + with pytest.raises(Exception): + get_all_contract_ids_by_owner(owner_address="0xOwner", network="xyz") + @enforce_types @patch( - "pdr_backend.util.subgraph_predictions.query_subgraph", + "pdr_backend.subgraph.subgraph_predictions.query_subgraph", return_value=MOCK_CONTRACT_DETAILS_RESPONSE, ) def test_fetch_contract_id_and_spe( mock_query_subgraph, ): # pylint: disable=unused-argument - contract_details = fetch_contract_id_and_spe( + contracts_list = fetch_contract_id_and_spe( contract_addresses=["contract1", "contract2"], network="mainnet" ) - assert len(contract_details) == 2 - assert contract_details[0]["id"] == "contract1" - assert contract_details[0]["seconds_per_epoch"] == 300 - assert contract_details[0]["name"] == "token1" - assert contract_details[1]["id"] == "contract2" - assert contract_details[1]["seconds_per_epoch"] == 600 - assert contract_details[1]["name"] == "token2" + assert len(contracts_list) == 2 + + c0, c1 = contracts_list # pylint: disable=unbalanced-tuple-unpacking + assert c0["ID"] == "contract1" + assert c0["seconds_per_epoch"] == 300 + assert c0["name"] == "token1" + assert c1["ID"] == "contract2" + assert c1["seconds_per_epoch"] == 600 + assert c1["name"] == "token2" + mock_query_subgraph.assert_called_once() + + with patch( + "pdr_backend.subgraph.subgraph_predictions.query_subgraph", + return_value={"data": {}}, + ): + with pytest.raises(Exception): + fetch_contract_id_and_spe( + contract_addresses=["contract1", "contract2"], network="mainnet" + ) + + # network not supported + with pytest.raises(Exception): + fetch_contract_id_and_spe( + contract_addresses=["contract1", "contract2"], network="xyz" + ) diff --git a/pdr_backend/subgraph/test/test_subgraph_slot.py b/pdr_backend/subgraph/test/test_subgraph_slot.py new file mode 100644 index 000000000..3c84073bd --- /dev/null +++ b/pdr_backend/subgraph/test/test_subgraph_slot.py @@ -0,0 +1,99 @@ +from typing import Dict +from unittest.mock import patch + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subgraph_slot import ( + PredictSlot, + fetch_slots_for_all_assets, + get_predict_slots_query, + get_slots, +) + + +@enforce_types +def test_get_predict_slots_query(): + # Test the get_predict_slots_query function with expected inputs and outputs + query = get_predict_slots_query( + asset_ids=["0xAsset"], initial_slot=1000, last_slot=2000, first=10, skip=0 + ) + assert "predictSlots" in query + assert "0xAsset" in query + assert "1000" in query + assert "2000" in query + + +# Sample data for tests +SAMPLE_PREDICT_QUERY_RESULT_ITEM = { + "id": "0xAsset-12345", + "slot": "12345", + "trueValues": [{"ID": "1", "trueValue": True}], + "roundSumStakesUp": 150.0, + "roundSumStakes": 100.0, +} + + +MOCK_QUERY_RESPONSE = {"data": {"predictSlots": [SAMPLE_PREDICT_QUERY_RESULT_ITEM]}} + +MOCK_QUERY_RESPONSE_FIRST_CALL = { + "data": { + "predictSlots": [SAMPLE_PREDICT_QUERY_RESULT_ITEM] + * 1000 # Simulate a full page of results + } +} + +MOCK_QUERY_RESPONSE_SECOND_CALL: Dict[str, Dict[str, list]] = { + "data": {"predictSlots": []} # Simulate no further results, stopping the recursion +} + + +@enforce_types 
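A usage sketch for the reworked fetch_filtered_predictions filter modes exercised by the tests above; the address and timestamps are placeholders, and the payout_only/trueval_only flags are simply the new keyword arguments introduced in this patch.

from pdr_backend.subgraph.subgraph_predictions import (
    FilterMode,
    fetch_filtered_predictions,
)

# Placeholder window and contract address (illustrative only).
start_ts, end_ts = 1701129600, 1701216000
feed_addr = "0x18f54cc21b7a2fdd011bea06bba7801b280e3151"

# FilterMode.CONTRACT filters by contract and slot range; CONTRACT_TS filters the
# same contracts by prediction timestamp instead; PREDICTOOR filters by user
# address; NONE applies only the timestamp window.
predictions = fetch_filtered_predictions(
    start_ts,
    end_ts,
    [feed_addr],
    "mainnet",
    FilterMode.CONTRACT,
    payout_only=False,   # keep predictions that have no payout yet
    trueval_only=False,  # keep predictions whose true value is not yet settled
)
print(len(predictions), "predictions fetched")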
+@patch("pdr_backend.subgraph.subgraph_slot.query_subgraph") +def test_get_slots(mock_query_subgraph): + # Configure the mock to return a full page of results on the first call, + # and no results on the second call + mock_query_subgraph.side_effect = [ + MOCK_QUERY_RESPONSE_FIRST_CALL, + MOCK_QUERY_RESPONSE_SECOND_CALL, + ] + + result_slots = get_slots( + addresses=["0xAsset"], + end_ts_param=2000, + start_ts_param=1000, + skip=0, + slots=[], + network="mainnet", + ) + + print("test_get_slots", result_slots) + + # Verify that the mock was called twice (once for the initial call, once for the recursive call) + assert mock_query_subgraph.call_count == 2 + # Verify that the result contains the expected number of slots + assert len(result_slots) == 1000 + # Verify that the slots contain instances of PredictSlot + assert isinstance(result_slots[0], PredictSlot) + # Verify the first slot's data matches the sample + assert result_slots[0].ID == "0xAsset-12345" + + +@enforce_types +@patch( + "pdr_backend.subgraph.subgraph_slot.query_subgraph", + return_value=MOCK_QUERY_RESPONSE, +) +def test_fetch_slots_for_all_assets(mock_query_subgraph): + # Test the fetch_slots_for_all_assets function + result = fetch_slots_for_all_assets( + asset_ids=["0xAsset"], start_ts_param=1000, end_ts_param=2000, network="mainnet" + ) + + print("test_fetch_slots_for_all_assets", result) + # Verify that the result is structured correctly + assert "0xAsset" in result + assert all(isinstance(slot, PredictSlot) for slot in result["0xAsset"]) + assert len(result["0xAsset"]) == 1 + assert result["0xAsset"][0].ID == "0xAsset-12345" + # Verify that the mock was called + mock_query_subgraph.assert_called() diff --git a/pdr_backend/subgraph/test/test_subgraph_subscriptions.py b/pdr_backend/subgraph/test/test_subgraph_subscriptions.py new file mode 100644 index 000000000..bd9c23ccd --- /dev/null +++ b/pdr_backend/subgraph/test/test_subgraph_subscriptions.py @@ -0,0 +1,133 @@ +from typing import Dict +from unittest.mock import patch + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subgraph_subscriptions import ( + Subscription, + fetch_filtered_subscriptions, +) + +SAMPLE_PREDICTION = Subscription( + # pylint: disable=line-too-long + ID="0x18f54cc21b7a2fdd011bea06bba7801b280e3151-0x00d1e4950e0de743fe88956f02f44b16d22a1827f8c29ff561b69716dbcc2677-14", + pair="ADA/USDT", + timeframe="5m", + source="binance", + timestamp=1701129777, + tx_id="0x00d1e4950e0de743fe88956f02f44b16d22a1827f8c29ff561b69716dbcc2677", + last_price_value=float("2.4979184013322233") * 1.201, + user="0x2433e002ed10b5d6a3d8d1e0c5d2083be9e37f1d", +) + +# pylint: disable=line-too-long +MOCK_SUBSCRIPTIONS_RESPONSE_FIRST_CALL = { + "data": { + "predictSubscriptions": [ + { + "id": "0x18f54cc21b7a2fdd011bea06bba7801b280e3151-0x00d1e4950e0de743fe88956f02f44b16d22a1827f8c29ff561b69716dbcc2677-14", + "predictContract": { + "id": "0x18f54cc21b7a2fdd011bea06bba7801b280e3151", + "token": { + "id": "0x18f54cc21b7a2fdd011bea06bba7801b280e3151", + "name": "ADA/USDT", + "nft": { + "nftData": [ + { + "key": "0x238ad53218834f943da60c8bafd36c36692dcb35e6d76bdd93202f5c04c0baff", + "value": "0x55534454", + }, + { + "key": "0x2cef5778d97683b4f64607f72e862fc0c92376e44cc61195ef72a634c0b1793e", + "value": "0x4144412f55534454", + }, + { + "key": "0x49435d2ff85f9f3594e40e887943d562765d026d50b7383e76891f8190bff4c9", + "value": "0x356d", + }, + { + "key": "0xf1f3eb40f5bc1ad1344716ced8b8a0431d840b5783aea1fd01786bc26f35ac0f", + "value": "0x414441", + }, + 
{ + "key": "0xf7e3126f87228afb82c9b18537eed25aaeb8171a78814781c26ed2cfeff27e69", + "value": "0x62696e616e6365", + }, + ] + }, + "lastPriceValue": "2.4979184013322233", + }, + "secondsPerSubscription": "86400", + "secondsPerEpoch": "300", + }, + "user": {"id": "0x2433e002ed10b5d6a3d8d1e0c5d2083be9e37f1d"}, + "expireTime": "1701216000", + "eventIndex": 14, + "block": 1342747, + "timestamp": 1701129777, + "txId": "0x00d1e4950e0de743fe88956f02f44b16d22a1827f8c29ff561b69716dbcc2677", + } + ] + } +} + +MOCK_SUBSCRIPTIONS_RESPONSE_SECOND_CALL: Dict[str, dict] = {} + + +@enforce_types +@patch("pdr_backend.subgraph.subgraph_subscriptions.query_subgraph") +def test_fetch_filtered_subscriptions(mock_query_subgraph): + mock_query_subgraph.side_effect = [ + MOCK_SUBSCRIPTIONS_RESPONSE_FIRST_CALL, + MOCK_SUBSCRIPTIONS_RESPONSE_SECOND_CALL, + ] + subscriptions = fetch_filtered_subscriptions( + start_ts=1701129700, + end_ts=1701129800, + contracts=["0x18f54cc21b7a2fdd011bea06bba7801b280e3151"], + network="mainnet", + ) + + assert len(subscriptions) == 1 + assert isinstance(subscriptions[0], Subscription) + assert subscriptions[0].user == "0x2433e002ed10b5d6a3d8d1e0c5d2083be9e37f1d" + assert subscriptions[0].pair == "ADA/USDT" + assert mock_query_subgraph.call_count == 1 + + +@enforce_types +def test_fetch_filtered_subscriptions_no_data(): + # network not supported + with pytest.raises(Exception): + fetch_filtered_subscriptions( + start_ts=1701129700, + end_ts=1701129800, + contracts=["0x18f54cc21b7a2fdd011bea06bba7801b280e3151"], + network="xyz", + ) + + with patch( + "pdr_backend.subgraph.subgraph_subscriptions.query_subgraph" + ) as mock_query_subgraph: + mock_query_subgraph.return_value = {"data": {}} + subscriptions = fetch_filtered_subscriptions( + start_ts=1701129700, + end_ts=1701129800, + contracts=["0x18f54cc21b7a2fdd011bea06bba7801b280e3151"], + network="mainnet", + ) + assert len(subscriptions) == 0 + + with patch( + "pdr_backend.subgraph.subgraph_subscriptions.query_subgraph" + ) as mock_query_subgraph: + mock_query_subgraph.return_value = {"data": {"predictPredictions": []}} + subscriptions = fetch_filtered_subscriptions( + start_ts=1701129700, + end_ts=1701129800, + contracts=["0x18f54cc21b7a2fdd011bea06bba7801b280e3151"], + network="mainnet", + ) + + assert len(subscriptions) == 0 diff --git a/pdr_backend/subgraph/test/test_subgraph_sync.py b/pdr_backend/subgraph/test/test_subgraph_sync.py new file mode 100644 index 000000000..df1efbf47 --- /dev/null +++ b/pdr_backend/subgraph/test/test_subgraph_sync.py @@ -0,0 +1,47 @@ +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subgraph_sync import ( + block_number_is_synced, + wait_until_subgraph_syncs, +) +from pdr_backend.util.web3_config import Web3Config + + +@enforce_types +def test_block_number_is_synced(): + def mock_response(url: str, query: str): # pylint:disable=unused-argument + if "number:50" in query: + return { + "errors": [ + { + # pylint: disable=line-too-long + "message": "Failed to decode `block.number` value: `subgraph QmaGAi4jQw5L8J2xjnofAJb1PX5LLqRvGjsWbVehBELAUx only has data starting at block number 499 and data for block number 500 is therefore not available`" + } + ] + } + + return {"data": {"predictContracts": [{"id": "sample_id"}]}} + + with patch( + "pdr_backend.subgraph.subgraph_sync.query_subgraph", + side_effect=mock_response, + ): + assert block_number_is_synced("foo", 499) is True + assert block_number_is_synced("foo", 500) is False + assert 
block_number_is_synced("foo", 501) is False + + +@enforce_types +def test_wait_until_subgraph_syncs(): + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.w3 = Mock() + mock_web3_config.w3.eth = Mock() + mock_web3_config.w3.eth.block_number = 500 + + with patch( + "pdr_backend.subgraph.subgraph_sync.block_number_is_synced", + side_effect=[False, True], + ): + wait_until_subgraph_syncs(mock_web3_config, "foo") diff --git a/pdr_backend/subgraph/test/test_subscriptions.py b/pdr_backend/subgraph/test/test_subscriptions.py new file mode 100644 index 000000000..742e442ff --- /dev/null +++ b/pdr_backend/subgraph/test/test_subscriptions.py @@ -0,0 +1,20 @@ +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subscription import Subscription, mock_subscriptions + + +@enforce_types +def test_subscriptions(): + subscriptions = mock_subscriptions() + + assert len(subscriptions) == 8 + assert isinstance(subscriptions[0], Subscription) + assert isinstance(subscriptions[1], Subscription) + assert ( + subscriptions[0].ID + == "ETH/USDT-0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809592-98" + ) + assert ( + subscriptions[1].ID + == "BTC/USDT-0x01d3285e0e3b83a4c029142477c0573c3be5317ff68223703696093b27809593-99" + ) diff --git a/pdr_backend/trader/README.md b/pdr_backend/trader/README.md deleted file mode 100644 index bb6eacc7b..000000000 --- a/pdr_backend/trader/README.md +++ /dev/null @@ -1 +0,0 @@ -See [READMEs/trader.md](../../READMEs/trader.md). diff --git a/pdr_backend/trader/approach1/main.py b/pdr_backend/trader/approach1/main.py deleted file mode 100644 index fb57ea879..000000000 --- a/pdr_backend/trader/approach1/main.py +++ /dev/null @@ -1,12 +0,0 @@ -from pdr_backend.trader.approach1.trader_agent1 import TraderAgent1 -from pdr_backend.trader.approach1.trader_config1 import TraderConfig1 - - -def main(testing=False): - config = TraderConfig1() - t = TraderAgent1(config) - t.run(testing) - - -if __name__ == "__main__": - main() diff --git a/pdr_backend/trader/approach1/test/test_trader_agent1.py b/pdr_backend/trader/approach1/test/test_trader_agent1.py index ded735676..e3ba3bf85 100644 --- a/pdr_backend/trader/approach1/test/test_trader_agent1.py +++ b/pdr_backend/trader/approach1/test/test_trader_agent1.py @@ -1,78 +1,47 @@ -from unittest.mock import Mock, patch +from unittest.mock import patch -from enforce_typing import enforce_types import pytest +from enforce_typing import enforce_types -from pdr_backend.models.feed import Feed from pdr_backend.trader.approach1.trader_agent1 import TraderAgent1 -from pdr_backend.trader.approach1.trader_config1 import TraderConfig1 +from pdr_backend.trader.test.trader_agent_runner import ( + do_constructor, + do_run, + setup_trade, +) @enforce_types -def mock_feed(): - feed = Mock(spec=Feed) - feed.name = "test feed" - feed.seconds_per_epoch = 60 - return feed +@patch.object(TraderAgent1, "check_subscriptions_and_subscribe") +def test_trader_agent1_constructor(check_subscriptions_and_subscribe_mock): + do_constructor(TraderAgent1, check_subscriptions_and_subscribe_mock) @enforce_types @patch.object(TraderAgent1, "check_subscriptions_and_subscribe") -def test_new_agent(check_subscriptions_and_subscribe_mock, predictoor_contract): - trader_config = Mock(spec=TraderConfig1) - trader_config.exchange_str = "mexc" - trader_config.exchange_pair = "BTC/USDT" - trader_config.timeframe = "5m" - trader_config.size = 10.0 - trader_config.get_feeds = Mock() - trader_config.get_feeds.return_value = { - 
"0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.get_contracts = Mock() - trader_config.get_contracts.return_value = { - "0x0000000000000000000000000000000000000000": predictoor_contract - } - agent = TraderAgent1(trader_config) - assert agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() - - no_feeds_config = Mock(spec=TraderConfig1) - no_feeds_config.get_feeds.return_value = {} - no_feeds_config.max_tries = 10 - - with pytest.raises(SystemExit): - TraderAgent1(no_feeds_config) +def test_trader_agent1_run(check_subscriptions_and_subscribe_mock): + do_run(TraderAgent1, check_subscriptions_and_subscribe_mock) @enforce_types @pytest.mark.asyncio @patch.object(TraderAgent1, "check_subscriptions_and_subscribe") -async def test_do_trade( - check_subscriptions_and_subscribe_mock, - predictoor_contract, - web3_config, -): - trader_config = Mock(spec=TraderConfig1) - trader_config.exchange_str = "mexc" - trader_config.exchange_pair = "BTC/USDT" - trader_config.timeframe = "5m" - trader_config.size = 10.0 - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.get_contracts = Mock() - trader_config.get_contracts.return_value = { - "0x0000000000000000000000000000000000000000": predictoor_contract - } - trader_config.max_tries = 10 - trader_config.web3_config = web3_config +async def test_trader_agent1_do_trade(check_subscriptions_and_subscribe_mock, capfd): + agent, feed = setup_trade( + TraderAgent1, + check_subscriptions_and_subscribe_mock, + ) + + await agent._do_trade(feed, (1.0, 1.0)) + assert agent.exchange.create_market_buy_order.call_count == 1 - agent = TraderAgent1(trader_config) - assert agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() + await agent._do_trade(feed, (1.0, 0)) + out, _ = capfd.readouterr() + assert "There's no stake on this" in out - agent.exchange = Mock() - agent.exchange.create_market_buy_order.return_value = {"info": {"origQty": 1}} + agent.order = {} + with patch.object(agent.exchange, "create_market_sell_order", return_value="mock"): + await agent._do_trade(feed, (1.0, 1.0)) - await agent._do_trade(mock_feed(), (1.0, 1.0)) - assert agent.exchange.create_market_buy_order.call_count == 1 + out, _ = capfd.readouterr() + assert "Closing Order" in out diff --git a/pdr_backend/trader/approach1/trader_agent1.py b/pdr_backend/trader/approach1/trader_agent1.py index 94fe26d3e..2e0cd62ea 100644 --- a/pdr_backend/trader/approach1/trader_agent1.py +++ b/pdr_backend/trader/approach1/trader_agent1.py @@ -1,19 +1,19 @@ from os import getenv -from typing import Any, Dict, Tuple, Optional +from typing import Any, Dict, Optional, Tuple import ccxt from enforce_typing import enforce_types -from pdr_backend.models.feed import Feed -from pdr_backend.trader.approach1.trader_config1 import TraderConfig1 -from pdr_backend.trader.trader_agent import TraderAgent +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.trader.base_trader_agent import BaseTraderAgent @enforce_types -class TraderAgent1(TraderAgent): +class TraderAgent1(BaseTraderAgent): """ @description - TraderAgent Naive CCXT + Naive trader agent. - Market order buy-only - Doesn't save client state or manage pending open trades - Only works with MEXC. How to improve: @@ -25,10 +25,6 @@ class TraderAgent1(TraderAgent): 1. If existing open position, close it. 2. 
If new long prediction meets criteria, open long. - You can use the ENV_VARS to: - 1. Configure your strategy: pair, timeframe, etc.. - 2. Configure your exchange: api_key + secret_key - You can improve this by: 1. Improving how to enter/exit trade w/ orders 2. Improving when to buy @@ -36,12 +32,11 @@ class TraderAgent1(TraderAgent): 4. Using SL and TP """ - def __init__(self, config: TraderConfig1): - super().__init__(config) - self.config: TraderConfig1 = config + def __init__(self, ppss: PPSS): + super().__init__(ppss) - # Generic exchange clss - exchange_class = getattr(ccxt, self.config.exchange_str) + # Generic exchange class + exchange_class = self.ppss.trader_ss.exchange_class self.exchange: ccxt.Exchange = exchange_class( { "apiKey": getenv("EXCHANGE_API_KEY"), @@ -60,7 +55,7 @@ def __init__(self, config: TraderConfig1): self.order: Optional[Dict[str, Any]] = None assert self.exchange is not None, "Exchange cannot be None" - async def do_trade(self, feed: Feed, prediction: Tuple[float, float]): + async def do_trade(self, feed: SubgraphFeed, prediction: Tuple[float, float]): """ @description Logic: @@ -77,31 +72,34 @@ async def do_trade(self, feed: Feed, prediction: Tuple[float, float]): if self.order is not None and isinstance(self.order, dict): # get existing long position amount = 0.0 - if self.config.exchange_str in ("mexc"): + if self.ppss.trader_ss.exchange_str == "mexc": amount = float(self.order["info"]["origQty"]) # close it order = self.exchange.create_market_sell_order( - self.config.exchange_pair, amount + self.ppss.trader_ss.pair_str, amount ) print(f" [Trade Closed] {self.exchange}") print(f" [Previous Order] {self.order}") print(f" [Closing Order] {order}") - # TO DO - Calculate PNL (self.order - order) self.order = None ### Create new order if prediction meets our criteria pred_nom, pred_denom = prediction print(f" {feed} has a new prediction: {pred_nom} / {pred_denom}.") + if pred_denom == 0: + print(" There's no stake on this, one way or the other. 
Exiting.") + return + pred_properties = self.get_pred_properties(pred_nom, pred_denom) print(f" prediction properties are: {pred_properties}") if pred_properties["dir"] == 1 and pred_properties["confidence"] > 0.5: order = self.exchange.create_market_buy_order( - self.config.exchange_pair, self.config.size + self.ppss.trader_ss.pair_str, self.ppss.trader_ss.position_size ) # If order is successful, we log the order so we can close it diff --git a/pdr_backend/trader/approach1/trader_config1.py b/pdr_backend/trader/approach1/trader_config1.py deleted file mode 100644 index e762c69ae..000000000 --- a/pdr_backend/trader/approach1/trader_config1.py +++ /dev/null @@ -1,48 +0,0 @@ -from os import getenv - -from enforce_typing import enforce_types - -from pdr_backend.trader.trader_config import TraderConfig - -CAND_EXCHANGE = ["mexc"] -CAND_PAIR = [ - "BTC/USDT", - "ETH/USDT", - "ADA/USDT", - "BNB/USDT", - "SOL/USDT", - "XRP/USDT", - "DOT/USDT", - "LTC/USDT", - "DOGE/USDT", - "TRX/USDT", -] -CAND_TIMEFRAME = ["5m", "1h"] - - -# Mexc does not support -@enforce_types -class TraderConfig1(TraderConfig): - def __init__(self): - super().__init__() - - self.exchange_str = getenv("EXCHANGE_FILTER") - self.pair = getenv("PAIR_FILTER") - self.timeframe = getenv("TIMEFRAME_FILTER") - - ## Exchange Parameters - self.exchange_pair = ( - getenv("EXCHANGE_PAIR_FILTER") - if getenv("EXCHANGE_PAIR_FILTER") - else self.pair - ) - - ## Position Parameters - self.size = getenv("POSITION_SIZE") - - assert self.exchange_str in CAND_EXCHANGE, "Exchange must be valid" - assert self.pair in CAND_PAIR, "Pair must be valid" - assert self.timeframe in CAND_TIMEFRAME, "Timeframe must be valid" - assert ( - self.size is not None and self.size > 0.0 - ), "Position size must be greater than 0.0" diff --git a/pdr_backend/trader/approach2/main.py b/pdr_backend/trader/approach2/main.py deleted file mode 100644 index 74c66cee9..000000000 --- a/pdr_backend/trader/approach2/main.py +++ /dev/null @@ -1,12 +0,0 @@ -from pdr_backend.trader.approach2.trader_agent2 import TraderAgent2 -from pdr_backend.trader.approach2.trader_config2 import TraderConfig2 - - -def main(testing=False): - config = TraderConfig2() - t = TraderAgent2(config) - t.run(testing) - - -if __name__ == "__main__": - main() diff --git a/pdr_backend/trader/approach2/portfolio.py b/pdr_backend/trader/approach2/portfolio.py index a2f2c6e30..dd5e072bb 100644 --- a/pdr_backend/trader/approach2/portfolio.py +++ b/pdr_backend/trader/approach2/portfolio.py @@ -1,7 +1,6 @@ from enum import Enum from typing import Dict, List, Optional - -import ccxt +from enforce_typing import enforce_types class OrderState(Enum): @@ -36,7 +35,7 @@ def timestamp(self): return None -class MEXCOrder(Order): +class MexcOrder(Order): def __init__(self, order: Dict): # pylint: disable=useless-parent-delegation super().__init__(order) @@ -54,22 +53,18 @@ def timestamp(self): return self.order["timestamp"] -def create_order(order: Dict, exchange: ccxt.Exchange) -> Order: - if exchange in ("mexc"): - return MEXCOrder(order) - return Order(order) +@enforce_types +def create_order(order: Dict, exchange_str: str) -> Order: + return MexcOrder(order) if exchange_str == "mexc" else Order(order) class Position: """ @description Has an open and and a close order minimum - TO DO - Support many buy/sell orders, balance, etc... 
""" def __init__(self, order: Order): - # TO DO - Have N open_orders, have N close_orders - # TO DO - Move from __init__(order) to open(order) self.open_order: Order = order self.close_order: Optional[Order] = None self.state: OrderState = OrderState.OPEN @@ -79,7 +74,6 @@ def __init__(self, order: Order): def __str__(self): return f"<{self.open_order}, {self.close_order}, {self.__class__}>" - # TO DO - Only callable by portfolio def close(self, order: Order): self.close_order = order self.state = OrderState.CLOSED @@ -107,13 +101,15 @@ def open_position(self, open_order: Order) -> Position: return position def close_position(self, close_order: Order) -> Optional[Position]: - position = self.open_positions.pop() - if position: - position.close(close_order) - self.closed_positions.append(position) - print(" [Position closed in Sheet]") - return position - return None + position = self.open_positions.pop() if self.open_positions else None + + if not position: + return None + + position.close(close_order) + self.closed_positions.append(position) + print(" [Position closed in Sheet]") + return position class Portfolio: @@ -131,14 +127,8 @@ def get_sheet(self, addr: str) -> Optional[Sheet]: def open_position(self, addr: str, order: Order) -> Optional[Position]: sheet = self.get_sheet(addr) - if sheet: - return sheet.open_position(order) - - return None + return sheet.open_position(order) if sheet else None def close_position(self, addr: str, order: Order) -> Optional[Position]: sheet = self.get_sheet(addr) - if sheet: - return sheet.close_position(order) - - return None + return sheet.close_position(order) if sheet else None diff --git a/pdr_backend/trader/approach2/test/test_portfolio.py b/pdr_backend/trader/approach2/test/test_portfolio.py new file mode 100644 index 000000000..ae702b6a5 --- /dev/null +++ b/pdr_backend/trader/approach2/test/test_portfolio.py @@ -0,0 +1,67 @@ +from pdr_backend.trader.approach2.portfolio import ( + MexcOrder, + Order, + OrderState, + Portfolio, + Position, + Sheet, + create_order, +) + + +def test_order_classes(): + order_dict = {"id": 1, "info": {"origQty": 2}, "timestamp": 3} + + order = Order(order_dict) + assert order.id is None + assert order.amount is None + assert order.timestamp is None + + mexc_order = MexcOrder(order_dict) + assert mexc_order.id == 1 + assert mexc_order.amount == 2 + assert mexc_order.timestamp == 3 + + assert isinstance(create_order(order_dict, "mexc"), MexcOrder) + assert isinstance(create_order(order_dict, "other"), Order) + + +def test_position(): + position = Position(Order({})) + assert position.state == OrderState.OPEN + assert position.close_order is None + + position.close(Order({})) + assert position.state == OrderState.CLOSED + assert isinstance(position.close_order, Order) + + +def test_sheet(): + sheet = Sheet("0x123") + assert sheet.asset == "0x123" + assert sheet.open_positions == [] + assert sheet.closed_positions == [] + + # no effect + sheet.close_position(Order({})) + assert sheet.open_positions == [] + assert sheet.closed_positions == [] + + # open and close + sheet.open_position(Order({})) + assert isinstance(sheet.open_positions[0], Position) + + sheet.close_position(Order({})) + assert isinstance(sheet.closed_positions[0], Position) + assert sheet.open_positions == [] + + +def test_portfolio(): + portfolio = Portfolio(["0x123", "0x456"]) + assert portfolio.sheets.keys() == {"0x123", "0x456"} + + assert portfolio.get_sheet("0x123").asset == "0x123" + assert portfolio.get_sheet("xxxx") is None + + assert 
isinstance(portfolio.open_position("0x123", Order({})), Position) + assert isinstance(portfolio.close_position("0x123", Order({})), Position) diff --git a/pdr_backend/trader/approach2/test/test_trader_agent2.py b/pdr_backend/trader/approach2/test/test_trader_agent2.py index 9656f66b0..2de828019 100644 --- a/pdr_backend/trader/approach2/test/test_trader_agent2.py +++ b/pdr_backend/trader/approach2/test/test_trader_agent2.py @@ -1,144 +1,57 @@ from datetime import datetime from unittest.mock import Mock, patch -from enforce_typing import enforce_types import pytest +from enforce_typing import enforce_types -from pdr_backend.models.feed import Feed +from pdr_backend.ppss.ppss import mock_feed_ppss +from pdr_backend.ppss.web3_pp import inplace_mock_feedgetters from pdr_backend.trader.approach2.trader_agent2 import TraderAgent2 -from pdr_backend.trader.approach2.trader_config2 import TraderConfig2 +from pdr_backend.trader.test.trader_agent_runner import ( + do_constructor, + do_run, + setup_trade, +) @enforce_types -def mock_feed(): - feed = Mock(spec=Feed) - feed.name = "test feed" - feed.address = "0xtestfeed" - feed.seconds_per_epoch = 60 - return feed +@patch.object(TraderAgent2, "check_subscriptions_and_subscribe") +def test_trader_agent2_constructor(check_subscriptions_and_subscribe_mock): + do_constructor(TraderAgent2, check_subscriptions_and_subscribe_mock) @enforce_types @patch.object(TraderAgent2, "check_subscriptions_and_subscribe") -def test_new_agent(check_subscriptions_and_subscribe_mock, predictoor_contract): - # Setting up the mock trader configuration - trader_config = Mock(spec=TraderConfig2) - trader_config.exchange_str = "mexc" - trader_config.exchange_pair = "BTC/USDT" - trader_config.timeframe = "5m" - trader_config.size = 10.0 - trader_config.get_feeds = Mock() - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.get_contracts = Mock() - trader_config.get_contracts.return_value = { - "0x0000000000000000000000000000000000000000": predictoor_contract - } - - # Creating a new agent and asserting the configuration - agent = TraderAgent2(trader_config) - assert agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() - - # Setting up a configuration with no feeds and testing for SystemExit - no_feeds_config = Mock(spec=TraderConfig2) - no_feeds_config.get_feeds.return_value = {} - no_feeds_config.max_tries = 10 - - with pytest.raises(SystemExit): - TraderAgent2(no_feeds_config) +def test_trader_agent2_run(check_subscriptions_and_subscribe_mock): + do_run(TraderAgent2, check_subscriptions_and_subscribe_mock) @enforce_types @pytest.mark.asyncio @patch.object(TraderAgent2, "check_subscriptions_and_subscribe") -async def test_do_trade( - check_subscriptions_and_subscribe_mock, - predictoor_contract, - web3_config, -): - # Mocking the trader configuration - trader_config = Mock(spec=TraderConfig2) - trader_config.exchange_str = "mexc" - trader_config.exchange_pair = "BTC/USDT" - trader_config.timeframe = "5m" - trader_config.size = 10.0 - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.get_contracts = Mock() - trader_config.get_contracts.return_value = { - "0x0000000000000000000000000000000000000000": predictoor_contract - } - trader_config.max_tries = 10 - trader_config.web3_config = web3_config - - # Creating a new agent and setting up the mock objects - agent = TraderAgent2(trader_config) - assert 
agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() - - # Creating mock objects and functions - agent.exchange = Mock() - agent.exchange.create_market_buy_order.return_value = {"info": {"origQty": 1}} - - agent.portfolio = Mock() - agent.update_positions = Mock() - agent.update_cache = Mock() - - agent.get_pred_properties = Mock() - agent.get_pred_properties.return_value = { - "confidence": 100.0, - "dir": 1, - "stake": 1, - } - - # Performing a trade and checking the call counts of the methods - await agent._do_trade(mock_feed(), (1.0, 1.0)) +async def test_trader_agent2_do_trade(check_subscriptions_and_subscribe_mock): + agent, feed = setup_trade( + TraderAgent2, + check_subscriptions_and_subscribe_mock, + ) - assert agent.get_pred_properties.call_count == 1 + await agent._do_trade(feed, (1.0, 1.0)) assert agent.exchange.create_market_buy_order.call_count == 1 - assert agent.update_positions.call_count == 1 - assert agent.portfolio.open_position.call_count == 1 - assert agent.update_cache.call_count == 1 -# Test for TraderAgent2.update_positions @enforce_types @patch.object(TraderAgent2, "check_subscriptions_and_subscribe") -def test_update_positions( +def test_trader_agent2_update_positions( # pylint: disable=unused-argument check_subscriptions_and_subscribe_mock, - predictoor_contract, - web3_config, ): - trader_config = Mock(spec=TraderConfig2) - trader_config.exchange_str = "mexc" - trader_config.exchange_pair = "BTC/USDT" - trader_config.timeframe = "5m" - trader_config.size = 10.0 - trader_config.get_feeds = Mock() - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.get_contracts = Mock() - trader_config.get_contracts.return_value = { - "0x0000000000000000000000000000000000000000": predictoor_contract - } - trader_config.max_tries = 10 - trader_config.web3_config = web3_config - - # Creating a new agent and setting up the mock objects - agent = TraderAgent2(trader_config) - assert agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() - - # Creating mock objects and functions + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + + agent = TraderAgent2(ppss) + agent.exchange = Mock() agent.exchange.create_market_sell_order.return_value = {"info": {"origQty": 1}} - agent.feeds = Mock() - agent.feeds.keys.return_value = ["0x0000000000000000000000000000000000000000"] agent.portfolio = Mock() mock_sheet = Mock() mock_sheet.open_positions = [Mock(), Mock()] @@ -149,59 +62,66 @@ def test_update_positions( agent.close_position = Mock() agent.update_cache = Mock() - # Update agent positions agent.update_positions() assert agent.portfolio.get_sheet.call_count == 1 assert agent.exchange.create_market_sell_order.call_count == 2 assert agent.portfolio.close_position.call_count == 2 assert agent.portfolio.close_position.call_args == ( - ("0x0000000000000000000000000000000000000000", {"info": {"origQty": 1}}), + (feed.address, {"info": {"origQty": 1}}), ) assert agent.update_cache.call_count == 2 + original_call_count = agent.update_cache.call_count + + # does nothing without sheet + agent.portfolio = Mock() + mock_sheet = None + agent.portfolio.get_sheet.return_value = mock_sheet + agent.update_positions() + assert agent.update_cache.call_count == original_call_count + + # does nothing without a portfolio + agent.portfolio = None + agent.update_positions() + assert 
agent.update_cache.call_count == original_call_count + -# Test for TraderAgent2.should_close @enforce_types @patch.object(TraderAgent2, "check_subscriptions_and_subscribe") -def test_should_close( +def test_trader_agent2_should_close( # pylint: disable=unused-argument check_subscriptions_and_subscribe_mock, - predictoor_contract, - web3_config, ): - trader_config = Mock(spec=TraderConfig2) - trader_config.exchange_str = "mexc" - trader_config.exchange_pair = "BTC/USDT" - trader_config.timeframe = "5m" - trader_config.size = 10.0 - trader_config.get_feeds = Mock() - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.get_contracts = Mock() - trader_config.get_contracts.return_value = { - "0x0000000000000000000000000000000000000000": predictoor_contract - } - trader_config.max_tries = 10 - trader_config.web3_config = web3_config - - # TraderConfig2.timedelta is a property, so we need to mock it - trader_config.timedelta = 300 - - # Creating a new agent and setting up the mock objects - agent = TraderAgent2(trader_config) - assert agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() - - # Test 1 - Creating mock objects and functions to handle should_close + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + + agent = TraderAgent2(ppss) + + # test 1 - creating mock objects and functions to handle should_close mock_order = Mock() mock_order.timestamp = 1 result = agent.should_close(mock_order) assert result - # Test 2 - Make more order recent, now it should not close + # test 2 - ensure more order recent, now it should not close mock_order.timestamp = datetime.now().timestamp() * 1000 result = agent.should_close(mock_order) assert not result + + +@enforce_types +@pytest.mark.asyncio +@patch.object(TraderAgent2, "check_subscriptions_and_subscribe") +async def test_trader_agent2_do_trade_edges( + check_subscriptions_and_subscribe_mock, capfd +): + agent, feed = setup_trade( + TraderAgent2, + check_subscriptions_and_subscribe_mock, + ) + + await agent._do_trade(feed, (1.0, 0)) + out, _ = capfd.readouterr() + assert "There's no stake on this" in out diff --git a/pdr_backend/trader/approach2/trader_agent2.py b/pdr_backend/trader/approach2/trader_agent2.py index cfb4522d1..a66b97ae2 100644 --- a/pdr_backend/trader/approach2/trader_agent2.py +++ b/pdr_backend/trader/approach2/trader_agent2.py @@ -5,52 +5,33 @@ import ccxt from enforce_typing import enforce_types -from pdr_backend.models.feed import Feed -from pdr_backend.trader.approach2.portfolio import ( - Portfolio, - Order, - create_order, -) -from pdr_backend.trader.approach2.trader_config2 import TraderConfig2 -from pdr_backend.trader.trader_agent import TraderAgent +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.trader.approach2.portfolio import Order, Portfolio, create_order +from pdr_backend.trader.base_trader_agent import BaseTraderAgent @enforce_types -class TraderAgent2(TraderAgent): +class TraderAgent2(BaseTraderAgent): """ @description - TraderAgent Naive CCXT - - This is a naive algorithm. It will simply: - 1. If open position, close it - 2. If new prediction up, open long - 3. If new prediction down, open short - - You can use the ENV_VARS to: - 1. Configure your strategy: pair, timeframe, etc.. - 2. Configure your exchange: api_key + secret_key - - You can improve this by: - 1. 
Improving the type of method to buy/exit (i.e. limit) - 2. Improving the buy Conditional statement - 3. Enabling buying and shorting - 4. Using SL and TP + Trader agent that's slightly less naive than agent 1. """ - def __init__(self, config: TraderConfig2): - # Initialize cache params + def __init__(self, ppss: PPSS): + # Initialize cache params. Must be *before* calling parent constructor! self.portfolio = None self.reset_cache = False - super().__init__(config) - self.config: TraderConfig2 = config + # + super().__init__(ppss) # If cache params are empty, instantiate if self.portfolio is None: - self.portfolio = Portfolio(list(self.feeds.keys())) + self.portfolio = Portfolio([self.feed.address]) # Generic exchange class - exchange_class = getattr(ccxt, self.config.exchange_str) + exchange_class = self.ppss.trader_ss.exchange_class self.exchange: ccxt.Exchange = exchange_class( { "apiKey": getenv("EXCHANGE_API_KEY"), @@ -65,7 +46,7 @@ def __init__(self, config: TraderConfig2): } ) - self.update_positions(list(self.feeds.keys())) + self.update_positions([self.feed.address]) def update_cache(self): super().update_cache() @@ -83,11 +64,11 @@ def load_cache(self): def should_close(self, order: Order): """ @description - Check if order has lapsed in time relative to config.timeframe + Check if order has lapsed in time relative to trader_ss.timeframe """ now_ts = int(datetime.now().timestamp() * 1000) tx_ts = int(order.timestamp) - order_lapsed = now_ts - tx_ts > self.config.timedelta * 1000 + order_lapsed = now_ts - tx_ts > self.ppss.trader_ss.timeframe_ms return order_lapsed def update_positions(self, feeds: Optional[List[str]] = None): @@ -95,7 +76,7 @@ def update_positions(self, feeds: Optional[List[str]] = None): """ @description Cycle through open positions and assess them """ - feeds = list(self.feeds.keys()) if feeds is None or feeds == [] else feeds + feeds = [self.feed.address] if feeds is None or feeds == [] else feeds if not feeds: return if not self.portfolio: @@ -117,13 +98,13 @@ def update_positions(self, feeds: Optional[List[str]] = None): print(" [Close Position] Requirements met") order = self.exchange.create_market_sell_order( - self.config.exchange_pair, + self.ppss.trader_ss.exchange_str, position.open_order.amount, ) self.portfolio.close_position(addr, order) self.update_cache() - async def do_trade(self, feed: Feed, prediction: Tuple[float, float]): + async def do_trade(self, feed: SubgraphFeed, prediction: Tuple[float, float]): """ @description Logic: @@ -143,16 +124,21 @@ async def do_trade(self, feed: Feed, prediction: Tuple[float, float]): pred_nom, pred_denom = prediction print(f" {feed.address} has a new prediction: {pred_nom} / {pred_denom}.") + if pred_denom == 0: + print(" There's no stake on this, one way or the other. 
Exiting.") + return + pred_properties = self.get_pred_properties(pred_nom, pred_denom) print(f" prediction properties are: {pred_properties}") if pred_properties["dir"] == 1 and pred_properties["confidence"] > 0.5: print(" [Open Position] Requirements met") order = self.exchange.create_market_buy_order( - symbol=self.config.exchange_pair, amount=self.config.size + symbol=self.ppss.trader_ss.exchange_str, + amount=self.ppss.trader_ss.position_size, ) if order and self.portfolio: - order = create_order(order, self.config.exchange_str) + order = create_order(order, self.ppss.trader_ss.exchange_str) self.portfolio.open_position(feed.address, order) self.update_cache() else: diff --git a/pdr_backend/trader/approach2/trader_config2.py b/pdr_backend/trader/approach2/trader_config2.py deleted file mode 100644 index 05c2039ec..000000000 --- a/pdr_backend/trader/approach2/trader_config2.py +++ /dev/null @@ -1,53 +0,0 @@ -from os import getenv - -from enforce_typing import enforce_types - -from pdr_backend.trader.trader_config import TraderConfig - -CAND_EXCHANGE = ["mexc"] -CAND_PAIR = [ - "BTC/USDT", - "ETH/USDT", - "ADA/USDT", - "BNB/USDT", - "SOL/USDT", - "XRP/USDT", - "DOT/USDT", - "LTC/USDT", - "DOGE/USDT", - "TRX/USDT", -] -CAND_TIMEFRAME = ["5m", "1h"] - - -# Mexc does not support -class TraderConfig2(TraderConfig): - @enforce_types - def __init__(self): - super().__init__() - - self.exchange_str = getenv("EXCHANGE_FILTER") - self.pair = getenv("PAIR_FILTER") - self.timeframe = getenv("TIMEFRAME_FILTER") - - ## Exchange Parameters - self.exchange_pair = ( - getenv("EXCHANGE_PAIR_FILTER") - if getenv("EXCHANGE_PAIR_FILTER") - else self.pair - ) - - ## Position Parameters - self.size = getenv("POSITION_SIZE") - - assert self.exchange_str in CAND_EXCHANGE, "Exchange must be valid" - assert self.pair in CAND_PAIR, "Pair must be valid" - assert self.timeframe in CAND_TIMEFRAME, "Timeframe must be valid" - assert ( - self.size is not None and self.size > 0.0 - ), "Position size must be greater than 0.0" - - @property - def timedelta(self): - delta = {"5m": 300, "1h": 3600} - return delta[self.timeframe] diff --git a/pdr_backend/trader/trader_agent.py b/pdr_backend/trader/base_trader_agent.py similarity index 62% rename from pdr_backend/trader/trader_agent.py rename to pdr_backend/trader/base_trader_agent.py index 590000b9e..f97840dd8 100644 --- a/pdr_backend/trader/trader_agent.py +++ b/pdr_backend/trader/base_trader_agent.py @@ -1,83 +1,74 @@ -import sys -import time import asyncio +import time from typing import Any, Callable, Dict, List, Optional, Tuple -from enforce_typing import enforce_types - -from pdr_backend.models.feed import Feed -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.trader.trader_config import TraderConfig +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed, print_feeds from pdr_backend.util.cache import Cache +from pdr_backend.util.mathutil import sole_value # pylint: disable=too-many-instance-attributes -class TraderAgent: +class BaseTraderAgent: def __init__( self, - trader_config: TraderConfig, - _do_trade: Optional[Callable[[Feed, Tuple], Any]] = None, + ppss: PPSS, + _do_trade: Optional[Callable[[SubgraphFeed, Tuple], Any]] = None, cache_dir=".cache", ): - self.config = trader_config - self._do_trade = _do_trade if _do_trade else self.do_trade + # ppss + self.ppss = ppss + print("\n" + "-" * 80) + print(self.ppss) - self.feeds: Dict[str, Feed] = self.config.get_feeds() # [addr] : 
Feed + # _do_trade + self._do_trade = _do_trade or self.do_trade - if not self.feeds: - print("No feeds found. Exiting") - sys.exit() + # set self.feeds + cand_feeds = ppss.web3_pp.query_feed_contracts() + print_feeds(cand_feeds, f"cand feeds, owner={ppss.web3_pp.owner_addrs}") - feed_addrs = list(self.feeds.keys()) - self.contracts = self.config.get_contracts(feed_addrs) # [addr] : contract + feed = ppss.trader_ss.get_feed_from_candidates(cand_feeds) + print_feeds({feed.address: feed}, "filtered feeds") + if not feed: + raise ValueError("No feeds found.") + self.feed = feed + contracts = ppss.web3_pp.get_contracts([feed.address]) + self.contract = sole_value(contracts) + + # set attribs to track block self.prev_block_timestamp: int = 0 self.prev_block_number: int = 0 - self.prev_traded_epochs_per_feed: Dict[str, List[int]] = { - addr: [] for addr in self.feeds - } + self.prev_traded_epochs: List[int] = [] self.cache = Cache(cache_dir=cache_dir) self.load_cache() - print("-" * 80) - print("Config:") - print(self.config) - - print("\n" + "." * 80) - print("Feeds (detailed):") - for feed in self.feeds.values(): - print(f" {feed.longstr()}") - - print("\n" + "." * 80) - print("Feeds (succinct):") - for addr, feed in self.feeds.items(): - print(f" {feed}, {feed.seconds_per_epoch} s/epoch, addr={addr}") - self.check_subscriptions_and_subscribe() def check_subscriptions_and_subscribe(self): - for addr, feed in self.feeds.items(): - contract = PredictoorContract(self.config.web3_config, addr) - if not contract.is_valid_subscription(): - print(f"Purchasing new subscription for feed: {feed}") - contract.buy_and_start_subscription(None, True) + if not self.contract.is_valid_subscription(): + print(f"Purchase subscription for feed {self.feed}: begin") + self.contract.buy_and_start_subscription( + gasLimit=None, + wait_for_receipt=True, + ) + print(f"Purchase new subscription for feed {self.feed}: done") time.sleep(1) def update_cache(self): - for feed, epochs in self.prev_traded_epochs_per_feed.items(): - if epochs: - last_epoch = epochs[-1] - self.cache.save(f"trader_last_trade_{feed}", last_epoch) + epochs = self.prev_traded_epochs + if epochs: + last_epoch = epochs[-1] + self.cache.save(f"trader_last_trade_{self.feed.address}", last_epoch) def load_cache(self): - for feed in self.feeds: - last_epoch = self.cache.load(f"trader_last_trade_{feed}") - if last_epoch is not None: - self.prev_traded_epochs_per_feed[feed].append(last_epoch) + last_epoch = self.cache.load(f"trader_last_trade_{self.feed.address}") + if last_epoch is not None: + self.prev_traded_epochs.append(last_epoch) - @enforce_types def run(self, testing: bool = False): while True: asyncio.run(self.take_step()) @@ -85,7 +76,8 @@ def run(self, testing: bool = False): break async def take_step(self): - w3 = self.config.web3_config.w3 + web3_config = self.ppss.web3_pp.web3_config + w3 = web3_config.w3 # at new block number yet? block_number = w3.eth.block_number @@ -95,16 +87,14 @@ async def take_step(self): self.prev_block_number = block_number # is new block ready yet? 
- block = self.config.web3_config.get_block(block_number, full_transactions=False) + block = web3_config.get_block(block_number, full_transactions=False) if not block: return self.prev_block_number = block_number self.prev_block_timestamp = block["timestamp"] print("before:", time.time()) - tasks = [ - self._process_block_at_feed(addr, block["timestamp"]) for addr in self.feeds - ] + tasks = [self._process_block(block["timestamp"])] s_till_epoch_ends, log_list = zip(*await asyncio.gather(*tasks)) for logs in log_list: @@ -122,12 +112,11 @@ async def take_step(self): time.sleep(sleep_time) - async def _process_block_at_feed( - self, addr: str, timestamp: int, tries: int = 0 + async def _process_block( + self, timestamp: int, tries: int = 0 ) -> Tuple[int, List[str]]: """ @param: - addr - contract address of the feed timestamp - timestamp/epoch to process [tries] - number of attempts made in case of an error, 0 by default @return: @@ -135,21 +124,18 @@ async def _process_block_at_feed( logs - list of strings of function logs """ logs = [] - feed, predictoor_contract = self.feeds[addr], self.contracts[addr] - s_per_epoch = feed.seconds_per_epoch + predictoor_contract = self.contract + s_per_epoch = self.feed.seconds_per_epoch epoch = int(timestamp / s_per_epoch) epoch_s_left = epoch * s_per_epoch + s_per_epoch - timestamp - logs.append(f"{'-'*40} Processing {feed} {'-'*40}\nEpoch {epoch}") - logs.append("Seconds remaining in epoch: {epoch_s_left}") + logs.append(f"{'-'*40} Processing {self.feed} {'-'*40}\nEpoch {epoch}") + logs.append(f"Seconds remaining in epoch: {epoch_s_left}") - if ( - self.prev_traded_epochs_per_feed.get(addr) - and epoch == self.prev_traded_epochs_per_feed[addr][-1] - ): + if self.prev_traded_epochs and epoch == self.prev_traded_epochs[-1]: logs.append(" Done feed: already traded this epoch") return epoch_s_left, logs - if epoch_s_left < self.config.trader_min_buffer: + if epoch_s_left < self.ppss.trader_ss.min_buffer: logs.append(" Done feed: not enough time left in epoch") return epoch_s_left, logs @@ -159,7 +145,7 @@ async def _process_block_at_feed( None, predictoor_contract.get_agg_predval, (epoch + 1) * s_per_epoch ) except Exception as e: - if tries < self.config.max_tries: + if tries < self.ppss.trader_ss.max_tries: logs.append(e.args[0]["message"]) if ( len(e.args) > 0 @@ -174,14 +160,14 @@ async def _process_block_at_feed( ) # -1 means the subscription has expired for this pair logs.append(" Could not get aggpredval, trying again in a second") await asyncio.sleep(1) - return await self._process_block_at_feed(addr, timestamp, tries + 1) + return await self._process_block(timestamp, tries + 1) logs.append( f" Done feed: aggpredval not available, an error occurred: {e}" ) return epoch_s_left, logs - await self._do_trade(feed, prediction) - self.prev_traded_epochs_per_feed[addr].append(epoch) + await self._do_trade(self.feed, prediction) + self.prev_traded_epochs.append(epoch) self.update_cache() return epoch_s_left, logs @@ -199,10 +185,7 @@ def get_pred_properties( """ confidence: float = pred_nom / pred_denom direction: float = 1 if confidence >= 0.5 else 0 - if confidence > 0.5: - confidence -= 0.5 - else: - confidence = 0.5 - confidence + confidence = abs(confidence - 0.5) confidence = (confidence / 0.5) * 100 return { @@ -211,7 +194,7 @@ def get_pred_properties( "stake": pred_denom, } - async def do_trade(self, feed: Feed, prediction: Tuple[float, float]): + async def do_trade(self, feed: SubgraphFeed, prediction: Tuple[float, float]): """ @description This 
function is called each time there's a new prediction available. diff --git a/pdr_backend/trader/main.py b/pdr_backend/trader/main.py deleted file mode 100644 index 89f70c998..000000000 --- a/pdr_backend/trader/main.py +++ /dev/null @@ -1,12 +0,0 @@ -from pdr_backend.trader.trader_agent import TraderAgent -from pdr_backend.trader.trader_config import TraderConfig - - -def main(testing=False): - config = TraderConfig() - t = TraderAgent(config) - t.run(testing) - - -if __name__ == "__main__": - main() diff --git a/pdr_backend/trader/test/test_base_trader_agent.py b/pdr_backend/trader/test/test_base_trader_agent.py new file mode 100644 index 000000000..2c5f41c9d --- /dev/null +++ b/pdr_backend/trader/test/test_base_trader_agent.py @@ -0,0 +1,157 @@ +import os +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.ppss.ppss import mock_feed_ppss +from pdr_backend.ppss.web3_pp import ( + inplace_mock_feedgetters, + inplace_mock_w3_and_contract_with_tracking, +) +from pdr_backend.trader.base_trader_agent import BaseTraderAgent +from pdr_backend.trader.test.trader_agent_runner import ( + INIT_BLOCK_NUMBER, + INIT_TIMESTAMP, + do_constructor, + do_run, + setup_take_step, +) + + +@enforce_types +@patch.object(BaseTraderAgent, "check_subscriptions_and_subscribe") +def test_trader_agent_constructor(check_subscriptions_and_subscribe_mock): + do_constructor(BaseTraderAgent, check_subscriptions_and_subscribe_mock) + + +@enforce_types +@patch.object(BaseTraderAgent, "check_subscriptions_and_subscribe") +def test_trader_agent_run(check_subscriptions_and_subscribe_mock): + do_run(BaseTraderAgent, check_subscriptions_and_subscribe_mock) + + +@enforce_types +@pytest.mark.asyncio +@patch.object(BaseTraderAgent, "check_subscriptions_and_subscribe") +async def test_trader_agent_take_step( + check_subscriptions_and_subscribe_mock, + monkeypatch, +): + agent = setup_take_step( + BaseTraderAgent, + check_subscriptions_and_subscribe_mock, + monkeypatch, + ) + + await agent.take_step() + + assert check_subscriptions_and_subscribe_mock.call_count == 2 + assert agent._process_block.call_count == 1 + + +@enforce_types +def custom_do_trade(feed, prediction): + return (feed, prediction) + + +@pytest.mark.asyncio +@patch.object(BaseTraderAgent, "check_subscriptions_and_subscribe") +async def test_process_block( # pylint: disable=unused-argument + check_subscriptions_and_subscribe_mock, + monkeypatch, +): + feed, ppss = mock_feed_ppss("1m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + _mock_pdr_contract = inplace_mock_w3_and_contract_with_tracking( + ppss.web3_pp, + INIT_TIMESTAMP, + INIT_BLOCK_NUMBER, + ppss.trader_ss.timeframe_s, + feed.address, + monkeypatch, + ) + + agent = BaseTraderAgent(ppss, custom_do_trade) + + agent.prev_traded_epochs = [] + + async def _do_trade(feed, prediction): # pylint: disable=unused-argument + pass + + agent._do_trade = Mock(side_effect=_do_trade) + + # mock feed seconds per epoch is 60 + # test agent config min buffer is 30 + # so it should trade if there's more than 30 seconds left in the epoch + + # epoch_s_left = 60 - 55 = 5, so we should not trade + # because it's too close to the epoch end + s_till_epoch_end, _ = await agent._process_block(55) + assert len(agent.prev_traded_epochs) == 0 + assert s_till_epoch_end == 5 + + # epoch_s_left = 60 + 60 - 80 = 40, so we should trade + s_till_epoch_end, _ = await agent._process_block(80) + assert 
len(agent.prev_traded_epochs) == 1 + assert s_till_epoch_end == 40 + + # but not again, because we've already traded this epoch + s_till_epoch_end, _ = await agent._process_block(80) + assert len(agent.prev_traded_epochs) == 1 + assert s_till_epoch_end == 40 + + # but we should trade again in the next epoch + _mock_pdr_contract.get_current_epoch = Mock() + _mock_pdr_contract.get_current_epoch.return_value = 2 + s_till_epoch_end, _ = await agent._process_block(140) + assert len(agent.prev_traded_epochs) == 2 + assert s_till_epoch_end == 40 + + +@enforce_types +@patch.object(BaseTraderAgent, "check_subscriptions_and_subscribe") +def test_save_and_load_cache( + check_subscriptions_and_subscribe_mock, +): # pylint: disable=unused-argument + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + + agent = BaseTraderAgent(ppss, custom_do_trade, cache_dir=".test_cache") + + agent.prev_traded_epochs = [1, 2, 3] + + agent.update_cache() + + agent_new = BaseTraderAgent(ppss, custom_do_trade, cache_dir=".test_cache") + assert agent_new.prev_traded_epochs == [3] + cache_dir_path = ( + Path(os.path.dirname(os.path.abspath(__file__))).parent.parent + / "util/.test_cache" + ) + for item in cache_dir_path.iterdir(): + item.unlink() + cache_dir_path.rmdir() + + +@pytest.mark.asyncio +@patch.object(BaseTraderAgent, "check_subscriptions_and_subscribe") +async def test_get_pred_properties( + check_subscriptions_and_subscribe_mock, +): # pylint: disable=unused-argument + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + + agent = BaseTraderAgent(ppss) + check_subscriptions_and_subscribe_mock.assert_called_once() + + agent.get_pred_properties = Mock() + agent.get_pred_properties.return_value = { + "confidence": 100.0, + "dir": 1, + "stake": 1, + } + + await agent._do_trade(feed, (1.0, 1.0)) + assert agent.get_pred_properties.call_count == 1 diff --git a/pdr_backend/trader/test/test_trader_agent.py b/pdr_backend/trader/test/test_trader_agent.py deleted file mode 100644 index 3b8a2c7ca..000000000 --- a/pdr_backend/trader/test/test_trader_agent.py +++ /dev/null @@ -1,221 +0,0 @@ -import os -from pathlib import Path -from unittest.mock import Mock, patch - -import pytest - -from pdr_backend.models.feed import Feed -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.trader.trader_agent import TraderAgent -from pdr_backend.trader.trader_config import TraderConfig - - -def mock_feed(): - feed = Mock(spec=Feed) - feed.name = "test feed" - feed.seconds_per_epoch = 60 - return feed - - -@patch.object(TraderAgent, "check_subscriptions_and_subscribe") -def test_new_agent( - check_subscriptions_and_subscribe_mock, predictoor_contract -): # pylint: disable=unused-argument - trader_config = TraderConfig() - trader_config.get_feeds = Mock() - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.get_contracts = Mock() - trader_config.get_contracts.return_value = { - "0x0000000000000000000000000000000000000000": predictoor_contract - } - agent = TraderAgent(trader_config) - assert agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() - - no_feeds_config = Mock(spec=TraderConfig) - no_feeds_config.get_feeds.return_value = {} - no_feeds_config.max_tries = 10 - - with pytest.raises(SystemExit): - TraderAgent(no_feeds_config) - - 
-@patch.object(TraderAgent, "check_subscriptions_and_subscribe") -def test_run(check_subscriptions_and_subscribe_mock): # pylint: disable=unused-argument - trader_config = Mock(spec=TraderConfig) - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.max_tries = 10 - agent = TraderAgent(trader_config) - - with patch.object(agent, "take_step") as ts_mock: - agent.run(True) - - ts_mock.assert_called_once() - - -@pytest.mark.asyncio -@patch.object(TraderAgent, "check_subscriptions_and_subscribe") -async def test_take_step( - check_subscriptions_and_subscribe_mock, web3_config -): # pylint: disable=unused-argument - trader_config = Mock(spec=TraderConfig) - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.max_tries = 10 - trader_config.web3_config = web3_config - agent = TraderAgent(trader_config) - - # Create async mock fn so we can await asyncio.gather(*tasks) - async def _process_block_at_feed( - addr, timestamp - ): # pylint: disable=unused-argument - return (-1, []) - - agent._process_block_at_feed = Mock(side_effect=_process_block_at_feed) - - await agent.take_step() - - assert check_subscriptions_and_subscribe_mock.call_count == 2 - assert agent._process_block_at_feed.call_count == 1 - - -def custom_trader(feed, prediction): - return (feed, prediction) - - -@pytest.mark.asyncio -@patch.object(TraderAgent, "check_subscriptions_and_subscribe") -async def test_process_block_at_feed( - check_subscriptions_and_subscribe_mock, -): # pylint: disable=unused-argument - trader_config = Mock(spec=TraderConfig) - trader_config.max_tries = 10 - trader_config.trader_min_buffer = 20 - feed = mock_feed() - predictoor_contract = Mock(spec=PredictoorContract) - predictoor_contract.get_agg_predval.return_value = (1, 2) - - trader_config.get_feeds.return_value = {"0x123": feed} - trader_config.get_contracts.return_value = {"0x123": predictoor_contract} - - agent = TraderAgent(trader_config, custom_trader) - agent.prev_traded_epochs_per_feed.clear() - agent.prev_traded_epochs_per_feed["0x123"] = [] - - async def _do_trade(feed, prediction): # pylint: disable=unused-argument - pass - - agent._do_trade = Mock(side_effect=_do_trade) - - # epoch_s_left = 60 - 55 = 5, so we should not trade - # because it's too close to the epoch end - s_till_epoch_end, _ = await agent._process_block_at_feed("0x123", 55) - assert len(agent.prev_traded_epochs_per_feed["0x123"]) == 0 - assert s_till_epoch_end == 5 - - # epoch_s_left = 60 + 60 - 80 = 40, so we should not trade - s_till_epoch_end, _ = await agent._process_block_at_feed("0x123", 80) - assert len(agent.prev_traded_epochs_per_feed["0x123"]) == 1 - assert s_till_epoch_end == 40 - - # but not again, because we've already traded this epoch - s_till_epoch_end, _ = await agent._process_block_at_feed("0x123", 80) - assert len(agent.prev_traded_epochs_per_feed["0x123"]) == 1 - assert s_till_epoch_end == 40 - - # but we should trade again in the next epoch - predictoor_contract.get_current_epoch.return_value = 2 - s_till_epoch_end, _ = await agent._process_block_at_feed("0x123", 140) - assert len(agent.prev_traded_epochs_per_feed["0x123"]) == 2 - assert s_till_epoch_end == 40 - - # prediction is empty, so no trading - predictoor_contract.get_current_epoch.return_value = 3 - predictoor_contract.get_agg_predval.side_effect = Exception( - {"message": "An error occurred while getting agg_predval."} - ) - s_till_epoch_end, _ = await 
agent._process_block_at_feed("0x123", 20) - assert len(agent.prev_traded_epochs_per_feed["0x123"]) == 2 - assert s_till_epoch_end == 40 - - # default trader - agent = TraderAgent(trader_config) - agent.prev_traded_epochs_per_feed.clear() - agent.prev_traded_epochs_per_feed["0x123"] = [] - predictoor_contract.get_agg_predval.return_value = (1, 3) - predictoor_contract.get_agg_predval.side_effect = None - s_till_epoch_end, _ = await agent._process_block_at_feed("0x123", 20) - assert len(agent.prev_traded_epochs_per_feed["0x123"]) == 1 - assert s_till_epoch_end == 40 - - -@patch.object(TraderAgent, "check_subscriptions_and_subscribe") -def test_save_and_load_cache( - check_subscriptions_and_subscribe_mock, -): # pylint: disable=unused-argument - trader_config = Mock(spec=TraderConfig) - trader_config.max_tries = 10 - trader_config.trader_min_buffer = 20 - feed = mock_feed() - predictoor_contract = Mock(spec=PredictoorContract) - predictoor_contract.get_agg_predval.return_value = (1, 2) - - trader_config.get_feeds.return_value = {"0x1": feed, "0x2": feed, "0x3": feed} - trader_config.get_contracts.return_value = { - "0x1": predictoor_contract, - "0x2": predictoor_contract, - "0x3": predictoor_contract, - } - - agent = TraderAgent(trader_config, custom_trader, cache_dir=".test_cache") - - agent.prev_traded_epochs_per_feed = { - "0x1": [1, 2, 3], - "0x2": [4, 5, 6], - "0x3": [1, 24, 66], - } - - agent.update_cache() - - agent_new = TraderAgent(trader_config, custom_trader, cache_dir=".test_cache") - assert agent_new.prev_traded_epochs_per_feed["0x1"] == [3] - assert agent_new.prev_traded_epochs_per_feed["0x2"] == [6] - assert agent_new.prev_traded_epochs_per_feed["0x3"] == [66] - cache_dir_path = ( - Path(os.path.dirname(os.path.abspath(__file__))).parent.parent - / "util/.test_cache" - ) - for item in cache_dir_path.iterdir(): - item.unlink() - cache_dir_path.rmdir() - - -@pytest.mark.asyncio -@patch.object(TraderAgent, "check_subscriptions_and_subscribe") -async def test_get_pred_properties( - check_subscriptions_and_subscribe_mock, web3_config -): # pylint: disable=unused-argument - trader_config = Mock(spec=TraderConfig) - trader_config.get_feeds.return_value = { - "0x0000000000000000000000000000000000000000": mock_feed() - } - trader_config.max_tries = 10 - trader_config.web3_config = web3_config - agent = TraderAgent(trader_config) - assert agent.config == trader_config - check_subscriptions_and_subscribe_mock.assert_called_once() - - agent.get_pred_properties = Mock() - agent.get_pred_properties.return_value = { - "confidence": 100.0, - "dir": 1, - "stake": 1, - } - - await agent._do_trade(mock_feed(), (1.0, 1.0)) - assert agent.get_pred_properties.call_count == 1 diff --git a/pdr_backend/trader/test/test_trader_config.py b/pdr_backend/trader/test/test_trader_config.py deleted file mode 100644 index 1ce247d6d..000000000 --- a/pdr_backend/trader/test/test_trader_config.py +++ /dev/null @@ -1,24 +0,0 @@ -import os - -from pdr_backend.trader.trader_config import TraderConfig - - -def test_trader_config(monkeypatch): - monkeypatch.setenv("PAIR_FILTER", "BTC/USDT,ETH/USDT") - monkeypatch.setenv("TIMEFRAME_FILTER", "5m,15m") - monkeypatch.setenv("SOURCE_FILTER", "binance,kraken") - monkeypatch.setenv("OWNER_ADDRS", "0x123,0x124") - - config = TraderConfig() - - # values handled by BaseConfig - assert config.rpc_url == os.getenv("RPC_URL") - assert config.subgraph_url == os.getenv("SUBGRAPH_URL") - assert config.private_key == os.getenv("PRIVATE_KEY") - - assert config.pair_filters == 
["BTC/USDT", "ETH/USDT"] - assert config.timeframe_filter == ["5m", "15m"] - assert config.source_filter == ["binance", "kraken"] - assert config.owner_addresses == ["0x123", "0x124"] - - assert config.web3_config is not None diff --git a/pdr_backend/trader/test/trader_agent_runner.py b/pdr_backend/trader/test/trader_agent_runner.py new file mode 100644 index 000000000..ee6afc403 --- /dev/null +++ b/pdr_backend/trader/test/trader_agent_runner.py @@ -0,0 +1,83 @@ +from unittest.mock import Mock, patch + +from enforce_typing import enforce_types + +from pdr_backend.ppss.ppss import mock_feed_ppss +from pdr_backend.ppss.web3_pp import ( + inplace_mock_feedgetters, + inplace_mock_w3_and_contract_with_tracking, +) + +INIT_TIMESTAMP = 107 +INIT_BLOCK_NUMBER = 13 + + +@enforce_types +def do_constructor(agent_class, check_subscriptions_and_subscribe_mock): + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + + # 1 predict feed + assert ppss.trader_ss.feed + agent = agent_class(ppss) + assert agent.ppss == ppss + assert agent.feed + check_subscriptions_and_subscribe_mock.assert_called_once() + + +@enforce_types +def do_run( # pylint: disable=unused-argument + agent_class, + check_subscriptions_and_subscribe_mock, +): + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + + agent = agent_class(ppss) + + with patch.object(agent, "take_step") as mock_stake_step: + agent.run(True) + mock_stake_step.assert_called_once() + + +@enforce_types +def setup_take_step( # pylint: disable=unused-argument + agent_class, + check_subscriptions_and_subscribe_mock, + monkeypatch, +): + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + _mock_pdr_contract = inplace_mock_w3_and_contract_with_tracking( + ppss.web3_pp, + INIT_TIMESTAMP, + INIT_BLOCK_NUMBER, + ppss.trader_ss.timeframe_s, + feed.address, + monkeypatch, + ) + + agent = agent_class(ppss) + + # Create async mock fn so we can await asyncio.gather(*tasks) + async def _process_block(timestamp): # pylint: disable=unused-argument + return (-1, []) + + agent._process_block = Mock(side_effect=_process_block) + + return agent + + +@enforce_types +def setup_trade( # pylint: disable=unused-argument + agent_class, check_subscriptions_and_subscribe_mock +): + feed, ppss = mock_feed_ppss("5m", "binance", "BTC/USDT") + inplace_mock_feedgetters(ppss.web3_pp, feed) # mock publishing feeds + + agent = agent_class(ppss) + + agent.exchange = Mock() + agent.exchange.create_market_buy_order.return_value = {"info": {"origQty": 1}} + + return agent, feed diff --git a/pdr_backend/trader/trader_config.py b/pdr_backend/trader/trader_config.py deleted file mode 100644 index 406f38d3d..000000000 --- a/pdr_backend/trader/trader_config.py +++ /dev/null @@ -1,18 +0,0 @@ -from os import getenv - -from enforce_typing import enforce_types - -from pdr_backend.models.base_config import BaseConfig - - -@enforce_types -class TraderConfig(BaseConfig): - def __init__(self): - super().__init__() - - # Sets a threshold (in seconds) for trade decisions. - # For example, if set to 180 and there's 179 seconds left, no trade. If 181, then trade. 
- self.trader_min_buffer = int(getenv("TRADER_MIN_BUFFER", "60")) - - # Maximum attempts to process a feed - self.max_tries = 10 diff --git a/pdr_backend/trueval/get_trueval.py b/pdr_backend/trueval/get_trueval.py new file mode 100644 index 000000000..733ba7253 --- /dev/null +++ b/pdr_backend/trueval/get_trueval.py @@ -0,0 +1,66 @@ +from typing import Tuple + +import ccxt +from enforce_typing import enforce_types + +from pdr_backend.lake.fetch_ohlcv import safe_fetch_ohlcv +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed + + +@enforce_types +def get_trueval( + feed: SubgraphFeed, init_timestamp: int, end_timestamp: int +) -> Tuple[bool, bool]: + """ + @description + Checks if the price has risen between two given timestamps. + If the round should be canceled, the second value in the returned tuple is set to True. + + @arguments + feed -- SubgraphFeed -- The feed object containing pair details + init_timestamp -- int -- The starting timestamp. + end_timestamp -- int -- The ending timestamp. + + @return + trueval -- did price rise y/n? + cancel_round -- should we cancel the round y/n? + """ + symbol = feed.pair + symbol = symbol.replace("-", "/") + symbol = symbol.upper() + + # since we will get close price + # we need to go back 1 candle + init_timestamp -= feed.seconds_per_epoch + end_timestamp -= feed.seconds_per_epoch + + # convert seconds to ms + init_timestamp = int(init_timestamp * 1000) + end_timestamp = int(end_timestamp * 1000) + + exchange_class = getattr(ccxt, feed.source) + exchange = exchange_class() + tohlcvs = safe_fetch_ohlcv( + exchange, symbol, feed.timeframe, since=init_timestamp, limit=2 + ) + assert len(tohlcvs) == 2, f"expected exactly 2 tohlcv tuples. {tohlcvs}" + init_tohlcv, end_tohlcv = tohlcvs[0], tohlcvs[1] + + if init_tohlcv[0] != init_timestamp: + raise Exception( + f"exchange's init_tohlcv[0]={init_tohlcv[0]} should have matched" + f" target init_timestamp={init_timestamp}" + ) + if end_tohlcv[0] != end_timestamp: + raise Exception( + f"exchange's end_tohlcv[0]={end_tohlcv[0]} should have matched" + f" target end_timestamp={end_timestamp}" + ) + + init_c, end_c = init_tohlcv[4], end_tohlcv[4] # c = closing price + if end_c == init_c: + return False, True + + trueval = end_c > init_c + cancel_round = False + return trueval, cancel_round diff --git a/pdr_backend/trueval/main.py b/pdr_backend/trueval/main.py deleted file mode 100644 index 7450ba219..000000000 --- a/pdr_backend/trueval/main.py +++ /dev/null @@ -1,51 +0,0 @@ -import sys - -from enforce_typing import enforce_types - -from pdr_backend.trueval.trueval_agent_base import get_trueval -from pdr_backend.trueval.trueval_agent_batch import TruevalAgentBatch -from pdr_backend.trueval.trueval_agent_single import TruevalAgentSingle -from pdr_backend.trueval.trueval_config import TruevalConfig -from pdr_backend.util.contract import get_address - - -HELP = """Trueval runner. 
- -Usage: python pdr_backend/trueval/main.py APPROACH - - where APPROACH=1 submits truevals one by one - APPROACH=2 submits truevals in a batch -""" - - -def do_help(): - print(HELP) - sys.exit() - - -@enforce_types -def main(testing=False): - if len(sys.argv) <= 1: - do_help() - arg1 = sys.argv[1] - config = TruevalConfig() - - if arg1 == "1": - t = TruevalAgentSingle(config, get_trueval) - t.run(testing) - - elif arg1 == "2": - predictoor_batcher_addr = get_address( - config.web3_config.w3.eth.chain_id, "PredictoorHelper" - ) - t = TruevalAgentBatch(config, get_trueval, predictoor_batcher_addr) - t.run(testing) - - elif arg1 == "help": - do_help() - else: - do_help() - - -if __name__ == "__main__": - main() diff --git a/pdr_backend/trueval/test/conftest.py b/pdr_backend/trueval/test/conftest.py index c75b40027..469922498 100644 --- a/pdr_backend/trueval/test/conftest.py +++ b/pdr_backend/trueval/test/conftest.py @@ -1,61 +1,36 @@ -import os from unittest.mock import Mock, patch from pdr_backend.conftest_ganache import * # pylint: disable=wildcard-import -from pdr_backend.models.feed import Feed -from pdr_backend.models.slot import Slot -from pdr_backend.trueval.trueval_config import TruevalConfig +from pdr_backend.contract.slot import Slot +from pdr_backend.ppss.ppss import PPSS, fast_test_yaml_str +from pdr_backend.subgraph.subgraph_feed import mock_feed @pytest.fixture() def slot(): - feed = Feed( - name="ETH-USDT", - address="0xBE5449a6A97aD46c8558A3356267Ee5D2731ab5e", - symbol="ETH-USDT", - seconds_per_epoch=60, - seconds_per_subscription=500, - pair="eth-usdt", - source="kraken", - timeframe="5m", - trueval_submit_timeout=100, - owner="0xowner", - ) - + feed = mock_feed("5m", "kraken", "ETH/USDT") return Slot( feed=feed, slot_number=1692943200, ) -@pytest.fixture(autouse=True) -def set_env_vars(): - original_value = os.environ.get("OWNER_ADDRS", None) - os.environ["OWNER_ADDRS"] = "0xBE5449a6A97aD46c8558A3356267Ee5D2731ab5e" - yield - if original_value is not None: - os.environ["OWNER_ADDRS"] = original_value - else: - os.environ.pop("OWNER_ADDRS", None) - - @pytest.fixture() -def trueval_config(): - return TruevalConfig() +def mock_ppss(tmpdir): + return PPSS(yaml_str=fast_test_yaml_str(tmpdir), network="development") @pytest.fixture() def predictoor_contract_mock(): + def mock_contract(*args, **kwarg): # pylint: disable=unused-argument + m = Mock() + m.get_secondsPerEpoch.return_value = 60 + m.submit_trueval.return_value = {"tx": "0x123"} + m.contract_address = "0x1" + return m + with patch( - "pdr_backend.trueval.trueval_agent_base.PredictoorContract", + "pdr_backend.trueval.trueval_agent.PredictoorContract", return_value=mock_contract(), ) as mock_predictoor_contract_mock: yield mock_predictoor_contract_mock - - -def mock_contract(*args, **kwarg): # pylint: disable=unused-argument - m = Mock() - m.get_secondsPerEpoch.return_value = 60 - m.submit_trueval.return_value = {"tx": "0x123"} - m.contract_address = "0x1" - return m diff --git a/pdr_backend/trueval/test/test_get_trueval.py b/pdr_backend/trueval/test/test_get_trueval.py new file mode 100644 index 000000000..36ee87b50 --- /dev/null +++ b/pdr_backend/trueval/test/test_get_trueval.py @@ -0,0 +1,41 @@ +import pytest +from enforce_typing import enforce_types + +from pdr_backend.subgraph.subgraph_feed import mock_feed +from pdr_backend.trueval.get_trueval import get_trueval + +_PATH = "pdr_backend.trueval.get_trueval" + + +@enforce_types +def test_get_trueval_success(monkeypatch): + def mock_fetch_ohlcv(*args, **kwargs): # 
pylint: disable=unused-argument + since = kwargs.get("since") + if since == 0: + return [[0, 0, 0, 0, 100], [300000, 0, 0, 0, 200]] + raise ValueError(f"Invalid timestamp: since={since}") + + monkeypatch.setattr(f"{_PATH}.safe_fetch_ohlcv", mock_fetch_ohlcv) + + feed = mock_feed("5m", "kraken", "ETH/USDT") + + init_ts = feed.seconds_per_epoch + end_ts = init_ts + feed.seconds_per_epoch + result = get_trueval(feed, init_ts, end_ts) + assert result == (True, False) + + +@enforce_types +def test_get_trueval_fail(monkeypatch): + def mock_fetch_ohlcv_fail(*args, **kwargs): # pylint: disable=unused-argument + return [[0, 0, 0, 0, 0], [300000, 0, 0, 0, 200]] + + monkeypatch.setattr(f"{_PATH}.safe_fetch_ohlcv", mock_fetch_ohlcv_fail) + + feed = mock_feed("5m", "kraken", "eth-usdt") + + init_ts = feed.seconds_per_epoch + end_ts = init_ts + feed.seconds_per_epoch + with pytest.raises(Exception): + result = get_trueval(feed, init_ts, end_ts) + assert result == (False, True) # 2nd True because failed diff --git a/pdr_backend/trueval/test/test_trueval.py b/pdr_backend/trueval/test/test_trueval.py deleted file mode 100644 index 4df63cf3d..000000000 --- a/pdr_backend/trueval/test/test_trueval.py +++ /dev/null @@ -1,98 +0,0 @@ -from enforce_typing import enforce_types - -import pytest - -from pdr_backend.trueval.main import get_trueval -from pdr_backend.models.feed import Feed - - -def mock_fetch_ohlcv(*args, **kwargs): # pylint: disable=unused-argument - since = kwargs.get("since") - if since == 0: - return [[0, 0, 0, 0, 100], [60000, 0, 0, 0, 200]] - raise ValueError("Invalid timestamp") - - -def mock_fetch_ohlcv_fail(*args, **kwargs): # pylint: disable=unused-argument - return [[0, 0, 0, 0, 0]] - - -@enforce_types -def test_get_trueval_success(monkeypatch): - feed = Feed( - name="ETH-USDT", - address="0x1", - symbol="ETH-USDT", - seconds_per_epoch=60, - seconds_per_subscription=500, - pair="eth-usdt", - source="kraken", - timeframe="5m", - trueval_submit_timeout=100, - owner="0xowner", - ) - - monkeypatch.setattr("ccxt.kraken.fetch_ohlcv", mock_fetch_ohlcv) - - result = get_trueval(feed, 60, 120) - assert result == (True, False) - - -@enforce_types -def test_get_trueval_live_lowercase_slash_5m(): - feed = Feed( - name="ETH-USDT", - address="0x1", - symbol="ETH-USDT", - seconds_per_epoch=300, - seconds_per_subscription=500, - pair="btc/usdt", - source="kucoin", - timeframe="5m", - trueval_submit_timeout=100, - owner="0xowner", - ) - - result = get_trueval(feed, 1692943200, 1692943200 + 5 * 60) - assert result == (False, False) - - -@enforce_types -def test_get_trueval_live_lowercase_dash_1h(): - feed = Feed( - name="ETH-USDT", - address="0x1", - symbol="ETH-USDT", - seconds_per_epoch=3600, - seconds_per_subscription=500, - pair="btc-usdt", - source="kucoin", - timeframe="1h", - trueval_submit_timeout=100, - owner="0xowner", - ) - - result = get_trueval(feed, 1692943200, 1692943200 + 1 * 60 * 60) - assert result == (False, False) - - -@enforce_types -def test_get_trueval_fail(monkeypatch): - feed = Feed( - name="ETH-USDT", - address="0x1", - symbol="ETH-USDT", - seconds_per_epoch=60, - seconds_per_subscription=500, - pair="eth-usdt", - source="kraken", - timeframe="5m", - trueval_submit_timeout=100, - owner="0xowner", - ) - - monkeypatch.setattr("ccxt.kraken.fetch_ohlcv", mock_fetch_ohlcv_fail) - - with pytest.raises(Exception): - result = get_trueval(feed, 1, 2) - assert result == (False, True) # 2nd True because failed diff --git a/pdr_backend/trueval/test/test_trueval_agent.py 
b/pdr_backend/trueval/test/test_trueval_agent.py index aa337ec78..74a52dc93 100644 --- a/pdr_backend/trueval/test/test_trueval_agent.py +++ b/pdr_backend/trueval/test/test_trueval_agent.py @@ -1,30 +1,20 @@ -from unittest.mock import patch, MagicMock +from copy import deepcopy +from unittest.mock import MagicMock, Mock, patch -from enforce_typing import enforce_types import pytest +from enforce_typing import enforce_types -from pdr_backend.trueval.trueval_agent_base import get_trueval -from pdr_backend.trueval.trueval_agent_single import TruevalAgentSingle -from pdr_backend.trueval.trueval_config import TruevalConfig -from pdr_backend.util.web3_config import Web3Config - +from pdr_backend.trueval.trueval_agent import TruevalAgent, TruevalSlot +from pdr_backend.util.constants import ZERO_ADDRESS -@enforce_types -def test_new_agent(trueval_config): - agent_ = TruevalAgentSingle(trueval_config, get_trueval) - assert agent_.config == trueval_config +PATH = "pdr_backend.trueval.trueval_agent" @enforce_types -def test_process_slot( - agent, slot, predictoor_contract_mock -): # pylint: disable=unused-argument - with patch.object( - agent, "get_and_submit_trueval", return_value={"tx": "0x123"} - ) as mock_submit: - result = agent.process_slot(slot) - assert result == {"tx": "0x123"} - mock_submit.assert_called() +def test_trueval_agent_constructor(mock_ppss): + agent_ = TruevalAgent(mock_ppss, ZERO_ADDRESS) + assert agent_.ppss == mock_ppss + assert agent_.predictoor_batcher.contract_address == ZERO_ADDRESS @enforce_types @@ -32,148 +22,149 @@ def test_get_contract_info_caching(agent, predictoor_contract_mock): agent.get_contract_info("0x1") agent.get_contract_info("0x1") assert predictoor_contract_mock.call_count == 1 - predictoor_contract_mock.assert_called_once_with(agent.config.web3_config, "0x1") - - -@enforce_types -def test_submit_trueval_mocked_price_down(agent, slot, predictoor_contract_mock): - with patch.object(agent, "get_trueval", return_value=(False, False)): - result = agent.get_and_submit_trueval( - slot, predictoor_contract_mock.return_value - ) - assert result == {"tx": "0x123"} - predictoor_contract_mock.return_value.submit_trueval.assert_called_once_with( - False, 1692943200, False, True - ) - - -@enforce_types -def test_submit_trueval_mocked_price_up(agent, slot, predictoor_contract_mock): - with patch.object(agent, "get_trueval", return_value=(True, False)): - result = agent.get_and_submit_trueval( - slot, predictoor_contract_mock.return_value - ) - assert result == {"tx": "0x123"} - predictoor_contract_mock.return_value.submit_trueval.assert_called_once_with( - True, 1692943200, False, True - ) + predictoor_contract_mock.assert_called_once_with(agent.ppss.web3_pp, "0x1") @enforce_types -def test_submit_trueval_mocked_cancel(agent, slot, predictoor_contract_mock): - with patch.object(agent, "get_trueval", return_value=(True, True)): - result = agent.get_and_submit_trueval( - slot, predictoor_contract_mock.return_value - ) - assert result == {"tx": "0x123"} - predictoor_contract_mock.return_value.submit_trueval.assert_called_once_with( - True, 1692943200, True, True - ) +def test_get_trueval_slot( + agent, slot, predictoor_contract_mock +): # pylint: disable=unused-argument + for trueval, cancel in [ + (True, True), # up + (False, True), # down + (True, False), # cancel + ]: + with patch(f"{PATH}.get_trueval", Mock(return_value=(trueval, cancel))): + result = agent.get_trueval_slot(slot) + assert result == (trueval, cancel) @enforce_types -def test_get_trueval_slot_up( +def 
test_get_trueval_slot_too_many_requests_retry( agent, slot, predictoor_contract_mock ): # pylint: disable=unused-argument - with patch.object(agent, "get_trueval", return_value=(True, True)): + mock_get_trueval = MagicMock( + side_effect=[Exception("Too many requests"), (True, True)] + ) + with patch(f"{PATH}.get_trueval", mock_get_trueval), patch( + "time.sleep", return_value=None + ) as mock_sleep: result = agent.get_trueval_slot(slot) + mock_sleep.assert_called_once_with(60) assert result == (True, True) + assert mock_get_trueval.call_count == 2 @enforce_types -def test_get_trueval_slot_down( - agent, slot, predictoor_contract_mock -): # pylint: disable=unused-argument - with patch.object(agent, "get_trueval", return_value=(False, True)): - result = agent.get_trueval_slot(slot) - assert result == (False, True) +def test_trueval_agent_run(agent): + mock_take_step = Mock() + with patch.object(agent, "take_step", mock_take_step): + agent.run(testing=True) + + mock_take_step.assert_called_once() @enforce_types -def test_get_trueval_slot_cancel( - agent, slot, predictoor_contract_mock -): # pylint: disable=unused-argument - with patch.object(agent, "get_trueval", return_value=(True, False)): - result = agent.get_trueval_slot(slot) - assert result == (True, False) +def test_trueval_agent_get_init_and_ts(agent): + ts = 2000 + seconds_per_epoch = 300 + + (initial_ts, end_ts) = agent.get_init_and_ts(ts, seconds_per_epoch) + assert initial_ts == ts - 300 + assert end_ts == ts @enforce_types -def test_get_trueval_slot_too_many_requests_retry( +def test_process_trueval_slot( agent, slot, predictoor_contract_mock ): # pylint: disable=unused-argument - mock_get_trueval = MagicMock( - side_effect=[Exception("Too many requests"), (True, True)] - ) - with patch.object(agent, "get_trueval", mock_get_trueval), patch( - "time.sleep", return_value=None - ) as mock_sleep: - result = agent.get_trueval_slot(slot) - mock_sleep.assert_called_once_with(60) - assert result == (True, True) - assert mock_get_trueval.call_count == 2 + for trueval, cancel in [ + (True, True), # up + (False, True), # down + (True, False), # cancel + ]: + with patch(f"{PATH}.get_trueval", Mock(return_value=(trueval, cancel))): + slot = TruevalSlot(slot_number=slot.slot_number, feed=slot.feed) + agent.process_trueval_slot(slot) + + assert slot.trueval == trueval + assert slot.cancel == cancel @enforce_types -def test_take_step(slot, agent): - mocked_env = { - "SLEEP_TIME": "1", - "BATCH_SIZE": "1", - } - - mocked_web3_config = MagicMock(spec=Web3Config) - - with patch.dict("os.environ", mocked_env), patch.object( - agent.config, "web3_config", new=mocked_web3_config - ), patch( - "pdr_backend.trueval.trueval_agent_single.wait_until_subgraph_syncs" - ), patch.object( - TruevalConfig, "get_pending_slots", return_value=[slot] - ), patch( - "time.sleep" - ), patch.object( - TruevalConfig, "get_pending_slots", return_value=[slot] - ), patch.object( - TruevalAgentSingle, "process_slot" - ) as ps_mock: - agent.take_step() +def test_batch_submit_truevals(agent, slot): + times = 3 + slot.feed.address = "0x0000000000000000000000000000000000c0ffee" + trueval_slots = [ + TruevalSlot(feed=slot.feed, slot_number=i) for i in range(0, times) + ] + for i in trueval_slots: + i.set_trueval(True) + i.set_cancel(False) + + slot2 = deepcopy(slot) + slot2.feed.address = "0x0000000000000000000000000000000000badbad" + trueval_slots_2 = [ + TruevalSlot(feed=slot2.feed, slot_number=i) for i in range(0, times) + ] + for i in trueval_slots_2: + i.set_trueval(True) + 
i.set_cancel(False) + + trueval_slots.extend(trueval_slots_2) + + contract_addrs = [ + "0x0000000000000000000000000000000000C0FFEE", + "0x0000000000000000000000000000000000baDbad", + ] # checksum + epoch_starts = [list(range(0, times))] * 2 + truevals = [[True] * times, [True] * times] + cancels = [[False] * times, [False] * times] - ps_mock.assert_called_once_with(slot) + with patch.object( + agent.predictoor_batcher, + "submit_truevals_contracts", + return_value={"transactionHash": bytes.fromhex("badc0ffeee")}, + ) as mock: + tx = agent.batch_submit_truevals(trueval_slots) + assert tx == "badc0ffeee" + mock.assert_called_with(contract_addrs, epoch_starts, truevals, cancels, True) @enforce_types -def test_run(slot, agent): - mocked_env = { - "SLEEP_TIME": "1", - "BATCH_SIZE": "1", - } - - mocked_web3_config = MagicMock(spec=Web3Config) - - with patch.dict("os.environ", mocked_env), patch.object( - agent.config, "web3_config", new=mocked_web3_config - ), patch( - "pdr_backend.trueval.trueval_agent_single.wait_until_subgraph_syncs" - ), patch( +def test_trueval_agent_take_step(agent, slot): + with patch(f"{PATH}.wait_until_subgraph_syncs"), patch.object( + agent, "get_batch", return_value=[slot] + ) as mock_get_batch, patch.object( + agent, "process_trueval_slot" + ) as mock_process_trueval_slot, patch( "time.sleep" ), patch.object( - TruevalConfig, "get_pending_slots", return_value=[slot] - ), patch.object( - TruevalAgentSingle, "process_slot" - ) as ps_mock: - agent.run(True) + agent, "batch_submit_truevals" + ) as mock_batch_submit_truevals: + agent.take_step() + + mock_get_batch.assert_called_once() + call_args = mock_process_trueval_slot.call_args[0][0] + assert call_args.slot_number == slot.slot_number + assert call_args.feed == slot.feed - ps_mock.assert_called_once_with(slot) + call_args = mock_batch_submit_truevals.call_args[0][0] + assert call_args[0].slot_number == slot.slot_number + assert call_args[0].feed == slot.feed @enforce_types -def test_get_init_and_ts(agent): - ts = 2000 - seconds_per_epoch = 300 +def test_trueval_agent_get_batch(agent, slot): + with patch.object(agent.ppss.web3_pp, "get_pending_slots", return_value=[slot]): + batch = agent.get_batch() - (initial_ts, end_ts) = agent.get_init_and_ts(ts, seconds_per_epoch) - assert initial_ts == ts - 300 - assert end_ts == ts + assert batch == [slot] + + with patch.object(agent.ppss.web3_pp, "get_pending_slots", return_value=[]): + batch = agent.get_batch() + + assert batch == [] # ---------------------------------------------- @@ -181,5 +172,5 @@ def test_get_init_and_ts(agent): @pytest.fixture(name="agent") -def agent_fixture(trueval_config): - return TruevalAgentSingle(trueval_config, get_trueval) +def agent_fixture(mock_ppss): + return TruevalAgent(mock_ppss, ZERO_ADDRESS) diff --git a/pdr_backend/trueval/test/test_trueval_agent_batch.py b/pdr_backend/trueval/test/test_trueval_agent_batch.py deleted file mode 100644 index 25c7c9457..000000000 --- a/pdr_backend/trueval/test/test_trueval_agent_batch.py +++ /dev/null @@ -1,119 +0,0 @@ -from copy import deepcopy -from unittest.mock import patch - -import pytest - -from pdr_backend.trueval.trueval_agent_base import get_trueval -from pdr_backend.trueval.trueval_agent_batch import TruevalAgentBatch, TruevalSlot -from pdr_backend.util.constants import ZERO_ADDRESS - - -def test_new_agent(trueval_config): - agent_ = TruevalAgentBatch(trueval_config, get_trueval, ZERO_ADDRESS) - assert agent_.config == trueval_config - assert agent_.predictoor_batcher.contract_address == 
ZERO_ADDRESS - - -def test_process_trueval_slot_up( - agent, slot, predictoor_contract_mock -): # pylint: disable=unused-argument - with patch.object(agent, "get_trueval", return_value=(True, False)): - slot = TruevalSlot(slot_number=slot.slot_number, feed=slot.feed) - agent.process_trueval_slot(slot) - - assert not slot.cancel - assert slot.trueval - - -def test_process_trueval_slot_down( - agent, slot, predictoor_contract_mock -): # pylint: disable=unused-argument - with patch.object(agent, "get_trueval", return_value=(False, False)): - slot = TruevalSlot(slot_number=slot.slot_number, feed=slot.feed) - agent.process_trueval_slot(slot) - - assert not slot.cancel - assert not slot.trueval - - -def test_process_trueval_slot_cancel( - agent, slot, predictoor_contract_mock -): # pylint: disable=unused-argument - with patch.object(agent, "get_trueval", return_value=(False, True)): - slot = TruevalSlot(slot_number=slot.slot_number, feed=slot.feed) - agent.process_trueval_slot(slot) - - assert slot.cancel - assert not slot.trueval - - -def test_batch_submit_truevals(agent, slot): - times = 3 - slot.feed.address = "0x0000000000000000000000000000000000c0ffee" - trueval_slots = [ - TruevalSlot(feed=slot.feed, slot_number=i) for i in range(0, times) - ] - for i in trueval_slots: - i.set_trueval(True) - i.set_cancel(False) - - slot2 = deepcopy(slot) - slot2.feed.address = "0x0000000000000000000000000000000000badbad" - trueval_slots_2 = [ - TruevalSlot(feed=slot2.feed, slot_number=i) for i in range(0, times) - ] - for i in trueval_slots_2: - i.set_trueval(True) - i.set_cancel(False) - - trueval_slots.extend(trueval_slots_2) - - contract_addrs = [ - "0x0000000000000000000000000000000000C0FFEE", - "0x0000000000000000000000000000000000baDbad", - ] # checksum - epoch_starts = [list(range(0, times))] * 2 - truevals = [[True] * times, [True] * times] - cancels = [[False] * times, [False] * times] - - with patch.object( - agent.predictoor_batcher, - "submit_truevals_contracts", - return_value={"transactionHash": bytes.fromhex("badc0ffeee")}, - ) as mock: - tx = agent.batch_submit_truevals(trueval_slots) - assert tx == "badc0ffeee" - mock.assert_called_with(contract_addrs, epoch_starts, truevals, cancels, True) - - -def test_take_step(agent, slot): - with patch( - "pdr_backend.trueval.trueval_agent_batch.wait_until_subgraph_syncs" - ), patch.object( - agent, "get_batch", return_value=[slot] - ) as mock_get_batch, patch.object( - agent, "process_trueval_slot" - ) as mock_process_trueval_slot, patch( - "time.sleep" - ), patch.object( - agent, "batch_submit_truevals" - ) as mock_batch_submit_truevals: - agent.take_step() - - mock_get_batch.assert_called_once() - call_args = mock_process_trueval_slot.call_args[0][0] - assert call_args.slot_number == slot.slot_number - assert call_args.feed == slot.feed - - call_args = mock_batch_submit_truevals.call_args[0][0] - assert call_args[0].slot_number == slot.slot_number - assert call_args[0].feed == slot.feed - - -# ---------------------------------------------- -# Fixtures - - -@pytest.fixture(name="agent") -def agent_fixture(trueval_config): - return TruevalAgentBatch(trueval_config, get_trueval, ZERO_ADDRESS) diff --git a/pdr_backend/trueval/test/test_trueval_config.py b/pdr_backend/trueval/test/test_trueval_config.py deleted file mode 100644 index e66e769df..000000000 --- a/pdr_backend/trueval/test/test_trueval_config.py +++ /dev/null @@ -1,18 +0,0 @@ -import os -from pdr_backend.trueval.trueval_config import TruevalConfig -from pdr_backend.util.env import 
parse_filters - - -def test_trueval_config(): - config = TruevalConfig() - assert config.rpc_url == os.getenv("RPC_URL") - assert config.subgraph_url == os.getenv("SUBGRAPH_URL") - assert config.private_key == os.getenv("PRIVATE_KEY") - assert config.sleep_time == int(os.getenv("SLEEP_TIME", "30")) - assert config.batch_size == int(os.getenv("BATCH_SIZE", "30")) - - (f0, f1, f2, f3) = parse_filters() - assert config.pair_filters == f0 - assert config.timeframe_filter == f1 - assert config.source_filter == f2 - assert config.owner_addresses == f3 diff --git a/pdr_backend/trueval/test/test_trueval_main.py b/pdr_backend/trueval/test/test_trueval_main.py deleted file mode 100644 index 9494d7e3e..000000000 --- a/pdr_backend/trueval/test/test_trueval_main.py +++ /dev/null @@ -1,49 +0,0 @@ -from unittest.mock import MagicMock, Mock, patch -from pdr_backend.trueval.main import main -from pdr_backend.trueval.trueval_agent_batch import TruevalAgentBatch -from pdr_backend.trueval.trueval_agent_single import TruevalAgentSingle -from pdr_backend.trueval.trueval_config import TruevalConfig -from pdr_backend.util.constants import ZERO_ADDRESS -from pdr_backend.util.web3_config import Web3Config - - -def test_trueval_main_1(slot): - mocked_web3_config = Mock(spec=Web3Config) - mocked_web3_config.get_block = Mock() - mocked_web3_config.get_block.return_value = {"timestamp": 0} - mocked_web3_config.w3 = MagicMock() - - with patch( - "pdr_backend.models.base_config.Web3Config", return_value=mocked_web3_config - ), patch( - "pdr_backend.trueval.trueval_agent_single.wait_until_subgraph_syncs" - ), patch( - "time.sleep" - ), patch( - "pdr_backend.trueval.main.sys.argv", [0, "1"] - ), patch.object( - TruevalConfig, "get_pending_slots", return_value=[slot] - ), patch.object( - TruevalAgentSingle, "process_slot" - ) as ps_mock: - main(True) - - ps_mock.assert_called_once_with(slot) - - -def test_trueval_main_2(): - mocked_web3_config = Mock(spec=Web3Config) - mocked_web3_config.get_block = Mock() - mocked_web3_config.get_block.return_value = {"timestamp": 0} - mocked_web3_config.w3 = MagicMock() - - with patch( - "pdr_backend.models.base_config.Web3Config", return_value=mocked_web3_config - ), patch("pdr_backend.trueval.main.get_address", return_value=ZERO_ADDRESS), patch( - "pdr_backend.trueval.main.sys.argv", [0, "2"] - ), patch.object( - TruevalAgentBatch, "take_step" - ) as ts_mock: - main(True) - - ts_mock.assert_called_once() diff --git a/pdr_backend/trueval/trueval_agent.py b/pdr_backend/trueval/trueval_agent.py new file mode 100644 index 000000000..560641bad --- /dev/null +++ b/pdr_backend/trueval/trueval_agent.py @@ -0,0 +1,180 @@ +import os +import time +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +from enforce_typing import enforce_types + +from pdr_backend.contract.predictoor_batcher import PredictoorBatcher +from pdr_backend.contract.predictoor_contract import PredictoorContract +from pdr_backend.contract.slot import Slot +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.subgraph.subgraph_sync import wait_until_subgraph_syncs +from pdr_backend.trueval.get_trueval import get_trueval + + +@enforce_types +class TruevalSlot(Slot): + def __init__(self, slot_number: int, feed: SubgraphFeed): + super().__init__(slot_number, feed) + self.trueval: Optional[bool] = None + self.cancel = False + + def set_trueval(self, trueval: Optional[bool]): + self.trueval = trueval + + def set_cancel(self, cancel: 
bool): + self.cancel = cancel + + +@enforce_types +class TruevalAgent: + def __init__(self, ppss: PPSS, predictoor_batcher_addr: str): + self.ppss = ppss + self.predictoor_batcher: PredictoorBatcher = PredictoorBatcher( + self.ppss.web3_pp, predictoor_batcher_addr + ) + self.contract_cache: Dict[str, tuple] = {} + + def run(self, testing: bool = False): + while True: + self.take_step() + if testing or os.getenv("TEST") == "true": + break + + def take_step(self): + wait_until_subgraph_syncs( + self.ppss.web3_pp.web3_config, self.ppss.web3_pp.subgraph_url + ) + pending_slots = self.get_batch() + + if len(pending_slots) == 0: + print( + f"No pending slots, sleeping for {self.ppss.trueval_ss.sleep_time} seconds..." + ) + time.sleep(self.ppss.trueval_ss.sleep_time) + return + + # convert slots to TruevalSlot + trueval_slots = [ + TruevalSlot(slot.slot_number, slot.feed) for slot in pending_slots + ] + + # get the trueval for each slot + for slot in trueval_slots: + self.process_trueval_slot(slot) + print(".", end="", flush=True) + print() # new line + + print("Submitting transaction...") + + tx_hash = self.batch_submit_truevals(trueval_slots) + print( + f"Tx sent: {tx_hash}, sleeping for {self.ppss.trueval_ss.sleep_time} seconds..." + ) + + time.sleep(self.ppss.trueval_ss.sleep_time) + + def get_batch(self) -> List[Slot]: + timestamp = self.ppss.web3_pp.web3_config.get_block("latest")["timestamp"] + + pending_slots = self.ppss.web3_pp.get_pending_slots( + timestamp, + allowed_feeds=self.ppss.trueval_ss.feeds, + ) + print( + f"Found {len(pending_slots)} pending slots" + f", processing {self.ppss.trueval_ss.batch_size}" + ) + pending_slots = pending_slots[: self.ppss.trueval_ss.batch_size] + return pending_slots + + def get_contract_info( + self, contract_address: str + ) -> Tuple[PredictoorContract, int]: + if contract_address in self.contract_cache: + predictoor_contract, seconds_per_epoch = self.contract_cache[ + contract_address + ] + else: + predictoor_contract = PredictoorContract( + self.ppss.web3_pp, contract_address + ) + seconds_per_epoch = predictoor_contract.get_secondsPerEpoch() + self.contract_cache[contract_address] = ( + predictoor_contract, + seconds_per_epoch, + ) + return (predictoor_contract, seconds_per_epoch) + + def get_init_and_ts(self, slot: int, seconds_per_epoch: int) -> Tuple[int, int]: + initial_ts = slot - seconds_per_epoch + end_ts = slot + return initial_ts, end_ts + + def get_trueval_slot(self, slot: Slot): + """ + @description + Get trueval at the specified slot + + @arguments + slot + + @return + trueval: bool + cancel_round: bool + """ + _, s_per_epoch = self.get_contract_info(slot.feed.address) + init_ts, end_ts = self.get_init_and_ts(slot.slot_number, s_per_epoch) + + print( + f"Get trueval slot: begin. For slot_number {slot.slot_number}" + f" of {slot.feed}" + ) + try: + # calls to get_trueval() func below, via Callable attribute on self + (trueval, cancel_round) = get_trueval(slot.feed, init_ts, end_ts) + except Exception as e: + if "Too many requests" in str(e): + print("Get trueval slot: too many requests, wait for a minute") + time.sleep(60) + return self.get_trueval_slot(slot) + + # pylint: disable=line-too-long + raise Exception(f"An error occured: {e}") from e + + print(f"Get trueval slot: done. 
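For reference, the slot-to-candle-window arithmetic in get_init_and_ts() above can be sketched numerically. The numbers are illustrative; the extra one-epoch shift and millisecond conversion follow the older get_trueval() that this diff deletes further below.

slot_number = 1692943200              # a slot is the epoch end, in seconds
s_per_epoch = 300                     # 5m feed

init_ts = slot_number - s_per_epoch   # get_init_and_ts() -> (init_ts, end_ts)
end_ts = slot_number

# close prices are compared, so the candle fetch starts one epoch earlier,
# expressed in exchange milliseconds, and asks for exactly 2 OHLCV rows:
since_ms = (init_ts - s_per_epoch) * 1000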
trueval={trueval}, cancel_round={cancel_round}") + return (trueval, cancel_round) + + def batch_submit_truevals(self, slots: List[TruevalSlot]) -> str: + contracts: dict = defaultdict( + lambda: {"epoch_starts": [], "trueVals": [], "cancelRounds": []} + ) + + w3 = self.ppss.web3_pp.web3_config.w3 + for slot in slots: + if slot.trueval is None: # We only want slots with non-None truevals + continue + data = contracts[w3.to_checksum_address(slot.feed.address)] + data["epoch_starts"].append(slot.slot_number) + data["trueVals"].append(slot.trueval) + data["cancelRounds"].append(slot.cancel) + + contract_addrs = list(contracts.keys()) + epoch_starts = [data["epoch_starts"] for data in contracts.values()] + trueVals = [data["trueVals"] for data in contracts.values()] + cancelRounds = [data["cancelRounds"] for data in contracts.values()] + + tx = self.predictoor_batcher.submit_truevals_contracts( + contract_addrs, epoch_starts, trueVals, cancelRounds, True + ) + return tx["transactionHash"].hex() + + def process_trueval_slot(self, slot: TruevalSlot): + # (don't wrap with try/except because the called func already does) + (trueval, cancel_round) = self.get_trueval_slot(slot) + + slot.set_trueval(trueval) + if cancel_round: + slot.set_cancel(True) diff --git a/pdr_backend/trueval/trueval_agent_base.py b/pdr_backend/trueval/trueval_agent_base.py deleted file mode 100644 index 5b2e78bdc..000000000 --- a/pdr_backend/trueval/trueval_agent_base.py +++ /dev/null @@ -1,125 +0,0 @@ -import time -from abc import ABC -from typing import Dict, List, Tuple, Callable - -import ccxt -from enforce_typing import enforce_types - -from pdr_backend.models.slot import Slot -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.models.feed import Feed -from pdr_backend.trueval.trueval_config import TruevalConfig - - -@enforce_types -class TruevalAgentBase(ABC): - def __init__( - self, - trueval_config: TruevalConfig, - _get_trueval: Callable[[Feed, int, int], Tuple[bool, bool]], - ): - self.config = trueval_config - self.get_trueval = _get_trueval - self.contract_cache: Dict[str, tuple] = {} - - def run(self, testing: bool = False): - while True: - self.take_step() - if testing: - break - - def take_step(self): - raise NotImplementedError("Take step is not implemented") - - def get_batch(self) -> List[Slot]: - timestamp = self.config.web3_config.get_block("latest")["timestamp"] - pending_slots = self.config.get_pending_slots( - timestamp, - ) - print( - f"Found {len(pending_slots)} pending slots, processing {self.config.batch_size}" - ) - pending_slots = pending_slots[: self.config.batch_size] - return pending_slots - - def get_contract_info( - self, contract_address: str - ) -> Tuple[PredictoorContract, int]: - if contract_address in self.contract_cache: - predictoor_contract, seconds_per_epoch = self.contract_cache[ - contract_address - ] - else: - predictoor_contract = PredictoorContract( - self.config.web3_config, contract_address - ) - seconds_per_epoch = predictoor_contract.get_secondsPerEpoch() - self.contract_cache[contract_address] = ( - predictoor_contract, - seconds_per_epoch, - ) - return (predictoor_contract, seconds_per_epoch) - - def get_init_and_ts(self, slot: int, seconds_per_epoch: int) -> Tuple[int, int]: - initial_ts = slot - seconds_per_epoch - end_ts = slot - return initial_ts, end_ts - - def get_trueval_slot(self, slot: Slot): - _, seconds_per_epoch = self.get_contract_info(slot.feed.address) - init_ts, end_ts = self.get_init_and_ts(slot.slot_number, 
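For reference, the per-contract grouping done by batch_submit_truevals() above (and exercised by test_batch_submit_truevals earlier in this diff) can be sketched with plain tuples; the feed addresses are illustrative placeholders.

from collections import defaultdict

slots = [
    # (feed_address, slot_number, trueval, cancel)
    ("0xc0ffee...", 0, True, False),
    ("0xc0ffee...", 1, True, False),
    ("0xbadbad...", 0, True, False),
]

contracts = defaultdict(lambda: {"epoch_starts": [], "trueVals": [], "cancelRounds": []})
for addr, slot_number, trueval, cancel in slots:
    d = contracts[addr]                  # the real code checksums the address first
    d["epoch_starts"].append(slot_number)
    d["trueVals"].append(trueval)
    d["cancelRounds"].append(cancel)

# parallel lists, one entry per contract, submitted to the batcher in a single tx:
contract_addrs = list(contracts.keys())                          # 2 addresses
epoch_starts = [d["epoch_starts"] for d in contracts.values()]   # [[0, 1], [0]]
trueVals = [d["trueVals"] for d in contracts.values()]           # [[True, True], [True]]
cancelRounds = [d["cancelRounds"] for d in contracts.values()]   # [[False, False], [False]]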
seconds_per_epoch) - try: - (trueval, cancel) = self.get_trueval(slot.feed, init_ts, end_ts) - return trueval, cancel - except Exception as e: - if "Too many requests" in str(e): - print("Too many requests, waiting for a minute") - time.sleep(60) - return self.get_trueval_slot(slot) - - # pylint: disable=line-too-long - raise Exception( - f"An error occured: {e}, while getting trueval for: {slot.feed.address} {slot.feed.pair} {slot.slot_number}" - ) from e - - -@enforce_types -def get_trueval( - feed: Feed, initial_timestamp: int, end_timestamp: int -) -> Tuple[bool, bool]: - """ - @description - Checks if the price has risen between two given timestamps. - If the round should be canceled, the second value in the returned tuple is set to True. - - @arguments - feed -- Feed -- The feed object containing pair details - initial_timestamp -- int -- The starting timestamp. - end_timestamp -- int -- The ending timestamp. - - @return - Tuple[bool, bool] -- The trueval and a boolean indicating if the round should be canceled. - """ - symbol = feed.pair - symbol = symbol.replace("-", "/") - symbol = symbol.upper() - - # since we will get close price - # we need to go back 1 candle - initial_timestamp -= feed.seconds_per_epoch - end_timestamp -= feed.seconds_per_epoch - - # convert seconds to ms - initial_timestamp = int(initial_timestamp * 1000) - end_timestamp = int(end_timestamp * 1000) - - exchange_class = getattr(ccxt, feed.source) - exchange = exchange_class() - price_data = exchange.fetch_ohlcv( - symbol, feed.timeframe, since=initial_timestamp, limit=2 - ) - if price_data[0][0] != initial_timestamp or price_data[1][0] != end_timestamp: - raise Exception("Timestamp mismatch") - if price_data[1][4] == price_data[0][4]: - return (False, True) - return (price_data[1][4] >= price_data[0][4], False) diff --git a/pdr_backend/trueval/trueval_agent_batch.py b/pdr_backend/trueval/trueval_agent_batch.py deleted file mode 100644 index c7f84e777..000000000 --- a/pdr_backend/trueval/trueval_agent_batch.py +++ /dev/null @@ -1,100 +0,0 @@ -from collections import defaultdict -import time -from typing import List, Optional, Tuple, Callable - -from enforce_typing import enforce_types - -from pdr_backend.models.feed import Feed -from pdr_backend.models.predictoor_batcher import PredictoorBatcher -from pdr_backend.models.slot import Slot -from pdr_backend.trueval.trueval_agent_base import TruevalAgentBase -from pdr_backend.trueval.trueval_config import TruevalConfig -from pdr_backend.util.subgraph import wait_until_subgraph_syncs - - -@enforce_types -class TruevalSlot(Slot): - def __init__(self, slot_number: int, feed: Feed): - super().__init__(slot_number, feed) - self.trueval: Optional[bool] = None - self.cancel = False - - def set_trueval(self, trueval: Optional[bool]): - self.trueval = trueval - - def set_cancel(self, cancel: bool): - self.cancel = cancel - - -@enforce_types -class TruevalAgentBatch(TruevalAgentBase): - def __init__( - self, - trueval_config: TruevalConfig, - _get_trueval: Callable[[Feed, int, int], Tuple[bool, bool]], - predictoor_batcher_addr: str, - ): - super().__init__(trueval_config, _get_trueval) - self.predictoor_batcher: PredictoorBatcher = PredictoorBatcher( - self.config.web3_config, predictoor_batcher_addr - ) - - def take_step(self): - wait_until_subgraph_syncs(self.config.web3_config, self.config.subgraph_url) - pending_slots = self.get_batch() - - if len(pending_slots) == 0: - print(f"No pending slots, sleeping for {self.config.sleep_time} seconds...") - 
time.sleep(self.config.sleep_time) - return - - # convert slots to TruevalSlot - trueval_slots = [ - TruevalSlot(slot.slot_number, slot.feed) for slot in pending_slots - ] - - # get the trueval for each slot - for slot in trueval_slots: - self.process_trueval_slot(slot) - print(".", end="", flush=True) - print() # new line - - print("Submitting transaction...") - - tx_hash = self.batch_submit_truevals(trueval_slots) - print(f"Tx sent: {tx_hash}, sleeping for {self.config.sleep_time} seconds...") - - time.sleep(self.config.sleep_time) - - def batch_submit_truevals(self, slots: List[TruevalSlot]) -> str: - contracts: dict = defaultdict( - lambda: {"epoch_starts": [], "trueVals": [], "cancelRounds": []} - ) - - for slot in slots: - if slot.trueval is not None: # Only consider slots with non-None truevals - data = contracts[ - self.config.web3_config.w3.to_checksum_address(slot.feed.address) - ] - data["epoch_starts"].append(slot.slot_number) - data["trueVals"].append(slot.trueval) - data["cancelRounds"].append(slot.cancel) - - contract_addrs = list(contracts.keys()) - epoch_starts = [data["epoch_starts"] for data in contracts.values()] - trueVals = [data["trueVals"] for data in contracts.values()] - cancelRounds = [data["cancelRounds"] for data in contracts.values()] - - tx = self.predictoor_batcher.submit_truevals_contracts( - contract_addrs, epoch_starts, trueVals, cancelRounds, True - ) - return tx["transactionHash"].hex() - - def process_trueval_slot(self, slot: TruevalSlot): - try: - (trueval, cancel) = self.get_trueval_slot(slot) - slot.set_trueval(trueval) - if cancel: - slot.set_cancel(True) - except Exception as e: - print("An error occured while getting processing slot:", e) diff --git a/pdr_backend/trueval/trueval_agent_single.py b/pdr_backend/trueval/trueval_agent_single.py deleted file mode 100644 index b5461d9fa..000000000 --- a/pdr_backend/trueval/trueval_agent_single.py +++ /dev/null @@ -1,53 +0,0 @@ -import time -from enforce_typing import enforce_types -from pdr_backend.models.predictoor_contract import PredictoorContract -from pdr_backend.models.slot import Slot - -from pdr_backend.trueval.trueval_agent_base import TruevalAgentBase -from pdr_backend.util.subgraph import wait_until_subgraph_syncs - - -@enforce_types -class TruevalAgentSingle(TruevalAgentBase): - def take_step(self): - wait_until_subgraph_syncs(self.config.web3_config, self.config.subgraph_url) - pending_slots = self.get_batch() - - if len(pending_slots) == 0: - print(f"No pending slots, sleeping for {self.config.sleep_time} seconds...") - time.sleep(self.config.sleep_time) - return - - for slot in pending_slots: - print("-" * 30) - print(f"Processing slot {slot.slot_number} for {slot.feed}") - try: - self.process_slot(slot) - except Exception as e: - print("An error occured", e) - print(f"Done processing, sleeping for {self.config.sleep_time} seconds...") - time.sleep(self.config.sleep_time) - - def process_slot(self, slot: Slot) -> dict: - predictoor_contract, _ = self.get_contract_info(slot.feed.address) - return self.get_and_submit_trueval(slot, predictoor_contract) - - def get_and_submit_trueval( - self, - slot: Slot, - predictoor_contract: PredictoorContract, - ) -> dict: - try: - (trueval, cancel) = self.get_trueval_slot(slot) - - # pylint: disable=line-too-long - print( - f"{slot.feed} - Submitting trueval {trueval} and slot:{slot.slot_number}" - ) - tx = predictoor_contract.submit_trueval( - trueval, slot.slot_number, cancel, True - ) - return tx - except Exception as e: - print("Error while getting 
trueval:", e) - return {} diff --git a/pdr_backend/trueval/trueval_config.py b/pdr_backend/trueval/trueval_config.py deleted file mode 100644 index cd453c350..000000000 --- a/pdr_backend/trueval/trueval_config.py +++ /dev/null @@ -1,17 +0,0 @@ -from os import getenv - -from enforce_typing import enforce_types - -from pdr_backend.models.base_config import BaseConfig - - -@enforce_types -class TruevalConfig(BaseConfig): - def __init__(self): - super().__init__() - - if self.owner_addresses is None: - raise Exception("env var OWNER_ADDRS must be set") - - self.sleep_time = int(getenv("SLEEP_TIME", "30")) - self.batch_size = int(getenv("BATCH_SIZE", "30")) diff --git a/pdr_backend/util/constants.py b/pdr_backend/util/constants.py index 3dd3eebb0..58a8e92d6 100644 --- a/pdr_backend/util/constants.py +++ b/pdr_backend/util/constants.py @@ -1,13 +1,22 @@ ZERO_ADDRESS = "0x0000000000000000000000000000000000000000" MAX_UINT = 2**256 - 1 +DEVELOPMENT_CHAINID = 8996 + SAPPHIRE_TESTNET_RPC = "https://testnet.sapphire.oasis.dev" SAPPHIRE_TESTNET_CHAINID = 23295 + SAPPHIRE_MAINNET_RPC = "https://sapphire.oasis.io" SAPPHIRE_MAINNET_CHAINID = 23294 +OCEAN_TOKEN_ADDRS = { + SAPPHIRE_MAINNET_CHAINID: "0x39d22B78A7651A76Ffbde2aaAB5FD92666Aca520", + SAPPHIRE_TESTNET_CHAINID: "0x973e69303259B0c2543a38665122b773D28405fB", +} + S_PER_MIN = 60 S_PER_DAY = 86400 +S_PER_WEEK = S_PER_DAY * 7 SUBGRAPH_MAX_TRIES = 5 WEB3_MAX_TRIES = 5 diff --git a/scripts/addresses.py b/pdr_backend/util/constants_opf_addrs.py similarity index 92% rename from scripts/addresses.py rename to pdr_backend/util/constants_opf_addrs.py index a6ba360ff..96d1f9137 100644 --- a/scripts/addresses.py +++ b/pdr_backend/util/constants_opf_addrs.py @@ -1,7 +1,12 @@ -def get_opf_addresses(chain_id): - addresses = {} - if chain_id == 23295: - addresses = { +from typing import Dict + +from enforce_typing import enforce_types + + +@enforce_types +def get_opf_addresses(network_name: str) -> Dict[str, str]: + if network_name == "sapphire-testnet": + return { "predictoor1": "0xE02A421dFc549336d47eFEE85699Bd0A3Da7D6FF", "predictoor2": "0x00C4C993e7B0976d63E7c92D874346C3D0A05C1e", "predictoor3": "0x005C414442a892077BD2c1d62B1dE2Fc127E5b9B", @@ -10,8 +15,8 @@ def get_opf_addresses(chain_id): "dfbuyer": "0xeA24C440eC55917fFa030C324535fc49B42c2fD7", } - if chain_id == 23294: - addresses = { + if network_name == "sapphire-mainnet": + return { "predictoor1": "0x35Afee1168D1e1053298F368488F4eE95E891a6e", "predictoor2": "0x1628BeA0Fb859D56Cd2388054c0bA395827e4374", "predictoor3": "0x3f0825d0c0bbfbb86cd13C7E6c9dC778E3Bb44ec", @@ -56,4 +61,5 @@ def get_opf_addresses(chain_id): "websocket": "0x6Cc4Fe9Ba145AbBc43227b3D4860FA31AFD225CB", "dfbuyer": "0x2433e002Ed10B5D6a3d8d1e0C5D2083BE9E37f1D", } - return addresses + + raise ValueError(network_name) diff --git a/pdr_backend/util/contract.py b/pdr_backend/util/contract.py index 5e5678db2..78c282c80 100644 --- a/pdr_backend/util/contract.py +++ b/pdr_backend/util/contract.py @@ -1,3 +1,4 @@ +import copy import json import os from pathlib import Path @@ -9,8 +10,12 @@ @enforce_types -def get_address(chain_id: int, contract_name: str): - network = get_addresses(chain_id) +def get_address(web3_pp, contract_name: str): + """ + Returns address for this contract + in web3_pp.address_file, in web3_pp.network + """ + network = get_addresses(web3_pp) if not network: raise ValueError(f"Cannot figure out {contract_name} address") address = network.get(contract_name) @@ -18,12 +23,16 @@ def get_address(chain_id: int, contract_name: 
str): @enforce_types -def get_addresses(chain_id: int): - address_filename = os.getenv("ADDRESS_FILE") +def get_addresses(web3_pp) -> Union[dict, None]: + """ + Returns addresses in web3_pp.address_file, in web3_pp.network + """ + address_file = web3_pp.address_file + path = None - if address_filename: - address_filename = os.path.expanduser(address_filename) - path = Path(address_filename) + if address_file: + address_file = os.path.expanduser(address_file) + path = Path(address_file) else: path = Path(str(os.path.dirname(addresses.__file__)) + "/address.json") @@ -31,18 +40,23 @@ def get_addresses(chain_id: int): raise TypeError(f"Cannot find address.json file at {path}") with open(path) as f: - data = json.load(f) - for name in data: - network = data[name] - if network["chainId"] == chain_id: - return network + d = json.load(f) + + d = _condition_sapphire_keys(d) + + if "barge" in web3_pp.network: # eg "barge-pytest" + return d["development"] + + if web3_pp.network in d: # eg "development", "oasis_sapphire" + return d[web3_pp.network] + return None @enforce_types -def get_contract_abi(contract_name): +def get_contract_abi(contract_name: str, address_file: Union[str, None]): """Returns the abi dict for a contract name.""" - path = get_contract_filename(contract_name) + path = get_contract_filename(contract_name, address_file) if not path.exists(): raise TypeError("Contract name does not exist in artifacts.") @@ -53,20 +67,19 @@ def get_contract_abi(contract_name): @enforce_types -def get_contract_filename(contract_name: str): +def get_contract_filename(contract_name: str, address_file: Union[str, None]): """Returns filename for a contract name.""" contract_basename = f"{contract_name}.json" # first, try to find locally - address_filename = os.getenv("ADDRESS_FILE") path: Union[None, str, Path] = None - if address_filename: - address_filename = os.path.expanduser(address_filename) - address_dir = os.path.dirname(address_filename) + if address_file: + address_file = os.path.expanduser(address_file) + address_dir = os.path.dirname(address_file) root_dir = os.path.join(address_dir, "..") paths = list(Path(root_dir).rglob(contract_basename)) if paths: - assert len(paths) == 1, "had duplicates for {contract_basename}" + assert len(paths) == 1, f"had duplicates for {contract_basename}" path = paths[0] path = Path(path).expanduser().resolve() assert ( @@ -77,6 +90,24 @@ def get_contract_filename(contract_name: str): # didn't find locally, so use use artifacts lib path = os.path.join(os.path.dirname(artifacts.__file__), "", contract_basename) path = Path(path).expanduser().resolve() + if not path.exists(): raise TypeError(f"Contract '{contract_name}' not found in artifacts.") + return path + + +@enforce_types +def _condition_sapphire_keys(d: dict) -> dict: + """ + For each sapphire-related key seen from address.json, + transform it to something friendly to pdr-backend (and named better) + """ + d2 = copy.deepcopy(d) + names = list(d.keys()) # eg ["mumbai", "oasis_saphire"] + for name in names: + if name == "oasis_saphire_testnet": + d2["sapphire-testnet"] = d[name] + elif name == "oasis_saphire": + d2["sapphire-mainnet"] = d[name] + return d2 diff --git a/pdr_backend/util/csvs.py b/pdr_backend/util/csvs.py index d9cba25cd..0b8e41218 100644 --- a/pdr_backend/util/csvs.py +++ b/pdr_backend/util/csvs.py @@ -1,36 +1,101 @@ -import os import csv +import os +from typing import Dict, List +from enforce_typing import enforce_types -def write_prediction_csv(all_predictions, csv_output_dir): - if not 
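For reference, the new get_addresses() resolution above boils down to a key rename plus a name lookup; a minimal sketch against a toy address.json dict:

d = {
    "development":   {"chainId": 8996},
    "oasis_saphire": {"chainId": 23294},   # spelling as it appears in address.json
}
# _condition_sapphire_keys() adds friendlier aliases (deep-copying in the real code):
d["sapphire-mainnet"] = d["oasis_saphire"]
# get_addresses() then resolves web3_pp.network:
#   any "barge*" network (e.g. "barge-pytest")    -> d["development"]
#   a name present in d (e.g. "sapphire-mainnet") -> d[network]
#   anything else                                 -> None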
os.path.exists(csv_output_dir): - os.makedirs(csv_output_dir) +from pdr_backend.subgraph.subgraph_predictions import Prediction + + +@enforce_types +def get_plots_dir(pq_dir: str): + plots_dir = os.path.join(pq_dir, "plots") + + if not os.path.exists(plots_dir): + os.makedirs(plots_dir) - data = {} - for prediction in all_predictions: + return plots_dir + + +@enforce_types +def key_csv_filename_with_dir(csv_output_dir: str, key: str) -> str: + return os.path.join( + csv_output_dir, + key + ".csv", + ) + + +@enforce_types +def generate_prediction_data_structure( + predictions: List[Prediction], +) -> Dict[str, List[Prediction]]: + data: Dict[str, List[Prediction]] = {} + for prediction in predictions: key = ( prediction.pair.replace("/", "-") + prediction.timeframe + prediction.source ) + if key not in data: data[key] = [] data[key].append(prediction) + return data + + +@enforce_types +def _save_prediction_csv( + all_predictions: List[Prediction], + csv_output_dir: str, + headers: List, + attribute_names: List, +): + if not os.path.isdir(csv_output_dir): + os.makedirs(csv_output_dir) + + data = generate_prediction_data_structure(all_predictions) + for key, predictions in data.items(): predictions.sort(key=lambda x: x.timestamp) - filename = os.path.join(csv_output_dir, key + ".csv") + filename = key_csv_filename_with_dir(csv_output_dir, key) with open(filename, "w", newline="") as file: writer = csv.writer(file) - writer.writerow( - ["Predicted Value", "True Value", "Timestamp", "Stake", "Payout"] - ) + + writer.writerow(headers) + for prediction in predictions: writer.writerow( [ - prediction.prediction, - prediction.trueval, - prediction.timestamp, - prediction.stake, - prediction.payout, + getattr(prediction, attribute_name) + for attribute_name in attribute_names ] ) + print(f"CSV file '{filename}' created successfully.") + + +def save_prediction_csv(all_predictions: List[Prediction], csv_output_dir: str): + _save_prediction_csv( + all_predictions, + csv_output_dir, + ["Predicted Value", "True Value", "Timestamp", "Stake", "Payout"], + ["prediction", "trueval", "timestamp", "stake", "payout"], + ) + + +@enforce_types +def save_analysis_csv(all_predictions: List[Prediction], csv_output_dir: str): + _save_prediction_csv( + all_predictions, + csv_output_dir, + [ + "PredictionID", + "Timestamp", + "Slot", + "Stake", + "Wallet", + "Payout", + "True Value", + "Predicted Value", + ], + ["ID", "timestamp", "slot", "stake", "user", "payout", "trueval", "prediction"], + ) diff --git a/pdr_backend/util/env.py b/pdr_backend/util/env.py index 5c825b982..68a3e011f 100644 --- a/pdr_backend/util/env.py +++ b/pdr_backend/util/env.py @@ -1,6 +1,6 @@ -from os import getenv import sys -from typing import List, Tuple, Union +from os import getenv +from typing import Union from enforce_typing import enforce_types @@ -12,38 +12,3 @@ def getenv_or_exit(envvar_name: str) -> Union[None, str]: print(f"You must set {envvar_name} environment variable") sys.exit(1) return value - - -@enforce_types -def parse_filters() -> Tuple[List[str], List[str], List[str], List[str]]: - """ - @description - Grabs envvar values for each of the filters (PAIR_FILTER, etc). - Then, parses each: splits the string into a list. - Returns the list. - - @arguments - - - @return - parsed_pair_filter -- e.g. [] or ["ETH-USDT", "BTC-USDT"] - parsed_timeframe_filter -- e.g. ["5m"] - parsed_source_filter -- e.g. ["binance"] - parsed_owner_addrs -- e.g. 
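For reference, the csvs.py helpers above group predictions into one file per (pair, timeframe, source); a small sketch of the key and filename, with the Prediction fields inferred from their use in this diff:

import os

pair, timeframe, source = "ETH/USDT", "5m", "binance"
key = pair.replace("/", "-") + timeframe + source    # "ETH-USDT5mbinance"
filename = os.path.join("./csvs", key + ".csv")      # key_csv_filename_with_dir()
# save_prediction_csv() columns: Predicted Value, True Value, Timestamp, Stake, Payout
# save_analysis_csv() also adds PredictionID, Slot and Wallet (from ID, slot, user)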
["0x123.."] - - @notes - if envvar is None, the parsed filter is [], *not* None - """ - - def _parse1(envvar) -> List[str]: - envval = getenv(envvar) - if envval is None: - return [] - return envval.split(",") - - return ( - _parse1("PAIR_FILTER"), - _parse1("TIMEFRAME_FILTER"), - _parse1("SOURCE_FILTER"), - _parse1("OWNER_ADDRS"), - ) diff --git a/pdr_backend/util/feedstr.py b/pdr_backend/util/feedstr.py deleted file mode 100644 index d0e6b2ed5..000000000 --- a/pdr_backend/util/feedstr.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import List, Tuple - -import ccxt -from enforce_typing import enforce_types - -from pdr_backend.util.pairstr import ( - unpack_pairs_str, - verify_pair_str, -) -from pdr_backend.util.signalstr import ( - unpack_signalchar_str, - verify_signal_str, -) - -# ========================================================================== -# unpack..() functions - - -@enforce_types -def unpack_feeds_strs( - feeds_strs: List[str], do_verify: bool = True -) -> List[Tuple[str, List[str], List[str]]]: - """ - @description - Unpack *one or more* feeds strs. - - Example: Given [ - 'binance oc ADA/USDT BTC/USDT', - 'kraken o BTC-USDT', - ] - Return [ - ('binance', 'open', 'ADA-USDT'), - ('binance', 'close', 'ADA-USDT'), - ('binance', 'open', 'BTC-USDT'), - ('binance', 'close', 'BTC-USDT'), - ('kraken', 'open', 'BTC-USDT'), - ] - - @arguments - feeds_strs - list of ' ' - do_verify - typically T. Only F to avoid recursion from verify functions - - @return - feed_tups - list of (exchange_str, signal_str, pair_str) - """ - if do_verify: - if not feeds_strs: - raise ValueError(feeds_strs) - - feed_tups = [] - for feeds_str in feeds_strs: - feed_tups += unpack_feeds_str(feeds_str, do_verify=False) - - if do_verify: - for feed_tup in feed_tups: - verify_feed_tup(feed_tup) - return feed_tups - - -@enforce_types -def unpack_feeds_str( - feeds_str: str, do_verify: bool = True -) -> List[Tuple[str, str, str]]: - """ - @description - Unpack a *single* feeds str. It can have >1 feeds of course. - - Example: Given 'binance oc ADA/USDT BTC/USDT' - Return [ - ('binance', 'open', 'ADA-USDT'), - ('binance', 'close', 'ADA-USDT'), - ('binance', 'open', 'BTC-USDT'), - ('binance', 'close', 'BTC-USDT'), - ] - - @arguments - feeds_str - ' ' - do_verify - typically T. Only F to avoid recursion from verify functions - - @return - feed_tups - list of (exchange_str, signal_str, pair_str) - """ - feeds_str = feeds_str.strip() - feeds_str = " ".join(feeds_str.split()) # replace multiple whitespace w/ 1 - exchange_str, signal_char_str, pairs_str = feeds_str.split(" ", maxsplit=2) - signal_str_list = unpack_signalchar_str(signal_char_str) - pair_str_list = unpack_pairs_str(pairs_str) - feed_tups = [ - (exchange_str, signal_str, pair_str) - for signal_str in signal_str_list - for pair_str in pair_str_list - ] - - if do_verify: - for feed_tup in feed_tups: - verify_feed_tup(feed_tup) - return feed_tups - - -@enforce_types -def unpack_feed_str(feed_str: str, do_verify: bool = True) -> Tuple[str, str, str]: - """ - @description - Unpack the string for a *single* feed: 1 exchange, 1 signal, 1 pair - - Example: 'binance o ADA/USDT' - Return ('binance', 'open', 'BTC-USDT') - - @argument - feed_str -- eg 'binance o ADA/USDT'; not eg 'binance oc ADA/USDT BTC/DAI' - do_verify - typically T. 
Only F to avoid recursion from verify functions - - @return - feed_tup - (exchange_str, signal_str, pair_str) - """ - feeds_str = feed_str - feed_tups = unpack_feeds_str(feeds_str, do_verify=False) - if do_verify: - if len(feed_tups) != 1: - raise ValueError(feed_str) - feed_tup = feed_tups[0] - return feed_tup - - -# ========================================================================== -# verify..() functions - - -@enforce_types -def verify_feeds_strs(feeds_strs: List[str]): - """ - @description - Raise an error if feeds_strs is invalid - - @argument - feeds_strs -- eg ['binance oh ADA/USDT BTC/USDT', 'kraken o ADA/DAI'] - """ - if not feeds_strs: - raise ValueError() - for feeds_str in feeds_strs: - verify_feeds_str(feeds_str) - - -@enforce_types -def verify_feeds_str(feeds_str: str): - """ - @description - Raise an error if feeds_str is invalid - - @argument - feeds_str -- e.g. 'binance oh ADA/USDT BTC/USDT' - """ - feed_tups = unpack_feeds_str(feeds_str, do_verify=False) - if not feed_tups: - raise ValueError(feeds_str) - for feed_tup in feed_tups: - verify_feed_tup(feed_tup) - - -@enforce_types -def verify_feed_str(feed_str: str): - """ - @description - Raise an error if feed_str is invalid - - @argument - feed_str -- e.g. 'binance o ADA/USDT' - """ - feeds_str = feed_str - feed_tups = unpack_feeds_str(feeds_str, do_verify=False) - if not len(feed_tups) == 1: - raise ValueError(feed_str) - verify_feed_tup(feed_tups[0]) - - -@enforce_types -def verify_feed_tup(feed_tup: Tuple[str, str, str]): - """ - @description - Raise an error if feed_tup is invalid. - - @argument - feed_tup -- (exchange_str, signal_str, pair_str) - E.g. ('binance', 'open', 'BTC/USDT') - """ - exchange_str, signal_str, pair_str = feed_tup - verify_exchange_str(exchange_str) - verify_signal_str(signal_str) - verify_pair_str(pair_str) - - -@enforce_types -def verify_exchange_str(exchange_str: str): - """ - @description - Raise an error if exchange is invalid. - - @argument - exchange_str -- e.g. "binance" - """ - # it's valid if ccxt sees it - if not hasattr(ccxt, exchange_str): - raise ValueError(exchange_str) diff --git a/pdr_backend/util/fund_accounts.py b/pdr_backend/util/fund_accounts.py new file mode 100644 index 000000000..a60d1dc54 --- /dev/null +++ b/pdr_backend/util/fund_accounts.py @@ -0,0 +1,49 @@ +import os +from typing import List + +from enforce_typing import enforce_types +from eth_account import Account + +from pdr_backend.contract.token import Token +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.util.contract import get_address + + +@enforce_types +def fund_accounts_with_OCEAN(web3_pp: Web3PP): + """ + Fund accounts, with opinions: use OCEAN, and choices of amounts. + Meant to be used from CLI. 
+ """ + print(f"Fund accounts with OCEAN, network = {web3_pp.network}") + accounts_to_fund = [ + # account_key_env, OCEAN_to_send + ("PREDICTOOR_PRIVATE_KEY", 2000.0), + ("PREDICTOOR2_PRIVATE_KEY", 2000.0), + ("PREDICTOOR3_PRIVATE_KEY", 2000.0), + ("TRADER_PRIVATE_KEY", 2000.0), + ("DFBUYER_PRIVATE_KEY", 10000.0), + ("PDR_WEBSOCKET_KEY", 10000.0), + ("PDR_MM_USER", 10000.0), + ] + + OCEAN_addr = get_address(web3_pp, "Ocean") + OCEAN = Token(web3_pp, OCEAN_addr) + fund_accounts(accounts_to_fund, web3_pp.web3_config.owner, OCEAN) + print("Done funding.") + + +@enforce_types +def fund_accounts(accounts_to_fund: List[tuple], owner: str, token: Token): + """Worker function to actually fund accounts""" + for private_key_name, amount in accounts_to_fund: + if private_key_name in os.environ: + private_key = os.getenv(private_key_name) + account = Account.from_key( # pylint: disable=no-value-for-parameter + private_key + ) + print( + f"Sending OCEAN to account defined by envvar {private_key_name}" + f", with address {account.address}" + ) + token.transfer(account.address, amount * 1e18, owner) diff --git a/pdr_backend/util/listutil.py b/pdr_backend/util/listutil.py new file mode 100644 index 000000000..824263458 --- /dev/null +++ b/pdr_backend/util/listutil.py @@ -0,0 +1,12 @@ +from enforce_typing import enforce_types + + +@enforce_types +def remove_dups(seq: list): + """Returns all the items of seq, except duplicates. Preserves x's order.""" + + # the implementation below is the fastest according to stackoverflow + # https://stackoverflow.com/questions/480214/how-do-i-remove-duplicates-from-a-list-while-preserving-order + seen = set() # type: ignore[var-annotated] + seen_add = seen.add + return [x for x in seq if not (x in seen or seen_add(x))] diff --git a/pdr_backend/util/mathutil.py b/pdr_backend/util/mathutil.py index 1f237ed1d..7a9ad56d4 100644 --- a/pdr_backend/util/mathutil.py +++ b/pdr_backend/util/mathutil.py @@ -1,11 +1,12 @@ -from math import log10, floor import random import re +from math import floor, log10 from typing import Union -from enforce_typing import enforce_types import numpy as np import pandas as pd +import polars as pl +from enforce_typing import enforce_types from pdr_backend.util.strutil import StrMixin @@ -50,25 +51,78 @@ def round_sig(x: Union[int, float], sig: int) -> Union[int, float]: @enforce_types -def has_nan(x: Union[np.ndarray, pd.DataFrame, pd.Series]) -> bool: - """Returns True if any entry in x has a nan""" - if type(x) == np.ndarray: - return np.isnan(np.min(x)) - if type(x) in [pd.DataFrame, pd.Series]: - return x.isnull().values.any() # type: ignore[union-attr] - raise ValueError(f"Can't handle type {type(x)}") +def all_nan( + x: Union[np.ndarray, pd.DataFrame, pd.Series, pl.DataFrame, pl.Series] +) -> bool: + """Returns True if all entries in x have a nan _or_ a None""" + if isinstance(x, np.ndarray): + x = np.array(x, dtype=float) + return np.isnan(x).all() + + if isinstance(x, pd.Series): + x = x.fillna(value=np.nan, inplace=False) + return x.isnull().all() + + if isinstance(x, pd.DataFrame): + x = x.fillna(value=np.nan) + return x.isnull().all().all() + + # pl.Series or pl.DataFrame + return all_nan(x.to_numpy()) # type: ignore[union-attr] @enforce_types -def fill_nans(df: pd.DataFrame) -> pd.DataFrame: - """Interpolate the nans using Linear method. 
+def has_nan( + x: Union[np.ndarray, pd.DataFrame, pd.Series, pl.DataFrame, pl.Series] +) -> bool: + """Returns True if any entry in x has a nan _or_ a None""" + if isinstance(x, np.ndarray): + has_None = (x == None).any() # pylint: disable=singleton-comparison + return has_None or np.isnan(np.min(x)) + + if isinstance(x, pl.Series): + has_None = x.has_validity() + return has_None or sum(x.is_nan()) > 0 # type: ignore[union-attr] + + if isinstance(x, pl.DataFrame): + has_None = any(col.has_validity() for col in x) + return has_None or sum(sum(x).is_nan()) > 0 # type: ignore[union-attr] + + # pd.Series or pd.DataFrame + return x.isnull().values.any() # type: ignore[union-attr] + + +@enforce_types +def fill_nans( + df: Union[pd.DataFrame, pl.DataFrame] +) -> Union[pd.DataFrame, pl.DataFrame]: + """Interpolate the nans using Linear method available in pandas. It ignores the index and treat the values as equally spaced. Ref: https://www.geeksforgeeks.org/working-with-missing-data-in-pandas/ """ - df = df.interpolate(method="linear", limit_direction="forward") - df = df.interpolate(method="linear", limit_direction="backward") # row 0 - return df + interpolate_df = pd.DataFrame() + output_type = type(df) + + # polars support + if output_type == pl.DataFrame: + interpolate_df = df.to_pandas() + else: + interpolate_df = df + + # interpolate is a pandas-only feature + interpolate_df = interpolate_df.interpolate( + method="linear", limit_direction="forward" + ) + interpolate_df = interpolate_df.interpolate( + method="linear", limit_direction="backward" + ) # row 0 + + # return polars if input was polars + if type(output_type) == pl.DataFrame: + interpolate_df = pl.from_pandas(interpolate_df) + + return interpolate_df @enforce_types @@ -116,3 +170,34 @@ def nmse(yhat, y, ymin=None, ymax=None) -> float: nmse_result = mse_xy / mse_x return nmse_result + + +@enforce_types +def from_wei(amt_wei: int) -> Union[int, float]: + return float(amt_wei / 1e18) + + +@enforce_types +def to_wei(amt_eth: Union[int, float]) -> int: + return int(amt_eth * 1e18) + + +@enforce_types +def str_with_wei(amt_wei: int) -> str: + return f"{from_wei(amt_wei)} ({amt_wei} wei)" + + +@enforce_types +def string_to_bytes32(data) -> bytes: + if len(data) > 32: + myBytes32 = data[:32] + else: + myBytes32 = data.ljust(32, "0") + return bytes(myBytes32, "utf-8") + + +@enforce_types +def sole_value(d: dict): + if len(d) != 1: + raise ValueError(len(d)) + return list(d.values())[0] diff --git a/pdr_backend/util/networkutil.py b/pdr_backend/util/networkutil.py index 4e2540047..75c68c428 100644 --- a/pdr_backend/util/networkutil.py +++ b/pdr_backend/util/networkutil.py @@ -1,43 +1,14 @@ from enforce_typing import enforce_types -from sapphirepy import wrapper - -from pdr_backend.util.constants import ( - SAPPHIRE_TESTNET_CHAINID, - SAPPHIRE_MAINNET_CHAINID, -) @enforce_types -def is_sapphire_network(chain_id: int) -> bool: - return chain_id in [SAPPHIRE_TESTNET_CHAINID, SAPPHIRE_MAINNET_CHAINID] - +def get_sapphire_postfix(network: str) -> str: + if network == "sapphire-testnet": + return "testnet" + if network == "sapphire-mainnet": + return "mainnet" -@enforce_types -def send_encrypted_tx( - contract_instance, - function_name, - args, - pk, - sender, - receiver, - rpc_url, - value=0, # in wei - gasLimit=10000000, - gasCost=0, # in wei - nonce=0, -) -> tuple: - data = contract_instance.encodeABI(fn_name=function_name, args=args) - return wrapper.send_encrypted_sapphire_tx( - pk, - sender, - receiver, - rpc_url, - value, - gasLimit, - data, 
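For reference, quick usage sketches for the None-aware nan helpers and the wei conversions added above; the values follow directly from the code in this hunk:

import numpy as np
import polars as pl

from pdr_backend.util.mathutil import all_nan, from_wei, has_nan, str_with_wei, to_wei

has_nan(np.array([1.0, np.nan, 3.0]))   # True  (a NaN is present)
has_nan(pl.Series([1.0, None, 3.0]))    # True  (a None now counts as well)
all_nan(pl.Series([None, None]))        # True
all_nan(np.array([1.0, np.nan]))        # False

to_wei(1.5)             # 1500000000000000000
from_wei(2 * 10**18)    # 2.0
str_with_wei(10**18)    # "1.0 (1000000000000000000 wei)"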
- gasCost, - nonce, - ) + raise ValueError(f"'{network}' is not valid name") @enforce_types diff --git a/pdr_backend/util/pairstr.py b/pdr_backend/util/pairstr.py deleted file mode 100644 index b1356f41e..000000000 --- a/pdr_backend/util/pairstr.py +++ /dev/null @@ -1,137 +0,0 @@ -import re -from typing import List, Tuple - -from enforce_typing import enforce_types - -from pdr_backend.util.constants import CAND_USDCOINS - - -# ========================================================================== -# unpack..() functions - - -@enforce_types -def unpack_pairs_str(pairs_str: str, do_verify: bool = True) -> List[str]: - """ - @description - Unpack the string for *one or more* pairs, into list of pair_str - - Example: Given 'ADA-USDT, BTC/USDT, ETH/USDT' - Return ['ADA-USDT', 'BTC-USDT', 'ETH-USDT'] - - @argument - pairs_str - '/' or 'base-quote' - do_verify - typically T. Only F to avoid recursion from verify functions - - @return - pair_str_list -- List[] - """ - pairs_str = pairs_str.strip() - pairs_str = " ".join(pairs_str.split()) # replace multiple whitespace w/ 1 - pairs_str = pairs_str.replace(", ", ",").replace(" ,", ",") - pairs_str = pairs_str.replace(" ", ",") - pairs_str = pairs_str.replace("/", "-") # ETH/USDT -> ETH-USDT. Safer files. - pair_str_list = pairs_str.split(",") - - if do_verify: - if not pair_str_list: - raise ValueError(pairs_str) - - for pair_str in pair_str_list: - verify_pair_str(pair_str) - - return pair_str_list - - -@enforce_types -def unpack_pair_str(pair_str: str, do_verify: bool = True) -> Tuple[str, str]: - """ - @description - Unpack the string for a *single* pair, into base_str and quote_str. - - Example: Given 'BTC/USDT' or 'BTC-USDT' - Return ('BTC', 'USDT') - - @argument - pair_str - '/' or 'base-quote' - do_verify - typically T. Only F to avoid recursion from verify functions - - @return - base_str -- e.g. 'BTC' - quote_str -- e.g. 'USDT' - """ - if do_verify: - verify_pair_str(pair_str) - - pair_str = pair_str.replace("/", "-") - base_str, quote_str = pair_str.split("-") - - if do_verify: - verify_base_str(base_str) - verify_quote_str(quote_str) - - return (base_str, quote_str) - - -# ========================================================================== -# verify..() functions - - -# @enforce_types -def verify_pairs_str(pairs_str: str): - """ - @description - Raise an error if pairs_str is invalid - - @argument - pairs_str -- e.g. 'ADA/USDT BTC/USDT' or 'ADA-USDT, ETH-RAI' - """ - pair_str_list = unpack_pairs_str(pairs_str, do_verify=False) - for pair_str in pair_str_list: - verify_pair_str(pair_str) - - -@enforce_types -def verify_pair_str(pair_str: str): - """ - @description - Raise an error if pair_str is invalid - - @argument - pair_str -- e.g. 'ADA/USDT' or 'ETH-RAI' - """ - pair_str = pair_str.strip() - if not re.match("[A-Z]+[-/][A-Z]+", pair_str): - raise ValueError(pair_str) - - base_str, quote_str = unpack_pair_str(pair_str, do_verify=False) - verify_base_str(base_str) - verify_quote_str(quote_str) - - -@enforce_types -def verify_base_str(base_str: str): - """ - @description - Raise an error if base_str is invalid - - @argument - base_str -- e.g. 'ADA' or ' ETH ' - """ - base_str = base_str.strip() - if not re.match("[A-Z]+$", base_str): - raise ValueError(base_str) - - -@enforce_types -def verify_quote_str(quote_str: str): - """ - @description - Raise an error if quote_str is invalid - - @argument - quote_str -- e.g. 
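For reference, the slimmed-down networkutil above now just maps a network name to the Sapphire RPC postfix; a minimal usage sketch:

from pdr_backend.util.networkutil import get_sapphire_postfix

get_sapphire_postfix("sapphire-testnet")   # -> "testnet"
get_sapphire_postfix("sapphire-mainnet")   # -> "mainnet"
# get_sapphire_postfix("barge-pytest")     # any other name raises ValueError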
'USDT' or ' RAI ' - """ - quote_str = quote_str.strip() - if quote_str not in CAND_USDCOINS: - raise ValueError(quote_str) diff --git a/pdr_backend/util/predictoor_stats.py b/pdr_backend/util/predictoor_stats.py deleted file mode 100644 index 16c9bdae1..000000000 --- a/pdr_backend/util/predictoor_stats.py +++ /dev/null @@ -1,193 +0,0 @@ -from typing import List, Dict, Tuple, TypedDict, Set -from enforce_typing import enforce_types -from pdr_backend.models.prediction import Prediction - - -class PairTimeframeStat(TypedDict): - pair: str - timeframe: str - accuracy: float - stake: float - payout: float - number_of_predictions: int - - -class PredictoorStat(TypedDict): - predictoor_address: str - accuracy: float - stake: float - payout: float - number_of_predictions: int - details: Set[Tuple[str, str, str]] - - -@enforce_types -def aggregate_prediction_statistics( - all_predictions: List[Prediction], -) -> Tuple[Dict[str, Dict], int]: - """ - Aggregates statistics from a list of prediction objects. It organizes statistics - by currency pair and timeframe and predictor address. For each category, it - tallies the total number of predictions, the number of correct predictions, - and the total stakes and payouts. It also returns the total number of correct - predictions across all categories. - - Args: - all_predictions (List[Prediction]): A list of Prediction objects to aggregate. - - Returns: - Tuple[Dict[str, Dict], int]: A tuple containing a dictionary of aggregated - statistics and the total number of correct predictions. - """ - stats: Dict[str, Dict] = {"pair_timeframe": {}, "predictor": {}} - correct_predictions = 0 - - for prediction in all_predictions: - pair_timeframe_key = (prediction.pair, prediction.timeframe) - predictor_key = prediction.user - source = prediction.source - - is_correct = prediction.prediction == prediction.trueval - - if pair_timeframe_key not in stats["pair_timeframe"]: - stats["pair_timeframe"][pair_timeframe_key] = { - "correct": 0, - "total": 0, - "stake": 0, - "payout": 0, - } - - if predictor_key not in stats["predictor"]: - stats["predictor"][predictor_key] = { - "correct": 0, - "total": 0, - "stake": 0, - "payout": 0, - "details": set(), - } - - if is_correct: - correct_predictions += 1 - stats["pair_timeframe"][pair_timeframe_key]["correct"] += 1 - stats["predictor"][predictor_key]["correct"] += 1 - - stats["pair_timeframe"][pair_timeframe_key]["total"] += 1 - stats["pair_timeframe"][pair_timeframe_key]["stake"] += prediction.stake - stats["pair_timeframe"][pair_timeframe_key]["payout"] += prediction.payout - - stats["predictor"][predictor_key]["total"] += 1 - stats["predictor"][predictor_key]["stake"] += prediction.stake - stats["predictor"][predictor_key]["payout"] += prediction.payout - stats["predictor"][predictor_key]["details"].add( - (prediction.pair, prediction.timeframe, source) - ) - - return stats, correct_predictions - - -@enforce_types -def get_endpoint_statistics( - all_predictions: List[Prediction], -) -> Tuple[float, List[PairTimeframeStat], List[PredictoorStat]]: - """ - Calculates the overall accuracy of predictions, and aggregates detailed prediction - statistics by currency pair and timeframe with predictoor. - - The function first determines the overall accuracy of all given predictions. - It then organizes individual prediction statistics into two separate lists: - one for currency pair and timeframe statistics, and another for predictor statistics. 
- - Args: - all_predictions (List[Prediction]): A list of Prediction objects to be analyzed. - - Returns: - Tuple[float, List[Dict[str, Any]], List[Dict[str, Any]]]: A tuple containing the - overall accuracy as a float, a list of dictionaries with statistics for each - currency pair and timeframe, and a list of dictionaries with statistics for each - predictor. - """ - total_predictions = len(all_predictions) - stats, correct_predictions = aggregate_prediction_statistics(all_predictions) - - overall_accuracy = ( - correct_predictions / total_predictions * 100 if total_predictions else 0 - ) - - pair_timeframe_stats: List[PairTimeframeStat] = [] - for key, stat_pair_timeframe_item in stats["pair_timeframe"].items(): - pair, timeframe = key - accuracy = ( - stat_pair_timeframe_item["correct"] - / stat_pair_timeframe_item["total"] - * 100 - if stat_pair_timeframe_item["total"] - else 0 - ) - pair_timeframe_stat: PairTimeframeStat = { - "pair": pair, - "timeframe": timeframe, - "accuracy": accuracy, - "stake": stat_pair_timeframe_item["stake"], - "payout": stat_pair_timeframe_item["payout"], - "number_of_predictions": stat_pair_timeframe_item["total"], - } - pair_timeframe_stats.append(pair_timeframe_stat) - - predictoor_stats: List[PredictoorStat] = [] - for predictoor_addr, stat_predictoor_item in stats["predictor"].items(): - accuracy = ( - stat_predictoor_item["correct"] / stat_predictoor_item["total"] * 100 - if stat_predictoor_item["total"] - else 0 - ) - predictoor_stat: PredictoorStat = { - "predictoor_address": predictoor_addr, - "accuracy": accuracy, - "stake": stat_predictoor_item["stake"], - "payout": stat_predictoor_item["payout"], - "number_of_predictions": stat_predictoor_item["total"], - "details": set(stat_predictoor_item["details"]), - } - predictoor_stats.append(predictoor_stat) - - return overall_accuracy, pair_timeframe_stats, predictoor_stats - - -@enforce_types -def get_cli_statistics(all_predictions: List[Prediction]) -> None: - total_predictions = len(all_predictions) - - stats, correct_predictions = aggregate_prediction_statistics(all_predictions) - - if total_predictions == 0: - print("No predictions found.") - return - - if correct_predictions == 0: - print("No correct predictions found.") - return - - print(f"Overall Accuracy: {correct_predictions/total_predictions*100:.2f}%") - - for key, stat_pair_timeframe_item in stats["pair_timeframe"].items(): - pair, timeframe = key - accuracy = ( - stat_pair_timeframe_item["correct"] - / stat_pair_timeframe_item["total"] - * 100 - ) - print(f"Accuracy for Pair: {pair}, Timeframe: {timeframe}: {accuracy:.2f}%") - print(f"Total stake: {stat_pair_timeframe_item['stake']}") - print(f"Total payout: {stat_pair_timeframe_item['payout']}") - print(f"Number of predictions: {stat_pair_timeframe_item['total']}\n") - - for predictoor_addr, stat_predictoor_item in stats["predictor"].items(): - accuracy = stat_predictoor_item["correct"] / stat_predictoor_item["total"] * 100 - print(f"Accuracy for Predictoor Address: {predictoor_addr}: {accuracy:.2f}%") - print(f"Stake: {stat_predictoor_item['stake']}") - print(f"Payout: {stat_predictoor_item['payout']}") - print(f"Number of predictions: {stat_predictoor_item['total']}") - print("Details of Predictions:") - for detail in stat_predictoor_item["details"]: - print(f"Pair: {detail[0]}, Timeframe: {detail[1]}, Source: {detail[2]}") - print("\n") diff --git a/pdr_backend/util/signalstr.py b/pdr_backend/util/signalstr.py index 660c178de..0440312f2 100644 --- a/pdr_backend/util/signalstr.py +++ 
b/pdr_backend/util/signalstr.py @@ -1,9 +1,58 @@ -from typing import List +from typing import List, Set, Union from enforce_typing import enforce_types from pdr_backend.util.constants import CAND_SIGNALS, CHAR_TO_SIGNAL + +# ========================================================================== +# conversions +@enforce_types +def char_to_signal(char: str) -> str: + """eg given "o", return "open" """ + if char not in CHAR_TO_SIGNAL: + raise ValueError() + return CHAR_TO_SIGNAL[char] + + +@enforce_types +def signal_to_char(signal_str: str) -> str: + """ + Example: Given "open" + Return "o" + """ + for c, s in CHAR_TO_SIGNAL.items(): + if s == signal_str: + return c + + raise ValueError(signal_str) + + +# don't use @enforce_types, it causes problems +def signals_to_chars(signal_strs: Union[List[str], Set[str]]) -> str: + """ + Example: Given {"high", "close", "open"} + Return "ohc" + """ + # preconditions + if not signal_strs: + raise ValueError() + for signal_str in signal_strs: + verify_signal_str(signal_str) + + # main work + chars = "" + for cand_signal in CAND_SIGNALS: + if cand_signal in signal_strs: + c = signal_to_char(cand_signal) + chars += c + + # postconditions + if chars == "": + raise ValueError(signal_strs) + return chars + + # ========================================================================== # unpack..() functions @@ -33,7 +82,7 @@ def unpack_signalchar_str(signalchar_str: str) -> List[str]: @enforce_types -def verify_signalchar_str(signalchar_str: str): +def verify_signalchar_str(signalchar_str: str, graceful: bool = False): """ @description Raise an error if signalchar_str is invalid @@ -47,9 +96,14 @@ def verify_signalchar_str(signalchar_str: str): c_seen = set() for c in signalchar_str: if c not in "ohlcv" or c in c_seen: + if graceful: + return False + raise ValueError(signalchar_str) c_seen.add(c) + return True + @enforce_types def verify_signal_str(signal_str: str): diff --git a/pdr_backend/util/strutil.py b/pdr_backend/util/strutil.py index ae5109192..7c519ebcf 100644 --- a/pdr_backend/util/strutil.py +++ b/pdr_backend/util/strutil.py @@ -20,7 +20,11 @@ def longstr(self) -> str: obj = self short_attrs, long_attrs = [], [] - for attr in dir(obj): + if hasattr(self, "__STR_OBJDIR__"): + obj_dir = self.__STR_OBJDIR__ + else: + obj_dir = dir(obj) + for attr in obj_dir: if "__" in attr: continue attr_obj = getattr(obj, attr) @@ -74,12 +78,10 @@ def asCurrency(amount, decimals: bool = True) -> str: if decimals: if amount >= 0: return f"${amount:,.2f}" - return f"-${-amount:,.2f}".format(-amount) - if amount >= 0: - return f"${amount:,.0f}" + return f"-${-amount:,.2f}".format(-amount) - return f"-${-amount:,.0f}" + return f"${amount:,.0f}" if amount >= 0 else f"-${-amount:,.0f}" def prettyBigNum(amount, remove_zeroes: bool = True) -> str: diff --git a/pdr_backend/util/subgraph.py b/pdr_backend/util/subgraph.py deleted file mode 100644 index fa0a06207..000000000 --- a/pdr_backend/util/subgraph.py +++ /dev/null @@ -1,482 +0,0 @@ -""" -- READMEs/subgraph.md describes usage of Predictoor subgraph, with an example query -- the functions below provide other specific examples, that are used by agents of pdr-backend -""" -import time -from collections import defaultdict -from typing import Optional, Dict, List - -from enforce_typing import enforce_types -import requests -from web3 import Web3 - -from pdr_backend.util.constants import SUBGRAPH_MAX_TRIES -from pdr_backend.models.feed import Feed -from pdr_backend.models.slot import Slot -from pdr_backend.util.web3_config 
import Web3Config - -_N_ERRORS = {} # exception_str : num_occurrences -_N_THR = 3 - - -@enforce_types -def key_to_725(key: str): - key725 = Web3.keccak(key.encode("utf-8")).hex() - return key725 - - -@enforce_types -def value_to_725(value: str): - value725 = Web3.to_hex(text=value) - return value725 - - -@enforce_types -def value_from_725(value725) -> str: - value = Web3.to_text(hexstr=value725) - return value - - -@enforce_types -def info_from_725(info725_list: list) -> Dict[str, Optional[str]]: - """ - @arguments - info725_list -- eg [{"key":encoded("pair"), "value":encoded("ETH/USDT")}, - {"key":encoded("timeframe"), "value":encoded("5m") }, - ... ] - @return - info_dict -- e.g. {"pair": "ETH/USDT", - "timeframe": "5m", - ... } - """ - target_keys = ["pair", "timeframe", "source", "base", "quote"] - info_dict: Dict[str, Optional[str]] = {} - for key in target_keys: - info_dict[key] = None - for item725 in info725_list: - key725, value725 = item725["key"], item725["value"] - if key725 == key_to_725(key): - value = value_from_725(value725) - info_dict[key] = value - break - - return info_dict - - -@enforce_types -def query_subgraph( - subgraph_url: str, query: str, tries: int = 3, timeout: float = 30.0 -) -> Dict[str, dict]: - """ - @arguments - subgraph_url -- e.g. http://172.15.0.15:8000/subgraphs/name/oceanprotocol/ocean-subgraph # pylint: disable=line-too-long - query -- e.g. in docstring above - - @return - result -- e.g. {"data" : {"predictContracts": ..}} - """ - request = requests.post(subgraph_url, "", json={"query": query}, timeout=timeout) - if request.status_code != 200: - # pylint: disable=broad-exception-raised - if tries < SUBGRAPH_MAX_TRIES: - return query_subgraph(subgraph_url, query, tries + 1) - raise Exception( - f"Query failed. Url: {subgraph_url}. Return code is {request.status_code}\n{query}" - ) - result = request.json() - return result - - -@enforce_types -def query_pending_payouts(subgraph_url: str, addr: str) -> Dict[str, List[int]]: - chunk_size = 1000 - offset = 0 - pending_slots: Dict[str, List[int]] = {} - addr = addr.lower() - - while True: - query = """ - { - predictPredictions( - where: {user: "%s", payout: null, slot_: {status: "Paying"} }, first: %s, skip: %s - ) { - id - timestamp - slot { - id - slot - predictContract { - id - } - } - } - } - """ % ( - addr, - chunk_size, - offset, - ) - offset += chunk_size - print(".", end="", flush=True) - try: - result = query_subgraph(subgraph_url, query) - if not "data" in result or len(result["data"]) == 0: - print("No data in result") - break - predict_predictions = result["data"]["predictPredictions"] - if len(predict_predictions) == 0: - break - for prediction in predict_predictions: - contract_address = prediction["slot"]["predictContract"]["id"] - timestamp = prediction["slot"]["slot"] - pending_slots.setdefault(contract_address, []).append(timestamp) - except Exception as e: - print("An error occured", e) - - print() # print new line - return pending_slots - - -@enforce_types -def query_feed_contracts( # pylint: disable=too-many-statements - subgraph_url: str, - pairs_string: Optional[str] = None, - timeframes_string: Optional[str] = None, - sources_string: Optional[str] = None, - owners_string: Optional[str] = None, -) -> Dict[str, dict]: - """ - @description - Query the chain for prediction feed contracts, then filter down - according to pairs, timeframes, sources, or owners. - - @arguments - subgraph_url -- e.g. - pairs -- E.g. filter to "BTC/USDT,ETH/USDT". If None/"", allow all - timeframes -- E.g. 
filter to "5m,15m". If None/"", allow all - sources -- E.g. filter to "binance,kraken". If None/"", allow all - owners -- E.g. filter to "0x123,0x124". If None/"", allow all - - @return - feed_dicts -- dict of [contract_id] : feed_dict - where feed_dict is a dict with fields name, address, symbol, .. - """ - pairs = None - timeframes = None - sources = None - owners = None - - if pairs_string: - pairs = pairs_string.split(",") - if timeframes_string: - timeframes = timeframes_string.split(",") - if sources_string: - sources = sources_string.split(",") - if owners_string: - owners = owners_string.lower().split(",") - - chunk_size = 1000 # max for subgraph = 1000 - offset = 0 - feed_dicts = {} - - while True: - query = """ - { - predictContracts(skip:%s, first:%s){ - id - token { - id - name - symbol - nft { - owner { - id - } - nftData { - key - value - } - } - } - secondsPerEpoch - secondsPerSubscription - truevalSubmitTimeout - } - } - """ % ( - offset, - chunk_size, - ) - offset += chunk_size - try: - result = query_subgraph(subgraph_url, query) - contract_list = result["data"]["predictContracts"] - if not contract_list: - break - for contract in contract_list: - info725 = contract["token"]["nft"]["nftData"] - info = info_from_725(info725) # {"pair": "ETH/USDT", "base":...} - - # filter out unwanted - owner_id = contract["token"]["nft"]["owner"]["id"] - if owners and (owner_id not in owners): - continue - - pair = info["pair"] - if pair and pairs and (pair not in pairs): - continue - - timeframe = info["timeframe"] - if timeframe and timeframes and (timeframe not in timeframes): - continue - - source = info["source"] - if source and sources and (source not in sources): - continue - - # ok, add this one - addr = contract["id"] - feed_dict = { - "name": contract["token"]["name"], - "address": contract["id"], - "symbol": contract["token"]["symbol"], - "seconds_per_epoch": int(contract["secondsPerEpoch"]), - "seconds_per_subscription": int(contract["secondsPerSubscription"]), - "trueval_submit_timeout": int(contract["truevalSubmitTimeout"]), - "owner": owner_id, - "last_submited_epoch": 0, - } - feed_dict.update(info) - feed_dicts[addr] = feed_dict - - except Exception as e: - e_str = str(e) - e_key = e_str - if "Connection object" in e_str: - i = e_str.find("Connection object") + len("Connection object") - e_key = e_key[:i] - - if e_key not in _N_ERRORS: - _N_ERRORS[e_key] = 0 - _N_ERRORS[e_key] += 1 - - if _N_ERRORS[e_key] <= _N_THR: - print(e_str) - if _N_ERRORS[e_key] == _N_THR: - print("Future errors like this will be hidden") - return {} - - return feed_dicts - - -def get_pending_slots( - subgraph_url: str, - timestamp: int, - owner_addresses: Optional[List[str]], - pair_filter: Optional[List[str]] = None, - timeframe_filter: Optional[List[str]] = None, - source_filter: Optional[List[str]] = None, -): - chunk_size = 1000 - offset = 0 - owners: Optional[List[str]] = owner_addresses - - slots: List[Slot] = [] - - now_ts = time.time() - # rounds older than 3 days are canceled + 10 min buffer - three_days_ago = int(now_ts - 60 * 60 * 24 * 3 + 10 * 60) - - while True: - query = """ - { - predictSlots(where: {slot_gt: %s, slot_lte: %s, status: "Pending"}, skip:%s, first:%s){ - id - slot - status - trueValues { - id - } - predictContract { - id - token { - id - name - symbol - nft { - owner { - id - } - nftData { - key - value - } - } - } - secondsPerEpoch - secondsPerSubscription - truevalSubmitTimeout - } - } - } - """ % ( - three_days_ago, - timestamp, - offset, - chunk_size, - ) - - 
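# The paging pattern in this loop: fetch predictSlots in pages of `chunk_size` (1000),
# advance `offset` by one page per iteration, and stop on the first empty page.
# The slot_gt bound (three_days_ago) skips rounds older than roughly three days plus
# a 10-minute buffer, since those rounds are already canceled.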
offset += chunk_size - try: - result = query_subgraph(subgraph_url, query) - if not "data" in result: - print("No data in result") - break - slot_list = result["data"]["predictSlots"] - if slot_list == []: - break - for slot in slot_list: - if slot["trueValues"] != []: - continue - - contract = slot["predictContract"] - info725 = contract["token"]["nft"]["nftData"] - info = info_from_725(info725) - assert info["pair"], "need a pair" - assert info["timeframe"], "need a timeframe" - assert info["source"], "need a source" - - owner_id = contract["token"]["nft"]["owner"]["id"] - if owners and (owner_id not in owners): - continue - - if pair_filter and (info["pair"] not in pair_filter): - continue - - if timeframe_filter and (info["timeframe"] not in timeframe_filter): - continue - - if source_filter and (info["source"] not in source_filter): - continue - - feed = Feed( - name=contract["token"]["name"], - address=contract["id"], - symbol=contract["token"]["symbol"], - seconds_per_epoch=int(contract["secondsPerEpoch"]), - seconds_per_subscription=int(contract["secondsPerSubscription"]), - trueval_submit_timeout=int(contract["truevalSubmitTimeout"]), - owner=contract["token"]["nft"]["owner"]["id"], - pair=info["pair"], - timeframe=info["timeframe"], - source=info["source"], - ) - - slot_number = int(slot["slot"]) - slot = Slot(slot_number, feed) - slots.append(slot) - - except Exception as e: - print(e) - break - - return slots - - -def get_consume_so_far_per_contract( - subgraph_url: str, - user_address: str, - since_timestamp: int, - contract_addresses: List[str], -) -> Dict[str, float]: - chunk_size = 1000 # max for subgraph = 1000 - offset = 0 - consume_so_far: Dict[str, float] = defaultdict(float) - print("Getting consume so far...") - while True: # pylint: disable=too-many-nested-blocks - query = """ - { - predictContracts(first:1000, where: {id_in: %s}){ - id - token{ - id - name - symbol - nft { - owner { - id - } - nftData { - key - value - } - } - orders(where: {createdTimestamp_gt:%s, consumer_in:["%s"]}, first: %s, skip: %s){ - createdTimestamp - consumer { - id - } - lastPriceValue - } - } - secondsPerEpoch - secondsPerSubscription - truevalSubmitTimeout - } - } - """ % ( - str(contract_addresses).replace("'", '"'), - since_timestamp, - user_address.lower(), - chunk_size, - offset, - ) - offset += chunk_size - result = query_subgraph(subgraph_url, query, 3, 30.0) - contracts = result["data"]["predictContracts"] - if contracts == []: - break - no_of_zeroes = 0 - for contract in contracts: - contract_address = contract["id"] - if contract_address not in contract_addresses: - continue - order_count = len(contract["token"]["orders"]) - if order_count == 0: - no_of_zeroes += 1 - for buy in contract["token"]["orders"]: - # 1.2 20% fee - # 0.001 0.01% community swap fee - consume_amt = float(buy["lastPriceValue"]) * 1.201 - consume_so_far[contract_address] += consume_amt - if no_of_zeroes == len(contracts): - break - return consume_so_far - - -@enforce_types -def block_number_is_synced(subgraph_url: str, block_number: int) -> bool: - query = """ - { - predictContracts(block:{number:%s}){ - id - } - } - """ % ( - block_number - ) - try: - result = query_subgraph(subgraph_url, query) - if "errors" in result: - return False - except Exception as _: - return False - return True - - -@enforce_types -def wait_until_subgraph_syncs(web3_config: Web3Config, subgraph_url: str): - block_number = web3_config.w3.eth.block_number - while block_number_is_synced(subgraph_url, block_number) is not True: - 
print("Subgraph is out of sync, trying again in 5 seconds") - time.sleep(5) diff --git a/pdr_backend/util/subgraph_slot.py b/pdr_backend/util/subgraph_slot.py deleted file mode 100644 index d2b74fb6c..000000000 --- a/pdr_backend/util/subgraph_slot.py +++ /dev/null @@ -1,360 +0,0 @@ -from dataclasses import dataclass -from typing import List, Dict, Any, Tuple, Optional -from enforce_typing import enforce_types - -from pdr_backend.util.subgraph import query_subgraph -from pdr_backend.util.networkutil import get_subgraph_url -from pdr_backend.util.subgraph_predictions import ContractIdAndSPE - - -@dataclass -class PredictSlot: - id: str - slot: str - trueValues: List[Dict[str, Any]] - roundSumStakesUp: float - roundSumStakes: float - - -@enforce_types -def get_predict_slots_query( - asset_ids: List[str], initial_slot: int, last_slot: int, first: int, skip: int -) -> str: - """ - Constructs a GraphQL query string to fetch prediction slot data for - specified assets within a slot range. - - Args: - asset_ids: A list of asset identifiers to include in the query. - initial_slot: The starting slot number for the query range. - last_slot: The ending slot number for the query range. - first: The number of records to fetch per query (pagination limit). - skip: The number of records to skip (pagination offset). - - Returns: - A string representing the GraphQL query. - """ - asset_ids_str = str(asset_ids).replace("[", "[").replace("]", "]").replace("'", '"') - - return """ - query { - predictSlots ( - first: %s - skip: %s - where: { - slot_lte: %s - slot_gte: %s - predictContract_in: %s - } - ) { - id - slot - trueValues { - id - trueValue - } - roundSumStakesUp - roundSumStakes - } - } - """ % ( - first, - skip, - initial_slot, - last_slot, - asset_ids_str, - ) - - -SECONDS_IN_A_DAY = 86400 - - -@enforce_types -def get_slots( - addresses: List[str], - end_ts_param: int, - start_ts_param: int, - skip: int, - slots: List[PredictSlot], - network: str = "mainnet", -) -> List[PredictSlot]: - """ - Retrieves slots information for given addresses and a specified time range from a subgraph. - - Args: - addresses: A list of contract addresses to query. - end_ts_param: The Unix timestamp representing the end of the time range. - start_ts_param: The Unix timestamp representing the start of the time range. - skip: The number of records to skip for pagination. - slots: An existing list of slots to which new data will be appended. - network: The blockchain network to query ('mainnet' or 'testnet'). - - Returns: - A list of PredictSlot TypedDicts with the queried slot information. 
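A minimal usage sketch for the removed get_slots helper, assuming the pre-refactor import path, a placeholder feed address and timestamps, and live subgraph access; illustrative only.

from pdr_backend.util.subgraph_slot import PredictSlot, get_slots

feed_addr = "0x0000000000000000000000000000000000000000"  # placeholder feed address
end_ts = 1_700_000_000                                     # placeholder end of range
start_ts = end_ts - 24 * 60 * 60                           # one day earlier

slots = get_slots(
    addresses=[feed_addr],
    end_ts_param=end_ts,
    start_ts_param=start_ts,
    skip=0,
    slots=[],
    network="mainnet",
)
assert all(isinstance(s, PredictSlot) for s in slots)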
- """ - - slots = slots or [] - - records_per_page = 1000 - - query = get_predict_slots_query( - addresses, - end_ts_param, - start_ts_param, - records_per_page, - skip, - ) - - result = query_subgraph( - get_subgraph_url(network), - query, - timeout=20.0, - ) - - new_slots = result["data"]["predictSlots"] or [] - - # Convert the list of dicts to a list of PredictSlot objects - # by passing the dict as keyword arguments - # convert roundSumStakesUp and roundSumStakes to float - new_slots = [ - PredictSlot( - **{ - **slot, - "roundSumStakesUp": float(slot["roundSumStakesUp"]), - "roundSumStakes": float(slot["roundSumStakes"]), - } - ) - for slot in new_slots - ] - - slots.extend(new_slots) - if len(new_slots) == records_per_page: - return get_slots( - addresses, - end_ts_param, - start_ts_param, - skip + records_per_page, - slots, - network, - ) - return slots - - -@enforce_types -def fetch_slots_for_all_assets( - asset_ids: List[str], - start_ts_param: int, - end_ts_param: int, - network: str = "mainnet", -) -> Dict[str, List[PredictSlot]]: - """ - Fetches slots for all provided asset IDs within a given time range and organizes them by asset. - - Args: - asset_ids: A list of asset identifiers for which slots will be fetched. - start_ts_param: The Unix timestamp marking the beginning of the desired time range. - end_ts_param: The Unix timestamp marking the end of the desired time range. - network: The blockchain network to query ('mainnet' or 'testnet'). - - Returns: - A dictionary mapping asset IDs to lists of PredictSlot dataclass - containing slot information. - """ - - all_slots = get_slots(asset_ids, end_ts_param, start_ts_param, 0, [], network) - - slots_by_asset: Dict[str, List[PredictSlot]] = {} - for slot in all_slots: - slot_id = slot.id - # split the id to get the asset id - asset_id = slot_id.split("-")[0] - if asset_id not in slots_by_asset: - slots_by_asset[asset_id] = [] - - slots_by_asset[asset_id].append(slot) - - return slots_by_asset - - -@enforce_types -def calculate_prediction_result( - round_sum_stakes_up: float, round_sum_stakes: float -) -> Optional[bool]: - """ - Calculates the prediction result based on the sum of stakes. - - Args: - round_sum_stakes_up: The summed stakes for the 'up' prediction. - round_sum_stakes: The summed stakes for all prediction. - - Returns: - A boolean indicating the predicted direction. - """ - - # checks for to be sure that the division is not by zero - round_sum_stakes_up_float = float(round_sum_stakes_up) - round_sum_stakes_float = float(round_sum_stakes) - - if round_sum_stakes_float == 0.0: - return None - - if round_sum_stakes_up_float == 0.0: - return False - - return (round_sum_stakes_up_float / round_sum_stakes_float) > 0.5 - - -@enforce_types -def process_single_slot( - slot: PredictSlot, end_of_previous_day_timestamp: int -) -> Optional[Tuple[float, float, int, int]]: - """ - Processes a single slot and calculates the staked amounts for yesterday and today, - as well as the count of correct predictions. - - Args: - slot: A PredictSlot TypedDict containing information about a single prediction slot. - end_of_previous_day_timestamp: The Unix timestamp marking the end of the previous day. - - Returns: - A tuple containing staked amounts for yesterday, today, and the counts of correct - predictions and slots evaluated, or None if no stakes were made today. 
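A sketch of the stake-ratio logic described above, with values taken from the unit tests shown later in this patch (pre-refactor import path):

from pdr_backend.util.subgraph_slot import (
    PredictSlot,
    calculate_prediction_result,
    process_single_slot,
)

# direction is "up" when more than half of the round's stake is on up
assert calculate_prediction_result(150.0, 200.0)         # 0.75 -> up
assert not calculate_prediction_result(100.0, 250.0)     # 0.40 -> down
assert calculate_prediction_result(0.0, 0.0) is None     # no stake at all

slot = PredictSlot(
    id="0xasset-12345",
    slot="12345",
    trueValues=[{"id": "1", "trueValue": True}],
    roundSumStakesUp=150.0,
    roundSumStakes=100.0,
)
# returns (staked_yesterday, staked_today, correct_predictions, slots_evaluated)
assert process_single_slot(slot, end_of_previous_day_timestamp=12340) == (0.0, 100.0, 1, 1)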
- """ - - staked_yesterday = staked_today = 0.0 - correct_predictions_count = slots_evaluated = 0 - - if float(slot.roundSumStakes) == 0.0: - return None - - # split the id to get the slot timestamp - timestamp = int(slot.id.split("-")[1]) # Using dot notation for attribute access - - if ( - end_of_previous_day_timestamp - SECONDS_IN_A_DAY - < timestamp - < end_of_previous_day_timestamp - ): - staked_yesterday += float(slot.roundSumStakes) - elif timestamp > end_of_previous_day_timestamp: - staked_today += float(slot.roundSumStakes) - - prediction_result = calculate_prediction_result( - slot.roundSumStakesUp, slot.roundSumStakes - ) - - if prediction_result is None: - print("Prediction result is None for slot: ", slot.id) - return ( - staked_yesterday, - staked_today, - correct_predictions_count, - slots_evaluated, - ) - - true_values: List[Dict[str, Any]] = slot.trueValues or [] - true_value: Optional[bool] = true_values[0]["trueValue"] if true_values else None - - if len(true_values) > 0 and prediction_result == true_value: - correct_predictions_count += 1 - - if len(true_values) > 0 and true_value is not None: - slots_evaluated += 1 - - return staked_yesterday, staked_today, correct_predictions_count, slots_evaluated - - -@enforce_types -def aggregate_statistics( - slots: List[PredictSlot], end_of_previous_day_timestamp: int -) -> Tuple[float, float, int, int]: - """ - Aggregates statistics across all provided slots for an asset. - - Args: - slots: A list of PredictSlot TypedDicts containing information - about multiple prediction slots. - end_of_previous_day_timestamp: The Unix timestamp marking the end of the previous day. - - Returns: - A tuple containing the total staked amounts for yesterday, today, - and the total counts of correct predictions and slots evaluated. - """ - - total_staked_yesterday = ( - total_staked_today - ) = total_correct_predictions = total_slots_evaluated = 0 - for slot in slots: - slot_results = process_single_slot(slot, end_of_previous_day_timestamp) - if slot_results: - ( - staked_yesterday, - staked_today, - correct_predictions_count, - slots_evaluated, - ) = slot_results - total_staked_yesterday += staked_yesterday - total_staked_today += staked_today - total_correct_predictions += correct_predictions_count - total_slots_evaluated += slots_evaluated - return ( - total_staked_yesterday, - total_staked_today, - total_correct_predictions, - total_slots_evaluated, - ) - - -@enforce_types -def calculate_statistics_for_all_assets( - asset_ids: List[str], - contracts: List[ContractIdAndSPE], - start_ts_param: int, - end_ts_param: int, - network: str = "mainnet", -) -> Dict[str, Dict[str, Any]]: - """ - Calculates statistics for all provided assets based on - slot data within a specified time range. - - Args: - asset_ids: A list of asset identifiers for which statistics will be calculated. - start_ts_param: The Unix timestamp for the start of the time range. - end_ts_param: The Unix timestamp for the end of the time range. - network: The blockchain network to query ('mainnet' or 'testnet'). - - Returns: - A dictionary mapping asset IDs to another dictionary with - calculated statistics such as average accuracy and total staked amounts. 
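A brief sketch of the top-level entry point, mirroring the removed test_calculate_statistics_for_all_assets; it queries the live subgraph, and "0xAsset" is a placeholder id rather than a real contract.

from typing import List

from pdr_backend.util.subgraph_predictions import ContractIdAndSPE
from pdr_backend.util.subgraph_slot import calculate_statistics_for_all_assets

contracts: List[ContractIdAndSPE] = [
    {"id": "0xAsset", "seconds_per_epoch": 300, "name": "TEST/USDT"}
]
stats = calculate_statistics_for_all_assets(
    asset_ids=["0xAsset"],
    contracts=contracts,
    start_ts_param=1000,
    end_ts_param=2000,
    network="mainnet",
)
# per asset: token_name, average_accuracy, total_staked_yesterday, total_staked_today
print(stats)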
- """ - slots_by_asset = fetch_slots_for_all_assets( - asset_ids, start_ts_param, end_ts_param, network - ) - - overall_stats = {} - for asset_id, slots in slots_by_asset.items(): - ( - staked_yesterday, - staked_today, - correct_predictions_count, - slots_evaluated, - ) = aggregate_statistics(slots, end_ts_param - SECONDS_IN_A_DAY) - average_accuracy = ( - 0 - if correct_predictions_count == 0 - else (correct_predictions_count / slots_evaluated) * 100 - ) - - # filter contracts to get the contract with the current asset id - contract = next( - (contract for contract in contracts if contract["id"] == asset_id), - None, - ) - - overall_stats[asset_id] = { - "token_name": contract["name"] if contract else None, - "average_accuracy": average_accuracy, - "total_staked_yesterday": staked_yesterday, - "total_staked_today": staked_today, - } - return overall_stats diff --git a/pdr_backend/trader/approach1/test/test_ccxt_exchanges.py b/pdr_backend/util/test_ganache/test_ccxt_exchanges.py similarity index 100% rename from pdr_backend/trader/approach1/test/test_ccxt_exchanges.py rename to pdr_backend/util/test_ganache/test_ccxt_exchanges.py diff --git a/pdr_backend/util/test_ganache/test_contract.py b/pdr_backend/util/test_ganache/test_contract.py deleted file mode 100644 index 5caf80854..000000000 --- a/pdr_backend/util/test_ganache/test_contract.py +++ /dev/null @@ -1,34 +0,0 @@ -from pathlib import Path - -from enforce_typing import enforce_types - -from pdr_backend.util.contract import ( - get_address, - get_addresses, - get_contract_abi, - get_contract_filename, -) - - -@enforce_types -def test_get_address(chain_id): - result = get_address(chain_id, "Ocean") - assert result is not None - - -@enforce_types -def test_get_addresses(chain_id): - result = get_addresses(chain_id) - assert result is not None - - -@enforce_types -def test_get_contract_abi(): - result = get_contract_abi("ERC20Template3") - assert len(result) > 0 and isinstance(result, list) - - -@enforce_types -def test_get_contract_filename(): - result = get_contract_filename("ERC20Template3") - assert result is not None and isinstance(result, Path) diff --git a/pdr_backend/util/test_ganache/test_fund_accounts.py b/pdr_backend/util/test_ganache/test_fund_accounts.py new file mode 100644 index 000000000..456db3c17 --- /dev/null +++ b/pdr_backend/util/test_ganache/test_fund_accounts.py @@ -0,0 +1,50 @@ +import os +from unittest.mock import Mock, call + +from enforce_typing import enforce_types +from eth_account import Account + +from pdr_backend.contract.token import Token +from pdr_backend.ppss.web3_pp import mock_web3_pp +from pdr_backend.util.fund_accounts import fund_accounts, fund_accounts_with_OCEAN + + +@enforce_types +def test_fund_accounts_with_OCEAN(monkeypatch): + web3_pp = mock_web3_pp("development") + + path = "pdr_backend.util.fund_accounts" + + monkeypatch.setattr(f"{path}.get_address", Mock()) + monkeypatch.setattr(f"{path}.Token", Mock()) + + mock_f = Mock() + monkeypatch.setattr(f"{path}.fund_accounts", mock_f) + + fund_accounts_with_OCEAN(web3_pp) + mock_f.assert_called() + + +@enforce_types +def test_fund_accounts(monkeypatch): + pk = os.getenv("PRIVATE_KEY") + monkeypatch.setenv("PREDICTOOR_PRIVATE_KEY", pk) + monkeypatch.setenv("PREDICTOOR2_PRIVATE_KEY", pk) + + mock_token = Mock(spec=Token) + mock_account = Mock(spec=str) + + accounts_to_fund = [ + ("PREDICTOOR_PRIVATE_KEY", 2000), + ("PREDICTOOR2_PRIVATE_KEY", 3000), + ] + + fund_accounts(accounts_to_fund, mock_account, mock_token) + + a = 
Account.from_key(private_key=pk) # pylint: disable=no-value-for-parameter + mock_token.transfer.assert_has_calls( + [ + call(a.address, 2e21, mock_account), + call(a.address, 3e21, mock_account), + ] + ) diff --git a/pdr_backend/util/test_ganache/test_networkutil.py b/pdr_backend/util/test_ganache/test_networkutil.py index 51c9376bc..4a3be329c 100644 --- a/pdr_backend/util/test_ganache/test_networkutil.py +++ b/pdr_backend/util/test_ganache/test_networkutil.py @@ -1,78 +1,37 @@ -from unittest.mock import Mock -import os - -from enforce_typing import enforce_types import pytest +from enforce_typing import enforce_types -from pdr_backend.util.constants import ( - SAPPHIRE_TESTNET_CHAINID, - SAPPHIRE_MAINNET_CHAINID, -) from pdr_backend.util.networkutil import ( - is_sapphire_network, - send_encrypted_tx, + get_sapphire_postfix, + get_subgraph_url, ) @enforce_types -def test_is_sapphire_network(): - assert not is_sapphire_network(0) - assert is_sapphire_network(SAPPHIRE_TESTNET_CHAINID) - assert is_sapphire_network(SAPPHIRE_MAINNET_CHAINID) - - -@enforce_types -def test_send_encrypted_tx( - mock_send_encrypted_sapphire_tx, # pylint: disable=redefined-outer-name - ocean_token, - web3_config, -): - # Set up dummy return value for the mocked function - mock_send_encrypted_sapphire_tx.return_value = ( - 0, - "dummy_tx_hash", - ) - # Sample inputs for send_encrypted_tx - function_name = "transfer" - args = [web3_config.owner, 100] - pk = os.getenv("PRIVATE_KEY") - sender = web3_config.owner - receiver = web3_config.w3.eth.accounts[1] - rpc_url = "http://localhost:8545" - value = 0 - gasLimit = 10000000 - gasCost = 0 - nonce = 0 - tx_hash, encrypted_data = send_encrypted_tx( - ocean_token.contract_instance, - function_name, - args, - pk, - sender, - receiver, - rpc_url, - value, - gasLimit, - gasCost, - nonce, +def test_get_sapphire_postfix(): + assert get_sapphire_postfix("sapphire-testnet"), "testnet" + assert get_sapphire_postfix("sapphire-mainnet"), "mainnet" + + unwanteds = [ + "oasis_saphire_testnet", + "saphire_mainnet", + "barge-pytest", + "barge-predictoor-bot", + "development", + "foo", + "", + ] + for unwanted in unwanteds: + with pytest.raises(ValueError): + assert get_sapphire_postfix(unwanted) + + +def test_get_subgraph(): + expected = ( + "https://v4.subgraph.sapphire-testnet.oceanprotocol.com/" + "subgraphs/name/oceanprotocol/ocean-subgraph" ) - assert tx_hash == 0 - assert encrypted_data == "dummy_tx_hash" - mock_send_encrypted_sapphire_tx.assert_called_once_with( - pk, - sender, - receiver, - rpc_url, - value, - gasLimit, - ocean_token.contract_instance.encodeABI(fn_name=function_name, args=args), - gasCost, - nonce, - ) - + assert get_subgraph_url("testnet") == expected -@pytest.fixture -def mock_send_encrypted_sapphire_tx(monkeypatch): - mock_function = Mock(return_value=(0, "dummy_tx_hash")) - monkeypatch.setattr("sapphirepy.wrapper.send_encrypted_sapphire_tx", mock_function) - return mock_function + with pytest.raises(ValueError): + get_subgraph_url("sapphire-testnet") diff --git a/pdr_backend/util/test_ganache/test_predictoor_stats b/pdr_backend/util/test_ganache/test_predictoor_stats deleted file mode 100644 index 0fa8cd542..000000000 --- a/pdr_backend/util/test_ganache/test_predictoor_stats +++ /dev/null @@ -1,108 +0,0 @@ -from typing import List, Set -from enforce_typing import enforce_types - -from pdr_backend.util.predictoor_stats import ( - aggregate_prediction_statistics, - get_endpoint_statistics, - get_cli_statistics, -) - -from pdr_backend.util.subgraph_predictions 
import ( - Prediction, -) - -sample_predictions = [ - Prediction( - pair="ADA/USDT", - timeframe="5m", - prediction=True, - stake=0.050051425480971974, - trueval=False, - timestamp=1698527100, - source="binance", - payout=0.0, - user="0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - pair="ADA/USDT", - timeframe="5m", - prediction=True, - stake=0.0500, - trueval=True, - timestamp=1698527700, - source="binance", - payout=0.0, - user="0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", - ), -] - - -@enforce_types -def test_aggregate_prediction_statistics(): - stats, correct_predictions = aggregate_prediction_statistics(sample_predictions) - assert isinstance(stats, dict) - assert "pair_timeframe" in stats - assert "predictor" in stats - assert correct_predictions == 1 # Adjust based on your sample data - - -@enforce_types -def test_get_endpoint_statistics(): - accuracy, pair_timeframe_stats, predictoor_stats = get_endpoint_statistics( - sample_predictions - ) - assert isinstance(accuracy, float) - assert isinstance(pair_timeframe_stats, List) # List[PairTimeframeStat] - assert isinstance(predictoor_stats, List) # List[PredictoorStat] - for pair_timeframe_stat in pair_timeframe_stats: - assert isinstance(pair_timeframe_stat, dict) - assert "pair" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["pair"], str - ) - assert "timeframe" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["timeframe"], str - ) - assert "accuracy" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["accuracy"], float - ) - assert "stake" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["stake"], float - ) - assert "payout" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["payout"], float - ) - assert "number_of_predictions" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["number_of_predictions"], int - ) - - for predictoor_stat in predictoor_stats: - assert isinstance(predictoor_stat, dict) and len(predictoor_stat) == 6 - assert "predictoor_address" in predictoor_stat and isinstance( - predictoor_stat["predictoor_address"], str - ) - assert "accuracy" in predictoor_stat and isinstance( - predictoor_stat["accuracy"], float - ) - assert "stake" in predictoor_stat and isinstance( - predictoor_stat["stake"], float - ) - assert "payout" in predictoor_stat and isinstance( - predictoor_stat["payout"], float - ) - assert "number_of_predictions" in predictoor_stat and isinstance( - predictoor_stat["number_of_predictions"], int - ) - assert "details" in predictoor_stat and isinstance( - predictoor_stat["details"], Set - ) - assert len(predictoor_stat["details"]) == 1 - - -@enforce_types -def test_get_cli_statistics(capsys): - get_cli_statistics(sample_predictions) - captured = capsys.readouterr() - output = captured.out - assert "Overall Accuracy" in output - assert "Accuracy for Pair" in output - assert "Accuracy for Predictoor Address" in output diff --git a/pdr_backend/util/test_ganache/test_predictor_stats.py b/pdr_backend/util/test_ganache/test_predictor_stats.py deleted file mode 100644 index 0fa8cd542..000000000 --- a/pdr_backend/util/test_ganache/test_predictor_stats.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import List, Set -from enforce_typing import enforce_types - -from pdr_backend.util.predictoor_stats import ( - aggregate_prediction_statistics, - get_endpoint_statistics, - get_cli_statistics, -) - -from pdr_backend.util.subgraph_predictions import ( - Prediction, -) - -sample_predictions = [ - Prediction( - pair="ADA/USDT", - 
timeframe="5m", - prediction=True, - stake=0.050051425480971974, - trueval=False, - timestamp=1698527100, - source="binance", - payout=0.0, - user="0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", - ), - Prediction( - pair="ADA/USDT", - timeframe="5m", - prediction=True, - stake=0.0500, - trueval=True, - timestamp=1698527700, - source="binance", - payout=0.0, - user="0xd2a24cb4ff2584bad80ff5f109034a891c3d88dd", - ), -] - - -@enforce_types -def test_aggregate_prediction_statistics(): - stats, correct_predictions = aggregate_prediction_statistics(sample_predictions) - assert isinstance(stats, dict) - assert "pair_timeframe" in stats - assert "predictor" in stats - assert correct_predictions == 1 # Adjust based on your sample data - - -@enforce_types -def test_get_endpoint_statistics(): - accuracy, pair_timeframe_stats, predictoor_stats = get_endpoint_statistics( - sample_predictions - ) - assert isinstance(accuracy, float) - assert isinstance(pair_timeframe_stats, List) # List[PairTimeframeStat] - assert isinstance(predictoor_stats, List) # List[PredictoorStat] - for pair_timeframe_stat in pair_timeframe_stats: - assert isinstance(pair_timeframe_stat, dict) - assert "pair" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["pair"], str - ) - assert "timeframe" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["timeframe"], str - ) - assert "accuracy" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["accuracy"], float - ) - assert "stake" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["stake"], float - ) - assert "payout" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["payout"], float - ) - assert "number_of_predictions" in pair_timeframe_stat and isinstance( - pair_timeframe_stat["number_of_predictions"], int - ) - - for predictoor_stat in predictoor_stats: - assert isinstance(predictoor_stat, dict) and len(predictoor_stat) == 6 - assert "predictoor_address" in predictoor_stat and isinstance( - predictoor_stat["predictoor_address"], str - ) - assert "accuracy" in predictoor_stat and isinstance( - predictoor_stat["accuracy"], float - ) - assert "stake" in predictoor_stat and isinstance( - predictoor_stat["stake"], float - ) - assert "payout" in predictoor_stat and isinstance( - predictoor_stat["payout"], float - ) - assert "number_of_predictions" in predictoor_stat and isinstance( - predictoor_stat["number_of_predictions"], int - ) - assert "details" in predictoor_stat and isinstance( - predictoor_stat["details"], Set - ) - assert len(predictoor_stat["details"]) == 1 - - -@enforce_types -def test_get_cli_statistics(capsys): - get_cli_statistics(sample_predictions) - captured = capsys.readouterr() - output = captured.out - assert "Overall Accuracy" in output - assert "Accuracy for Pair" in output - assert "Accuracy for Predictoor Address" in output diff --git a/pdr_backend/util/test_ganache/test_subgraph.py b/pdr_backend/util/test_ganache/test_subgraph.py deleted file mode 100644 index 0ecf6d358..000000000 --- a/pdr_backend/util/test_ganache/test_subgraph.py +++ /dev/null @@ -1,364 +0,0 @@ -from unittest.mock import patch - -from enforce_typing import enforce_types -import pytest -import requests -from web3 import Web3 -from pytest import approx - -from pdr_backend.models.slot import Slot -from pdr_backend.util.subgraph import ( - block_number_is_synced, - key_to_725, - value_to_725, - value_from_725, - info_from_725, - query_subgraph, - query_feed_contracts, - get_pending_slots, - get_consume_so_far_per_contract, -) - - -@enforce_types 
-def test_key(): - key = "name" - key725 = key_to_725(key) - assert key725 == Web3.keccak(key.encode("utf-8")).hex() - - -@enforce_types -def test_value(): - value = "ETH/USDT" - value725 = value_to_725(value) - value_again = value_from_725(value725) - - assert value == value_again - assert value == Web3.to_text(hexstr=value725) - - -@enforce_types -def test_info_from_725(): - info725_list = [ - {"key": key_to_725("pair"), "value": value_to_725("ETH/USDT")}, - {"key": key_to_725("timeframe"), "value": value_to_725("5m")}, - ] - info_dict = info_from_725(info725_list) - assert info_dict == { - "pair": "ETH/USDT", - "timeframe": "5m", - "base": None, - "quote": None, - "source": None, - } - - -@enforce_types -class MockResponse: - def __init__(self, contract_list: list, status_code: int): - self.contract_list = contract_list - self.status_code = status_code - self.num_queries = 0 - - def json(self) -> dict: - self.num_queries += 1 - if self.num_queries > 1: - self.contract_list = [] - return {"data": {"predictContracts": self.contract_list}} - - -@enforce_types -class MockPost: - def __init__(self, contract_list: list = [], status_code: int = 200): - self.response = MockResponse(contract_list, status_code) - - def __call__(self, *args, **kwargs): - return self.response - - -@enforce_types -def test_query_subgraph_happypath(monkeypatch): - monkeypatch.setattr(requests, "post", MockPost(status_code=200)) - result = query_subgraph(subgraph_url="foo", query="bar") - assert result == {"data": {"predictContracts": []}} - - -@enforce_types -def test_query_subgraph_badpath(monkeypatch): - monkeypatch.setattr(requests, "post", MockPost(status_code=400)) - with pytest.raises(Exception): - query_subgraph(subgraph_url="foo", query="bar") - - -@enforce_types -def test_get_contracts_emptychain(monkeypatch): - contract_list = [] - monkeypatch.setattr(requests, "post", MockPost(contract_list)) - contracts = query_feed_contracts(subgraph_url="foo") - assert contracts == {} - - -@enforce_types -def test_query_feed_contracts_fullchain(monkeypatch): - # This test is a simple-as-possible happy path. Start here. 
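The key_to_725 / value_to_725 helpers used by these fixtures round-trip as follows; a sketch against the removed pdr_backend.util.subgraph module, matching the tests just above.

from web3 import Web3

from pdr_backend.util.subgraph import (
    info_from_725,
    key_to_725,
    value_from_725,
    value_to_725,
)

assert key_to_725("pair") == Web3.keccak("pair".encode("utf-8")).hex()
assert value_from_725(value_to_725("ETH/USDT")) == "ETH/USDT"

info = info_from_725(
    [
        {"key": key_to_725("pair"), "value": value_to_725("ETH/USDT")},
        {"key": key_to_725("timeframe"), "value": value_to_725("5m")},
    ]
)
assert info["pair"] == "ETH/USDT" and info["timeframe"] == "5m"
assert info["source"] is None  # keys absent from the 725 list come back as None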
- # Then follow up with test_filter() below, which is complex but thorough - info725_list = [ - {"key": key_to_725("pair"), "value": value_to_725("ETH/USDT")}, - {"key": key_to_725("timeframe"), "value": value_to_725("5m")}, - ] - - contract1 = { - "id": "contract1", - "token": { - "id": "token1", - "name": "ether", - "symbol": "ETH", - "nft": { - "owner": { - "id": "0xowner", - }, - "nftData": info725_list, - }, - }, - "secondsPerEpoch": 7, - "secondsPerSubscription": 700, - "truevalSubmitTimeout": 5, - } - contract_list = [contract1] - monkeypatch.setattr(requests, "post", MockPost(contract_list)) - feed_dicts = query_feed_contracts(subgraph_url="foo") - assert list(feed_dicts.keys()) == ["contract1"] - feed_dict = feed_dicts["contract1"] - assert feed_dict == { - "name": "ether", - "address": "contract1", - "symbol": "ETH", - "seconds_per_epoch": 7, - "seconds_per_subscription": 700, - "trueval_submit_timeout": 5, - "owner": "0xowner", - "last_submited_epoch": 0, - "pair": "ETH/USDT", - "base": None, - "quote": None, - "source": None, - "timeframe": "5m", - } - - -@enforce_types -@pytest.mark.parametrize( - "expect_result, pairs, timeframes, sources, owners", - [ - (True, None, None, None, None), - (True, "ETH/USDT", "5m", "binance", "0xowner1"), - (True, "ETH/USDT,BTC/USDT", "5m,15m", "binance,kraken", "0xowner1,o2"), - (True, "ETH/USDT", None, None, None), - (False, "BTC/USDT", None, None, None), - (True, "ETH/USDT,BTC/USDT", None, None, None), - (True, None, "5m", None, None), - (False, None, "15m", None, None), - (True, None, "5m,15m", None, None), - (True, None, None, "binance", None), - (False, None, None, "kraken", None), - (True, None, None, "binance,kraken", None), - (True, None, None, None, "0xowner1"), - (False, None, None, None, "owner2"), - (True, None, None, None, "0xowner1,owner2"), - (True, None, None, None, ""), - (True, "", "", "", ""), - (True, None, None, "", "0xowner1,owner2"), - ], -) -def test_filter(monkeypatch, expect_result, pairs, timeframes, sources, owners): - info725_list = [ - {"key": key_to_725("pair"), "value": value_to_725("ETH/USDT")}, - {"key": key_to_725("timeframe"), "value": value_to_725("5m")}, - {"key": key_to_725("source"), "value": value_to_725("binance")}, - {"key": key_to_725("base"), "value": value_to_725("USDT")}, - {"key": key_to_725("quote"), "value": value_to_725("1400.1")}, - {"key": key_to_725("extra1"), "value": value_to_725("extra1_value")}, - {"key": key_to_725("extra2"), "value": value_to_725("extra2_value")}, - ] - - contract1 = { - "id": "contract1", - "token": { - "id": "token1", - "name": "ether", - "symbol": "ETH", - "nft": { - "owner": { - "id": "0xowner1", - }, - "nftData": info725_list, - }, - }, - "secondsPerEpoch": 7, - "secondsPerSubscription": 700, - "truevalSubmitTimeout": 5, - } - - contract_list = [contract1] - - monkeypatch.setattr(requests, "post", MockPost(contract_list)) - feed_dicts = query_feed_contracts("foo", pairs, timeframes, sources, owners) - - assert bool(feed_dicts) == bool(expect_result) - - -@enforce_types -def test_get_pending_slots(): - sample_slot_data = [ - { - "id": "slot1", - "slot": 1000, - "trueValues": [], - "predictContract": { - "id": "contract1", - "token": { - "id": "token1", - "name": "ether", - "symbol": "ETH", - "nft": { - "owner": {"id": "0xowner1"}, - "nftData": [ - { - "key": key_to_725("pair"), - "value": value_to_725("ETH/USDT"), - }, - { - "key": key_to_725("timeframe"), - "value": value_to_725("5m"), - }, - { - "key": key_to_725("source"), - "value": value_to_725("binance"), - 
}, - ], - }, - }, - "secondsPerEpoch": 7, - "secondsPerSubscription": 700, - "truevalSubmitTimeout": 5, - }, - } - ] - - call_count = 0 - - def mock_query_subgraph(subgraph_url, query): # pylint:disable=unused-argument - nonlocal call_count - slot_data = sample_slot_data if call_count <= 1 else [] - call_count += 1 - return {"data": {"predictSlots": slot_data}} - - with patch("pdr_backend.util.subgraph.query_subgraph", mock_query_subgraph): - slots = get_pending_slots( - subgraph_url="foo", - timestamp=2000, - owner_addresses=None, - pair_filter=None, - timeframe_filter=None, - source_filter=None, - ) - - assert len(slots) == 2 - slot0 = slots[0] - assert isinstance(slot0, Slot) - assert slot0.slot_number == 1000 - assert slot0.feed.name == "ether" - - -@enforce_types -def test_get_consume_so_far_per_contract(): - sample_contract_data = [ - { - "id": "contract1", - "token": { - "id": "token1", - "name": "ether", - "symbol": "ETH", - "orders": [ - { - "createdTimestamp": 1695288424, - "consumer": { - "id": "0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a" - }, - "lastPriceValue": "2.4979184013322233", - }, - { - "createdTimestamp": 1695288724, - "consumer": { - "id": "0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a" - }, - "lastPriceValue": "2.4979184013322233", - }, - ], - "nft": { - "owner": {"id": "0xowner1"}, - "nftData": [ - { - "key": key_to_725("pair"), - "value": value_to_725("ETH/USDT"), - }, - { - "key": key_to_725("timeframe"), - "value": value_to_725("5m"), - }, - { - "key": key_to_725("source"), - "value": value_to_725("binance"), - }, - ], - }, - }, - "secondsPerEpoch": 7, - "secondsPerSubscription": 700, - "truevalSubmitTimeout": 5, - } - ] - - call_count = 0 - - def mock_query_subgraph( - subgraph_url, query, tries, timeout - ): # pylint:disable=unused-argument - nonlocal call_count - slot_data = sample_contract_data - - if call_count > 0: - slot_data[0]["token"]["orders"] = [] - - call_count += 1 - return {"data": {"predictContracts": slot_data}} - - with patch("pdr_backend.util.subgraph.query_subgraph", mock_query_subgraph): - consumes = get_consume_so_far_per_contract( - subgraph_url="foo", - user_address="0xff8dcdfc0a76e031c72039b7b1cd698f8da81a0a", - since_timestamp=2000, - contract_addresses=["contract1"], - ) - - assert consumes["contract1"] == approx(6, 0.001) - - -def test_block_number_is_synced(): - def mock_response(url: str, query: str): # pylint:disable=unused-argument - if "number:50" in query: - return { - "errors": [ - { - # pylint: disable=line-too-long - "message": "Failed to decode `block.number` value: `subgraph QmaGAi4jQw5L8J2xjnofAJb1PX5LLqRvGjsWbVehBELAUx only has data starting at block number 499 and data for block number 500 is therefore not available`" - } - ] - } - - return {"data": {"predictContracts": [{"id": "sample_id"}]}} - - with patch("pdr_backend.util.subgraph.query_subgraph", side_effect=mock_response): - assert block_number_is_synced("foo", 499) is True - assert block_number_is_synced("foo", 500) is False - assert block_number_is_synced("foo", 501) is False diff --git a/pdr_backend/util/test_ganache/test_subgraph_slot.py b/pdr_backend/util/test_ganache/test_subgraph_slot.py deleted file mode 100644 index 5f41ecd92..000000000 --- a/pdr_backend/util/test_ganache/test_subgraph_slot.py +++ /dev/null @@ -1,180 +0,0 @@ -from unittest.mock import patch -from dataclasses import asdict -from typing import Dict, List -from enforce_typing import enforce_types - -from pdr_backend.util.subgraph_slot import ( - get_predict_slots_query, - get_slots, - 
fetch_slots_for_all_assets, - calculate_prediction_result, - process_single_slot, - aggregate_statistics, - calculate_statistics_for_all_assets, - PredictSlot, -) -from pdr_backend.util.subgraph_predictions import ContractIdAndSPE - -# Sample data for tests -SAMPLE_PREDICT_SLOT = PredictSlot( - id="1-12345", - slot="12345", - trueValues=[{"id": "1", "trueValue": True}], - roundSumStakesUp=150.0, - roundSumStakes=100.0, -) - - -@enforce_types -def test_get_predict_slots_query(): - # Test the get_predict_slots_query function with expected inputs and outputs - query = get_predict_slots_query( - asset_ids=["0xAsset"], initial_slot=1000, last_slot=2000, first=10, skip=0 - ) - assert "predictSlots" in query - assert "0xAsset" in query - assert "1000" in query - assert "2000" in query - - -# Sample data for tests -SAMPLE_PREDICT_SLOT = PredictSlot( - id="0xAsset-12345", - slot="12345", - trueValues=[{"id": "1", "trueValue": True}], - roundSumStakesUp=150.0, - roundSumStakes=100.0, -) - - -MOCK_QUERY_RESPONSE = {"data": {"predictSlots": [asdict(SAMPLE_PREDICT_SLOT)]}} - -MOCK_QUERY_RESPONSE_FIRST_CALL = { - "data": { - "predictSlots": [asdict(SAMPLE_PREDICT_SLOT)] - * 1000 # Simulate a full page of results - } -} - -MOCK_QUERY_RESPONSE_SECOND_CALL: Dict[str, Dict[str, list]] = { - "data": {"predictSlots": []} # Simulate no further results, stopping the recursion -} - - -@enforce_types -@patch("pdr_backend.util.subgraph_slot.query_subgraph") -def test_get_slots(mock_query_subgraph): - # Configure the mock to return a full page of results on the first call, - # and no results on the second call - mock_query_subgraph.side_effect = [ - MOCK_QUERY_RESPONSE_FIRST_CALL, - MOCK_QUERY_RESPONSE_SECOND_CALL, - ] - - result_slots = get_slots( - addresses=["0xAsset"], - end_ts_param=2000, - start_ts_param=1000, - skip=0, - slots=[], - network="mainnet", - ) - - print("test_get_slots", result_slots) - - # Verify that the mock was called twice (once for the initial call, once for the recursive call) - assert mock_query_subgraph.call_count == 2 - # Verify that the result contains the expected number of slots - assert len(result_slots) == 1000 - # Verify that the slots contain instances of PredictSlot - assert isinstance(result_slots[0], PredictSlot) - # Verify the first slot's data matches the sample - assert result_slots[0].id == "0xAsset-12345" - - -@enforce_types -def test_calculate_prediction_result(): - # Test the calculate_prediction_prediction_result function with expected inputs - result = calculate_prediction_result(150.0, 200.0) - assert result - - result = calculate_prediction_result(100.0, 250.0) - assert not result - - -@enforce_types -def test_process_single_slot(): - # Test the process_single_slot function - ( - staked_yesterday, - staked_today, - correct_predictions, - slots_evaluated, - ) = process_single_slot( - slot=SAMPLE_PREDICT_SLOT, end_of_previous_day_timestamp=12340 - ) - - assert staked_yesterday == 0.0 - assert staked_today == 100.0 - assert correct_predictions == 1 - assert slots_evaluated == 1 - - -@enforce_types -def test_aggregate_statistics(): - # Test the aggregate_statistics function - ( - total_staked_yesterday, - total_staked_today, - total_correct_predictions, - total_slots_evaluated, - ) = aggregate_statistics( - slots=[SAMPLE_PREDICT_SLOT], end_of_previous_day_timestamp=12340 - ) - assert total_staked_yesterday == 0.0 - assert total_staked_today == 100.0 - assert total_correct_predictions == 1 - assert total_slots_evaluated == 1 - - -@enforce_types 
-@patch("pdr_backend.util.subgraph_slot.fetch_slots_for_all_assets") -def test_calculate_statistics_for_all_assets(mock_fetch_slots): - # Set up the mock to return a predetermined value - mock_fetch_slots.return_value = {"0xAsset": [SAMPLE_PREDICT_SLOT] * 1000} - # Contracts List - contracts: List[ContractIdAndSPE] = [ - {"id": "0xAsset", "seconds_per_epoch": 300, "name": "TEST/USDT"} - ] - # Test the calculate_statistics_for_all_assets function - statistics = calculate_statistics_for_all_assets( - asset_ids=["0xAsset"], - contracts=contracts, - start_ts_param=1000, - end_ts_param=2000, - network="mainnet", - ) - # Verify that the statistics are calculated as expected - assert statistics["0xAsset"]["average_accuracy"] == 100.0 - # Verify that the mock was called as expected - mock_fetch_slots.assert_called_once_with(["0xAsset"], 1000, 2000, "mainnet") - - -@enforce_types -@patch( - "pdr_backend.util.subgraph_slot.query_subgraph", return_value=MOCK_QUERY_RESPONSE -) -def test_fetch_slots_for_all_assets(mock_query_subgraph): - # Test the fetch_slots_for_all_assets function - result = fetch_slots_for_all_assets( - asset_ids=["0xAsset"], start_ts_param=1000, end_ts_param=2000, network="mainnet" - ) - - print("test_fetch_slots_for_all_assets", result) - # Verify that the result is structured correctly - assert "0xAsset" in result - assert all(isinstance(slot, PredictSlot) for slot in result["0xAsset"]) - assert len(result["0xAsset"]) == 1 - assert result["0xAsset"][0].id == "0xAsset-12345" - # Verify that the mock was called - mock_query_subgraph.assert_called() diff --git a/pdr_backend/util/test_ganache/test_web3_config.py b/pdr_backend/util/test_ganache/test_web3_config.py index 0887cb866..5c4aceb79 100644 --- a/pdr_backend/util/test_ganache/test_web3_config.py +++ b/pdr_backend/util/test_ganache/test_web3_config.py @@ -1,27 +1,27 @@ import os -from enforce_typing import enforce_types + import pytest +from enforce_typing import enforce_types +from pdr_backend.util.constants import MAX_UINT from pdr_backend.util.web3_config import Web3Config @enforce_types def test_Web3Config_bad_rpc(): private_key = os.getenv("PRIVATE_KEY") - with pytest.raises(ValueError): + with pytest.raises(TypeError): Web3Config(rpc_url=None, private_key=private_key) @enforce_types -def test_Web3Config_bad_key(): - rpc_url = os.getenv("RPC_URL") +def test_Web3Config_bad_key(rpc_url): with pytest.raises(ValueError): Web3Config(rpc_url=rpc_url, private_key="foo") @enforce_types -def test_Web3Config_happy_noPrivateKey(): - rpc_url = os.getenv("RPC_URL") +def test_Web3Config_happy_noPrivateKey(rpc_url): c = Web3Config(rpc_url=rpc_url, private_key=None) assert c.w3 is not None @@ -31,9 +31,8 @@ def test_Web3Config_happy_noPrivateKey(): @enforce_types -def test_Web3Config_happy_havePrivateKey_noKeywords(): +def test_Web3Config_happy_havePrivateKey_noKeywords(rpc_url): private_key = os.getenv("PRIVATE_KEY") - rpc_url = os.getenv("RPC_URL") c = Web3Config(rpc_url, private_key) assert c.account assert c.owner == c.account.address @@ -41,9 +40,8 @@ def test_Web3Config_happy_havePrivateKey_noKeywords(): @enforce_types -def test_Web3Config_happy_havePrivateKey_withKeywords(): +def test_Web3Config_happy_havePrivateKey_withKeywords(rpc_url): private_key = os.getenv("PRIVATE_KEY") - rpc_url = os.getenv("RPC_URL") c = Web3Config(rpc_url=rpc_url, private_key=private_key) assert c.account assert c.owner == c.account.address @@ -51,9 +49,8 @@ def test_Web3Config_happy_havePrivateKey_withKeywords(): @enforce_types -def 
test_Web3Config_get_block_latest(): +def test_Web3Config_get_block_latest(rpc_url): private_key = os.getenv("PRIVATE_KEY") - rpc_url = os.getenv("RPC_URL") c = Web3Config(rpc_url=rpc_url, private_key=private_key) block = c.get_block("latest") assert block @@ -61,10 +58,30 @@ def test_Web3Config_get_block_latest(): @enforce_types -def test_Web3Config_get_block_0(): +def test_Web3Config_get_block_0(rpc_url): private_key = os.getenv("PRIVATE_KEY") - rpc_url = os.getenv("RPC_URL") c = Web3Config(rpc_url=rpc_url, private_key=private_key) block = c.get_block(0) assert block assert block["timestamp"] > 0 + + +@enforce_types +def test_Web3Config_get_auth_signature(rpc_url): + private_key = os.getenv("PRIVATE_KEY") + c = Web3Config(rpc_url=rpc_url, private_key=private_key) + auth = c.get_auth_signature() + + # just a super basic test + assert sorted(auth.keys()) == sorted(["userAddress", "v", "r", "s", "validUntil"]) + + +@enforce_types +def test_get_max_gas(rpc_url): + private_key = os.getenv("PRIVATE_KEY") + web3_config = Web3Config(rpc_url=rpc_url, private_key=private_key) + max_gas = web3_config.get_max_gas() + assert 0 < max_gas < MAX_UINT + + target_max_gas = int(web3_config.get_block("latest").gasLimit * 0.99) + assert max_gas == target_max_gas diff --git a/pdr_backend/util/test_noganache/conftest.py b/pdr_backend/util/test_noganache/conftest.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pdr_backend/util/test_noganache/test_cache.py b/pdr_backend/util/test_noganache/test_cache.py index 694f64d26..cd0393e89 100644 --- a/pdr_backend/util/test_noganache/test_cache.py +++ b/pdr_backend/util/test_noganache/test_cache.py @@ -1,5 +1,6 @@ import os from pathlib import Path + import pytest from pdr_backend.util.cache import Cache diff --git a/pdr_backend/util/test_noganache/test_constants_get_opf_addrs.py b/pdr_backend/util/test_noganache/test_constants_get_opf_addrs.py new file mode 100644 index 000000000..34d7700fe --- /dev/null +++ b/pdr_backend/util/test_noganache/test_constants_get_opf_addrs.py @@ -0,0 +1,36 @@ +import pytest +from enforce_typing import enforce_types + +from pdr_backend.util.constants_opf_addrs import get_opf_addresses + + +@enforce_types +def test_get_opf_addresses_testnet(): + addrs = get_opf_addresses("sapphire-testnet") + assert len(addrs) > 3 + assert "dfbuyer" in addrs + assert "websocket" in addrs + assert "trueval" in addrs + + +@enforce_types +def test_get_opf_addresses_mainnet(): + # sapphire testnet + addrs = get_opf_addresses("sapphire-mainnet") + assert len(addrs) > 3 + assert "dfbuyer" in addrs + assert "websocket" in addrs + assert "trueval" in addrs + + +@enforce_types +def test_get_opf_addresses_other(): + for s in ( + "", + "foo", + "development", + "oasis_saphire_testnet", + "oasis_saphire", + ): + with pytest.raises(ValueError): + get_opf_addresses(s) diff --git a/pdr_backend/util/test_noganache/test_contract.py b/pdr_backend/util/test_noganache/test_contract.py new file mode 100644 index 000000000..d7662e1d7 --- /dev/null +++ b/pdr_backend/util/test_noganache/test_contract.py @@ -0,0 +1,63 @@ +from pathlib import Path + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.util.contract import ( + _condition_sapphire_keys, + get_address, + get_addresses, + get_contract_abi, + get_contract_filename, +) + +_NETWORKS = [ + "sapphire-testnet", + "sapphire-mainnet", + "development", +] + + +@enforce_types +@pytest.mark.parametrize("network", _NETWORKS) +def 
test_contract_main(network): + # setup + + ppss = mock_ppss(["binance BTC/USDT c 5m"], network) + web3_pp = ppss.web3_pp + assert web3_pp.network == network + + # tests + assert get_address(web3_pp, "Ocean") is not None + + assert get_addresses(web3_pp) is not None + + result = get_contract_abi("ERC20Template3", web3_pp.address_file) + assert len(result) > 0 and isinstance(result, list) + + result = get_contract_filename("ERC20Template3", web3_pp.address_file) + assert result is not None and isinstance(result, Path) + + with pytest.raises(TypeError): + get_contract_abi("xyz", web3_pp.address_file) + + +@enforce_types +def test_condition_sapphire_keys(): + assert _condition_sapphire_keys({}) == {} + + assert _condition_sapphire_keys({"foo": "bar"}) == {"foo": "bar"} + + k1 = {"oasis_saphire_testnet": "test", "oasis_saphire": "main", "foo": "bar"} + k2 = { + "oasis_saphire_testnet": "test", + "oasis_saphire": "main", + "sapphire-testnet": "test", + "sapphire-mainnet": "main", + "foo": "bar", + } + assert _condition_sapphire_keys(k1) == k2 + + k3 = {"sapphire-testnet": "test", "sapphire-mainnet": "main", "foo": "bar"} + assert _condition_sapphire_keys(k3) == k3 diff --git a/pdr_backend/util/test_noganache/test_csvs.py b/pdr_backend/util/test_noganache/test_csvs.py new file mode 100644 index 000000000..35a5f91e0 --- /dev/null +++ b/pdr_backend/util/test_noganache/test_csvs.py @@ -0,0 +1,58 @@ +import csv +import os + +from pdr_backend.subgraph.prediction import mock_daily_predictions +from pdr_backend.util.csvs import save_analysis_csv, save_prediction_csv + + +def test_save_analysis_csv(tmpdir): + predictions = mock_daily_predictions() + key = ( + predictions[0].pair.replace("/", "-") + + predictions[0].timeframe + + predictions[0].source + ) + save_analysis_csv(predictions, str(tmpdir)) + + with open(os.path.join(str(tmpdir), key + ".csv")) as f: + data = csv.DictReader(f) + data_rows = list(data) + + assert data_rows[0]["Predicted Value"] == str(predictions[0].prediction) + assert data_rows[0]["True Value"] == str(predictions[0].trueval) + assert data_rows[0]["Timestamp"] == str(predictions[0].timestamp) + assert list(data_rows[0].keys()) == [ + "PredictionID", + "Timestamp", + "Slot", + "Stake", + "Wallet", + "Payout", + "True Value", + "Predicted Value", + ] + + +def test_save_prediction_csv(tmpdir): + predictions = mock_daily_predictions() + key = ( + predictions[0].pair.replace("/", "-") + + predictions[0].timeframe + + predictions[0].source + ) + save_prediction_csv(predictions, str(tmpdir)) + + with open(os.path.join(str(tmpdir), key + ".csv")) as f: + data = csv.DictReader(f) + data_rows = list(row for row in data) + + assert data_rows[0]["Predicted Value"] == str(predictions[0].prediction) + assert data_rows[0]["True Value"] == str(predictions[0].trueval) + assert data_rows[0]["Timestamp"] == str(predictions[0].timestamp) + assert list(data_rows[0].keys()) == [ + "Predicted Value", + "True Value", + "Timestamp", + "Stake", + "Payout", + ] diff --git a/pdr_backend/util/test_noganache/test_env.py b/pdr_backend/util/test_noganache/test_env.py index c8eb50973..d8442718a 100644 --- a/pdr_backend/util/test_noganache/test_env.py +++ b/pdr_backend/util/test_noganache/test_env.py @@ -1,41 +1,14 @@ -from unittest.mock import patch - -from enforce_typing import enforce_types import pytest +from enforce_typing import enforce_types -from pdr_backend.util.env import getenv_or_exit, parse_filters +from pdr_backend.util.env import getenv_or_exit @enforce_types def test_getenv_or_exit(monkeypatch): 
- monkeypatch.delenv("RPC_URL", raising=False) + monkeypatch.delenv("MY_VAR", raising=False) with pytest.raises(SystemExit): - getenv_or_exit("RPC_URL") - - monkeypatch.setenv("RPC_URL", "http://test.url") - assert getenv_or_exit("RPC_URL") == "http://test.url" - - -@enforce_types -def test_parse_filters(): - mock_values = { - "PAIR_FILTER": "BTC-USDT,ETH-USDT", - "TIMEFRAME_FILTER": "1D,1H", - "SOURCE_FILTER": None, - "OWNER_ADDRS": "0x1234,0x5678", - } - - def mock_getenv(key, default=None): - return mock_values.get(key, default) - - with patch("pdr_backend.util.env.getenv", mock_getenv): - result = parse_filters() - - expected = ( - ["BTC-USDT", "ETH-USDT"], # pair - ["1D", "1H"], # timeframe - [], # source - ["0x1234", "0x5678"], # owner_addrs - ) + getenv_or_exit("MY_VAR") - assert result == expected, f"Expected {expected}, but got {result}" + monkeypatch.setenv("MY_VAR", "http://test.url") + assert getenv_or_exit("MY_VAR") == "http://test.url" diff --git a/pdr_backend/util/test_noganache/test_feedstr.py b/pdr_backend/util/test_noganache/test_feedstr.py deleted file mode 100644 index d64017c28..000000000 --- a/pdr_backend/util/test_noganache/test_feedstr.py +++ /dev/null @@ -1,286 +0,0 @@ -from typing import Set - -from enforce_typing import enforce_types -import pytest - -from pdr_backend.util.feedstr import ( - unpack_feeds_strs, - unpack_feeds_str, - unpack_feed_str, - verify_feeds_strs, - verify_feeds_str, - verify_feed_str, - verify_feed_tup, - verify_exchange_str, -) - - -# ========================================================================== -# unpack..() functions - - -@enforce_types -def test_unpack_feeds_strs(): - # 1 str w 1 feed, 1 feed total - feed_tups = unpack_feeds_strs(["binance o ADA/USDT"]) - assert feed_tups == [("binance", "open", "ADA-USDT")] - - # 1 str w 2 feeds, 2 feeds total - feed_tups = unpack_feeds_strs(["binance oh ADA/USDT"]) - assert feed_tups == [ - ("binance", "open", "ADA-USDT"), - ("binance", "high", "ADA-USDT"), - ] - - # 2 strs each w 1 feed, 2 feeds total - feed_tups = unpack_feeds_strs( - [ - "binance o ADA/USDT", - "kraken h ADA/RAI", - ] - ) - assert feed_tups == [ - ("binance", "open", "ADA-USDT"), - ("kraken", "high", "ADA-RAI"), - ] - - # first str has 4 feeds and second has 1 feed; 5 feeds total - feed_tups = unpack_feeds_strs( - [ - "binance oc ADA/USDT BTC/USDT", - "kraken h ADA/RAI", - ] - ) - assert sorted(feed_tups) == [ - ("binance", "close", "ADA-USDT"), - ("binance", "close", "BTC-USDT"), - ("binance", "open", "ADA-USDT"), - ("binance", "open", "BTC-USDT"), - ("kraken", "high", "ADA-RAI"), - ] - - # unhappy paths. 
Note: verify section has way more - lists = [ - [], - ["xyz o ADA/USDT"], - ["binance ox ADA/USDT"], - ["binance o ADA/X"], - ] - for feeds_strs in lists: - with pytest.raises(ValueError): - unpack_feeds_strs(feeds_strs) - - -@enforce_types -def test_unpack_feeds_str(): - # 1 feed - feed_tups = unpack_feeds_str("binance o ADA/USDT") - assert feed_tups == [("binance", "open", "ADA-USDT")] - - # >1 signal, so >1 feed - feed_tups = unpack_feeds_str("binance oc ADA/USDT") - assert feed_tups == [ - ("binance", "open", "ADA-USDT"), - ("binance", "close", "ADA-USDT"), - ] - - # >1 pair, so >1 feed - feed_tups = unpack_feeds_str("binance o ADA/USDT ETH/RAI") - assert feed_tups == [ - ("binance", "open", "ADA-USDT"), - ("binance", "open", "ETH-RAI"), - ] - - # >1 signal and >1 pair, so >1 feed - feed_tups = unpack_feeds_str("binance oc ADA/USDT,BTC/USDT") - assert len(feed_tups) == 4 - assert sorted(feed_tups) == [ - ("binance", "close", "ADA-USDT"), - ("binance", "close", "BTC-USDT"), - ("binance", "open", "ADA-USDT"), - ("binance", "open", "BTC-USDT"), - ] - - # unhappy paths. Note: verify section has way more - strs = [ - "xyz o ADA/USDT", - "binance ox ADA/USDT", - "binance o ADA/X", - ] - for feeds_str in strs: - with pytest.raises(ValueError): - unpack_feeds_str(feeds_str) - - # test separators between pairs: space, comma, both or a mix - # Note: verify section has way more - def _pairs(feed_tups) -> Set[str]: - return set(pair for (_, _, pair) in feed_tups) - - pairs = _pairs(unpack_feeds_str("binance o ADA/USDT BTC/USDT")) - assert pairs == set(["ADA-USDT", "BTC-USDT"]) - - pairs = _pairs(unpack_feeds_str("binance oc ADA/USDT,BTC/USDT")) - assert _pairs(feed_tups) == set(["ADA-USDT", "BTC-USDT"]) - - pairs = _pairs( - unpack_feeds_str("binance oc ADA/USDT BTC/USDT ,ETH/USDC, DOT/DAI") - ) - assert pairs == set(["ADA-USDT", "BTC-USDT", "ETH-USDC", "DOT-DAI"]) - - -@enforce_types -def test_unpack_feed_str(): - feed_tup = unpack_feed_str("binance c BTC/USDT") - exchange_str, signal, pair = feed_tup - assert exchange_str == "binance" - assert signal == "close" - assert pair == "BTC-USDT" - - -# ========================================================================== -# verify..() functions - - -@enforce_types -def test_verify_feeds_strs(): - # ok for verify_feeds_strs - lists = [ - ["binance o ADA/USDT"], - ["binance oc ADA/USDT BTC/USDT", "kraken h ADA/RAI"], - ] - for feeds_strs in lists: - verify_feeds_strs(feeds_strs) - - # not ok for verify_feeds_strs - lists = [ - [], - [""], - ["binance xoc ADA/USDT BTC/USDT", "kraken h ADA/RAI"], - ["", "kraken h ADA/RAI"], - ] - for feeds_strs in lists: - with pytest.raises(ValueError): - verify_feeds_strs(feeds_strs) - - -@enforce_types -def test_verify_feeds_str__and__verify_feed_str(): - # ok for verify_feeds_str, ok for verify_feed_str - # (well-formed 1 signal and 1 pair) - strs = [ - "binance o ADA/USDT", - "binance o ADA-USDT", - " binance o ADA/USDT", - "binance o ADA/USDT", - " binance o ADA/USDT ", - " binance o ADA/USDT ", - ] - for feed_str in strs: - verify_feed_str(feed_str) - for feeds_str in strs: - verify_feeds_str(feeds_str) - - # not ok for verify_feed_str, ok for verify_feeds_str - # (well-formed >1 signal or >1 pair) - strs = [ - "binance oh ADA/USDT", - " binance oh ADA/USDT", - "binance o ADA/USDT BTC/USDT", - " binance o ADA/USDT BTC/USDT ", - "binance o ADA/USDT, BTC/USDT ,ETH/USDC , DOT/DAI", - " binance o ADA/USDT, BTC/USDT ,ETH/USDC , DOT/DAI ", - ] - for feed_str in strs: - with pytest.raises(ValueError): - 
verify_feed_str(feed_str) - for feeds_str in strs: - verify_feeds_str(feeds_str) - - # not ok for verify_feed_str, not ok for verify_feeds_str - # (poorly formed) - strs = [ - "", - " ", - ",", - " , ", - " , ,", - " xyz ", - " xyz abc ", - "binance o", - "binance o ", - "binance o ,", - "o ADA/USDT", - "binance ADA/USDT", - "binance,ADA/USDT", - "binance , ADA/USDT", - "xyz o ADA/USDT", # catch non-exchanges! - "binancexyz o ADA/USDT", - "binance ohz ADA/USDT", - "binance z ADA/USDT", - "binance , o ADA/USDT", - "binance o , ADA/USDT", - "binance , o , ADA/USDT", - "binance,o,ADA/USDT", - "binance o XYZ", - "binance o USDT", - "binance o ADA/", - "binance o /USDT", - "binance o ADA:USDT", - "binance o ADA::USDT", - "binance o ADA,USDT", - "binance o ADA&USDT", - "binance o ADA/USDT XYZ", - ] - for feed_str in strs: - with pytest.raises(ValueError): - verify_feed_str(feed_str) - - for feeds_str in strs: - with pytest.raises(ValueError): - verify_feeds_str(feeds_str) - - -@enforce_types -def test_verify_feed_tup(): - # ok - tups = [ - ("binance", "open", "BTC/USDT"), - ("kraken", "close", "BTC/DAI"), - ] - for feed_tup in tups: - verify_feed_tup(feed_tup) - - # not ok - tups = [ - (), - ("binance", "open"), - ("binance", "open", ""), - ("xyz", "open", "BTC/USDT"), - ("binance", "xyz", "BTC/USDT"), - ("binance", "open", "BTC/XYZ"), - ("binance", "open", "BTC/USDT", ""), - ] - for feed_tup in tups: - with pytest.raises(ValueError): - verify_feed_tup(feed_tup) - - -@enforce_types -def test_verify_exchange_str(): - # ok - strs = [ - "binance", - "kraken", - ] - for exchange_str in strs: - verify_exchange_str(exchange_str) - - # not ok - strs = [ - "", - " ", - "xyz", - ] - for exchange_str in strs: - with pytest.raises(ValueError): - verify_exchange_str(exchange_str) diff --git a/pdr_backend/util/test_noganache/test_listutil.py b/pdr_backend/util/test_noganache/test_listutil.py new file mode 100644 index 000000000..eecabdd18 --- /dev/null +++ b/pdr_backend/util/test_noganache/test_listutil.py @@ -0,0 +1,9 @@ +from pdr_backend.util.listutil import remove_dups + + +def test_remove_dups(): + assert remove_dups([]) == [] + assert remove_dups([3]) == [3] + assert remove_dups(["foo"]) == ["foo"] + assert remove_dups([3, 3]) == [3] + assert remove_dups([3, "foo", "foo", 3, 4, 10, 4, 9]) == [3, "foo", 4, 10, 9] diff --git a/pdr_backend/util/test_noganache/test_mathutil.py b/pdr_backend/util/test_noganache/test_mathutil.py index e1a8ba984..5dec4b626 100644 --- a/pdr_backend/util/test_noganache/test_mathutil.py +++ b/pdr_backend/util/test_noganache/test_mathutil.py @@ -1,17 +1,24 @@ -from enforce_typing import enforce_types import numpy as np import pandas as pd +import polars as pl import pytest +from enforce_typing import enforce_types from pdr_backend.util.mathutil import ( - isNumber, - intInStr, Range, - randunif, - round_sig, - has_nan, + all_nan, fill_nans, + from_wei, + has_nan, + intInStr, + isNumber, nmse, + randunif, + round_sig, + sole_value, + str_with_wei, + string_to_bytes32, + to_wei, ) @@ -142,43 +149,130 @@ def test_round_sig(): @enforce_types -def test_has_nan(): +def test_all_nan__or_None(): + # 1d array + assert not all_nan(np.array([1.0, 2.0, 3.0, 4.0])) + assert not all_nan(np.array([1.0, None, 3.0, 4.0])) + assert not all_nan(np.array([1.0, 2.0, np.nan, 4.0])) + assert not all_nan(np.array([1.0, None, np.nan, 4.0])) + assert all_nan(np.array([None, None, None, None])) + assert all_nan(np.array([np.nan, np.nan, np.nan, None])) + assert all_nan(np.array([np.nan, np.nan, np.nan, 
np.nan])) + + # 2d array + assert not all_nan(np.array([[1.0, 2.0], [3.0, 4.0]])) + assert not all_nan(np.array([[1.0, None], [3.0, 4.0]])) + assert not all_nan(np.array([[1.0, 2.0], [np.nan, 4.0]])) + assert not all_nan(np.array([[1.0, None], [np.nan, 4.0]])) + assert all_nan(np.array([[None, None], [None, None]])) + assert all_nan(np.array([[np.nan, np.nan], [np.nan, None]])) + assert all_nan(np.array([[np.nan, np.nan], [np.nan, np.nan]])) + + # pd Series + assert not all_nan(pd.Series([1.0, 2.0, 3.0, 4.0])) + assert not all_nan(pd.Series([1.0, None, 3.0, 4.0])) + assert not all_nan(pd.Series([1.0, 2.0, np.nan, 4.0])) + assert not all_nan(pd.Series([1.0, None, np.nan, 4.0])) + assert all_nan(pd.Series([None, None, None, None])) + assert all_nan(pd.Series([np.nan, np.nan, np.nan, None])) + assert all_nan(pd.Series([np.nan, np.nan, np.nan, np.nan])) + + # pd DataFrame + assert not all_nan(pd.DataFrame({"A": [1.0, 2.0], "B": [3.0, 4.0]})) + assert not all_nan(pd.DataFrame({"A": [1.0, None], "B": [3.0, 4.0]})) + assert not all_nan(pd.DataFrame({"A": [1.0, 2.0], "B": [np.nan, 4.0]})) + assert not all_nan(pd.DataFrame({"A": [1.0, None], "B": [np.nan, 4.0]})) + assert all_nan(pd.DataFrame({"A": [None, None], "B": [None, None]})) + assert all_nan(pd.DataFrame({"A": [np.nan, np.nan], "B": [np.nan, None]})) + assert all_nan(pd.DataFrame({"A": [np.nan, np.nan], "B": [np.nan, np.nan]})) + + # pl Series + assert not all_nan(pl.Series([1.0, 2.0, 3.0, 4.0])) + assert not all_nan(pl.Series([1.0, None, 3.0, 4.0])) + assert not all_nan(pl.Series([1.0, 2.0, np.nan, 4.0])) + assert not all_nan(pl.Series([1.0, None, np.nan, 4.0])) + assert all_nan(pl.Series([None, None, None, None])) + assert all_nan(pl.Series([np.nan, np.nan, np.nan, None])) + assert all_nan(pl.Series([np.nan, np.nan, np.nan, np.nan])) + + # pl DataFrame + assert not all_nan(pl.DataFrame({"A": [1.0, 2.0], "B": [3.0, 4.0]})) + assert not all_nan(pl.DataFrame({"A": [1.0, None], "B": [3.0, 4.0]})) + assert not all_nan(pl.DataFrame({"A": [1.0, 2.0], "B": [np.nan, 4.0]})) + assert not all_nan(pl.DataFrame({"A": [1.0, None], "B": [np.nan, 4.0]})) + assert all_nan(pl.DataFrame({"A": [None, None], "B": [None, None]})) + assert all_nan(pl.DataFrame({"A": [np.nan, np.nan], "B": [np.nan, None]})) + assert all_nan(pl.DataFrame({"A": [np.nan, np.nan], "B": [np.nan, np.nan]})) + + +@enforce_types +def test_has_nan__or_None(): # 1d array assert not has_nan(np.array([1.0, 2.0, 3.0, 4.0])) assert has_nan(np.array([1.0, 2.0, np.nan, 4.0])) + assert has_nan(np.array([1.0, None, 3.0, 4.0])) + assert has_nan(np.array([1.0, None, np.nan, 4.0])) # 2d array assert not has_nan(np.array([[1.0, 2.0], [3.0, 4.0]])) assert has_nan(np.array([[1.0, 2.0], [np.nan, 4.0]])) + assert has_nan(np.array([[1.0, None], [3.0, 4.0]])) + assert has_nan(np.array([[1.0, None], [np.nan, 4.0]])) # pd Series assert not has_nan(pd.Series([1.0, 2.0, 3.0, 4.0])) assert has_nan(pd.Series([1.0, 2.0, np.nan, 4.0])) + assert has_nan(pd.Series([1.0, None, 3.0, 4.0])) + assert has_nan(pd.Series([1.0, None, np.nan, 4.0])) # pd DataFrame assert not has_nan(pd.DataFrame({"A": [1.0, 2.0], "B": [3.0, 4.0]})) assert has_nan(pd.DataFrame({"A": [1.0, 2.0], "B": [np.nan, 4.0]})) + assert has_nan(pd.DataFrame({"A": [1.0, None], "B": [3.0, 4.0]})) + assert has_nan(pd.DataFrame({"A": [1.0, None], "B": [np.nan, 4.0]})) + + # pl Series + assert not has_nan(pl.Series([1.0, 2.0, 3.0, 4.0])) + assert has_nan(pl.Series([1.0, 2.0, np.nan, 4.0])) + assert has_nan(pl.Series([1.0, None, 3.0, 4.0])) + assert 
has_nan(pl.Series([1.0, None, np.nan, 4.0])) + + # pl DataFrame + assert not has_nan(pl.DataFrame({"A": [1.0, 2.0], "B": [3.0, 4.0]})) + assert has_nan(pl.DataFrame({"A": [1.0, 2.0], "B": [np.nan, 4.0]})) + assert has_nan(pl.DataFrame({"A": [1.0, None], "B": [3.0, 4.0]})) + assert has_nan(pl.DataFrame({"A": [1.0, None], "B": [np.nan, 4.0]})) + + +@enforce_types +def test_fill_nans_pd(): + _test_fill_nans(pd) @enforce_types -def test_fill_nans(): +def test_fill_nans_pl(): + _test_fill_nans(pl) + + +@enforce_types +def _test_fill_nans(pdl): # nan at front - df1 = pd.DataFrame({"A": [np.nan, 1.0, 2.0, 3.0, 4.0, 5.0]}) + df1 = pdl.DataFrame({"A": [np.nan, 1.0, 2.0, 3.0, 4.0, 5.0]}) df2 = fill_nans(df1) assert not has_nan(df2) # nan in middle - df1 = pd.DataFrame({"A": [1.0, 2.0, np.nan, 3.0, 4.0]}) + df1 = pdl.DataFrame({"A": [1.0, 2.0, np.nan, 3.0, 4.0]}) df2 = fill_nans(df1) assert not has_nan(df2) # nan at end - df1 = pd.DataFrame({"A": [1.0, 2.0, 3.0, 4.0, np.nan]}) + df1 = pdl.DataFrame({"A": [1.0, 2.0, 3.0, 4.0, np.nan]}) df2 = fill_nans(df1) assert not has_nan(df2) # nan at front, middle, end - df1 = pd.DataFrame({"A": [np.nan, 1.0, 2.0, np.nan, 3.0, 4.0, np.nan]}) + df1 = pdl.DataFrame({"A": [np.nan, 1.0, 2.0, np.nan, 3.0, 4.0, np.nan]}) df2 = fill_nans(df1) assert not has_nan(df2) @@ -190,3 +284,59 @@ def test_nmse(): ymin, ymax = 10.0, 20.0 e = nmse(yhat, y, ymin, ymax) assert 0.035 <= e <= 0.036 + + +@enforce_types +def test_wei(): + assert from_wei(int(1234 * 1e18)) == 1234 + assert from_wei(int(12.34 * 1e18)) == 12.34 + assert from_wei(int(0.1234 * 1e18)) == 0.1234 + + assert to_wei(1234) == 1234 * 1e18 and type(to_wei(1234)) == int + assert to_wei(12.34) == 12.34 * 1e18 and type(to_wei(12.34)) == int + assert to_wei(0.1234) == 0.1234 * 1e18 and type(to_wei(0.1234)) == int + + assert str_with_wei(int(12.34 * 1e18)) == "12.34 (12340000000000000000 wei)" + + +@enforce_types +def test_string_to_bytes32_1_short(): + data = "hello" + data_bytes32 = string_to_bytes32(data) + assert data_bytes32 == b"hello000000000000000000000000000" + + +@enforce_types +def test_string_to_bytes32_2_long(): + data = "hello" + "a" * 50 + data_bytes32 = string_to_bytes32(data) + assert data_bytes32 == b"helloaaaaaaaaaaaaaaaaaaaaaaaaaaa" + + +@enforce_types +@pytest.mark.parametrize( + "input_data,expected_output", + [ + ("short", b"short" + b"0" * 27), + ("this is exactly 32 chars", b"this is exactly 32 chars00000000"), + ( + "this is a very long string which is more than 32 chars", + b"this is a very long string which", + ), + ], +) +def test_string_to_bytes32_3(input_data, expected_output): + result = string_to_bytes32(input_data) + assert ( + result == expected_output + ), f"For {input_data}, expected {expected_output}, but got {result}" + + +@enforce_types +def test_sole_value(): + assert sole_value({"b": 3}) == 3 + assert sole_value({5: "foo"}) == "foo" + with pytest.raises(ValueError): + sole_value({}) + with pytest.raises(ValueError): + sole_value({"a": 1, "b": 2}) diff --git a/pdr_backend/util/test_noganache/test_signalstr.py b/pdr_backend/util/test_noganache/test_signalstr.py index fb61dd08f..2c17a37fe 100644 --- a/pdr_backend/util/test_noganache/test_signalstr.py +++ b/pdr_backend/util/test_noganache/test_signalstr.py @@ -1,12 +1,57 @@ -from enforce_typing import enforce_types import pytest +from enforce_typing import enforce_types from pdr_backend.util.signalstr import ( + char_to_signal, + signal_to_char, + signals_to_chars, unpack_signalchar_str, - verify_signalchar_str, verify_signal_str, + 
verify_signalchar_str, ) +# ========================================================================== +# conversions + + +@enforce_types +def test_single_conversions(): + tups = [ + ("o", "open"), + ("h", "high"), + ("l", "low"), + ("c", "close"), + ("v", "volume"), + ] + for char, signal in tups: + assert char_to_signal(char) == signal + assert signal_to_char(signal) == char + + with pytest.raises(ValueError): + signal_to_char("xyz") + + with pytest.raises(ValueError): + char_to_signal("x") + + +@enforce_types +def test_multi_conversions(): + assert signals_to_chars(["open"]) == "o" + assert signals_to_chars(["close", "open"]) == "oc" + assert signals_to_chars({"close", "open"}) == "oc" + + for bad_input in [ + None, + "", + "foo", + "open", + "foo", + ["open", "foo"], + {"open", "foo"}, + ]: + with pytest.raises(ValueError): + signals_to_chars(bad_input) + # ========================================================================== # unpack..() functions diff --git a/pdr_backend/util/test_noganache/test_strutil.py b/pdr_backend/util/test_noganache/test_strutil.py index 456ecf327..6e1d14b7a 100644 --- a/pdr_backend/util/test_noganache/test_strutil.py +++ b/pdr_backend/util/test_noganache/test_strutil.py @@ -1,10 +1,10 @@ import random from pdr_backend.util import mathutil -from pdr_backend.util.strutil import StrMixin, dictStr, prettyBigNum, asCurrency +from pdr_backend.util.strutil import StrMixin, asCurrency, dictStr, prettyBigNum -def testStrMixin(): +def testStrMixin1(): class Foo(StrMixin): def __init__(self): self.x = 1 @@ -33,6 +33,29 @@ def ignoreMethod(self): s3 = f.longstr() assert s3 == s + f.__class__.__STR_GIVES_NEWLINE__ = True + s4 = f.longstr() + assert "\n" in s4 + f.__class__.__STR_GIVES_NEWLINE__ = False + + +def testStrMixin2(): + class Foo(StrMixin): + __STR_OBJDIR__ = ["x", "y"] + + def __init__(self): + self.x = 1 + self.y = 2 + self.z = 3 + + f = Foo() + s = str(f) + s2 = s.replace(" ", "") + assert "Foo={" in s + assert "x=1" in s2 + assert "y=2" in s2 + assert "z=3" not in s2 + def testDictStr(): d = {"a": 3, "b": 4} @@ -43,6 +66,9 @@ def testDictStr(): assert "'b':4" in s2 assert "dict}" in s + s = dictStr(d, True) + assert "\n" in s + def testEmptyDictStr(): d = {} @@ -67,6 +93,9 @@ def testAsCurrency(): assert asCurrency(2e6, False) == "$2,000,000" assert asCurrency(2e6 + 0.03, False) == "$2,000,000" + assert asCurrency(-0.03, True) == "-$0.03" + assert asCurrency(-0.03, False) == "-$0" + def testPrettyBigNum1_DoRemoveZeros_decimalsNeeded(): assert prettyBigNum(1.23456e13) == "1.23e13" diff --git a/pdr_backend/util/test_noganache/test_timeutil.py b/pdr_backend/util/test_noganache/test_timeutil.py index 87fed9f7a..63fb1352e 100644 --- a/pdr_backend/util/test_noganache/test_timeutil.py +++ b/pdr_backend/util/test_noganache/test_timeutil.py @@ -1,16 +1,17 @@ import datetime from datetime import timezone +import pytest from enforce_typing import enforce_types from pdr_backend.util.timeutil import ( - pretty_timestr, - current_ut, + current_ut_ms, dt_to_ut, - ut_to_dt, + ms_to_seconds, + pretty_timestr, timestr_to_ut, + ut_to_dt, ut_to_timestr, - ms_to_seconds, ) @@ -24,8 +25,8 @@ def test_pretty_timestr(): @enforce_types -def test_current_ut(): - ut = current_ut() +def test_current_ut_ms(): + ut = current_ut_ms() assert isinstance(ut, int) assert ut > 1648576500000 @@ -59,6 +60,10 @@ def test_timestr_to_ut(): assert timestr_to_ut("2022-03-29_17:55") == 1648576500000 assert timestr_to_ut("2022-03-29_17:55:12.345") == 1648576512345 + # test error + with 
pytest.raises(ValueError): + timestr_to_ut("::::::::") + @enforce_types def test_ut_to_timestr(): @@ -94,6 +99,9 @@ def test_dt_to_ut_and_back(): dt2 = ut_to_dt(ut) assert dt2 == dt + with pytest.raises(AssertionError): + ut_to_dt(-1) + @enforce_types def test_ms_to_seconds(): diff --git a/pdr_backend/util/test_noganache/test_topup.py b/pdr_backend/util/test_noganache/test_topup.py new file mode 100644 index 000000000..24e7eeac7 --- /dev/null +++ b/pdr_backend/util/test_noganache/test_topup.py @@ -0,0 +1,57 @@ +from unittest.mock import MagicMock, patch + +import pytest +from enforce_typing import enforce_types + +from pdr_backend.contract.token import NativeToken, Token +from pdr_backend.ppss.ppss import mock_ppss +from pdr_backend.util.mathutil import to_wei +from pdr_backend.util.topup import topup_main + + +@pytest.fixture(name="mock_token_") +def mock_token(): + token = MagicMock(spec=Token) + token.balanceOf.side_effect = [to_wei(500), 0, 0] # Mock balance in wei + token.transfer.return_value = None + return token + + +@pytest.fixture(name="mock_native_token_") +def mock_native_token(): + native_token = MagicMock(spec=NativeToken) + native_token.balanceOf.side_effect = [to_wei(500), 0, 0] # Mock balance in wei + native_token.transfer.return_value = None + return native_token + + +@pytest.fixture(name="mock_get_opf_addresses_") +def mock_get_opf_addresses(): + return MagicMock( + return_value={ + "predictoor1": "0x1", + "predictoor2": "0x2", + } + ) + + +@enforce_types +def test_topup_main(mock_token_, mock_native_token_, mock_get_opf_addresses_, tmpdir): + ppss = mock_ppss(["binance BTC/USDT c 5m"], "sapphire-mainnet", str(tmpdir)) + + PATH = "pdr_backend.util.topup" + with patch(f"{PATH}.Token", return_value=mock_token_), patch( + f"{PATH}.NativeToken", return_value=mock_native_token_ + ), patch(f"{PATH}.get_opf_addresses", mock_get_opf_addresses_), patch( + f"{PATH}.sys.exit" + ) as mock_exit: + topup_main(ppss) + + mock_exit.assert_called_with(0) + + assert mock_token_.transfer.called + assert mock_native_token_.transfer.called + + ppss.web3_pp.network = "foo" + with pytest.raises(SystemExit): + topup_main(ppss) diff --git a/pdr_backend/util/test_noganache/test_util_constants.py b/pdr_backend/util/test_noganache/test_util_constants.py index e5c695e5e..e91200858 100644 --- a/pdr_backend/util/test_noganache/test_util_constants.py +++ b/pdr_backend/util/test_noganache/test_util_constants.py @@ -1,20 +1,20 @@ -from enforce_typing import enforce_types import numpy as np +from enforce_typing import enforce_types from pdr_backend.util.constants import ( - ZERO_ADDRESS, - SAPPHIRE_TESTNET_RPC, - SAPPHIRE_TESTNET_CHAINID, - SAPPHIRE_MAINNET_RPC, - SAPPHIRE_MAINNET_CHAINID, - S_PER_MIN, + CAND_SIGNALS, + CAND_TIMEFRAMES, + CAND_USDCOINS, + CHAR_TO_SIGNAL, S_PER_DAY, + S_PER_MIN, + SAPPHIRE_MAINNET_CHAINID, + SAPPHIRE_MAINNET_RPC, + SAPPHIRE_TESTNET_CHAINID, + SAPPHIRE_TESTNET_RPC, SUBGRAPH_MAX_TRIES, WEB3_MAX_TRIES, - CAND_USDCOINS, - CAND_TIMEFRAMES, - CAND_SIGNALS, - CHAR_TO_SIGNAL, + ZERO_ADDRESS, ) diff --git a/pdr_backend/util/timeutil.py b/pdr_backend/util/timeutil.py index e51420451..1375c35c5 100644 --- a/pdr_backend/util/timeutil.py +++ b/pdr_backend/util/timeutil.py @@ -1,3 +1,4 @@ +import time import datetime from datetime import timezone @@ -11,16 +12,21 @@ def pretty_timestr(ut: int) -> str: @enforce_types -def current_ut() -> int: +def current_ut_ms() -> int: """Return the current date/time as a unix time (int in # ms)""" dt = datetime.datetime.now(timezone.utc) return 
dt_to_ut(dt) +def current_ut_s() -> int: + """Returns the current UTC unix time in seconds""" + return int(time.time()) + + @enforce_types def timestr_to_ut(timestr: str) -> int: """ - Convert a datetime string to unix time (in #ms) + Convert a datetime string to ut: unix time, in ms, in UTC time zone Needs a date; time for a given date is optional. Examples: @@ -33,7 +39,7 @@ def timestr_to_ut(timestr: str) -> int: Does not use local time, rather always uses UTC """ if timestr.lower() == "now": - return current_ut() + return current_ut_ms() ncolon = timestr.count(":") if ncolon == 0: @@ -79,6 +85,10 @@ def dt_to_ut(dt: datetime.datetime) -> int: @enforce_types def ut_to_dt(ut: int) -> datetime.datetime: """Convert unix time (in # ms) to datetime format""" + # precondition + assert ut >= 0, ut + + # main work dt = datetime.datetime.utcfromtimestamp(ut / 1000) dt = dt.replace(tzinfo=timezone.utc) # tack on timezone diff --git a/pdr_backend/util/topup.py b/pdr_backend/util/topup.py new file mode 100644 index 000000000..68dd07b76 --- /dev/null +++ b/pdr_backend/util/topup.py @@ -0,0 +1,81 @@ +import sys +from typing import Dict + +from enforce_typing import enforce_types + +from pdr_backend.contract.token import NativeToken, Token +from pdr_backend.ppss.ppss import PPSS +from pdr_backend.util.constants_opf_addrs import get_opf_addresses +from pdr_backend.util.contract import get_address +from pdr_backend.util.mathutil import from_wei, to_wei + + +@enforce_types +def topup_main(ppss: PPSS): + # if there is not enough balance, exit 1 so we know that script failed + failed = False + + web3_pp = ppss.web3_pp + owner = web3_pp.web3_config.owner + if web3_pp.network not in ["sapphire-testnet", "sapphire-mainnet"]: + print("Unknown network") + sys.exit(1) + + OCEAN_addr = get_address(ppss.web3_pp, "Ocean") + OCEAN = Token(ppss.web3_pp, OCEAN_addr) + ROSE = NativeToken(ppss.web3_pp) + + owner_OCEAN_bal = from_wei(OCEAN.balanceOf(owner)) + owner_ROSE_bal = from_wei(ROSE.balanceOf(owner)) + print( + f"Topup address ({owner}) has " + + f"{owner_OCEAN_bal:.2f} OCEAN and {owner_ROSE_bal:.2f} ROSE\n\n" + ) + + addresses: Dict[str, str] = get_opf_addresses(web3_pp.network) + for addr_label, address in addresses.items(): + OCEAN_bal = from_wei(OCEAN.balanceOf(address)) + ROSE_bal = from_wei(ROSE.balanceOf(address)) + + min_OCEAN_bal, topup_OCEAN_bal = ( + (0, 0) if addr_label in ["trueval", "dfbuyer"] else (20, 20) + ) + min_ROSE_bal, topup_ROSE_bal = ( + (250, 250) if addr_label == "dfbuyer" else (30, 30) + ) + + print(f"{addr_label}: {OCEAN_bal:.2f} OCEAN, {ROSE_bal:.2f} ROSE") + + # check if we need to transfer + if min_OCEAN_bal > 0 and OCEAN_bal < min_OCEAN_bal: + print(f"\t Transferring {topup_OCEAN_bal} OCEAN to {address}...") + if owner_OCEAN_bal > topup_OCEAN_bal: + OCEAN.transfer( + address, + to_wei(topup_OCEAN_bal), + owner, + True, + ) + owner_OCEAN_bal = owner_OCEAN_bal - topup_OCEAN_bal + else: + failed = True + print("Not enough OCEAN :(") + + if min_ROSE_bal > 0 and ROSE_bal < min_ROSE_bal: + print(f"\t Transferring {topup_ROSE_bal} ROSE to {address}...") + if owner_ROSE_bal > topup_ROSE_bal: + ROSE.transfer( + address, + to_wei(topup_ROSE_bal), + owner, + True, + ) + owner_ROSE_bal = owner_ROSE_bal - topup_ROSE_bal + else: + failed = True + print("Not enough ROSE :(") + + if failed: + sys.exit(1) + + sys.exit(0) diff --git a/pdr_backend/util/web3_config.py b/pdr_backend/util/web3_config.py index c05e1cc03..6e7e4e6eb 100644 --- a/pdr_backend/util/web3_config.py +++ 
b/pdr_backend/util/web3_config.py @@ -4,6 +4,8 @@ from enforce_typing import enforce_types from eth_account.signers.local import LocalAccount +from eth_keys import KeyAPI +from eth_keys.backends import NativeECCBackend from eth_typing import BlockIdentifier from web3 import Web3 from web3.middleware import ( @@ -13,18 +15,18 @@ from web3.types import BlockData from pdr_backend.util.constants import WEB3_MAX_TRIES +from pdr_backend.util.constants import ( + SAPPHIRE_MAINNET_CHAINID, + SAPPHIRE_TESTNET_CHAINID, +) + +_KEYS = KeyAPI(NativeECCBackend) @enforce_types class Web3Config: - def __init__( - self, rpc_url: Optional[str] = None, private_key: Optional[str] = None - ): - self.rpc_url = rpc_url - - if rpc_url is None: - raise ValueError("You must set RPC_URL variable") - + def __init__(self, rpc_url: str, private_key: Optional[str] = None): + self.rpc_url: str = rpc_url self.w3 = Web3(Web3.HTTPProvider(rpc_url)) if private_key is not None: @@ -51,3 +53,53 @@ def get_block( time.sleep(((tries + 1) / 2) ** (2) * 10) return self.get_block(block, full_transactions, tries + 1) raise Exception("Couldn't get block") from e + + def get_auth_signature(self): + """ + @description + Digitally sign + + @return + auth -- dict with keys "userAddress", "v", "r", "s", "validUntil" + """ + valid_until = self.get_block("latest").timestamp + 3600 + message_hash = self.w3.solidity_keccak( + ["address", "uint256"], + [self.owner, valid_until], + ) + pk = _KEYS.PrivateKey(self.account.key) + prefix = "\x19Ethereum Signed Message:\n32" + signable_hash = self.w3.solidity_keccak( + ["bytes", "bytes"], + [ + self.w3.to_bytes(text=prefix), + self.w3.to_bytes(message_hash), + ], + ) + signed = _KEYS.ecdsa_sign(message_hash=signable_hash, private_key=pk) + auth = { + "userAddress": self.owner, + "v": (signed.v + 27) if signed.v <= 1 else signed.v, + "r": self.w3.to_hex(self.w3.to_bytes(signed.r).rjust(32, b"\0")), + "s": self.w3.to_hex(self.w3.to_bytes(signed.s).rjust(32, b"\0")), + "validUntil": valid_until, + } + return auth + + @property + def is_sapphire(self): + return self.w3.eth.chain_id in [ + SAPPHIRE_TESTNET_CHAINID, + SAPPHIRE_MAINNET_CHAINID, + ] + + @enforce_types + def get_max_gas(self) -> int: + """Returns max block gas""" + block = self.get_block(self.w3.eth.block_number, full_transactions=False) + return int(block["gasLimit"] * 0.99) + + @enforce_types + def get_current_timestamp(self): + """Returns latest block""" + return self.get_block("latest")["timestamp"] diff --git a/ppss.yaml b/ppss.yaml index 7b8e3aeb2..54d590f88 100644 --- a/ppss.yaml +++ b/ppss.yaml @@ -1,29 +1,118 @@ -data_pp: - timeframe: 1h - predict_feed: binance c BTC/USDT - test_n : 200 # sim only, not bots - -data_ss: - st_timestr: 2019-09-13_04:00 - fin_timestr: now - max_n_train: 5000 - autoregressive_n : 10 - input_feeds : - - binance ohlcv BTC/USDT - - binance c ETH/USDT,TRX/USDT,ADA/USDT -# - binance c DOT/USDT - - kraken ohlcv BTC/USDT - -model_ss: - approach: LIN - -trade_pp: # sim only, not bots - fee_percent: 0.0 - init_holdings: - - 100000 USDT - -trade_ss: - buy_amt: 10 USDT - -sim_ss: +# (web3_pp / network settings is at bottom, because it's long) +lake_ss: + parquet_dir: parquet_data + feeds: + - binance BTC/USDT 1h +# - binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT 1h +# - kraken BTC/USDT 1h + st_timestr: 2023-06-01_00:00 # starting date for data + fin_timestr: now # ending date for data + +predictoor_ss: + predict_feed: binance BTC/USDT c 1h + bot_only: + 
s_until_epoch_end: 60 # in s. Start predicting if there's > this time left + stake_amount: 1 # stake this amount with each prediction. In OCEAN + approach3: + aimodel_ss: + input_feeds: + - binance BTC/USDT c 1h +# - binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT ohlcv 1h +# - kraken BTC/USDT 1h + max_n_train: 5000 # no. epochs to train model on + autoregressive_n : 10 # no. epochs that model looks back, to predict next + approach: LIN + +trader_ss: + feed: binance BTC/USDT c 1h + sim_only: + buy_amt: 10 USD # buy this amount in each epoch + fee_percent: 0.0 # simulated % fee + init_holdings: + - 100000 USDT + - 0 BTC + bot_only: + min_buffer: 30 # in s. only trade if there's > this time left + max_tries: 10 # max no. attempts to process a feed + position_size: 3 # buy/sell this amount in each epoch + +sim_ss: # sim only do_plot: True + log_dir: logs + test_n : 200 # number of epochs to simulate + +# ------------------------------------------------------------------ +# Bots run by OPF + +publisher_ss: + sapphire-mainnet: + fee_collector_address: 0x0000000000000000000000000000000000000000 + feeds: + - binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT c 5m,1h + sapphire-testnet: + fee_collector_address: 0x0000000000000000000000000000000000000000 + feeds: + - binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT c 5m,1h + development: + fee_collector_address: 0x0000000000000000000000000000000000000000 + feeds: + - binance BTC/USDT ETH/USDT XRP/USDT c 5m + +trueval_ss: + feeds: + - binance BTC/USDT 5m +# - binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT 5m +# - kraken BTC/USDT 5m + batch_size: 30 + sleep_time: 30 + +dfbuyer_ss: + feeds: + - binance BTC/USDT 5m +# - binance BTC/USDT ETH/USDT BNB/USDT XRP/USDT ADA/USDT DOGE/USDT SOL/USDT LTC/USDT TRX/USDT DOT/USDT 5m +# - kraken BTC/USDT 5m + batch_size: 20 + consume_interval_seconds: 86400 + weekly_spending_limit: 37000 + +payout_ss: + batch_size: 250 + +# ------------------------------------------------------------------ +# Network settings + + +web3_pp: + + sapphire-testnet: + address_file: "~/.ocean/ocean-contracts/artifacts/address.json" + rpc_url: https://testnet.sapphire.oasis.dev + subgraph_url: https://v4.subgraph.sapphire-testnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph + owner_addrs: "0xe02a421dfc549336d47efee85699bd0a3da7d6ff" # OPF deployer address + + sapphire-mainnet: + address_file: "~/.ocean/ocean-contracts/artifacts/address.json" + rpc_url: https://sapphire.oasis.io + subgraph_url: https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph + owner_addrs: "0x4ac2e51f9b1b0ca9e000dfe6032b24639b172703" # OPF deployer address + + development: + address_file: "~/.ocean/ocean-contracts/artifacts/address.json" + rpc_url: http://localhost:8545 + subgraph_url: http://localhost:9000/subgraphs/name/oceanprotocol/ocean-subgraph + owner_addrs: "0xe2DD09d719Da89e5a3D0F2549c7E24566e947260" # OPF deployer address. Taken from ocean.py setup-local.md FACTORY_DEPLOYER_PRIVATE_KEY + + barge-predictoor-bot: + address_file: "~/barge-predictoor-bot.address.json" + private_key: "0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" # address is 0xe2DD... 
+ rpc_url: http://4.245.224.119:8545 # from VPS + subgraph_url: http://4.245.224.119:9000/subgraphs/name/oceanprotocol/ocean-subgraph # from VPS + owner_addrs: "0xe2DD09d719Da89e5a3D0F2549c7E24566e947260" # OPF deployer address. Taken from ocean.py setup-local.md FACTORY_DEPLOYER_PRIVATE_KEY + + barge-pytest: + address_file: "~/barge-pytest.address.json" + private_key: "0xc594c6e5def4bab63ac29eed19a134c130388f74f019bc74b8f4389df2837a58" + rpc_url: http://74.234.16.165:8545 + subgraph_url: http://74.234.16.165:9000/subgraphs/name/oceanprotocol/ocean-subgraph + owner_addrs: "0xe2DD09d719Da89e5a3D0F2549c7E24566e947260" # OPF deployer address. Taken from ocean.py setup-local.md FACTORY_DEPLOYER_PRIVATE_KEY + diff --git a/scripts/check_network.py b/scripts/check_network.py deleted file mode 100644 index f79f78f70..000000000 --- a/scripts/check_network.py +++ /dev/null @@ -1,168 +0,0 @@ -import math -import sys -import time -from addresses import get_opf_addresses -from pdr_backend.models.base_config import BaseConfig -from pdr_backend.models.token import Token -from pdr_backend.util.subgraph import get_consume_so_far_per_contract, query_subgraph - - -WEEK = 86400 * 7 - - -def seconds_to_text(seconds: int) -> str: - if seconds == 300: - return "5m" - if seconds == 3600: - return "1h" - return "" - - -def print_stats(contract_dict, field_name, threshold=0.9): - count = sum(1 for _ in contract_dict["slots"]) - with_field = sum(1 for slot in contract_dict["slots"] if len(slot[field_name]) > 0) - if count == 0: - count += 1 - status = "PASS" if with_field / count > threshold else "FAIL" - token_name = contract_dict["token"]["name"] - timeframe = seconds_to_text(int(contract_dict["secondsPerEpoch"])) - print(f"{token_name} {timeframe}: {with_field}/{count} {field_name} - {status}") - - -def check_dfbuyer(dfbuyer_addr, contract_query_result, subgraph_url, tokens): - ts_now = time.time() - ts_start_time = int((ts_now // WEEK) * WEEK) - contract_addresses = [ - i["id"] for i in contract_query_result["data"]["predictContracts"] - ] - sofar = get_consume_so_far_per_contract( - subgraph_url, - dfbuyer_addr, - ts_start_time, - contract_addresses, - ) - expected = get_expected_consume(int(ts_now), tokens) - print( - f"Checking consume amounts (dfbuyer), expecting {expected} consume per contract" - ) - for addr in contract_addresses: - x = sofar[addr] - log_text = "PASS" if x >= expected else "FAIL" - print( - f" {log_text}... 
got {x} consume for contract: {addr}, expected {expected}" - ) - - -def get_expected_consume(for_ts: int, tokens: int): - amount_per_feed_per_interval = tokens / 7 / 20 - week_start = (math.floor(for_ts / WEEK)) * WEEK - time_passed = for_ts - week_start - n_intervals = int(time_passed / 86400) + 1 - return n_intervals * amount_per_feed_per_interval - - -if __name__ == "__main__": - config = BaseConfig() - - lookback_hours = 24 - if len(sys.argv) > 1: - try: - lookback_hours = int(sys.argv[1]) - except ValueError: - print("Please provide a valid integer as the number of epochs to check!") - - addresses = get_opf_addresses(config.web3_config.w3.eth.chain_id) - - ts = int(time.time()) - ts_start = ts - lookback_hours * 60 * 60 - query = """ - { - predictContracts{ - id - token{ - name - } - subscriptions(orderBy: expireTime orderDirection:desc first:10){ - user { - id - } - expireTime - } - slots(where:{slot_lt:%s, slot_gt:%s} orderBy: slot orderDirection:desc first:1000){ - slot - roundSumStakesUp - roundSumStakes - predictions(orderBy: timestamp orderDirection:desc){ - stake - user { - id - } - timestamp - payout{ - payout - predictedValue - trueValue - } - } - trueValues{ - trueValue - } - } - secondsPerEpoch - } - } - """ % ( - ts, - ts_start, - ) - result = query_subgraph(config.subgraph_url, query, timeout=10.0) - # check no of contracts - no_of_contracts = len(result["data"]["predictContracts"]) - if no_of_contracts >= 11: - print(f"Number of Predictoor contracts: {no_of_contracts} - OK") - else: - print(f"Number of Predictoor contracts: {no_of_contracts} - FAILED") - - print("-" * 60) - # check number of predictions - print("Predictions:") - for contract in result["data"]["predictContracts"]: - print_stats(contract, "predictions") - - print() - - # Check number of truevals - print("True Values:") - for contract in result["data"]["predictContracts"]: - print_stats(contract, "trueValues") - print("\nChecking account balances") - # pylint: disable=line-too-long - ocean_address = ( - "0x39d22B78A7651A76Ffbde2aaAB5FD92666Aca520" - if config.web3_config.w3.eth.chain_id == 23294 - else "0x973e69303259B0c2543a38665122b773D28405fB" - ) - ocean_token = Token(config.web3_config, ocean_address) - - for name, value in addresses.items(): - ocean_bal_wei = ocean_token.balanceOf(value) - native_bal_wei = config.web3_config.w3.eth.get_balance(value) - - ocean_bal = ocean_bal_wei / 1e18 - native_bal = native_bal_wei / 1e18 - - ocean_warning = " WARNING LOW OCEAN BALANCE!" if ocean_bal < 10 else " OK " - native_warning = " WARNING LOW NATIVE BALANCE!" 
if native_bal < 10 else " OK " - - if name == "trueval": - ocean_warning = " OK " - - # pylint: disable=line-too-long - print( - f"{name}: OCEAN: {ocean_bal:.2f}{ocean_warning}, Native: {native_bal:.2f}{native_warning}" - ) - - # ---------------- dfbuyer ---------------- - - token_amt = 44460 - check_dfbuyer(addresses["dfbuyer"].lower(), result, config.subgraph_url, token_amt) diff --git a/scripts/get_opf_predictions.py b/scripts/get_opf_predictions.py deleted file mode 100644 index ad8603c1a..000000000 --- a/scripts/get_opf_predictions.py +++ /dev/null @@ -1,172 +0,0 @@ -import csv -from typing import List -from pdr_backend.util.subgraph import query_subgraph - -addresses = { - "predictoor1": "0x35Afee1168D1e1053298F368488F4eE95E891a6e", - "predictoor2": "0x1628BeA0Fb859D56Cd2388054c0bA395827e4374", - "predictoor3": "0x3f0825d0c0bbfbb86cd13C7E6c9dC778E3Bb44ec", - "predictoor4": "0x20704E4297B1b014d9dB3972cc63d185C6C00615", - "predictoor5": "0xf38426BF6c117e7C5A6e484Ed0C8b86d4Ae7Ff78", - "predictoor6": "0xFe4A9C5F3A4EA5B1BfC089713ed5fce4bB276683", - "predictoor7": "0x078F083525Ad1C0d75Bc7e329EE6656bb7C81b12", - "predictoor8": "0x4A15CC5C20c5C5F71A9EA6376356f72b2A760f12", - "predictoor9": "0xD2a24CB4ff2584bAD80FF5F109034a891c3d88dD", - "predictoor10": "0x8a64CF23B5BB16Fd7444B47f94673B90Cc0F75cE", - "predictoor11": "0xD15749B83Be987fEAFa1D310eCc642E0e24CadBA", - "predictoor12": "0xAAbDBaB266b31d6C263b110bA9BE4930e63ce817", - "predictoor13": "0xB6431778C00F44c179D8D53f0E3d13334C051bd3", - "predictoor14": "0x2c2C599EC040F47C518fa96D08A92c5df5f50951", - "predictoor15": "0x5C72F76F7dae16dD34Cb6183b73F4791aa4B3BC4", - "predictoor16": "0x19C0A543664F819C7F9fb6475CE5b90Bfb112d26", - "predictoor17": "0x8cC3E2649777d59809C8d3E2Dd6E90FDAbBed502", - "predictoor18": "0xF5F2a495E0bcB50bF6821a857c5d4a694F5C19b4", - "predictoor19": "0x4f17B06177D37E24158fec982D48563bCEF97Fe6", - "predictoor20": "0x784b52987A894d74E37d494F91eD03a5Ab37aB36", -} - - -predictoor_pairs = { - "predictoor1": {"pair": "BTC", "timeframe": "5m"}, - "predictoor2": {"pair": "BTC", "timeframe": "1h"}, - "predictoor3": {"pair": "ETH", "timeframe": "5m"}, - "predictoor4": {"pair": "ETH", "timeframe": "1h"}, - "predictoor5": {"pair": "BNB", "timeframe": "5m"}, - "predictoor6": {"pair": "BNB", "timeframe": "1h"}, - "predictoor7": {"pair": "XRP", "timeframe": "5m"}, - "predictoor8": {"pair": "XRP", "timeframe": "1h"}, - "predictoor9": {"pair": "ADA", "timeframe": "5m"}, - "predictoor10": {"pair": "ADA", "timeframe": "1h"}, - "predictoor11": {"pair": "DOGE", "timeframe": "5m"}, - "predictoor12": {"pair": "DOGE", "timeframe": "1h"}, - "predictoor13": {"pair": "SOL", "timeframe": "5m"}, - "predictoor14": {"pair": "SOL", "timeframe": "1h"}, - "predictoor15": {"pair": "LTC", "timeframe": "5m"}, - "predictoor16": {"pair": "LTC", "timeframe": "1h"}, - "predictoor17": {"pair": "TRX", "timeframe": "5m"}, - "predictoor18": {"pair": "TRX", "timeframe": "1h"}, - "predictoor19": {"pair": "DOT", "timeframe": "5m"}, - "predictoor20": {"pair": "DOT", "timeframe": "1h"}, -} - - -class Prediction: - def __init__(self, pair, timeframe, prediction, stake, trueval, timestamp) -> None: - self.pair = pair - self.timeframe = timeframe - self.prediction = prediction - self.stake = stake - self.trueval = trueval - self.timestamp = timestamp - - -def get_all_predictions(): - chunk_size = 1000 - offset = 0 - predictions: List[Prediction] = [] - - address_filter = [a.lower() for a in addresses.values()] - - while True: - query = """ - { - predictPredictions(skip: %s, 
first: %s, where: {user_: {id_in: %s}}) { - id - user { - id - } - stake - payout { - payout - trueValue - predictedValue - } - slot { - slot - } - } - } - """ % ( - offset, - chunk_size, - str(address_filter).replace("'", '"'), - ) - # pylint: disable=line-too-long - mainnet_subgraph = "https://v4.subgraph.sapphire-mainnet.oceanprotocol.com/subgraphs/name/oceanprotocol/ocean-subgraph" - result = query_subgraph( - mainnet_subgraph, - query, - timeout=10.0, - ) - - print(".") - - offset += 1000 - - if not "data" in result: - break - - data = result["data"]["predictPredictions"] - if len(data) == 0: - break - for prediction in data: - predictoor_key = [ - key - for key, value in addresses.items() - if value.lower() == prediction["user"]["id"] - ][0] - pair_info = predictoor_pairs[predictoor_key] - pair_name = pair_info["pair"] - timeframe = pair_info["timeframe"] - timestamp = prediction["slot"]["slot"] - - if prediction["payout"] is None: - continue - - trueval = prediction["payout"]["trueValue"] - - if trueval is None: - continue - - predictedValue = prediction["payout"]["predictedValue"] - stake = float(prediction["stake"]) - - if stake < 0.01: - continue - - prediction_obj = Prediction( - pair_name, timeframe, predictedValue, stake, trueval, timestamp - ) - predictions.append(prediction_obj) - - return predictions - - -def write_csv(all_predictions): - data = {} - for prediction in all_predictions: - key = prediction.pair + prediction.timeframe - if key not in data: - data[key] = [] - data[key].append(prediction) - for key, prediction in data.items(): - prediction.sort(key=lambda x: x.timestamp) - filename = key + ".csv" - with open(filename, "w", newline="") as file: - writer = csv.writer(file) - writer.writerow(["Predicted Value", "True Value", "Timestamp", "Stake"]) - for prediction in prediction: - writer.writerow( - [ - prediction.prediction, - prediction.trueval, - prediction.timestamp, - prediction.stake, - ] - ) - print(f"CSV file '{filename}' created successfully.") - - -if __name__ == "__main__": - _predictions = get_all_predictions() - write_csv(_predictions) diff --git a/scripts/get_predictoor_info.py b/scripts/get_predictoor_info.py deleted file mode 100644 index 3886492a2..000000000 --- a/scripts/get_predictoor_info.py +++ /dev/null @@ -1,45 +0,0 @@ -import sys -from pdr_backend.util.csvs import write_prediction_csv -from pdr_backend.util.predictoor_stats import get_cli_statistics -from pdr_backend.util.subgraph_predictions import fetch_filtered_predictions, FilterMode -from pdr_backend.util.timeutil import ms_to_seconds, timestr_to_ut - -if __name__ == "__main__": - if len(sys.argv) != 6: - # pylint: disable=line-too-long - print( - "Usage: python get_predictoor_info.py [predictoor_addr | str] [start_date | yyyy-mm-dd] [end_date | yyyy-mm-dd] [network | mainnet | testnet] [csv_output_dir | str]" - ) - sys.exit(1) - - # single address or multiple addresses separated my comma - predictoor_addrs = sys.argv[1] - - # yyyy-mm-dd - start_dt = sys.argv[2] - end_dt = sys.argv[3] - - # mainnet or tesnet - network_param = sys.argv[4] - - csv_output_dir_param = sys.argv[5] - - start_ts_param = ms_to_seconds(timestr_to_ut(start_dt)) - end_ts_param = ms_to_seconds(timestr_to_ut(end_dt)) - - if "," in predictoor_addrs: - address_filter = predictoor_addrs.lower().split(",") - else: - address_filter = [predictoor_addrs.lower()] - - _predictions = fetch_filtered_predictions( - start_ts_param, - end_ts_param, - address_filter, - network_param, - FilterMode.PREDICTOOR, - ) - - 
write_prediction_csv(_predictions, csv_output_dir_param) - - get_cli_statistics(_predictions) diff --git a/scripts/topup.py b/scripts/topup.py deleted file mode 100644 index 5c4bd8701..000000000 --- a/scripts/topup.py +++ /dev/null @@ -1,87 +0,0 @@ -import sys -from addresses import get_opf_addresses -from pdr_backend.models.base_config import BaseConfig -from pdr_backend.models.token import Token, NativeToken - - -if __name__ == "__main__": - config = BaseConfig() - failed = ( - False # if there is not enough balance, exit 1 so we know that script failed - ) - addresses = get_opf_addresses(config.web3_config.w3.eth.chain_id) - ocean_address = None - if config.web3_config.w3.eth.chain_id == 23294: # mainnet - ocean_address = "0x39d22B78A7651A76Ffbde2aaAB5FD92666Aca520" - if config.web3_config.w3.eth.chain_id == 23295: # testnet - ocean_address = "0x973e69303259B0c2543a38665122b773D28405fB" - if ocean_address is None: - print("Unknown network") - sys.exit(1) - - ocean_token = Token(config.web3_config, ocean_address) - rose = NativeToken(config.web3_config) - - owner_ocean_balance = int(ocean_token.balanceOf(config.web3_config.owner)) / 1e18 - owner_rose_balance = int(rose.balanceOf(config.web3_config.owner)) / 1e18 - print( - f"Topup address ({config.web3_config.owner}) has " - + f"{owner_ocean_balance:.2f} OCEAN and {owner_rose_balance:.2f} ROSE\n\n" - ) - total_ocean = 0 - total_rose = 0 - for name, value in addresses.items(): - ocean_bal_wei = ocean_token.balanceOf(value) - rose_bal_wei = rose.balanceOf(value) - - ocean_bal = ocean_bal_wei / 1e18 - rose_bal = rose_bal_wei / 1e18 - - minimum_ocean_bal = 20 - minimum_rose_bal = 30 - topup_ocean_bal = 20 - topup_rose_bal = 30 - - if name == "trueval": - minimum_ocean_bal = 0 - topup_ocean_bal = 0 - - if name == "dfbuyer": - minimum_ocean_bal = 0 - topup_ocean_bal = 0 - minimum_rose_bal = 250 - topup_rose_bal = 250 - - # pylint: disable=line-too-long - print(f"{name}: {ocean_bal:.2f} OCEAN, {rose_bal:.2f} ROSE") - # check if we need to transfer - if minimum_ocean_bal > 0 and ocean_bal < minimum_ocean_bal: - print(f"\t Transfering {topup_ocean_bal} OCEAN to {value}...") - if owner_ocean_balance > topup_ocean_bal: - ocean_token.transfer( - value, - config.web3_config.w3.to_wei(topup_ocean_bal, "ether"), - config.web3_config.owner, - True, - ) - owner_ocean_balance = owner_ocean_balance - topup_ocean_bal - else: - failed = True - print("Not enough OCEAN :(") - if minimum_rose_bal > 0 and rose_bal < minimum_rose_bal: - print(f"\t Transfering {topup_rose_bal} ROSE to {value}...") - if owner_rose_balance > topup_rose_bal: - rose.transfer( - value, - config.web3_config.w3.to_wei(topup_rose_bal, "ether"), - config.web3_config.owner, - True, - ) - owner_rose_balance = owner_rose_balance - topup_rose_bal - else: - failed = True - print("Not enough ROSE :(") - if failed: - sys.exit(1) - else: - sys.exit(0) diff --git a/setup.py b/setup.py index 249a2dfc7..877bdb137 100644 --- a/setup.py +++ b/setup.py @@ -15,18 +15,24 @@ "enforce_typing", "eth-account", "eth-keys", + "flask", "matplotlib", "mypy", "numpy", "pandas", "pathlib", + "polars", + "polars[timezone]", + "pyarrow", "pylint", "pytest", - "pytest-asyncio", + "pytest-asyncio==0.21.1", "pytest-env", + "pyyaml", "requests", "scikit-learn", "statsmodels", + "types-pyYAML", "types-requests", "web3", "sapphire.py", diff --git a/system_tests/conftest.py b/system_tests/conftest.py new file mode 100644 index 000000000..66eeb5068 --- /dev/null +++ b/system_tests/conftest.py @@ -0,0 +1,46 @@ +import os 
+import sys
+
+from pdr_backend.subgraph.subgraph_feed import SubgraphFeed
+
+# Add the root directory to the path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from pdr_backend.conftest_ganache import *  # pylint: disable=wildcard-import,wrong-import-position
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_test_env_var():
+    previous = os.getenv("TEST")
+    os.environ["TEST"] = "true"
+    yield
+    if previous is None:
+        del os.environ["TEST"]
+    else:
+        os.environ["TEST"] = previous
+
+
+@pytest.fixture(scope="session")
+def mock_feeds():
+    feeds = {
+        "0x1": SubgraphFeed(
+            "Feed: binance | BTC/USDT | 5m",
+            "0x1",
+            "BTC",
+            100,
+            300,
+            "0xf",
+            "BTC/USDT",
+            "5m",
+            "binance",
+        )
+    }
+    return feeds
+
+
+@pytest.fixture(scope="session")
+def mock_predictoor_contract():
+    mock_contract = Mock(spec=PredictoorContract)
+    mock_contract.payout_multiple.return_value = None
+    mock_contract.get_agg_predval.return_value = (12, 23)
+    mock_contract.get_current_epoch.return_value = 100
+    return mock_contract
diff --git a/system_tests/test_check_network_system.py b/system_tests/test_check_network_system.py
new file mode 100644
index 000000000..6643c832c
--- /dev/null
+++ b/system_tests/test_check_network_system.py
@@ -0,0 +1,78 @@
+import sys
+
+from unittest.mock import Mock, patch, MagicMock
+
+from pdr_backend.cli import cli_module
+from pdr_backend.ppss.web3_pp import Web3PP
+from pdr_backend.util.constants_opf_addrs import get_opf_addresses
+from pdr_backend.util.web3_config import Web3Config
+
+
+@patch("pdr_backend.analytics.check_network.print_stats")
+@patch("pdr_backend.analytics.check_network.check_dfbuyer")
+def test_check_network(mock_check_dfbuyer, mock_print_stats):
+    mock_web3_pp = MagicMock(spec=Web3PP)
+    mock_web3_pp.network = "sapphire-mainnet"
+    mock_web3_pp.subgraph_url = (
+        "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph"
+    )
+
+    mock_web3_config = Mock(spec=Web3Config)
+    mock_web3_config.w3 = Mock()
+    mock_web3_config.w3.eth.get_balance.return_value = 100
+    mock_web3_pp.web3_config = mock_web3_config
+    mock_web3_pp.web3_config.owner = "0xowner"
+
+    mock_token = MagicMock()
+    mock_token.balanceOf.return_value = int(5e18)
+    mock_token.transfer.return_value = True
+
+    mock_query_subgraph = Mock()
+    mock_query_subgraph.return_value = {
+        "data": {
+            "predictContracts": [
+                {},
+                {},
+                {},
+            ]
+        }
+    }
+    with patch("pdr_backend.contract.token.Token", return_value=mock_token), patch(
+        "pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp
+    ), patch(
+        "pdr_backend.analytics.check_network.Token", return_value=mock_token
+    ), patch(
+        "pdr_backend.analytics.check_network.get_address", return_value="0x1"
+    ), patch(
+        "sys.exit"
+    ), patch(
+        "pdr_backend.analytics.check_network.query_subgraph", mock_query_subgraph
+    ):
+        # Mock sys.argv
+        sys.argv = ["pdr", "check_network", "ppss.yaml", "development"]
+
+        with patch("builtins.print") as mock_print:
+            cli_module._do_main()
+
+        addresses = get_opf_addresses("sapphire-mainnet")
+        # Verifying outputs
+        mock_print.assert_any_call("pdr check_network: Begin")
+        mock_print.assert_any_call("Arguments:")
+        mock_print.assert_any_call("PPSS_FILE=ppss.yaml")
+        mock_print.assert_any_call("NETWORK=development")
+        mock_print.assert_any_call("Number of Predictoor contracts: 3 - FAILED")
+        mock_print.assert_any_call("Predictions:")
+        mock_print.assert_any_call("True Values:")
+        mock_print.assert_any_call("\nChecking account balances")
+
+        for key in addresses.values():
+            if
key.startswith("pred"): + mock_print.assert_any_call( + # pylint: disable=line-too-long + f"{key}: OCEAN: 5.00 WARNING LOW OCEAN BALANCE!, Native: 0.00 WARNING LOW NATIVE BALANCE!" + ) + + # Additional assertions + mock_print_stats.assert_called() + mock_check_dfbuyer.assert_called() + mock_token.balanceOf.assert_called() diff --git a/system_tests/test_dfbuyer_agent_system.py b/system_tests/test_dfbuyer_agent_system.py new file mode 100644 index 000000000..9fab3ed55 --- /dev/null +++ b/system_tests/test_dfbuyer_agent_system.py @@ -0,0 +1,101 @@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.contract.predictoor_batcher import PredictoorBatcher +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.util.constants import SAPPHIRE_MAINNET_CHAINID +from pdr_backend.util.web3_config import Web3Config + + +@patch("pdr_backend.dfbuyer.dfbuyer_agent.wait_until_subgraph_syncs") +@patch("pdr_backend.dfbuyer.dfbuyer_agent.time.sleep") +def test_dfbuyer_agent(mock_wait_until_subgraph_syncs, mock_time_sleep): + _ = mock_wait_until_subgraph_syncs + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + feeds = { + "0x1": SubgraphFeed( + "Feed: binance | BTC/USDT | 5m", + "0x1", + "BTC", + 100, + 300, + "0xf", + "BTC/USDT", + "5m", + "binance", + ) + } + mock_web3_pp.query_feed_contracts.return_value = feeds + + mock_token = MagicMock() + mock_token.balanceOf.return_value = 100 * 1e18 + mock_token.allowance.return_value = 1e128 + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.get_block.return_value = {"timestamp": 100} + mock_web3_config.owner = "0x00000000000000000000000000000000000c0ffe" + mock_web3_config.w3 = Mock() + mock_web3_config.w3.eth.block_number = 100 + mock_web3_config.w3.eth.chain_id = SAPPHIRE_MAINNET_CHAINID + mock_web3_config.w3.to_checksum_address.return_value = "0x1" + mock_web3_config.get_current_timestamp.return_value = 100 + + mock_web3_pp.web3_config = mock_web3_config + + mock_get_consume_so_far_per_contract = Mock() + mock_get_consume_so_far_per_contract.return_value = {"0x1": 120} + + mock_predictoor_batcher = Mock(spec=PredictoorBatcher) + mock_predictoor_batcher.contract_address = "0xpredictoor_batcher" + mock_predictoor_batcher.consume_multiple.return_value = { + "transactionHash": b"0xbatch_submit_tx", + "status": 1, + } + + with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp), patch( + "pdr_backend.cli.cli_module.get_address", return_value="0x1" + ), patch("pdr_backend.dfbuyer.dfbuyer_agent.Token", return_value=mock_token), patch( + "pdr_backend.dfbuyer.dfbuyer_agent.PredictoorBatcher", + return_value=mock_predictoor_batcher, + ), patch( + "pdr_backend.dfbuyer.dfbuyer_agent.get_address", return_value="0x1" + ), patch( + "pdr_backend.dfbuyer.dfbuyer_agent.get_consume_so_far_per_contract", + mock_get_consume_so_far_per_contract, + ), patch( + "pdr_backend.dfbuyer.dfbuyer_agent.DFBuyerAgent._get_prices", + return_value={"0x1": 100}, + ): + # Mock sys.argv + sys.argv = ["pdr", "dfbuyer", "ppss.yaml", "development"] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr dfbuyer: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + 
mock_print.assert_any_call("NETWORK=development") + mock_print.assert_any_call(" Feed: 5m binance BTC/USDT 0x1") + mock_print.assert_any_call("Checking allowance...") + mock_print.assert_any_call("Taking step for timestamp:", 100) + mock_print.assert_any_call( + "Missing consume amounts:", {"0x1": 5165.714285714285} + ) + mock_print.assert_any_call("Missing consume times:", {"0x1": 52}) + mock_print.assert_any_call("Processing 3 batches...") + mock_print.assert_any_call("Consuming contracts ['0x1'] for [20] times.") + + # Additional assertions + mock_web3_pp.query_feed_contracts.assert_called() + mock_predictoor_batcher.consume_multiple.assert_called() + mock_time_sleep.assert_called() + mock_get_consume_so_far_per_contract.assert_called() diff --git a/system_tests/test_get_predictions_info_system.py b/system_tests/test_get_predictions_info_system.py new file mode 100644 index 000000000..e3ca742e4 --- /dev/null +++ b/system_tests/test_get_predictions_info_system.py @@ -0,0 +1,85 @@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.subgraph.prediction import Prediction +from pdr_backend.subgraph.subgraph_predictions import FilterMode +from pdr_backend.util.web3_config import Web3Config + + +@patch("pdr_backend.analytics.get_predictions_info.get_cli_statistics") +@patch("pdr_backend.analytics.get_predictions_info.fetch_filtered_predictions") +@patch("pdr_backend.analytics.get_predictions_info.save_analysis_csv") +@patch("pdr_backend.analytics.get_predictions_info.get_all_contract_ids_by_owner") +def test_topup( + mock_get_all_contract_ids_by_owner, + mock_save_analysis_csv, + mock_fetch_filtered_predictions, + mock_get_cli_statistics, +): + mock_get_all_contract_ids_by_owner.return_value = ["0xfeed"] + mock_predictions = [ + Prediction( + "0xfeed", + "BTC", + "5m", + True, + 100.0, + False, + 100, + "binance", + 10.0, + 10, + "0xuser", + ) + ] + mock_fetch_filtered_predictions.return_value = mock_predictions + + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.w3 = Mock() + mock_web3_pp.web3_config = mock_web3_config + + with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp): + # Mock sys.argv + sys.argv = [ + "pdr", + "get_predictions_info", + "2023-12-01", + "2023-12-31", + "./dir", + "ppss.yaml", + "development", + "--FEEDS", + "0xfeed", + ] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr get_predictions_info: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + mock_print.assert_any_call("NETWORK=development") + mock_print.assert_any_call("FEEDS=0xfeed") + + # Additional assertions + mock_save_analysis_csv.assert_called_with(mock_predictions, "./dir") + mock_get_cli_statistics.assert_called_with(mock_predictions) + mock_fetch_filtered_predictions.assert_called_with( + 1701388800, + 1703980800, + ["0xfeed"], + "mainnet", + FilterMode.CONTRACT, + payout_only=True, + trueval_only=True, + ) diff --git a/system_tests/test_get_predictoors_info_system.py b/system_tests/test_get_predictoors_info_system.py new file mode 100644 index 000000000..83408a40d --- /dev/null +++ b/system_tests/test_get_predictoors_info_system.py @@ -0,0 +1,71 
@@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.util.web3_config import Web3Config + + +@patch("pdr_backend.analytics.get_predictoors_info.fetch_filtered_predictions") +@patch("pdr_backend.analytics.get_predictoors_info.get_cli_statistics") +@patch("pdr_backend.analytics.get_predictoors_info.save_prediction_csv") +def test_topup( + mock_fetch_filtered_predictions, mock_get_cli_statistics, mock_save_prediction_csv +): + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.w3 = Mock() + mock_web3_config.w3.eth.get_balance.return_value = 100 + mock_web3_pp.web3_config = mock_web3_config + mock_web3_pp.web3_config.owner = "0xowner" + + mock_token = MagicMock() + mock_token.balanceOf.return_value = int(5e18) + mock_token.transfer.return_value = True + + mock_query_subgraph = Mock() + mock_query_subgraph.return_value = { + "data": { + "predictContracts": [ + {}, + {}, + {}, + ] + } + } + with patch("pdr_backend.contract.token.Token", return_value=mock_token), patch( + "pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp + ): + # Mock sys.argv + sys.argv = [ + "pdr", + "get_predictoors_info", + "2023-12-01", + "2023-12-31", + "./dir", + "ppss.yaml", + "development", + "--PDRS", + "0xpredictoor", + ] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr get_predictoors_info: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + mock_print.assert_any_call("NETWORK=development") + mock_print.assert_any_call("PDRS=0xpredictoor") + + # Additional assertions + mock_fetch_filtered_predictions.assert_called() + mock_get_cli_statistics.assert_called() + mock_save_prediction_csv.assert_called() diff --git a/system_tests/test_get_traction_info_system.py b/system_tests/test_get_traction_info_system.py new file mode 100644 index 000000000..6a0d360e3 --- /dev/null +++ b/system_tests/test_get_traction_info_system.py @@ -0,0 +1,68 @@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.util.web3_config import Web3Config + + +@patch("pdr_backend.analytics.get_traction_info.plot_slot_daily_statistics") +@patch("pdr_backend.lake.gql_data_factory.get_all_contract_ids_by_owner") +@patch("pdr_backend.analytics.predictoor_stats.plt.savefig") +def test_topup( + mock_savefig, + mock_get_all_contract_ids_by_owner, + mock_plot_slot_daily_statistics, +): + mock_get_all_contract_ids_by_owner.return_value = ["0xfeed"] + + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.owner_addrs = "0xowner" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.w3 = Mock() + mock_web3_config.w3.owner_address = "0xowner" + mock_web3_pp.web3_config = mock_web3_config + + with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp): + # Mock sys.argv + sys.argv = [ + "pdr", + "get_traction_info", + "2023-12-01", + "2023-12-31", + "./dir", + "ppss.yaml", + "development", + ] + + with patch("builtins.print") as 
mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr get_traction_info: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + mock_print.assert_any_call("NETWORK=development") + mock_print.assert_any_call( + "Get predictions data across many feeds and timeframes." + ) + mock_print.assert_any_call( + " Data start: timestamp=1701388800000, dt=2023-12-01_00:00:00.000" + ) + mock_print.assert_any_call( + " Data fin: timestamp=1703980800000, dt=2023-12-31_00:00:00.000" + ) + mock_print.assert_any_call( + "Chart created:", "./dir/plots/daily_unique_predictoors.png" + ) + + # Additional assertions + mock_get_all_contract_ids_by_owner.assert_called() + mock_plot_slot_daily_statistics.assert_called() + mock_savefig.assert_called_with("./dir/plots/daily_unique_predictoors.png") diff --git a/system_tests/test_ocean_payout.py b/system_tests/test_ocean_payout.py new file mode 100644 index 000000000..799084d82 --- /dev/null +++ b/system_tests/test_ocean_payout.py @@ -0,0 +1,70 @@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.contract.predictoor_contract import PredictoorContract +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.util.constants import SAPPHIRE_MAINNET_CHAINID +from pdr_backend.util.web3_config import Web3Config + + +@patch("pdr_backend.payout.payout.wait_until_subgraph_syncs") +def test_ocean_payout_test(mock_wait_until_subgraph_syncs): + _ = mock_wait_until_subgraph_syncs + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.owner = "0x00000000000000000000000000000000000c0ffe" + mock_web3_config.w3 = Mock() + mock_web3_config.w3.eth.chain_id = SAPPHIRE_MAINNET_CHAINID + mock_web3_pp.web3_config = mock_web3_config + + mock_token = MagicMock() + mock_token.balanceOf.return_value = 100 * 1e18 + + mock_query_pending_payouts = Mock() + mock_query_pending_payouts.return_value = {"0x1": [1, 2, 3], "0x2": [5, 6, 7]} + number_of_slots = 6 # 3 + 3 + + mock_predictoor_contract = Mock(spec=PredictoorContract) + mock_predictoor_contract.payout_multiple.return_value = None + + with patch( + "pdr_backend.payout.payout.query_pending_payouts", mock_query_pending_payouts + ), patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp), patch( + "pdr_backend.contract.token.Token", return_value=mock_token + ), patch( + "pdr_backend.payout.payout.WrappedToken", return_value=mock_token + ), patch( + "pdr_backend.payout.payout.PredictoorContract", + return_value=mock_predictoor_contract, + ): + # Mock sys.argv + sys.argv = ["pdr", "claim_OCEAN", "ppss.yaml"] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr claim_OCEAN: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + mock_print.assert_any_call("Starting payout") + mock_print.assert_any_call("Finding pending payouts") + mock_print.assert_any_call(f"Found {number_of_slots} slots") + mock_print.assert_any_call("Claiming payouts for 0x1") + mock_print.assert_any_call("Claiming payouts for 0x2") + mock_print.assert_any_call("Payout done") + + # Additional assertions + mock_query_pending_payouts.assert_called_with( + mock_web3_pp.subgraph_url, 
mock_web3_config.owner + ) + mock_predictoor_contract.payout_multiple.assert_any_call([1, 2, 3], True) + mock_predictoor_contract.payout_multiple.assert_any_call([5, 6, 7], True) + assert mock_predictoor_contract.payout_multiple.call_count == 2 diff --git a/system_tests/test_predictoor_system.py b/system_tests/test_predictoor_system.py new file mode 100644 index 000000000..a5e3cee19 --- /dev/null +++ b/system_tests/test_predictoor_system.py @@ -0,0 +1,71 @@ +import sys +from unittest.mock import Mock, patch, MagicMock +from pdr_backend.cli import cli_module +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.util.web3_config import Web3Config + + +def setup_mock_web3_pp(mock_feeds, mock_predictoor_contract): + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "development" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + mock_web3_pp.query_feed_contracts.return_value = mock_feeds + mock_web3_pp.get_contracts.return_value = {"0x1": mock_predictoor_contract} + mock_web3_pp.w3.eth.block_number = 100 + mock_predictoor_ss = Mock() + mock_predictoor_ss.get_feed_from_candidates.return_value = mock_feeds["0x1"] + mock_predictoor_ss.s_until_epoch_end = 100 + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.w3 = Mock() + mock_web3_config.get_block.return_value = {"timestamp": 100} + mock_web3_pp.web3_config = mock_web3_config + + return mock_web3_pp, mock_predictoor_ss + + +def _test_predictoor_system(mock_feeds, mock_predictoor_contract, approach): + mock_web3_pp, mock_predictoor_ss = setup_mock_web3_pp( + mock_feeds, mock_predictoor_contract + ) + + with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp), patch( + "pdr_backend.publisher.publish_assets.get_address", return_value="0x1" + ), patch("pdr_backend.ppss.ppss.PredictoorSS", return_value=mock_predictoor_ss): + # Mock sys.argv + sys.argv = ["pdr", "predictoor", str(approach), "ppss.yaml", "development"] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr predictoor: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call(f"APPROACH={approach}") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + mock_print.assert_any_call("NETWORK=development") + mock_print.assert_any_call(" Feed: 5m binance BTC/USDT 0x1") + mock_print.assert_any_call("Starting main loop.") + mock_print.assert_any_call("Waiting...", end="") + + # Additional assertions + mock_predictoor_ss.get_feed_from_candidates.assert_called_once() + mock_predictoor_contract.get_current_epoch.assert_called() + + +@patch("pdr_backend.ppss.ppss.PPSS.verify_feed_dependencies") +def test_predictoor_approach_1_system( + mock_verify_feed_dependencies, mock_feeds, mock_predictoor_contract +): + _ = mock_verify_feed_dependencies + _test_predictoor_system(mock_feeds, mock_predictoor_contract, 1) + + +@patch("pdr_backend.ppss.ppss.PPSS.verify_feed_dependencies") +def test_predictoor_approach_3_system( + mock_verify_feed_dependencies, mock_feeds, mock_predictoor_contract +): + _ = mock_verify_feed_dependencies + _test_predictoor_system(mock_feeds, mock_predictoor_contract, 3) diff --git a/system_tests/test_publisher_system.py b/system_tests/test_publisher_system.py new file mode 100644 index 000000000..fbaf1a5c8 --- /dev/null +++ b/system_tests/test_publisher_system.py @@ -0,0 +1,42 @@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module 
+from pdr_backend.ppss.web3_pp import Web3PP
+from pdr_backend.util.web3_config import Web3Config
+
+
+@patch("pdr_backend.cli.cli_module.fund_accounts_with_OCEAN")
+@patch("pdr_backend.publisher.publish_assets.publish_asset")
+def test_publisher(mock_publish_asset, mock_fund_accounts):
+    mock_web3_pp = MagicMock(spec=Web3PP)
+    mock_web3_pp.network = "development"
+    mock_web3_pp.subgraph_url = (
+        "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph"
+    )
+
+    mock_web3_config = Mock(spec=Web3Config)
+    mock_web3_config.w3 = Mock()
+    mock_web3_pp.web3_config = mock_web3_config
+
+    with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp), patch(
+        "pdr_backend.publisher.publish_assets.get_address", return_value="0x1"
+    ):
+        # Mock sys.argv
+        sys.argv = ["pdr", "publisher", "ppss.yaml", "development"]
+
+        with patch("builtins.print") as mock_print:
+            cli_module._do_main()
+
+        # Verifying outputs
+        mock_print.assert_any_call("pdr publisher: Begin")
+        mock_print.assert_any_call("Arguments:")
+        mock_print.assert_any_call("PPSS_FILE=ppss.yaml")
+        mock_print.assert_any_call("NETWORK=development")
+        mock_print.assert_any_call("Publish on network = development")
+        mock_print.assert_any_call("Done publishing.")
+
+        # Additional assertions
+        mock_fund_accounts.assert_called()
+        mock_publish_asset.assert_called()
diff --git a/system_tests/test_rose_payout.py b/system_tests/test_rose_payout.py
new file mode 100644
index 000000000..a96b1b427
--- /dev/null
+++ b/system_tests/test_rose_payout.py
@@ -0,0 +1,54 @@
+import sys
+from unittest.mock import patch, MagicMock
+
+from pdr_backend.cli import cli_module
+from pdr_backend.ppss.web3_pp import Web3PP
+from pdr_backend.util.constants import SAPPHIRE_MAINNET_CHAINID
+
+
+def test_rose_payout_test():
+    mock_web3_pp = MagicMock(spec=Web3PP)
+    mock_web3_pp.network = "sapphire-mainnet"
+    mock_web3_pp.web3_config.w3.eth.chain_id = SAPPHIRE_MAINNET_CHAINID
+
+    mock_dfrewards = MagicMock()
+    mock_dfrewards.get_claimable_rewards.return_value = 100
+    mock_dfrewards.claim_rewards.return_value = None
+
+    mock_token = MagicMock()
+    mock_token.balanceOf.return_value = 100 * 1e18
+
+    with patch(
+        "pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp
+    ) as MockPPSS, patch(
+        "pdr_backend.payout.payout.DFRewards", return_value=mock_dfrewards
+    ), patch(
+        "pdr_backend.contract.token.Token", return_value=mock_token
+    ), patch(
+        "pdr_backend.payout.payout.WrappedToken", return_value=mock_token
+    ):
+        mock_ppss_instance = MockPPSS.return_value
+        mock_ppss_instance.web3_pp = mock_web3_pp
+
+        # Mock sys.argv
+        sys.argv = ["pdr", "claim_ROSE", "ppss.yaml"]
+
+        with patch("builtins.print") as mock_print:
+            cli_module._do_main()
+
+        # Verifying outputs
+        mock_print.assert_any_call("Found 100 wROSE available to claim")
+        mock_print.assert_any_call("Claiming wROSE rewards...")
+        mock_print.assert_any_call("Converting wROSE to ROSE")
+        mock_print.assert_any_call("Found 100.0 wROSE, converting to ROSE...")
+        mock_print.assert_any_call("ROSE reward claim done")
+
+        # Additional assertions
+        mock_dfrewards.get_claimable_rewards.assert_called_with(
+            mock_web3_pp.web3_config.owner,
+            "0x8Bc2B030b299964eEfb5e1e0b36991352E56D2D3",
+        )
+        mock_dfrewards.claim_rewards.assert_called_with(
+            mock_web3_pp.web3_config.owner,
+            "0x8Bc2B030b299964eEfb5e1e0b36991352E56D2D3",
+        )
diff --git a/system_tests/test_topup_system.py b/system_tests/test_topup_system.py
new file mode 100644
index 000000000..d5b6f8eba
--- /dev/null
+++ b/system_tests/test_topup_system.py
@@
-0,0 +1,55 @@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.util.constants_opf_addrs import get_opf_addresses +from pdr_backend.util.web3_config import Web3Config + + +def test_topup(): + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.w3 = Mock() + mock_web3_pp.web3_config = mock_web3_config + mock_web3_pp.web3_config.owner = "0xowner" + + mock_token = MagicMock() + balances_arr = [int(5000 * 1e18), int(5000 * 1e18)] + [int(5 * 1e18)] * 100 + mock_token.balanceOf.side_effect = balances_arr + mock_token.transfer.return_value = True + + with patch("pdr_backend.contract.token.Token", return_value=mock_token), patch( + "pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp + ), patch("pdr_backend.util.topup.Token", return_value=mock_token), patch( + "pdr_backend.util.topup.NativeToken", return_value=mock_token + ), patch( + "pdr_backend.util.topup.get_address", return_value="0x1" + ), patch( + "sys.exit" + ): + # Mock sys.argv + sys.argv = ["pdr", "topup", "ppss.yaml", "development"] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + addresses = get_opf_addresses("sapphire-mainnet") + # Verifying outputs + for key, value in addresses.items(): + mock_print.assert_any_call(f"{key}: 5.00 OCEAN, 5.00 ROSE") + if key.startswith("pred"): + mock_print.assert_any_call(f"\t Transferring 20 OCEAN to {value}...") + mock_print.assert_any_call(f"\t Transferring 30 ROSE to {value}...") + if key.startswith("dfbuyer"): + mock_print.assert_any_call(f"\t Transferring 250 ROSE to {value}...") + + # Additional assertions + mock_token.transfer.assert_called() + mock_token.balanceOf.assert_called() diff --git a/system_tests/test_trader_agent_system.py b/system_tests/test_trader_agent_system.py new file mode 100644 index 000000000..0691bc54e --- /dev/null +++ b/system_tests/test_trader_agent_system.py @@ -0,0 +1,89 @@ +import sys +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.util.constants import SAPPHIRE_MAINNET_CHAINID +from pdr_backend.util.web3_config import Web3Config + + +def setup_mock_objects(mock_web3_pp, mock_predictoor_contract, feeds): + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + mock_web3_pp.query_feed_contracts.return_value = feeds + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.get_block.return_value = {"timestamp": 100} + mock_web3_config.owner = "0x00000000000000000000000000000000000c0ffe" + mock_web3_config.w3 = Mock() + mock_web3_config.w3.eth.block_number = 100 + mock_web3_config.w3.eth.chain_id = SAPPHIRE_MAINNET_CHAINID + mock_web3_pp.web3_config = mock_web3_config + + mock_token = MagicMock() + mock_token.balanceOf.return_value = 100 * 1e18 + + mock_trader_ss = Mock() + mock_trader_ss.min_buffer = 1 + mock_trader_ss.get_feed_from_candidates.return_value = feeds["0x1"] + + mock_web3_pp.get_contracts.return_value = {"0x1": mock_predictoor_contract} + + return mock_web3_pp, mock_token, mock_trader_ss + + +def _test_trader( + mock_time_sleep, mock_run, mock_predictoor_contract, mock_feeds, approach +): + mock_web3_pp = 
MagicMock(spec=Web3PP) + + mock_web3_pp, mock_token, mock_trader_ss = setup_mock_objects( + mock_web3_pp, mock_predictoor_contract, mock_feeds + ) + + with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp), patch( + "pdr_backend.contract.token.Token", return_value=mock_token + ), patch("pdr_backend.payout.payout.WrappedToken", return_value=mock_token), patch( + "pdr_backend.payout.payout.PredictoorContract", + return_value=mock_predictoor_contract, + ), patch( + "pdr_backend.ppss.ppss.TraderSS", + return_value=mock_trader_ss, + ): + # Mock sys.argv + sys.argv = ["pdr", "trader", str(approach), "ppss.yaml", "development"] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr trader: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call(f"APPROACH={approach}") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + mock_print.assert_any_call("NETWORK=development") + mock_print.assert_any_call(" Feed: 5m binance BTC/USDT 0x1") + + # Additional assertions + mock_web3_pp.query_feed_contracts.assert_called() + mock_trader_ss.get_feed_from_candidates.assert_called_with(mock_feeds) + mock_time_sleep.assert_called() + mock_run.assert_called() + + +@patch("pdr_backend.trader.base_trader_agent.BaseTraderAgent.run") +@patch("pdr_backend.trader.base_trader_agent.time.sleep") +def test_trader_approach_1( + mock_time_sleep, mock_run, mock_predictoor_contract, mock_feeds +): + _test_trader(mock_time_sleep, mock_run, mock_predictoor_contract, mock_feeds, 1) + + +@patch("pdr_backend.trader.base_trader_agent.BaseTraderAgent.run") +@patch("pdr_backend.trader.base_trader_agent.time.sleep") +def test_trader_approach_2( + mock_time_sleep, mock_run, mock_predictoor_contract, mock_feeds +): + _test_trader(mock_time_sleep, mock_run, mock_predictoor_contract, mock_feeds, 2) diff --git a/system_tests/test_trueval_agent_system.py b/system_tests/test_trueval_agent_system.py new file mode 100644 index 000000000..63281890a --- /dev/null +++ b/system_tests/test_trueval_agent_system.py @@ -0,0 +1,80 @@ +import sys + +from unittest.mock import Mock, patch, MagicMock + +from pdr_backend.cli import cli_module +from pdr_backend.contract.predictoor_batcher import PredictoorBatcher +from pdr_backend.contract.slot import Slot +from pdr_backend.ppss.web3_pp import Web3PP +from pdr_backend.subgraph.subgraph_feed import SubgraphFeed +from pdr_backend.util.constants import SAPPHIRE_MAINNET_CHAINID +from pdr_backend.util.web3_config import Web3Config + + +@patch("pdr_backend.trueval.trueval_agent.wait_until_subgraph_syncs") +@patch("pdr_backend.trueval.trueval_agent.time.sleep") +@patch("pdr_backend.trueval.trueval_agent.TruevalAgent.process_trueval_slot") +def test_trueval_batch(mock_wait_until_subgraph_syncs, mock_time_sleep, mock_process): + _ = mock_wait_until_subgraph_syncs + mock_web3_pp = MagicMock(spec=Web3PP) + mock_web3_pp.network = "sapphire-mainnet" + mock_web3_pp.subgraph_url = ( + "http://localhost:8000/subgraphs/name/oceanprotocol/ocean-subgraph" + ) + feeds = { + "0x1": SubgraphFeed( + "Feed: binance | BTC/USDT | 5m", + "0x1", + "BTC", + 100, + 300, + "0xf", + "BTC/USDT", + "5m", + "binance", + ) + } + mock_web3_pp.get_pending_slots.return_value = [Slot(1, feeds["0x1"])] + + mock_web3_config = Mock(spec=Web3Config) + mock_web3_config.get_block.return_value = {"timestamp": 100} + mock_web3_config.owner = "0x00000000000000000000000000000000000c0ffe" + mock_web3_config.w3 = Mock() + 
mock_web3_config.w3.eth.block_number = 100 + mock_web3_config.w3.eth.chain_id = SAPPHIRE_MAINNET_CHAINID + mock_web3_pp.web3_config = mock_web3_config + + mock_predictoor_batcher = Mock(spec=PredictoorBatcher) + + with patch("pdr_backend.ppss.ppss.Web3PP", return_value=mock_web3_pp), patch( + "pdr_backend.cli.cli_module.get_address", return_value="0x1" + ), patch( + "pdr_backend.trueval.trueval_agent.PredictoorBatcher", + return_value=mock_predictoor_batcher, + ), patch( + "pdr_backend.trueval.trueval_agent.TruevalAgent.process_trueval_slot" + ), patch( + "pdr_backend.trueval.trueval_agent.TruevalAgent.batch_submit_truevals", + return_value="0xbatch_submit_tx", + ): + # Mock sys.argv + sys.argv = ["pdr", "trueval", "ppss.yaml", "development"] + + with patch("builtins.print") as mock_print: + cli_module._do_main() + + # Verifying outputs + mock_print.assert_any_call("pdr trueval: Begin") + mock_print.assert_any_call("Arguments:") + mock_print.assert_any_call("PPSS_FILE=ppss.yaml") + mock_print.assert_any_call("NETWORK=development") + mock_print.assert_any_call("Found 1 pending slots, processing 30") + mock_print.assert_any_call("Submitting transaction...") + mock_print.assert_any_call( + "Tx sent: 0xbatch_submit_tx, sleeping for 30 seconds..." + ) + + # Additional assertions + mock_web3_pp.get_pending_slots.assert_called() + mock_time_sleep.assert_called() + mock_process.assert_called()