From e920f08da06b507bb9d3da44203dde2c99254be5 Mon Sep 17 00:00:00 2001 From: frank <98238480+soundsonacid@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:25:27 -0800 Subject: [PATCH] frank/update driftpy (#34) * chore: update driftpy 0.7.19 -> 0.7.20 * silence all the annoying mypy linter errors --- python/pyproject.toml | 2 +- python/sdk/jit_proxy/jit_proxy_client.py | 42 +++++++++---------- python/sdk/jit_proxy/jitter/base_jitter.py | 10 ++--- python/sdk/jit_proxy/jitter/jitter_shotgun.py | 2 +- python/sdk/jit_proxy/jitter/jitter_sniper.py | 26 ++++++------ 5 files changed, 41 insertions(+), 41 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 3904cbe..06d6909 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -13,7 +13,7 @@ python = "^3.10" python-dotenv = "^1.0.0" solana = "^0.30.1" anchorpy = "^0.17.1" -driftpy = "^0.7.19" +driftpy = "^0.7.20" [build-system] requires = ["poetry-core"] diff --git a/python/sdk/jit_proxy/jit_proxy_client.py b/python/sdk/jit_proxy/jit_proxy_client.py index 7041af7..158a5bc 100644 --- a/python/sdk/jit_proxy/jit_proxy_client.py +++ b/python/sdk/jit_proxy/jit_proxy_client.py @@ -2,9 +2,9 @@ from typing import Optional, cast from borsh_construct.enum import _rust_enum -from sumtypes import constructor +from sumtypes import constructor # type: ignore -from solders.pubkey import Pubkey +from solders.pubkey import Pubkey # type: ignore from anchorpy import Context, Program @@ -74,7 +74,7 @@ async def jit(self, params: JitIxParams): await self.init() sub_account_id = self.drift_client.get_sub_account_id_for_ix( - params.sub_account_id + params.sub_account_id # type: ignore ) order = next( @@ -90,11 +90,11 @@ async def jit(self, params: JitIxParams): params.taker, self.drift_client.get_user_account(sub_account_id), ], - writable_spot_market_indexes=[order.market_index, QUOTE_SPOT_MARKET_INDEX] - if is_variant(order.market_type, "Spot") + writable_spot_market_indexes=[order.market_index, QUOTE_SPOT_MARKET_INDEX] # type: ignore + if is_variant(order.market_type, "Spot") # type: ignore else [], - writable_perp_market_indexes=[order.market_index] - if is_variant(order.market_type, "Perp") + writable_perp_market_indexes=[order.market_index] # type: ignore + if is_variant(order.market_type, "Perp") # type: ignore else [], ) @@ -114,11 +114,11 @@ async def jit(self, params: JitIxParams): ) ) - if is_variant(order.market_type, "Spot"): + if is_variant(order.market_type, "Spot"): # type: ignore remaining_accounts.append( AccountMeta( - pubkey=self.drift_client.get_spot_market_account( - order.market_index + pubkey=self.drift_client.get_spot_market_account( # type: ignore + order.market_index # type: ignore ).vault, is_writable=False, is_signer=False, @@ -126,23 +126,23 @@ async def jit(self, params: JitIxParams): ) remaining_accounts.append( AccountMeta( - pubkey=self.drift_client.get_quote_spot_market_account().vault, + pubkey=self.drift_client.get_quote_spot_market_account().vault, # type: ignore is_writable=False, is_signer=False, ) ) - jit_params = self.program.type["JitParams"]( + jit_params = self.program.type["JitParams"]( # type: ignore taker_order_id=params.taker_order_id, max_position=cast(int, params.max_position), min_position=cast(int, params.min_position), bid=cast(int, params.bid), ask=cast(int, params.ask), - price_type=self.get_price_type(params.price_type), + price_type=self.get_price_type(params.price_type), # type: ignore post_only=self.get_post_only(params.post_only), ) - ix = self.program.instruction["jit"]( + ix = self.program.instruction["jit"]( # type: ignore jit_params, ctx=Context( accounts={ @@ -156,7 +156,7 @@ async def jit(self, params: JitIxParams): "authority": self.drift_client.wallet.public_key, "drift_program": self.drift_client.program_id, }, - signers={self.drift_client.wallet}, + signers={self.drift_client.wallet}, # type: ignore remaining_accounts=remaining_accounts, ), ) @@ -167,16 +167,16 @@ async def jit(self, params: JitIxParams): def get_price_type(self, price_type: PriceType): if is_variant(price_type, "Oracle"): - return self.program.type["PriceType"].Oracle() + return self.program.type["PriceType"].Oracle() # type: ignore elif is_variant(price_type, "Limit"): - return self.program.type["PriceType"].Limit() - else: + return self.program.type["PriceType"].Limit() # type: ignore + else: raise ValueError(f"Unknown price type: {str(price_type)}") def get_post_only(self, post_only: PostOnlyParams): if is_variant(post_only, "MustPostOnly"): - return self.program.type["PostOnlyParam"].MustPostOnly() + return self.program.type["PostOnlyParam"].MustPostOnly() # type: ignore elif is_variant(post_only, "TryPostOnly"): - return self.program.type["PostOnlyParam"].TryPostOnly() + return self.program.type["PostOnlyParam"].TryPostOnly() # type: ignore elif is_variant(post_only, "Slide"): - return self.program.type["PostOnlyParam"].Slide() + return self.program.type["PostOnlyParam"].Slide() # type: ignore diff --git a/python/sdk/jit_proxy/jitter/base_jitter.py b/python/sdk/jit_proxy/jitter/base_jitter.py index bb4038a..59214ed 100644 --- a/python/sdk/jit_proxy/jitter/base_jitter.py +++ b/python/sdk/jit_proxy/jitter/base_jitter.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from solders.pubkey import Pubkey +from solders.pubkey import Pubkey # type: ignore from driftpy.types import is_variant, UserAccount, Order, UserStatsAccount, ReferrerInfo from driftpy.drift_client import DriftClient @@ -72,7 +72,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i taker_key_str = str(taker_key) taker_stats_key = get_user_stats_account_public_key( - self.drift_client.program_id, taker.authority + self.drift_client.program_id, taker.authority # type: ignore ) self.logger.info(f"Taker: {taker.authority}") @@ -110,7 +110,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i if ( order.base_asset_amount - order.base_asset_amount_filled - <= perp_market_account.amm.min_order_size + <= perp_market_account.amm.min_order_size # type: ignore ): self.logger.info("Order filled within min_order_size") self.logger.info("----------------------------") @@ -138,7 +138,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i if ( order.base_asset_amount - order.base_asset_amount_filled - <= spot_market_account.min_order_size + <= spot_market_account.min_order_size # type: ignore ): self.logger.info("Order filled within min_order_size") self.logger.info("----------------------------") @@ -177,7 +177,7 @@ async def create_try_fill( order: Order, order_sig: str, ): - future = asyncio.Future() + future = asyncio.Future() # type: ignore future.set_result(None) return future diff --git a/python/sdk/jit_proxy/jitter/jitter_shotgun.py b/python/sdk/jit_proxy/jitter/jitter_shotgun.py index 1a701f0..4adc493 100644 --- a/python/sdk/jit_proxy/jitter/jitter_shotgun.py +++ b/python/sdk/jit_proxy/jitter/jitter_shotgun.py @@ -2,7 +2,7 @@ from typing import Any, Coroutine -from solders.pubkey import Pubkey +from solders.pubkey import Pubkey # type: ignore from driftpy.drift_client import DriftClient from driftpy.auction_subscriber.auction_subscriber import AuctionSubscriber diff --git a/python/sdk/jit_proxy/jitter/jitter_sniper.py b/python/sdk/jit_proxy/jitter/jitter_sniper.py index 4fe4dfb..1168312 100644 --- a/python/sdk/jit_proxy/jitter/jitter_sniper.py +++ b/python/sdk/jit_proxy/jitter/jitter_sniper.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Any, Coroutine -from solders.pubkey import Pubkey +from solders.pubkey import Pubkey # type: ignore from driftpy.drift_client import DriftClient from driftpy.auction_subscriber.auction_subscriber import AuctionSubscriber @@ -243,7 +243,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails: auction_start_price = convert_to_number( get_auction_price_for_oracle_offset_auction( - order, order.slot, oracle_price.price + order, order.slot, oracle_price.price # type: ignore ) if is_variant(order.order_type, "Oracle") else order.auction_start_price, @@ -252,7 +252,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails: auction_end_price = convert_to_number( get_auction_price_for_oracle_offset_auction( - order, order.slot + order.auction_duration - 1, oracle_price.price + order, order.slot + order.auction_duration - 1, oracle_price.price # type: ignore ) if is_variant(order.order_type, "Oracle") else order.auction_end_price, @@ -260,15 +260,15 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails: ) bid = ( - convert_to_number(oracle_price.price + params.bid, PRICE_PRECISION) - if is_variant(params.price_type, "Oracle") - else convert_to_number(params.bid, PRICE_PRECISION) + convert_to_number(oracle_price.price + params.bid, PRICE_PRECISION) # type: ignore + if is_variant(params.price_type, "Oracle") # type: ignore + else convert_to_number(params.bid, PRICE_PRECISION) # type: ignore ) ask = ( - convert_to_number(oracle_price.price + params.ask, PRICE_PRECISION) - if is_variant(params.price_type, "Oracle") - else convert_to_number(params.ask, PRICE_PRECISION) + convert_to_number(oracle_price.price + params.ask, PRICE_PRECISION) # type: ignore + if is_variant(params.price_type, "Oracle") # type: ignore + else convert_to_number(params.ask, PRICE_PRECISION) # type: ignore ) slots_until_cross = 0 @@ -282,7 +282,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails: if ( convert_to_number( get_auction_price( - order, order.slot + slots_until_cross, oracle_price.price + order, order.slot + slots_until_cross, oracle_price.price # type: ignore ), PRICE_PRECISION, ) @@ -294,7 +294,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails: if ( convert_to_number( get_auction_price( - order, order.slot + slots_until_cross, oracle_price.price + order, order.slot + slots_until_cross, oracle_price.price # type: ignore ), PRICE_PRECISION, ) @@ -312,12 +312,12 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails: auction_start_price, auction_end_price, step_size, - oracle_price, + oracle_price, # type: ignore ) async def wait_for_slot_or_cross_or_expiry( self, target_slot: int, order: Order, initial_details: AuctionAndOrderDetails - ) -> (int, AuctionAndOrderDetails): + ) -> (int, AuctionAndOrderDetails): # type: ignore auction_end_slot = order.auction_duration + order.slot current_details: AuctionAndOrderDetails = initial_details will_cross = initial_details.will_cross