Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default user and show warning message on startup when no users found #2481

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/argilla/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

API_KEY_HEADER_NAME = "X-Argilla-Api-Key"
WORKSPACE_HEADER_NAME = "X-Argilla-Workspace"

DEFAULT_USERNAME = "argilla"
DEFAULT_PASSWORD = "1234"
DEFAULT_API_KEY = "argilla.apikey" # Keep the same api key for now

# TODO: This constant will be drop out with issue
Expand Down
8 changes: 4 additions & 4 deletions src/argilla/server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from passlib.context import CryptContext
from sqlalchemy.orm import Session

_CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_workspace_user_by_workspace_id_and_user_id(db: Session, workspace_id: UUID, user_id: UUID):
Expand Down Expand Up @@ -101,7 +101,7 @@ def create_user(db: Session, user_create: UserCreate):
last_name=user_create.last_name,
username=user_create.username,
role=user_create.role,
password_hash=_CRYPT_CONTEXT.hash(user_create.password),
password_hash=CRYPT_CONTEXT.hash(user_create.password),
)

db.add(user)
Expand All @@ -121,9 +121,9 @@ def delete_user(db: Session, user: User):
def authenticate_user(db: Session, username: str, password: str):
user = get_user_by_username(db, username)

if user and _CRYPT_CONTEXT.verify(password, user.password_hash):
if user and CRYPT_CONTEXT.verify(password, user.password_hash):
return user
elif user:
return
else:
_CRYPT_CONTEXT.dummy_verify()
CRYPT_CONTEXT.dummy_verify()
4 changes: 0 additions & 4 deletions src/argilla/server/security/auth_provider/local/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from pydantic import BaseSettings

from argilla._constants import DEFAULT_API_KEY
from argilla.server import helpers
from argilla.server.settings import settings as server_settings

Expand All @@ -41,9 +40,6 @@ class Settings(BaseSettings):
algorithm: str = "HS256"
token_expiration_in_minutes: int = 30000
token_api_url: str = "/api/security/token"

default_apikey: str = DEFAULT_API_KEY
default_password: str = "$2y$12$MPcRR71ByqgSI8AaqgxrMeSdrD4BcxDIdYkr.ePQoKz7wsGK7SAca" # 1234
users_db_file: str = ".users.yml"

@property
Expand Down
50 changes: 49 additions & 1 deletion src/argilla/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
This module configures the global fastapi application
"""
import contextlib
import glob
import inspect
import logging
Expand All @@ -30,19 +31,24 @@
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pydantic import ConfigError
from sqlalchemy.orm import Session

from argilla import __version__ as argilla_version
from argilla._constants import DEFAULT_API_KEY, DEFAULT_PASSWORD, DEFAULT_USERNAME
from argilla.logging import configure_logging
from argilla.server import helpers
from argilla.server.contexts import accounts
from argilla.server.daos.backend import GenericElasticEngineBackend
from argilla.server.daos.backend.base import GenericSearchError
from argilla.server.daos.datasets import DatasetsDAO
from argilla.server.daos.records import DatasetRecordsDAO
from argilla.server.database import get_db
from argilla.server.errors import (
APIErrorHandler,
EntityNotFoundError,
UnauthorizedError,
)
from argilla.server.models import User, UserRole, Workspace
from argilla.server.routes import api_router
from argilla.server.security import auth
from argilla.server.settings import settings
Expand Down Expand Up @@ -173,7 +179,7 @@ def configure_app_logging(app: FastAPI):
app.on_event("startup")(configure_logging)


def configure_telemetry(app):
def configure_telemetry(app: FastAPI):
message = "\n"
message += inspect.cleandoc(
"""
Expand All @@ -197,6 +203,47 @@ async def check_telemetry():
print(message, flush=True)


def configure_database(app: FastAPI):
get_db_wrapper = contextlib.contextmanager(get_db)

def _create_default_user(db: Session):
user = User(
first_name="",
username=DEFAULT_USERNAME,
role=UserRole.admin,
api_key=DEFAULT_API_KEY,
password_hash=accounts.CRYPT_CONTEXT.hash(DEFAULT_PASSWORD),
workspaces=[Workspace(name=DEFAULT_USERNAME)],
)

db.add(user)
db.commit()
db.refresh(user)

return user

def _user_has_default_credentials(user: User):
return user.api_key == DEFAULT_API_KEY or accounts.CRYPT_CONTEXT.verify(DEFAULT_PASSWORD, user.password_hash)

def _log_default_user_warning():
_LOGGER.warning(
f"User {DEFAULT_USERNAME!r} with default credentials has been found in the database. "
"If you are using argilla in a production environment this can be a serious security problem. "
f"We recommend that you create a new admin user and then delete the default {DEFAULT_USERNAME!r} one."
)

@app.on_event("startup")
async def create_default_user_if_not_present():
with get_db_wrapper() as db:
if db.query(User).count() == 0:
_create_default_user(db)
_log_default_user_warning()
else:
default_user = accounts.get_user_by_username(db, DEFAULT_USERNAME)
if default_user and _user_has_default_credentials(default_user):
_log_default_user_warning()


argilla_app = FastAPI(
title="Argilla",
description="Argilla API",
Expand All @@ -211,6 +258,7 @@ async def check_telemetry():
app.mount(settings.base_url, argilla_app)

configure_app_logging(app)
configure_database(app)
configure_storage(app)
configure_telemetry(app)

Expand Down
12 changes: 8 additions & 4 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla._constants import API_KEY_HEADER_NAME, WORKSPACE_HEADER_NAME
from argilla.server.security.auth_provider.local.settings import settings
from argilla._constants import (
API_KEY_HEADER_NAME,
DEFAULT_API_KEY,
DEFAULT_USERNAME,
WORKSPACE_HEADER_NAME,
)
from fastapi import FastAPI
from starlette.testclient import TestClient

Expand All @@ -22,8 +26,8 @@ class SecuredClient:
def __init__(self, client: TestClient):
self._client = client
self._header = {
API_KEY_HEADER_NAME: settings.default_apikey,
WORKSPACE_HEADER_NAME: "argilla", # Hard-coded default workspace
API_KEY_HEADER_NAME: DEFAULT_API_KEY,
WORKSPACE_HEADER_NAME: DEFAULT_USERNAME,
}
self._current_user = None

Expand Down