diff --git a/src/argilla/_constants.py b/src/argilla/_constants.py index b721488e1b..3f7af54304 100644 --- a/src/argilla/_constants.py +++ b/src/argilla/_constants.py @@ -26,4 +26,4 @@ _OLD_API_KEY_HEADER_NAME = "X-Rubrix-Api-Key" _OLD_WORKSPACE_HEADER_NAME = "X-Rubrix-Workspace" -DATASET_NAME_REGEX_PATTERN = r"^(?!-|_)[a-z0-9-_]+$" +ES_INDEX_REGEX_PATTERN = r"^(?!-|_)[a-z0-9-_]+$" diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 66003c8729..1df03aca60 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -18,15 +18,15 @@ import re import warnings from asyncio import Future -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from rich import print as rprint from rich.progress import Progress from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, - DATASET_NAME_REGEX_PATTERN, DEFAULT_API_KEY, + ES_INDEX_REGEX_PATTERN, WORKSPACE_HEADER_NAME, ) from argilla.client.apis.datasets import Datasets @@ -208,6 +208,15 @@ def set_workspace(self, workspace: str): if not workspace: raise Exception("Must provide a workspace") + if not re.match(ES_INDEX_REGEX_PATTERN, workspace): + raise InputValueError( + f"Provided workspace name {workspace} does not match the pattern" + f" {ES_INDEX_REGEX_PATTERN}. Please, use a valid name for your" + " workspace. This limitation is caused by naming conventions for indexes" + " in Elasticsearch. If applicable, you can try to lowercase the name of your workspace." + " https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html" + ) + if workspace != self.get_workspace(): if workspace == self._user.username: self._client.headers.pop(WORKSPACE_HEADER_NAME, workspace) @@ -326,10 +335,10 @@ async def log_async( if not name: raise InputValueError("Empty dataset name has been passed as argument.") - if not re.match(DATASET_NAME_REGEX_PATTERN, name): + if not re.match(ES_INDEX_REGEX_PATTERN, name): raise InputValueError( f"Provided dataset name {name} does not match the pattern" - f" {DATASET_NAME_REGEX_PATTERN}. Please, use a valid name for your" + f" {ES_INDEX_REGEX_PATTERN}. Please, use a valid name for your" " dataset. This limitation is caused by naming conventions for indexes" " in Elasticsearch." " https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html" diff --git a/src/argilla/server/apis/v0/models/commons/params.py b/src/argilla/server/apis/v0/models/commons/params.py index 4bd92aa4e2..68b5842649 100644 --- a/src/argilla/server/apis/v0/models/commons/params.py +++ b/src/argilla/server/apis/v0/models/commons/params.py @@ -18,12 +18,12 @@ from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, - DATASET_NAME_REGEX_PATTERN, + ES_INDEX_REGEX_PATTERN, WORKSPACE_HEADER_NAME, ) from argilla.server.security.model import WORKSPACE_NAME_PATTERN -DATASET_NAME_PATH_PARAM = Path(..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name") +DATASET_NAME_PATH_PARAM = Path(..., regex=ES_INDEX_REGEX_PATTERN, description="The dataset name") @dataclass diff --git a/src/argilla/server/apis/v0/models/datasets.py b/src/argilla/server/apis/v0/models/datasets.py index 7590f4ed13..65e4a2cc30 100644 --- a/src/argilla/server/apis/v0/models/datasets.py +++ b/src/argilla/server/apis/v0/models/datasets.py @@ -21,7 +21,7 @@ from pydantic import BaseModel, Field -from argilla._constants import DATASET_NAME_REGEX_PATTERN +from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.commons.models import TaskType from argilla.server.services.datasets import ServiceBaseDataset @@ -43,7 +43,7 @@ class UpdateDatasetRequest(BaseModel): class _BaseDatasetRequest(UpdateDatasetRequest): - name: str = Field(regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name") + name: str = Field(regex=ES_INDEX_REGEX_PATTERN, description="The dataset name") class CreateDatasetRequest(_BaseDatasetRequest): diff --git a/src/argilla/server/daos/models/datasets.py b/src/argilla/server/daos/models/datasets.py index 03f690f777..6b3c55f6cf 100644 --- a/src/argilla/server/daos/models/datasets.py +++ b/src/argilla/server/daos/models/datasets.py @@ -17,12 +17,12 @@ from pydantic import BaseModel, Field, validator -from argilla._constants import DATASET_NAME_REGEX_PATTERN +from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.commons.models import TaskType class BaseDatasetDB(BaseModel): - name: str = Field(regex=DATASET_NAME_REGEX_PATTERN) + name: str = Field(regex=ES_INDEX_REGEX_PATTERN) task: TaskType owner: Optional[str] = Field(description="Deprecated. Use `workspace` instead. Will be removed in v1.5.0") workspace: Optional[str] = None diff --git a/src/argilla/server/security/model.py b/src/argilla/server/security/model.py index 007f043507..5e7f77a491 100644 --- a/src/argilla/server/security/model.py +++ b/src/argilla/server/security/model.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, root_validator, validator -from argilla._constants import DATASET_NAME_REGEX_PATTERN +from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.errors import EntityNotFoundError WORKSPACE_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_\-]*$") @@ -37,9 +37,9 @@ class User(BaseModel): @validator("username") def check_username(cls, value): - if not re.compile(DATASET_NAME_REGEX_PATTERN).match(value): + if not re.compile(ES_INDEX_REGEX_PATTERN).match(value): raise ValueError( - "Wrong username. " f"The username {value} does not match the pattern {DATASET_NAME_REGEX_PATTERN}" + "Wrong username. " f"The username {value} does not match the pattern {ES_INDEX_REGEX_PATTERN}" ) return value diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 822a8e6607..d9de6616cb 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -94,6 +94,11 @@ def mock_get(*args, **kwargs): monkeypatch.setattr(users_api, "whoami", mock_get) +def test_init_uppercase_workspace(mocked_client): + with pytest.raises(InputValueError): + api.init(workspace="UPPERCASE_WORKSPACE") + + def test_init_correct(mock_response_200): """Testing correct default initialization