Skip to content

Commit

Permalink
feat: add workspace_id param to GET /api/v1/me/datasets (#3727)
Browse files Browse the repository at this point in the history
# Description

This PR adds the `workspace_id` param to `GET /api/v1/me/datasets` so
that the workspace filtering when listing `FeedbackTask` datasets is
applied on the API side, and makes sure that no local filters are
applied on the client side, e.g. in `FeedbackDataset.list(workspace=...)`.

Closes #3726

**Type of change**

- [X] New feature (non-breaking change which adds functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

- [x] Add tests for `GET /api/v1/me/datasets` using the `workspace_id`
param, including also the updated policies for non-owner users
- [x] Add tests for `list_datasets` in the Python SDK using the
`workspace_id` arg
- [x] Add tests for `FeedbackDataset.list` in the Python client using
the `workspace` arg

**Checklist**

- [X] I added relevant documentation
- [X] follows the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
alvarobartt and gabrielmbmb authored Sep 7, 2023
1 parent 4be9294 commit fb92ca2
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 25 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ These are the section headers that we use:
- Added `created_at` and `updated_at` properties to `RemoteFeedbackDataset` and `FilteredRemoteFeedbackDataset` ([#3709](https://github.com/argilla-io/argilla/pull/3709)).
- Added handling of `PermissionError` when executing a command as a logged-in user without enough permissions ([#3717](https://github.com/argilla-io/argilla/pull/3717)).
- Added `workspaces add-user` command to add a user to workspace ([#3712](https://github.com/argilla-io/argilla/pull/3712)).
- Added `workspace_id` param to `GET /api/v1/me/datasets` endpoint ([#3727](https://github.com/argilla-io/argilla/pull/3727)).
- Added `workspace_id` arg to `list_datasets` in the Python SDK ([#3727](https://github.com/argilla-io/argilla/pull/3727)).

### Changed

Expand Down
10 changes: 6 additions & 4 deletions src/argilla/client/feedback/dataset/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,17 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[
if workspace is not None:
workspace = Workspace.from_name(workspace)

# TODO(alvarobartt or gabrielmbmb): add `workspace_id` in `GET /api/v1/datasets`
# and in `GET /api/v1/me/datasets` to filter by workspace
try:
datasets = datasets_api_v1.list_datasets(client=httpx_client).parsed
datasets = datasets_api_v1.list_datasets(
client=httpx_client, workspace_id=workspace.id if workspace is not None else None
).parsed
except Exception as e:
raise RuntimeError(
f"Failed while listing the `FeedbackDataset` datasets in Argilla with exception: {e}"
) from e

if len(datasets) == 0:
return []
return [
RemoteFeedbackDataset(
client=httpx_client,
Expand All @@ -335,5 +338,4 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[
guidelines=dataset.guidelines or None,
)
for dataset in datasets
if workspace is None or dataset.workspace_id == workspace.id
]
17 changes: 13 additions & 4 deletions src/argilla/client/sdk/v1/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,28 @@ def publish_dataset(

def list_datasets(
client: httpx.Client,
) -> Response[Union[List[FeedbackDatasetModel], ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/datasets` endpoint to retrieve a list of `FeedbackTask` datasets.
workspace_id: Optional[UUID] = None,
) -> Response[Union[list, List[FeedbackDatasetModel], ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/me/datasets` endpoint to retrieve a list of
`FeedbackTask` datasets filtered by `workspace_id` if applicable.
Args:
client: the authenticated Argilla client to be used to send the request to the API.
workspace_id: the id of the workspace to filter the datasets by. Note that the user
should either be owner or have access to the workspace. Defaults to None.
Returns:
A `Response` object containing a `parsed` attribute with the parsed response if the
request was successful, which is a list of `FeedbackDatasetModel`.
request was successful, which is a list of `FeedbackDatasetModel` if any, otherwise
it will contain an empty list.
"""
url = "/api/v1/me/datasets"

response = client.get(url=url)
params = {}
if workspace_id is not None:
params["workspace_id"] = str(workspace_id)

response = client.get(url=url, params=params)

if response.status_code == 200:
response_obj = Response.from_httpx_response(response)
Expand Down
26 changes: 16 additions & 10 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
Expand Down Expand Up @@ -69,16 +69,22 @@ async def _get_dataset(

@router.get("/me/datasets", response_model=Datasets)
async def list_current_user_datasets(
*, db: AsyncSession = Depends(get_async_db), current_user: User = Security(auth.get_current_user)
*,
db: AsyncSession = Depends(get_async_db),
workspace_id: Optional[UUID] = None,
current_user: User = Security(auth.get_current_user),
):
await authorize(current_user, DatasetPolicyV1.list)

if current_user.is_owner:
dataset_list = await datasets.list_datasets(db)
return Datasets(items=dataset_list)

await current_user.awaitable_attrs.datasets
return Datasets(items=current_user.datasets)
await authorize(current_user, DatasetPolicyV1.list(workspace_id))

if not workspace_id:
if current_user.is_owner:
dataset_list = await datasets.list_datasets(db)
else:
await current_user.awaitable_attrs.datasets
dataset_list = current_user.datasets
else:
dataset_list = await datasets.list_datasets_by_workspace_id(db, workspace_id)
return Datasets(items=dataset_list)


@router.get("/datasets/{dataset_id}/fields", response_model=Fields)
Expand Down
11 changes: 8 additions & 3 deletions src/argilla/server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Awaitable, Callable
from typing import Awaitable, Callable, Optional
from uuid import UUID

from sqlalchemy.ext.asyncio import async_object_session
Expand Down Expand Up @@ -210,8 +210,13 @@ async def is_allowed(actor: User) -> bool:

class DatasetPolicyV1:
@classmethod
async def list(cls, actor: User) -> bool:
return True
def list(cls, workspace_id: Optional[UUID] = None) -> PolicyAction:
async def is_allowed(actor: User) -> bool:
if actor.is_owner or workspace_id is None:
return True
return await _exists_workspace_user_by_user_and_workspace_id(actor, workspace_id)

return is_allowed

@classmethod
def get(cls, dataset: Dataset) -> PolicyAction:
Expand Down
34 changes: 33 additions & 1 deletion tests/integration/client/sdk/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
)


@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner, UserRole.annotator])
@pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin, UserRole.annotator])
@pytest.mark.asyncio
async def test_list_datasets(role: UserRole) -> None:
dataset = await DatasetFactory.create()
Expand All @@ -65,6 +65,38 @@ async def test_list_datasets(role: UserRole) -> None:
assert isinstance(response.parsed[0], FeedbackDatasetModel)


@pytest.mark.parametrize(
    "role, with_workspace, expected_length",
    [
        # Without a workspace filter the user is also created without workspaces.
        (UserRole.owner, False, 2),
        (UserRole.admin, False, 0),
        (UserRole.annotator, False, 0),
        # With a workspace filter the user is linked to the filtered workspace.
        (UserRole.owner, True, 1),
        (UserRole.admin, True, 1),
        (UserRole.annotator, True, 1),
    ],
)
@pytest.mark.asyncio
async def test_list_datasets_by_workspace_id(role: UserRole, with_workspace: bool, expected_length: int) -> None:
    """Check `list_datasets` with and without the `workspace_id` argument.

    Two datasets are created, each one in its own workspace. When
    `with_workspace` is true the user is linked to the first workspace and the
    request is filtered by that workspace id; otherwise the user is linked to
    no workspace and no filter is sent. `expected_length` is the number of
    datasets the response is expected to contain for the given `role`.
    """
    filtered_workspace = await WorkspaceFactory.create()
    filtered_dataset = await DatasetFactory.create(workspace=filtered_workspace)

    if with_workspace:
        user_workspaces = [filtered_dataset.workspace]
    else:
        user_workspaces = []
    user = await UserFactory.create(role=role, workspaces=user_workspaces)

    # Second dataset living in a workspace the user is never linked to.
    unrelated_workspace = await WorkspaceFactory.create()
    await DatasetFactory.create(workspace=unrelated_workspace)

    httpx_client = Argilla(api_key=user.api_key).client.httpx
    if with_workspace:
        workspace_filter = str(filtered_dataset.workspace.id)
    else:
        workspace_filter = None
    response = list_datasets(client=httpx_client, workspace_id=workspace_filter)

    assert response.status_code == 200

    parsed = response.parsed
    assert isinstance(parsed, list)
    assert len(parsed) == expected_length
    if expected_length > 0:
        assert isinstance(parsed[0], FeedbackDatasetModel)


@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner, UserRole.annotator])
@pytest.mark.asyncio
async def test_get_datasets(role: UserRole) -> None:
Expand Down
31 changes: 28 additions & 3 deletions tests/unit/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@

@pytest.mark.asyncio
class TestSuiteDatasets:
async def test_list_current_user_datasets(self, async_client: "AsyncClient", owner_auth_header: dict):
async def test_list_current_user_datasets(self, async_client: "AsyncClient", owner_auth_header: dict) -> None:
dataset_a = await DatasetFactory.create(name="dataset-a")
dataset_b = await DatasetFactory.create(name="dataset-b", guidelines="guidelines")
dataset_c = await DatasetFactory.create(name="dataset-c", status=DatasetStatus.ready)
Expand Down Expand Up @@ -125,15 +125,15 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own
]
}

async def test_list_current_user_datasets_without_authentication(self, async_client: "AsyncClient"):
async def test_list_current_user_datasets_without_authentication(self, async_client: "AsyncClient") -> None:
response = await async_client.get("/api/v1/me/datasets")

assert response.status_code == 401

@pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin])
async def test_list_current_user_datasets_as_restricted_user_role(
self, async_client: "AsyncClient", role: UserRole
):
) -> None:
workspace = await WorkspaceFactory.create()
user = await UserFactory.create(workspaces=[workspace], role=role)

Expand All @@ -148,6 +148,31 @@ async def test_list_current_user_datasets_as_restricted_user_role(
response_body = response.json()
assert [dataset["name"] for dataset in response_body["items"]] == ["dataset-a", "dataset-b"]

@pytest.mark.parametrize("role", [UserRole.owner, UserRole.annotator, UserRole.admin])
async def test_list_current_user_datasets_by_workspace_id(
    self, async_client: "AsyncClient", role: UserRole
) -> None:
    """Check that `GET /api/v1/me/datasets` only returns the datasets of the given `workspace_id`.

    One dataset is created in the requested workspace and another one in a
    different workspace; only the former is expected in the response.
    """
    workspace = await WorkspaceFactory.create()
    another_workspace = await WorkspaceFactory.create()

    # The owner is created without any workspace; the other roles are linked
    # to the workspace used as filter.
    if role == UserRole.owner:
        user = await UserFactory.create(role=role)
    else:
        user = await UserFactory.create(workspaces=[workspace], role=role)

    await DatasetFactory.create(name="dataset-a", workspace=workspace)
    await DatasetFactory.create(name="dataset-b", workspace=another_workspace)

    response = await async_client.get(
        "/api/v1/me/datasets",
        params={"workspace_id": workspace.id},
        headers={API_KEY_HEADER_NAME: user.api_key},
    )

    assert response.status_code == 200

    returned_names = [item["name"] for item in response.json()["items"]]
    assert returned_names == ["dataset-a"]

async def test_list_dataset_fields(self, async_client: "AsyncClient", owner_auth_header: dict):
dataset = await DatasetFactory.create()
text_field_a = await TextFieldFactory.create(
Expand Down

0 comments on commit fb92ca2

Please sign in to comment.