Skip to content

Commit

Permalink
Merge pull request #31 from rcholic/async_token_store
Browse files Browse the repository at this point in the history
use backoff to refresh tokens
  • Loading branch information
rcholic authored Jul 14, 2024
2 parents e4c8f8a + f371282 commit 205cb70
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 125 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,4 @@ cython_debug/
notebooks
*.log
data-logs/
*tokens.json
*tokens*.json
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ pip install CSchwabPy
```python

# save these lines in a file named like cschwab.py
from cschwabpy.SchwabAsyncClient import SchwabAsyncClient
# NOTE: should use SchwabClient to get tokens manually after version 0.1.3
from cschwabpy.SchwabClient import SchwabClient

app_client_key = "---your-app-client-key-here-"
app_secret = "app-secret"

schwab_client = SchwabAsyncClient(app_client_id=app_client_key, app_secret=app_secret)
schwab_client = SchwabClient(app_client_id=app_client_key, app_secret=app_secret)
schwab_client.get_tokens_manually()

# run in your Terminal, follow the prompt to complete authentication:
Expand All @@ -47,13 +48,13 @@ schwab_client.get_tokens_manually()
#----------------
ticker = '$SPX'
# get option expirations:
expiration_list = await schwab_client.get_option_expirations_async(underlying_symbol = ticker)
expiration_list = schwab_client.get_option_expirations(underlying_symbol = ticker)

# download SPX option chains
from_date = 2024-07-01
to_date = 2024-07-01

opt_chain_result = await schwab_client.download_option_chain_async(ticker, from_date, to_date)
opt_chain_result = schwab_client.download_option_chain(ticker, from_date, to_date)

# get call-put dataframe pairs by expiration
opt_df_pairs = opt_chain_result.to_dataframe_pairs_by_expiration()
Expand Down
84 changes: 10 additions & 74 deletions cschwabpy/SchwabAsyncClient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cschwabpy.models.token import Tokens, ITokenStore, LocalTokenStore
from cschwabpy.models.token import Tokens, IAsyncTokenStore, AsyncLocalTokenStore
from cschwabpy.models import (
OptionChainQueryFilter,
OptionContractType,
Expand Down Expand Up @@ -28,7 +28,7 @@
SCHWAB_AUTH_PATH,
SCHWAB_TOKEN_PATH,
)

import backoff
import httpx
import re
import base64
Expand All @@ -42,7 +42,7 @@ def __init__(
self,
app_client_id: str,
app_secret: str,
token_store: ITokenStore = LocalTokenStore(),
token_store: IAsyncTokenStore = AsyncLocalTokenStore(),
tokens: Optional[Tokens] = None,
http_client: Optional[httpx.AsyncClient] = None,
) -> None:
Expand All @@ -51,28 +51,23 @@ def __init__(
self.__token_store = token_store
self.__client = http_client
self.__keep_client_alive = http_client is not None
if (
tokens is not None
and tokens.is_access_token_valid
and tokens.is_refresh_token_valid
):
token_store.save_tokens(tokens)

self.__tokens = token_store.get_tokens()
self.__tokens = tokens

@property
def token_url(self) -> str:
return f"{SCHWAB_API_BASE_URL}/{SCHWAB_TOKEN_PATH}"

@backoff.on_exception(backoff.expo, Exception, max_tries=3, max_time=10)
async def _ensure_valid_access_token(self, force_refresh: bool = False) -> bool:
if self.__tokens is None:
self.__tokens = await self.__token_store.get_tokens()

if self.__tokens is None:
raise Exception(
"Tokens are not available. Please use get_tokens_manually() to get tokens first."
)

if self.__tokens.is_access_token_valid and not force_refresh:
return True

client = httpx.AsyncClient() if self.__client is None else self.__client
try:
key_sec_encoded = self.__encode_app_key_secret()
Expand All @@ -91,7 +86,7 @@ async def _ensure_valid_access_token(self, force_refresh: bool = False) -> bool:
if response.status_code == 200:
json_res = response.json()
tokens = Tokens(**json_res)
self.__token_store.save_tokens(tokens)
await self.__token_store.save_tokens(tokens)
return True
else:
raise Exception(
Expand Down Expand Up @@ -385,67 +380,8 @@ async def download_option_chain_async(
url=target_url, params={}, headers=self.__auth_header()
)
json_res = response.json()
print("json_res: ", json_res)
return OptionChain(**json_res)
finally:
if not self.__keep_client_alive:
await client.aclose()

def get_tokens_manually(
self,
) -> None:
"""Manual steps to get tokens from Charles Schwab API."""
from prompt_toolkit import prompt
import urllib.parse as url_parser

redirect_uri = prompt("Enter your redirect uri> ").strip()
complete_auth_url = f"{SCHWAB_API_BASE_URL}/{SCHWAB_AUTH_PATH}?response_type=code&client_id={self.__client_id}&redirect_uri={redirect_uri}"
print(
f"Copy and open the following URL in browser. Complete Login & Authorization:\n {complete_auth_url}"
)
auth_code_response_url = prompt(
"Paste the entire authorization response URL here> "
).strip()

auth_code = ""
try:
auth_code_pattern = re.compile(r"code=(.+)&?")
d = re.search(auth_code_pattern, auth_code_response_url)
if d:
auth_code = d.group(1)
auth_code = url_parser.unquote(auth_code.split("&")[0])
else:
raise Exception(
"authorization response url does not contain authorization code"
)

if len(auth_code) == 0:
raise Exception("authorization code is empty")
except Exception as ex:
raise Exception(
"Failed to get authorization code. Please try again. Exception: ", ex
)

key_sec_encoded = self.__encode_app_key_secret()
with httpx.Client() as client:
response = client.post(
url=self.token_url,
headers={
"Authorization": f"Basic {key_sec_encoded}",
"Content-Type": "application/x-www-form-urlencoded",
},
data={
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": redirect_uri,
},
)

if response.status_code == 200:
json_res = response.json()
tokens = Tokens(**json_res)
self.__token_store.save_tokens(tokens)
print(
f"Tokens saved successfully at path: {self.__token_store.token_file_path}"
)
else:
print("Failed to get tokens. Please try again.")
14 changes: 5 additions & 9 deletions cschwabpy/SchwabClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
InstrumentProjection,
)
import cschwabpy.util as util

import backoff
from datetime import datetime, timedelta
from typing import Optional, List, Mapping
from cschwabpy.costants import (
Expand Down Expand Up @@ -53,20 +53,16 @@ def __init__(
self.__token_store = token_store
self.__client = http_client
self.__keep_client_alive = http_client is not None
if (
tokens is not None
and tokens.is_access_token_valid
and tokens.is_refresh_token_valid
):
token_store.save_tokens(tokens)

self.__tokens = token_store.get_tokens()
self.__tokens = tokens

@property
def token_url(self) -> str:
return f"{SCHWAB_API_BASE_URL}/{SCHWAB_TOKEN_PATH}"

@backoff.on_exception(backoff.expo, Exception, max_tries=3, max_time=10)
def _ensure_valid_access_token(self, force_refresh: bool = False) -> bool:
if self.__tokens is None:
self.__tokens = self.__token_store.get_tokens()
if self.__tokens is None:
raise Exception(
"Tokens are not available. Please use get_tokens_manually() to get tokens first."
Expand Down
46 changes: 46 additions & 0 deletions cschwabpy/models/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import json
import time
import aiofiles as af
from pathlib import Path

REFRESH_TOKEN_VALIDITY_SECONDS = 7 * 24 * 60 * 60 # 7 days
Expand Down Expand Up @@ -77,3 +78,48 @@ def get_tokens(self) -> Optional[Tokens]:
def save_tokens(self, tokens: Tokens) -> None:
with open(self.token_file_path, "w") as token_file:
token_file.write(json.dumps(tokens.to_json(), indent=4))


class IAsyncTokenStore(Protocol):
@property
def token_output_path(self) -> str:
"""Path for outputting tokens."""
return ""

async def get_tokens(self) -> Optional[Tokens]:
pass

async def save_tokens(self, tokens: Tokens) -> None:
pass


class AsyncLocalTokenStore(IAsyncTokenStore):
def __init__(
self, json_file_name: str = "tokens.json", file_path: Optional[str] = None
):
self.file_name = json_file_name
self.token_file_path = file_path
if file_path is None:
self.token_file_path = Path(Path(__file__).parent, json_file_name)
else:
self.token_file_path = Path(file_path)

if not os.path.exists(self.token_file_path.parent):
os.makedirs(self.token_file_path.parent)

@property
def token_output_path(self) -> str:
return str(self.token_file_path)

async def get_tokens(self) -> Optional[Tokens]:
try:
async with af.open(self.token_file_path, mode="r") as token_file:
token_json_str = await token_file.read()
tokens_json = json.loads(token_json_str)
return Tokens(**tokens_json)
except:
return None

async def save_tokens(self, tokens: Tokens) -> None:
async with af.open(self.token_file_path, mode="w") as token_file:
await token_file.write(json.dumps(tokens.to_json(), indent=4))
Loading

0 comments on commit 205cb70

Please sign in to comment.