-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sql_codegen.py
82 lines (68 loc) · 2.22 KB
/
sql_codegen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import re
import shutil
from typing import Iterator
# Directory that receives the generated module.
OUT_DIR: str = './foxfeed/gen/'
# File name of the generated module inside OUT_DIR.
OUT_FILE: str = 'db.py'
# One (query name, output model) pair per generated function. The query name
# selects the SQL file to read and becomes the generated function's name; the
# model is a class name the generated code resolves on foxfeed.database.
INPUT = [
    ('score_posts', 'ScorePostsOutputModel'),
    ('score_by_interactions', 'ScoreByInteractionOutputModel'),
    ('find_unlinks', 'FindUnlinksOutputModel'),
]
# Preamble emitted verbatim at the top of the generated module. It defines
# `Arg` (the value types a query parameter may take) and `escape`, which the
# generated query functions use to render each argument as SQL literal text
# before splicing it into the query with str.format.
HEADDER = '''
# This is kinda weird and really bad sorry
from datetime import datetime
import foxfeed.database
from typing import List, Union
Arg = Union[str, int, float, bool, datetime]
def escape(a: Arg) -> str:
    """Render *a* as SQL literal text for direct interpolation into a query."""
    # bool must be tested before int: bool is a subclass of int.
    if isinstance(a, bool):
        return 'TRUE' if a else 'FALSE'
    if isinstance(a, str):
        # A quote would let the value break out of the literal. Raise instead
        # of assert so the check is not stripped under `python -O`.
        if "'" in a:
            raise ValueError("single quotes are not allowed in string arguments: " + repr(a))
        return "'" + a + "'"
    if isinstance(a, int):
        return str(a)
    if isinstance(a, float):
        return str(a)
    if isinstance(a, datetime):
        # Drop fractional seconds and cast the literal to timestamp.
        return "'" + a.isoformat().split('.')[0] + "'::timestamp"
    # Never fall through to an implicit None, which would be formatted into
    # the query as the text "None".
    raise TypeError("unsupported argument type: " + type(a).__name__)
'''
# Not used by the generator code in this file; the generated module imports
# its own datetime through HEADDER. Kept so the module namespace is unchanged.
from datetime import datetime  # noqa: F401
# Matches a `:name` bind parameter and captures `name`. The lookbehind skips
# PostgreSQL `::type` casts (e.g. `x::date`) and colons glued to a preceding
# word; requiring an identifier start skips digit runs such as the `:30` in
# a time literal like '12:30'.
_PARAM_RE = re.compile(r'(?<![:\w]):([A-Za-z_]\w*)')


def codegen_for_query(function_name: str, output_model: str, sql: str) -> Iterator[str]:
    """Yield the lines of generated Python source that wrap one SQL query.

    Every ``:name`` placeholder in *sql* becomes a keyword-only ``Arg``
    parameter of a generated ``async def function_name(db, *, ...)``. At call
    time the generated function passes each argument through ``escape``,
    substitutes it into the query text with ``str.format`` and runs it with
    ``db.query_raw(..., model=foxfeed.database.<output_model>)``.

    Args:
        function_name: Name of the generated function; also the prefix of the
            module-level ``<function_name>_sql_query`` template variable.
        output_model: Name of the row model class on ``foxfeed.database``.
        sql: Query text using ``:name`` placeholders.

    Yields:
        Source fragments; the caller joins them with newlines.

    Raises:
        ValueError: if *sql* contains a triple double quote, which would end
            the generated string literal early.
    """
    if '"""' in sql:
        raise ValueError(f'{function_name}: SQL must not contain """')

    # Unique parameter names in a stable order, shared by signature and body.
    arguments = sorted(set(_PARAM_RE.findall(sql)))

    # Double literal braces so str.format leaves them alone (e.g. array
    # literals like '{1,2}'), then turn each :name into a {name} field.
    template = sql.replace('{', '{{').replace('}', '}}')
    template = _PARAM_RE.sub(lambda m: '{' + m[1] + '}', template)

    # Module-level query template. A raw string keeps backslashes in the SQL
    # verbatim instead of reading them as Python escape sequences.
    yield f'{function_name}_sql_query = r"""'
    yield template
    yield '"""'
    yield ''

    # Function signature.
    yield f'async def {function_name}('
    yield '    db: foxfeed.database.Database,'
    if arguments:
        # Bare * makes every query parameter keyword-only.
        yield '    *,'
        yield from (f'    {name}: Arg,' for name in arguments)
    yield f') -> List[foxfeed.database.{output_model}]:'

    # Function body: escape each argument, fill the template, run the query.
    yield f'    query = {function_name}_sql_query.format('
    yield from (f'        {name} = escape({name}),' for name in arguments)
    yield '    )'
    yield f'    result = await db.query_raw(query, model=foxfeed.database.{output_model})  # type: ignore'
    yield '    return result'
    yield ''
def codegen(sql_dir: str = './sql/') -> Iterator[str]:
    """Yield the complete source of the generated module.

    The output is HEADDER followed by one wrapper per ``(name, model)`` pair
    in INPUT, each built from the SQL file ``<sql_dir>/<name>.sql``.

    Args:
        sql_dir: Directory containing the ``.sql`` files. Defaults to
            ``./sql/``, resolved against the current working directory.
    """
    yield HEADDER
    for name, output_model in INPUT:
        # Explicit encoding so the result does not depend on the platform locale.
        with open(os.path.join(sql_dir, name + '.sql'), encoding='utf-8') as f:
            sql = f.read()
        # File is closed before the generated lines are handed to the caller.
        yield from codegen_for_query(name, output_model, sql)
if __name__ == '__main__':
    # NOTE(review): presumably clears cached bytecode for the regenerated
    # module; confirm it is still needed.
    try:
        shutil.rmtree(os.path.join(OUT_DIR, '__pycache__'))
    except FileNotFoundError:
        pass  # nothing cached yet
    # Build the full source first, so an error while generating does not
    # leave a truncated output file behind.
    code = '\n'.join(codegen())
    # Explicit encoding so the output does not depend on the platform locale.
    with open(os.path.join(OUT_DIR, OUT_FILE), 'w', encoding='utf-8') as f:
        f.write(code)