Fix abstract method not implemented error; Fix parameter checking; #12

Open · wants to merge 4 commits into base: master
23 changes: 13 additions & 10 deletions rqalpha_mod_incremental/mod.py
@@ -41,8 +41,15 @@ def start_up(self, env, mod_config):
         self._recorder = None
         self._mod_config = mod_config
 
-        if not self._mod_config.persist_folder:
-            return
+        if mod_config.recorder == "CsvRecorder":
+            if mod_config.persist_folder is None:
+                raise RuntimeError(_(u"You need to set persist_folder to use CsvRecorder!"))
+        elif mod_config.recorder == "MongodbRecorder":
+            if mod_config.strategy_id is None or mod_config.mongo_url is None or mod_config.mongo_dbname is None:
+                raise RuntimeError(_(u"MongodbRecorder requires strategy_id, mongo_url and mongo_dbname! "
+                                     u"But got {}").format(mod_config))
+        else:
+            raise RuntimeError(_(u"unknown recorder {}").format(mod_config.recorder))
 
         config = self._env.config
         if not env.data_source:
@@ -60,14 +67,10 @@ def _set_env_and_data_source(self):
         mod_config = self._mod_config
         system_log.info("use recorder {}", mod_config.recorder)
         if mod_config.recorder == "CsvRecorder":
-            if not mod_config.persist_folder:
-                raise RuntimeError(_(u"You need to set persist_folder to use CsvRecorder"))
             persist_folder = os.path.join(mod_config.persist_folder, "persist", str(mod_config.strategy_id))
             persist_provider = DiskPersistProvider(persist_folder)
             self._recorder = recorders.CsvRecorder(persist_folder)
         elif mod_config.recorder == "MongodbRecorder":
-            if mod_config.strategy_id is None:
-                raise RuntimeError(_(u"You need to set strategy_id"))
             persist_provider = persist_providers.MongodbPersistProvider(mod_config.strategy_id, mod_config.mongo_url,
                                                                         mod_config.mongo_dbname)
             self._recorder = recorders.MongodbRecorder(mod_config.strategy_id,
@@ -92,12 +95,14 @@ def _set_env_and_data_source(self):
         if persist_meta:
             # 不修改回测开始时间
             self._env.config.base.start_date = datetime.datetime.strptime(persist_meta['start_date'], '%Y-%m-%d').date()
-            event_start_time = datetime.datetime.strptime(persist_meta['last_end_time'], '%Y-%m-%d').date() + datetime.timedelta(days=1)
+            event_start_time = datetime.datetime.strptime(persist_meta['last_end_time'],
+                                                          '%Y-%m-%d').date() + datetime.timedelta(days=1)
             # 代表历史有运行过,根据历史上次运行的end_date下一天设为事件发送的start_time
             self._meta["origin_start_date"] = persist_meta["origin_start_date"]
             self._meta["start_date"] = persist_meta["start_date"]
             if self._meta["last_end_time"] <= persist_meta["last_end_time"]:
-                raise ValueError('The end_date should after end_date({}) last time'.format(persist_meta["last_end_time"]))
+                raise ValueError(
+                    'The end_date should after end_date({}) last time'.format(persist_meta["last_end_time"]))
             self._last_end_date = datetime.datetime.strptime(persist_meta["last_end_time"], "%Y-%m-%d").date()
             self._event_start_time = event_start_time
             self._overwrite_event_data_source_func()
@@ -137,8 +142,6 @@ def on_settlement(self, event):
         return True
 
     def tear_down(self, success, exception=None):
-        if not self._mod_config.persist_folder:
-            return
         if exception is None:
             self._recorder.store_meta(self._meta)
             self._recorder.flush()
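For context, a minimal sketch of mod settings that would pass the new checks in start_up above. Only the key names (recorder, persist_folder, strategy_id, mongo_url, mongo_dbname) are taken from the diff; the dict layout and the concrete values are illustrative assumptions, not part of this PR.

# Illustrative values only; just the key names come from the diff above.
csv_mod_config = {
    "recorder": "CsvRecorder",
    "persist_folder": "./persist",             # must be set, otherwise RuntimeError
    "strategy_id": 1,
}

mongo_mod_config = {
    "recorder": "MongodbRecorder",
    "strategy_id": 1,                          # strategy_id, mongo_url and mongo_dbname
    "mongo_url": "mongodb://localhost:27017",  # must all be set, otherwise RuntimeError
    "mongo_dbname": "incremental",
}

# Any other value of "recorder" now fails fast with "unknown recorder ...".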
43 changes: 39 additions & 4 deletions rqalpha_mod_incremental/persist_providers.py
@@ -1,23 +1,52 @@
 import os
 import datetime
 
+import jsonpickle
+import pandas as pd
+from rqrisk import Risk
 from rqalpha.interface import AbstractPersistProvider
 
 
+def get_performance(strategy_id, analysis_data):
+    daily_returns = analysis_data['portfolio_daily_returns']
+    benchmark = analysis_data['benchmark_daily_returns']
+    dates = [p['date'] for p in analysis_data['total_portfolios']]
+    assert len(daily_returns) == len(benchmark) == len(dates), 'unmatched length'
+    daily_returns = pd.Series(daily_returns, index=dates)
+    benchmark = pd.Series(benchmark, index=dates)
+    risk = Risk(daily_returns, benchmark, 0.)
+    perf = risk.all()
+    perf['strategy_id'] = strategy_id
+    perf['update_time'] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+    perf['start_date'] = analysis_data['total_portfolios'][0]['date'].strftime('%Y-%m-%d')
+    perf['end_date'] = analysis_data['total_portfolios'][-1]['date'].strftime('%Y-%m-%d')
+    return perf
+
+
 class MongodbPersistProvider(AbstractPersistProvider):
     def __init__(self, strategy_id, mongo_url, mongo_db):
         import pymongo
         import gridfs
 
-        persist_db = pymongo.MongoClient(mongo_url)[mongo_db]
+        self.persist_db = pymongo.MongoClient(mongo_url)[mongo_db]
         self._strategy_id = strategy_id
-        self._fs = gridfs.GridFS(persist_db)
+        self._fs = gridfs.GridFS(self.persist_db)
 
     def store(self, key, value):
         update_time = datetime.datetime.now()
         self._fs.put(value, strategy_id=self._strategy_id, key=key, update_time=update_time)
-        for grid_out in self._fs.find({"strategy_id": self._strategy_id, "key": key, "update_time": {"$lt": update_time}}):
+        for grid_out in self._fs.find(
+                {"strategy_id": self._strategy_id, "key": key, "update_time": {"$lt": update_time}}):
             self._fs.delete(grid_out._id)
+        if key == "mod_sys_analyser":
+            self._store_performance(value)
Contributor review comment on the line above (`self._store_performance(value)`): 为什么要做这个操作 ("Why is this operation being done?")

+
+    def _store_performance(self, analysis_data):
+        try:
+            perf = get_performance(self._strategy_id,
+                                   jsonpickle.loads(analysis_data.decode("utf-8")))
+            self.persist_db['performance'].update({"strategy_id": self._strategy_id}, perf, upsert=True)
+        except Exception as e:
+            print(e)
 
     def load(self, key, large_file=False):
         import gridfs
@@ -27,6 +56,12 @@ def load(self, key, large_file=False):
         except gridfs.errors.NoFile:
             return None
 
+    def should_resume(self):
+        return False
+
+    def should_run_init(self):
+        return False
+
 
 class DiskPersistProvider(AbstractPersistProvider):
     def __init__(self, path="./persist"):
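On the "abstract method not implemented" part of the PR title: the new should_resume / should_run_init stubs above give MongodbPersistProvider the hooks its base interface declares. Below is a standalone sketch of the failure mode and the fix, using Python's abc module as a stand-in; this is not rqalpha's actual interface definition.

from abc import ABC, abstractmethod


class FakePersistProviderInterface(ABC):
    """Stand-in for an interface that declares abstract hooks (not rqalpha's real class)."""

    @abstractmethod
    def should_resume(self):
        raise NotImplementedError

    @abstractmethod
    def should_run_init(self):
        raise NotImplementedError


class WithoutHooks(FakePersistProviderInterface):
    pass


class WithHooks(FakePersistProviderInterface):
    # Mirrors the PR's fix: both hooks simply return False.
    def should_resume(self):
        return False

    def should_run_init(self):
        return False


try:
    WithoutHooks()                  # abstract methods missing -> TypeError at instantiation
except TypeError as exc:
    print(exc)

provider = WithHooks()
print(provider.should_resume(), provider.should_run_init())   # False False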