Merge pull request #215 from sfmqrb/historical_intraday
Showing 7 changed files with 286 additions and 0 deletions.
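In brief, the PR adds a public get_trade_details entry point for historical intraday (tick-level) data. A minimal usage sketch, with the symbol and dates borrowed from the new tests:

import datetime
from pytse_client import get_trade_details

# fetch tick-level trades for two trading days and resample them
# into 5-minute OHLCV bars, concatenated into a single frame
frames = get_trade_details(
    symbol_name="خودرو",
    start_date=datetime.date(2023, 2, 27),
    end_date=datetime.date(2023, 2, 28),
    timeframe="5m",
    aggregate=True,
)
print(frames["aggregate"].head())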
pytse_client/historical_intraday/trade_details.py (new file, +175 lines; path inferred from the test imports):
import datetime
import os
import logging
import pandas as pd
import asyncio
import aiohttp
from pathlib import Path
from typing import Dict, Optional, Union
from pytse_client.tse_settings import TICKER_TRADE_DETAILS
from pytse_client.utils.trade_dates import get_valid_dates
from pytse_client.ticker.ticker import Ticker
from pytse_client.config import LOGGER_NAME, TRADE_DETAILS_HIST_PATH
from pytse_client.utils.logging_generator import get_logger

logger = get_logger(f"{LOGGER_NAME}_trade_details", logging.INFO)
ERROR_MSG = "{date} is not a valid trade day. Make sure it is a trade day."
TRADE_DETAILS_HEADER = {
    "Accept": "application/json, text/plain, */*",
    "Accept-Language": "en-US,en;q=0.9,fa-IR;q=0.8,fa;q=0.7",
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "DNT": "1",
    "Pragma": "no-cache",
    "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Mobile Safari/537.36",
    "Accept-Encoding": "gzip, deflate",
}
# map TSE API field names to friendly column names
mapping_api_col = {"pTran": "price", "qTitTran": "volume", "hEven": "datetime"}
# user-facing timeframe strings -> pandas resample rule aliases
valid_time_frames_mapping = {
    "30s": "30S",
    "1m": "1T",
    "5m": "5T",
    "10m": "10T",
    "15m": "15T",
    "30m": "30T",
    "1h": "1H",
}
reversed_keys = {val: key for key, val in mapping_api_col.items()}

def get_trade_details(
    symbol_name: str,
    start_date: datetime.date,
    end_date: Optional[datetime.date] = None,
    to_csv: bool = False,
    base_path: Optional[str] = None,
    timeframe: Optional[str] = None,
    aggregate: bool = False,
) -> Dict[str, pd.DataFrame]:
    if (
        timeframe is not None
        and timeframe not in valid_time_frames_mapping
    ):
        raise ValueError(
            "The provided timeframe is not valid. "
            f"It should be one of {list(valid_time_frames_mapping)}"
        )

    result = {}
    end_date = start_date if not end_date else end_date
    ticker = Ticker(symbol_name)

    # restrict the requested range to days on which the ticker traded
    all_valid_dates = get_valid_dates(ticker, start_date, end_date)

    date_df_list = []
    date_df_list.extend(get_df_valid_dates(ticker, all_valid_dates))
    for date_df in date_df_list:
        date, df = date_df
        df = common_process(df, date.strftime("%Y%m%d"))
        if df.empty:
            continue
        if timeframe:
            # build OHLCV bars: OHLC from trade prices, summed volume
            ohlcv_df = df.resample(valid_time_frames_mapping[timeframe]).agg(
                {"price": "ohlc", "volume": "sum"}
            )
            ohlcv_df.columns = ["open", "high", "low", "close", "volume"]
            ohlcv_df = ohlcv_df.dropna()
            result[date] = ohlcv_df
        else:
            result[date] = df

    if aggregate:
        # collapse the per-day frames into a single time-sorted frame
        result = {"aggregate": pd.concat(result.values())}
        result["aggregate"] = result["aggregate"].sort_values(
            ["datetime"], ascending=[True]
        )

    if to_csv:
        for date in result:
            write_to_csv(result[date], base_path, date)
    return result

def write_to_csv(
    df: pd.DataFrame,
    base_path: Union[str, None],
    date: Union[datetime.date, str],
):
    base_path = base_path or TRADE_DETAILS_HIST_PATH
    Path(base_path).mkdir(parents=True, exist_ok=True)
    extension = (
        date.strftime("%Y-%m-%d") if isinstance(date, datetime.date) else date
    )
    file_name = f"trade_details_{extension}.csv"
    path = os.path.join(base_path, file_name)
    df.to_csv(path)


def common_process(df: pd.DataFrame, date: str):
    if len(df) == 0:
        return pd.DataFrame(columns=list(mapping_api_col.values()))
    df.rename(columns=mapping_api_col, inplace=True)
    df = df.loc[:, list(mapping_api_col.values())]
    # hEven is an HHMMSS integer; combine it with the date into a timestamp
    df["datetime"] = pd.to_datetime(
        date + " " + df["datetime"].astype(str), format="%Y%m%d %H%M%S"
    )
    df = df.sort_values(["datetime"], ascending=[True])
    df.set_index("datetime", inplace=True)
    return df

def get_df_valid_dates(
    ticker: Ticker,
    valid_dates: list,
):
    return asyncio.run(
        get_df_valid_dates_async(
            ticker,
            valid_dates,
        ),
    )


async def get_df_valid_dates_async(ticker, valid_dates):
    # cap concurrent connections so the TSE server is not overwhelmed
    conn = aiohttp.TCPConnector(limit=25)
    async with aiohttp.ClientSession(connector=conn) as session:
        tasks = []
        for date in valid_dates:
            tasks.append(_get_trade_details(ticker, date, session))
        results = await asyncio.gather(*tasks)

    return results


async def _get_trade_details(ticker: Ticker, date_obj: datetime.date, session):
    index = ticker.index
    date = date_obj.strftime("%Y%m%d")
    url = TICKER_TRADE_DETAILS.format(index=index, date=date)
    max_retries = 9
    retry_count = 0

    while retry_count < max_retries:
        try:
            async with session.get(
                url, headers=TRADE_DETAILS_HEADER, timeout=100
            ) as response:
                if response.status == 503:
                    logger.info(
                        f"Received 503 Service Unavailable on {date_obj}. Retrying..."
                    )
                    retry_count += 1
                    await asyncio.sleep(1)
                else:
                    response.raise_for_status()
                    data = await response.json()
                    logger.info(
                        f"Successfully fetched trade details on {date_obj} from tse"
                    )
                    return [date_obj, pd.json_normalize(data["tradeHistory"])]
        except (aiohttp.ClientError, asyncio.TimeoutError):
            logger.error(f"Request failed for {date_obj}. Retrying...")
            retry_count += 1
            await asyncio.sleep(1)

    raise Exception(
        f"Failed to fetch trade details for {ticker} on {date_obj} "
        f"after {max_retries} retries"
    )
pytse_client/utils/trade_dates.py (new file, +15 lines; path inferred from the imports above):
import datetime
from pytse_client.ticker.ticker import Ticker


def get_valid_dates(
    ticker: Ticker,
    start_date: datetime.date,
    end_date: datetime.date,
):
    # walk the calendar range and keep only actual trading days
    all_valid_dates = []
    for n in range((end_date - start_date).days + 1):
        date = start_date + datetime.timedelta(n)
        if date in ticker.trade_dates:
            all_valid_dates.append(date)
    return all_valid_dates
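A quick usage sketch (assuming, as the membership test implies, that ticker.trade_dates holds datetime.date entries):

import datetime
from pytse_client.ticker.ticker import Ticker
from pytse_client.utils.trade_dates import get_valid_dates

ticker = Ticker("خودرو")
# keeps only the days in the range on which this ticker actually traded
days = get_valid_dates(
    ticker,
    start_date=datetime.date(2023, 2, 27),
    end_date=datetime.date(2023, 2, 28),
)
print(days)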
Test module for the new API (new file, +63 lines):
import shutil
import unittest
from os.path import exists
from pathlib import Path
from datetime import date
from parameterized import parameterized
from pytse_client import get_trade_details
from pytse_client.historical_intraday.trade_details import (
    valid_time_frames_mapping,
)


class TestTradeDetails(unittest.TestCase):
    # (symbol, expected shape of the first day's tick DataFrame)
    params = [
        ("خودرو", (6374, 2)),
        ("وغدیر", (1915, 2)),
    ]
    single_ticker = "اهرم"
    valid_timeframes = [
        (timeframe,) for timeframe in valid_time_frames_mapping.keys()
    ]

    def setUp(self) -> None:
        self.write_csv_path = "test_dir"
        self.valid_start_date = date(2023, 2, 27)
        self.valid_end_date = date(2023, 2, 28)
        return super().setUp()

    def tearDown(self) -> None:
        shutil.rmtree(self.write_csv_path)
        return super().tearDown()

    @parameterized.expand(params)
    def test_diff_trade_details(self, symbol_name: str, shape: tuple):
        dict_df = get_trade_details(
            symbol_name=symbol_name,
            start_date=self.valid_start_date,
            end_date=self.valid_end_date,
            to_csv=True,
            base_path=self.write_csv_path,
        )
        self.assertTrue(exists(Path(f"{self.write_csv_path}")))
        self.assertGreater(len(dict_df), 0)
        self.assertEqual(dict_df[self.valid_start_date].shape, shape)

    @parameterized.expand(valid_timeframes)
    def test_timeframes_aggregate(self, timeframe: str):
        dict_df = get_trade_details(
            symbol_name=self.single_ticker,
            start_date=self.valid_start_date,
            end_date=self.valid_end_date,
            to_csv=True,
            base_path=self.write_csv_path,
            timeframe=timeframe,
            aggregate=True,
        )
        self.assertTrue(exists(Path(f"{self.write_csv_path}")))
        self.assertFalse(dict_df["aggregate"].empty)


if __name__ == "__main__":
    suite = unittest.TestLoader().loadTestsFromTestCase(TestTradeDetails)
    unittest.TextTestRunner(verbosity=3).run(suite)