From 648359ca04e454897dcffd4b17d6300ca89cc59c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 6 Mar 2023 12:02:29 +0100 Subject: [PATCH 1/2] Add default user and show warning message on startup when no users found --- src/argilla/_constants.py | 3 ++ src/argilla/server/contexts/accounts.py | 8 ++-- src/argilla/server/server.py | 55 ++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/argilla/_constants.py b/src/argilla/_constants.py index e1336e5148..4d480d266b 100644 --- a/src/argilla/_constants.py +++ b/src/argilla/_constants.py @@ -17,6 +17,9 @@ API_KEY_HEADER_NAME = "X-Argilla-Api-Key" WORKSPACE_HEADER_NAME = "X-Argilla-Workspace" + +DEFAULT_USERNAME = "argilla" +DEFAULT_PASSWORD = "1234" DEFAULT_API_KEY = "argilla.apikey" # Keep the same api key for now # TODO: This constant will be drop out with issue diff --git a/src/argilla/server/contexts/accounts.py b/src/argilla/server/contexts/accounts.py index b2314d21a3..c712cb3999 100644 --- a/src/argilla/server/contexts/accounts.py +++ b/src/argilla/server/contexts/accounts.py @@ -23,7 +23,7 @@ from passlib.context import CryptContext from sqlalchemy.orm import Session -_CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") +CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") def get_workspace_user_by_workspace_id_and_user_id(db: Session, workspace_id: UUID, user_id: UUID): @@ -101,7 +101,7 @@ def create_user(db: Session, user_create: UserCreate): last_name=user_create.last_name, username=user_create.username, role=user_create.role, - password_hash=_CRYPT_CONTEXT.hash(user_create.password), + password_hash=CRYPT_CONTEXT.hash(user_create.password), ) db.add(user) @@ -121,9 +121,9 @@ def delete_user(db: Session, user: User): def authenticate_user(db: Session, username: str, password: str): user = get_user_by_username(db, username) - if user and _CRYPT_CONTEXT.verify(password, user.password_hash): + if user and CRYPT_CONTEXT.verify(password, user.password_hash): return user elif user: return else: - _CRYPT_CONTEXT.dummy_verify() + CRYPT_CONTEXT.dummy_verify() diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index b8256c1a70..e966fda2d2 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -16,6 +16,7 @@ """ This module configures the global fastapi application """ +import contextlib import glob import inspect import logging @@ -30,21 +31,29 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from pydantic import ConfigError +from sqlalchemy.orm import Session from argilla import __version__ as argilla_version +from argilla._constants import DEFAULT_PASSWORD, DEFAULT_USERNAME from argilla.logging import configure_logging from argilla.server import helpers +from argilla.server.contexts import accounts from argilla.server.daos.backend import GenericElasticEngineBackend from argilla.server.daos.backend.base import GenericSearchError from argilla.server.daos.datasets import DatasetsDAO from argilla.server.daos.records import DatasetRecordsDAO +from argilla.server.database import get_db from argilla.server.errors import ( APIErrorHandler, EntityNotFoundError, UnauthorizedError, ) +from argilla.server.models import User, UserRole, Workspace from argilla.server.routes import api_router from argilla.server.security import auth +from argilla.server.security.auth_provider.local.settings import ( + settings as auth_settings, +) from argilla.server.settings import settings from argilla.server.static_rewrite import RewriteStaticFiles @@ -173,7 +182,7 @@ def configure_app_logging(app: FastAPI): app.on_event("startup")(configure_logging) -def configure_telemetry(app): +def configure_telemetry(app: FastAPI): message = "\n" message += inspect.cleandoc( """ @@ -197,6 +206,49 @@ async def check_telemetry(): print(message, flush=True) +def configure_database(app: FastAPI): + get_db_wrapper = contextlib.contextmanager(get_db) + + def _create_default_user(db: Session): + user = User( + first_name="", + username=DEFAULT_USERNAME, + role=UserRole.admin, + api_key=auth_settings.default_apikey, + password_hash=auth_settings.default_password, + workspaces=[Workspace(name=DEFAULT_USERNAME)], + ) + + db.add(user) + db.commit() + db.refresh(user) + + return user + + def _user_has_default_credentials(user: User): + return user.api_key == auth_settings.default_apikey or accounts.CRYPT_CONTEXT.verify( + DEFAULT_PASSWORD, user.password_hash + ) + + def _log_default_user_warning(): + _LOGGER.warning( + f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. " + "If you are using argilla in a production environment this can be a serious security problem. " + f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one." + ) + + @app.on_event("startup") + async def create_default_user_if_not_present(): + with get_db_wrapper() as db: + if db.query(User).count() == 0: + _create_default_user(db) + _log_default_user_warning() + else: + default_user = accounts.get_user_by_username(db, DEFAULT_USERNAME) + if default_user and _user_has_default_credentials(default_user): + _log_default_user_warning() + + argilla_app = FastAPI( title="Argilla", description="Argilla API", @@ -211,6 +263,7 @@ async def check_telemetry(): app.mount(settings.base_url, argilla_app) configure_app_logging(app) +configure_database(app) configure_storage(app) configure_telemetry(app) From 57c958b6785ca71f588256d2ec8fa3d631c63aec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 6 Mar 2023 14:01:45 +0100 Subject: [PATCH 2/2] Remove ARGILLA_DEFAULT_APIKEY and ARGILLA_DEFAULT_PASSWORD environment variables from settings --- .../server/security/auth_provider/local/settings.py | 4 ---- src/argilla/server/server.py | 13 ++++--------- tests/helpers.py | 4 ++-- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/argilla/server/security/auth_provider/local/settings.py b/src/argilla/server/security/auth_provider/local/settings.py index b2157e1e61..758ffc51b4 100644 --- a/src/argilla/server/security/auth_provider/local/settings.py +++ b/src/argilla/server/security/auth_provider/local/settings.py @@ -15,7 +15,6 @@ from pydantic import BaseSettings -from argilla._constants import DEFAULT_API_KEY from argilla.server import helpers from argilla.server.settings import settings as server_settings @@ -41,9 +40,6 @@ class Settings(BaseSettings): algorithm: str = "HS256" token_expiration_in_minutes: int = 30000 token_api_url: str = "/api/security/token" - - default_apikey: str = DEFAULT_API_KEY - default_password: str = "$2y$12$MPcRR71ByqgSI8AaqgxrMeSdrD4BcxDIdYkr.ePQoKz7wsGK7SAca" # 1234 users_db_file: str = ".users.yml" @property diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index e966fda2d2..63bd5de1d3 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -34,7 +34,7 @@ from sqlalchemy.orm import Session from argilla import __version__ as argilla_version -from argilla._constants import DEFAULT_PASSWORD, DEFAULT_USERNAME +from argilla._constants import DEFAULT_API_KEY, DEFAULT_PASSWORD, DEFAULT_USERNAME from argilla.logging import configure_logging from argilla.server import helpers from argilla.server.contexts import accounts @@ -51,9 +51,6 @@ from argilla.server.models import User, UserRole, Workspace from argilla.server.routes import api_router from argilla.server.security import auth -from argilla.server.security.auth_provider.local.settings import ( - settings as auth_settings, -) from argilla.server.settings import settings from argilla.server.static_rewrite import RewriteStaticFiles @@ -214,8 +211,8 @@ def _create_default_user(db: Session): first_name="", username=DEFAULT_USERNAME, role=UserRole.admin, - api_key=auth_settings.default_apikey, - password_hash=auth_settings.default_password, + api_key=DEFAULT_API_KEY, + password_hash=accounts.CRYPT_CONTEXT.hash(DEFAULT_PASSWORD), workspaces=[Workspace(name=DEFAULT_USERNAME)], ) @@ -226,9 +223,7 @@ def _create_default_user(db: Session): return user def _user_has_default_credentials(user: User): - return user.api_key == auth_settings.default_apikey or accounts.CRYPT_CONTEXT.verify( - DEFAULT_PASSWORD, user.password_hash - ) + return user.api_key == DEFAULT_API_KEY or accounts.CRYPT_CONTEXT.verify(DEFAULT_PASSWORD, user.password_hash) def _log_default_user_warning(): _LOGGER.warning( diff --git a/tests/helpers.py b/tests/helpers.py index 92d12b88ed..a4e8d1420f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla._constants import API_KEY_HEADER_NAME +from argilla._constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY from argilla.server.security.auth_provider.local.settings import settings from fastapi import FastAPI from starlette.testclient import TestClient @@ -21,7 +21,7 @@ class SecuredClient: def __init__(self, client: TestClient): self._client = client - self._header = {API_KEY_HEADER_NAME: settings.default_apikey} + self._header = {API_KEY_HEADER_NAME: DEFAULT_API_KEY} self._current_user = None def update_api_key(self, api_key):