-
Notifications
You must be signed in to change notification settings - Fork 0
/
conftest.py
174 lines (129 loc) · 4.2 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""Configuration tests module."""
import numpy as np
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from database import Base
from dependencies import (
get_current_user,
get_db_session,
get_gcp_client,
get_llm_embedding_client,
get_pinecone_index,
get_redis_client,
)
from main import create_app
from settings import get_settings
# A single in-memory SQLite database for the test run. StaticPool keeps one
# connection and hands it out for every checkout, which is what keeps the
# ":memory:" database alive between sessions. check_same_thread=False lets
# that connection be used from threads other than the one that opened it.
engine = create_engine(
    "sqlite:///:memory:",
    poolclass=StaticPool,
    connect_args=dict(check_same_thread=False),
)

# Session factory bound to the test engine, with autocommit and autoflush
# both disabled.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def override_gcp_client():
    """Overridden gcp_client app dependency.

    Returns a stand-in for the Google Cloud Storage client. ``bucket()`` and
    ``blob()`` both return the mock itself, so a chain such as
    ``client.bucket(name).blob(path).upload_from_file(...)`` resolves
    entirely in-process. Uploads are discarded and every signed URL is the
    fixed string ``"signed-url"``.
    """

    class MockGCPClient:
        """Self-returning fake for the storage methods mocked here."""

        project = "test-project"

        def bucket(self, name):
            # The name is ignored; the mock plays the bucket role itself.
            return self

        def blob(self, name):
            # Likewise the mock plays the blob role.
            return self

        @staticmethod
        def upload_from_file(*args, **kwargs):
            return None

        @staticmethod
        def generate_signed_url(*args, **kwargs):
            # Accept positional arguments too (e.g. an expiration passed
            # positionally), matching the permissive upload_from_file mock.
            return "signed-url"

    return MockGCPClient()
def override_get_pinecone_index():
    """Overridden get_pinecone_index app dependency.

    Returns a Pinecone index stand-in: ``upsert`` accepts any arguments and
    returns ``None``; ``query`` accepts any arguments and returns an empty
    dict.
    """

    class MockPineconeIndex:
        @staticmethod
        def query(*args, **kwargs):
            # Nothing is ever stored, so every query answers with no data.
            empty_result = {}
            return empty_result

        @staticmethod
        def upsert(*args, **kwargs):
            # Writes are discarded.
            return None

    fake_index = MockPineconeIndex()
    return fake_index
def override_get_llm_embedding_client():
    """Overridden get_llm_embedding_client app dependency.

    Returns a fake embeddings client that produces random NumPy vectors of
    length 1536 instead of calling a remote embedding service.
    """
    embedding_dim = 1536  # length of every vector this mock produces

    class MockOpenAIEmbeddings:
        """Random-vector stand-in for an embeddings client."""

        @staticmethod
        async def aembed_documents(texts):
            # One vector per input text, as a batch-embedding call returns;
            # an empty input therefore yields an empty list.
            return [np.random.rand(embedding_dim) for _ in texts]

        @staticmethod
        def embed_query(text):
            return np.random.rand(embedding_dim)

    return MockOpenAIEmbeddings()
def override_get_redis_client():
    """Overridden get_redis_client app dependency.

    Returns a cache stand-in that never stores anything: ``set`` discards its
    arguments and ``get`` always returns ``None``.
    """

    class MockRedisClient:
        def set(self, *args, **kwargs):
            return None

        def get(self, key):
            return None

    fake_redis = MockRedisClient()
    return fake_redis
def override_get_settings():
    """Overridden get_settings app dependency.

    Returns a settings object whose attributes are fixed test values; the
    API keys and the database URL are left as empty strings.
    """

    class MockSettings:
        # Object storage
        bucket_name = "test-bucket"
        gcp_storage_exp_minutes = 15

        # Authentication tokens
        jwt_secret_key = "test-secret"
        jwt_algorithm = "HS256"
        token_exp_minutes = 15

        # Upload limits
        max_file_upload_count = 5
        max_file_bytes_size = 25_000_000

        # Credentials and connection strings (intentionally blank)
        openai_api_key = ""
        pinecone_api_key = ""
        db_url = ""

        # Embeddings and caching
        embedding_chunk_size = 200
        embedding_namespace = "sample-embedding-name-space"
        redis_cache_exp = 15

    settings = MockSettings()
    return settings
@pytest.fixture
def db_session():
    """Database session fixture.

    Creates every table registered on ``Base.metadata`` in the test engine,
    yields a fresh session, and on teardown closes the session and drops all
    tables.
    """
    Base.metadata.create_all(bind=engine)
    try:
        session = TestingSessionLocal()
        try:
            yield session
        finally:
            session.close()
    finally:
        # Runs even if opening or closing the session raised: the module-level
        # StaticPool engine keeps one in-memory database for the whole run, so
        # a skipped drop would leave tables and rows behind for later tests.
        Base.metadata.drop_all(bind=engine)
@pytest.fixture
def client(db_session):
    """
    Not logged in client fixture.

    Builds a fresh app with the rate limiter disabled, routes its database
    dependency to the test session and its settings to the mock settings,
    and yields a TestClient for it. No user override is installed.

    :param db_session: fixture dependency of mock database session
    """

    def _session_override():
        yield db_session

    application = create_app(disable_limiter=True)
    application.dependency_overrides.update(
        {
            get_db_session: _session_override,
            get_settings: override_get_settings,
        }
    )
    with TestClient(application) as anonymous_client:
        yield anonymous_client
@pytest.fixture
def login_client(db_session):
    """
    Login client fixture.

    Builds a fresh app with the rate limiter disabled and overrides the
    database, settings, GCP client, Pinecone index, Redis client and
    embedding client with test doubles. ``get_current_user`` is overridden
    to return ``None`` (presumably so routes depending on it run without a
    real token; confirm against the auth dependency).

    :param db_session: fixture dependency of mock database session
    :return: yields a ``(TestClient, app)`` tuple; the app is exposed,
        presumably so tests can install further dependency overrides.
    """

    def _session_override():
        yield db_session

    app = create_app(disable_limiter=True)
    overrides = {
        get_db_session: _session_override,
        get_current_user: lambda: None,
        get_settings: override_get_settings,
        get_gcp_client: override_gcp_client,
        get_pinecone_index: override_get_pinecone_index,
        get_redis_client: override_get_redis_client,
        get_llm_embedding_client: override_get_llm_embedding_client,
    }
    app.dependency_overrides.update(overrides)
    with TestClient(app) as authed_client:
        yield authed_client, app