Skip to content

Commit

Permalink
frank/update driftpy (#34)
Browse files Browse the repository at this point in the history
* chore: update driftpy 0.7.19 -> 0.7.20

* silence all the annoying mypy linter errors
  • Loading branch information
soundsonacid committed Jan 31, 2024
1 parent 82ccbdd commit e920f08
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 41 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ python = "^3.10"
python-dotenv = "^1.0.0"
solana = "^0.30.1"
anchorpy = "^0.17.1"
driftpy = "^0.7.19"
driftpy = "^0.7.20"

[build-system]
requires = ["poetry-core"]
Expand Down
42 changes: 21 additions & 21 deletions python/sdk/jit_proxy/jit_proxy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Optional, cast

from borsh_construct.enum import _rust_enum
from sumtypes import constructor
from sumtypes import constructor # type: ignore

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from anchorpy import Context, Program

Expand Down Expand Up @@ -74,7 +74,7 @@ async def jit(self, params: JitIxParams):
await self.init()

sub_account_id = self.drift_client.get_sub_account_id_for_ix(
params.sub_account_id
params.sub_account_id # type: ignore
)

order = next(
Expand All @@ -90,11 +90,11 @@ async def jit(self, params: JitIxParams):
params.taker,
self.drift_client.get_user_account(sub_account_id),
],
writable_spot_market_indexes=[order.market_index, QUOTE_SPOT_MARKET_INDEX]
if is_variant(order.market_type, "Spot")
writable_spot_market_indexes=[order.market_index, QUOTE_SPOT_MARKET_INDEX] # type: ignore
if is_variant(order.market_type, "Spot") # type: ignore
else [],
writable_perp_market_indexes=[order.market_index]
if is_variant(order.market_type, "Perp")
writable_perp_market_indexes=[order.market_index] # type: ignore
if is_variant(order.market_type, "Perp") # type: ignore
else [],
)

Expand All @@ -114,35 +114,35 @@ async def jit(self, params: JitIxParams):
)
)

if is_variant(order.market_type, "Spot"):
if is_variant(order.market_type, "Spot"): # type: ignore
remaining_accounts.append(
AccountMeta(
pubkey=self.drift_client.get_spot_market_account(
order.market_index
pubkey=self.drift_client.get_spot_market_account( # type: ignore
order.market_index # type: ignore
).vault,
is_writable=False,
is_signer=False,
)
)
remaining_accounts.append(
AccountMeta(
pubkey=self.drift_client.get_quote_spot_market_account().vault,
pubkey=self.drift_client.get_quote_spot_market_account().vault, # type: ignore
is_writable=False,
is_signer=False,
)
)

jit_params = self.program.type["JitParams"](
jit_params = self.program.type["JitParams"]( # type: ignore
taker_order_id=params.taker_order_id,
max_position=cast(int, params.max_position),
min_position=cast(int, params.min_position),
bid=cast(int, params.bid),
ask=cast(int, params.ask),
price_type=self.get_price_type(params.price_type),
price_type=self.get_price_type(params.price_type), # type: ignore
post_only=self.get_post_only(params.post_only),
)

ix = self.program.instruction["jit"](
ix = self.program.instruction["jit"]( # type: ignore
jit_params,
ctx=Context(
accounts={
Expand All @@ -156,7 +156,7 @@ async def jit(self, params: JitIxParams):
"authority": self.drift_client.wallet.public_key,
"drift_program": self.drift_client.program_id,
},
signers={self.drift_client.wallet},
signers={self.drift_client.wallet}, # type: ignore
remaining_accounts=remaining_accounts,
),
)
Expand All @@ -167,16 +167,16 @@ async def jit(self, params: JitIxParams):

def get_price_type(self, price_type: PriceType):
if is_variant(price_type, "Oracle"):
return self.program.type["PriceType"].Oracle()
return self.program.type["PriceType"].Oracle() # type: ignore
elif is_variant(price_type, "Limit"):
return self.program.type["PriceType"].Limit()
else:
return self.program.type["PriceType"].Limit() # type: ignore
else:
raise ValueError(f"Unknown price type: {str(price_type)}")

def get_post_only(self, post_only: PostOnlyParams):
if is_variant(post_only, "MustPostOnly"):
return self.program.type["PostOnlyParam"].MustPostOnly()
return self.program.type["PostOnlyParam"].MustPostOnly() # type: ignore
elif is_variant(post_only, "TryPostOnly"):
return self.program.type["PostOnlyParam"].TryPostOnly()
return self.program.type["PostOnlyParam"].TryPostOnly() # type: ignore
elif is_variant(post_only, "Slide"):
return self.program.type["PostOnlyParam"].Slide()
return self.program.type["PostOnlyParam"].Slide() # type: ignore
10 changes: 5 additions & 5 deletions python/sdk/jit_proxy/jitter/base_jitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from driftpy.types import is_variant, UserAccount, Order, UserStatsAccount, ReferrerInfo
from driftpy.drift_client import DriftClient
Expand Down Expand Up @@ -72,7 +72,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i
taker_key_str = str(taker_key)

taker_stats_key = get_user_stats_account_public_key(
self.drift_client.program_id, taker.authority
self.drift_client.program_id, taker.authority # type: ignore
)

self.logger.info(f"Taker: {taker.authority}")
Expand Down Expand Up @@ -110,7 +110,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i

if (
order.base_asset_amount - order.base_asset_amount_filled
<= perp_market_account.amm.min_order_size
<= perp_market_account.amm.min_order_size # type: ignore
):
self.logger.info("Order filled within min_order_size")
self.logger.info("----------------------------")
Expand Down Expand Up @@ -138,7 +138,7 @@ async def on_account_update(self, taker: UserAccount, taker_key: Pubkey, slot: i

if (
order.base_asset_amount - order.base_asset_amount_filled
<= spot_market_account.min_order_size
<= spot_market_account.min_order_size # type: ignore
):
self.logger.info("Order filled within min_order_size")
self.logger.info("----------------------------")
Expand Down Expand Up @@ -177,7 +177,7 @@ async def create_try_fill(
order: Order,
order_sig: str,
):
future = asyncio.Future()
future = asyncio.Future() # type: ignore
future.set_result(None)
return future

Expand Down
2 changes: 1 addition & 1 deletion python/sdk/jit_proxy/jitter/jitter_shotgun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Coroutine

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from driftpy.drift_client import DriftClient
from driftpy.auction_subscriber.auction_subscriber import AuctionSubscriber
Expand Down
26 changes: 13 additions & 13 deletions python/sdk/jit_proxy/jitter/jitter_sniper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Any, Coroutine

from solders.pubkey import Pubkey
from solders.pubkey import Pubkey # type: ignore

from driftpy.drift_client import DriftClient
from driftpy.auction_subscriber.auction_subscriber import AuctionSubscriber
Expand Down Expand Up @@ -243,7 +243,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:

auction_start_price = convert_to_number(
get_auction_price_for_oracle_offset_auction(
order, order.slot, oracle_price.price
order, order.slot, oracle_price.price # type: ignore
)
if is_variant(order.order_type, "Oracle")
else order.auction_start_price,
Expand All @@ -252,23 +252,23 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:

auction_end_price = convert_to_number(
get_auction_price_for_oracle_offset_auction(
order, order.slot + order.auction_duration - 1, oracle_price.price
order, order.slot + order.auction_duration - 1, oracle_price.price # type: ignore
)
if is_variant(order.order_type, "Oracle")
else order.auction_end_price,
PRICE_PRECISION,
)

bid = (
convert_to_number(oracle_price.price + params.bid, PRICE_PRECISION)
if is_variant(params.price_type, "Oracle")
else convert_to_number(params.bid, PRICE_PRECISION)
convert_to_number(oracle_price.price + params.bid, PRICE_PRECISION) # type: ignore
if is_variant(params.price_type, "Oracle") # type: ignore
else convert_to_number(params.bid, PRICE_PRECISION) # type: ignore
)

ask = (
convert_to_number(oracle_price.price + params.ask, PRICE_PRECISION)
if is_variant(params.price_type, "Oracle")
else convert_to_number(params.ask, PRICE_PRECISION)
convert_to_number(oracle_price.price + params.ask, PRICE_PRECISION) # type: ignore
if is_variant(params.price_type, "Oracle") # type: ignore
else convert_to_number(params.ask, PRICE_PRECISION) # type: ignore
)

slots_until_cross = 0
Expand All @@ -282,7 +282,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:
if (
convert_to_number(
get_auction_price(
order, order.slot + slots_until_cross, oracle_price.price
order, order.slot + slots_until_cross, oracle_price.price # type: ignore
),
PRICE_PRECISION,
)
Expand All @@ -294,7 +294,7 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:
if (
convert_to_number(
get_auction_price(
order, order.slot + slots_until_cross, oracle_price.price
order, order.slot + slots_until_cross, oracle_price.price # type: ignore
),
PRICE_PRECISION,
)
Expand All @@ -312,12 +312,12 @@ def get_auction_and_order_details(self, order: Order) -> AuctionAndOrderDetails:
auction_start_price,
auction_end_price,
step_size,
oracle_price,
oracle_price, # type: ignore
)

async def wait_for_slot_or_cross_or_expiry(
self, target_slot: int, order: Order, initial_details: AuctionAndOrderDetails
) -> (int, AuctionAndOrderDetails):
) -> (int, AuctionAndOrderDetails): # type: ignore
auction_end_slot = order.auction_duration + order.slot
current_details: AuctionAndOrderDetails = initial_details
will_cross = initial_details.will_cross
Expand Down

0 comments on commit e920f08

Please sign in to comment.