Skip to content

Commit

Permalink
refactor: use mixin class for tagging endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
jorwoods committed Aug 2, 2024
1 parent 3759248 commit 9897eed
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 47 deletions.
61 changes: 58 additions & 3 deletions tableauserverclient/server/endpoint/resource_tagger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import abc
import copy
from typing import Generic, Iterable, Set, TypeVar, Union
import urllib.parse

from .endpoint import Endpoint
from .exceptions import ServerResponseError
from ..exceptions import EndpointUnavailableError
from tableauserverclient.server.endpoint.endpoint import Endpoint
from tableauserverclient.server.endpoint.exceptions import ServerResponseError
from tableauserverclient.server.exceptions import EndpointUnavailableError
from tableauserverclient.server import RequestFactory
from tableauserverclient.models import TagItem

Expand Down Expand Up @@ -49,3 +51,56 @@ def update_tags(self, baseurl, resource_item):
resource_item.tags = self._add_tags(baseurl, resource_item.id, add_set)
resource_item._initial_tags = copy.copy(resource_item.tags)
logger.info("Updated tags to {0}".format(resource_item.tags))


T = TypeVar("T")


class TaggingMixin(Generic[T]):
@abc.abstractmethod
def baseurl(self) -> str:
raise NotImplementedError("baseurl must be implemented.")

def add_tags(self, item: Union[T, str], tags: Union[Iterable[str], str]) -> Set[str]:
item_id = getattr(item, "id", item)

if not isinstance(item_id, str):
raise ValueError("ID not found.")

if isinstance(tags, str):
tag_set = set([tags])
else:
tag_set = set(tags)

url = f"{self.baseurl}/{item_id}/tags"
add_req = RequestFactory.Tag.add_req(tag_set)
server_response = self.put_request(url, add_req)
return TagItem.from_response(server_response.content, self.parent_srv.namespace)

def delete_tags(self, item: Union[T, str], tags: Union[Iterable[str], str]) -> None:
item_id = getattr(item, "id", item)

if not isinstance(item_id, str):
raise ValueError("ID not found.")

if isinstance(tags, str):
tag_set = set([tags])
else:
tag_set = set(tags)

for tag in tag_set:
encoded_tag_name = urllib.parse.quote(tag)
url = f"{self.baseurl}/{item_id}/tags/{encoded_tag_name}"
self.delete_request(url)

def update_tags(self, item: T) -> None:
if item.tags == item._initial_tags:
return

add_set = item.tags - item._initial_tags
remove_set = item._initial_tags - item.tags
self.delete_tags(item, remove_set)
if add_set:
item.tags = self.add_tags(item, add_set)
item._initial_tags = copy.copy(item.tags)
logger.info(f"Updated tags to {item.tags}")
34 changes: 5 additions & 29 deletions tableauserverclient/server/endpoint/workbooks_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api, parameter_added_in
from tableauserverclient.server.endpoint.exceptions import InternalServerError, MissingRequiredFieldError
from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint
from tableauserverclient.server.endpoint.resource_tagger import _ResourceTagger
from tableauserverclient.server.endpoint.resource_tagger import _ResourceTagger, TaggingMixin

from tableauserverclient.filesys_helpers import (
to_filename,
Expand Down Expand Up @@ -58,7 +58,7 @@
PathOrFileW = Union[FilePath, FileObjectW]


class Workbooks(QuerysetEndpoint[WorkbookItem]):
class Workbooks(QuerysetEndpoint[WorkbookItem], TaggingMixin[WorkbookItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Workbooks, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down Expand Up @@ -501,31 +501,7 @@ def schedule_extract_refresh(
) -> List["AddResponse"]: # actually should return a task
return self.parent_srv.schedules.add_to_schedule(schedule_id, workbook=item)

@api(version="1.0")
def add_tags(self, workbook: Union[WorkbookItem, str], tags: Union[Iterable[str], str]) -> Set[str]:
workbook = getattr(workbook, "id", workbook)

if not isinstance(workbook, str):
raise ValueError("Workbook ID not found.")

if isinstance(tags, str):
tag_set = set([tags])
else:
tag_set = set(tags)

return self._resource_tagger._add_tags(self.baseurl, workbook, tag_set)

@api(version="1.0")
def delete_tags(self, workbook: Union[WorkbookItem, str], tags: Union[Iterable[str], str]) -> None:
workbook = getattr(workbook, "id", workbook)

if not isinstance(workbook, str):
raise ValueError("Workbook ID not found.")

if isinstance(tags, str):
tag_set = set([tags])
else:
tag_set = set(tags)

for tag in tag_set:
self._resource_tagger._delete_tag(self.baseurl, workbook, tag)
Workbooks.add_tags = api(version="1.0")(Workbooks.add_tags)
Workbooks.delete_tags = api(version="1.0")(Workbooks.delete_tags)
Workbooks.update_tags = api(version="1.0")(Workbooks.update_tags)
6 changes: 0 additions & 6 deletions test/assets/workbook_add_tag.xml

This file was deleted.

9 changes: 0 additions & 9 deletions test/assets/workbook_add_tags.xml

This file was deleted.

102 changes: 102 additions & 0 deletions test/test_tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import re
from typing import Iterable
from xml.etree import ElementTree as ET

import pytest
import requests_mock
import tableauserverclient as TSC


@pytest.fixture
def get_server() -> TSC.Server:
server = TSC.Server("http://test", False)

# Fake sign in
server._site_id = "dad65087-b08b-4603-af4e-2887b8aafc67"
server._auth_token = "j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM"
server.version = "3.28"
return server


def xml_response_factory(tags: Iterable[str]) -> str:
root = ET.Element("tsResponse")
tags_element = ET.SubElement(root, "tags")
for tag in tags:
tag_element = ET.SubElement(tags_element, "tag")
tag_element.attrib["label"] = tag
root.attrib["xmlns"] = "http://tableau.com/api"
return ET.tostring(root, encoding="utf-8").decode("utf-8")


def make_workbook() -> TSC.WorkbookItem:
workbook = TSC.WorkbookItem("project", "test")
workbook._id = "06b944d2-959d-4604-9305-12323c95e70e"
return workbook


@pytest.mark.parametrize(
"endpoint_type, item",
[
("workbooks", make_workbook()),
],
)
@pytest.mark.parametrize(
"tags",
[
"a",
["a", "b"],
],
)
def test_add_tags(get_server, endpoint_type, item, tags) -> None:
add_tags_xml = xml_response_factory(tags)
endpoint = getattr(get_server, endpoint_type)
id_ = getattr(item, "id", item)

with requests_mock.mock() as m:
m.put(
f"{endpoint.baseurl}/{id_}/tags",
status_code=200,
text=add_tags_xml,
)
tag_result = endpoint.add_tags(item, tags)

if isinstance(tags, str):
tags = [tags]
assert set(tag_result) == set(tags)


@pytest.mark.parametrize(
"endpoint_type, item",
[
("workbooks", make_workbook()),
],
)
@pytest.mark.parametrize(
"tags",
[
"a",
["a", "b"],
],
)
def test_delete_tags(get_server, endpoint_type, item, tags) -> None:
add_tags_xml = xml_response_factory(tags)
endpoint = getattr(get_server, endpoint_type)
id_ = getattr(item, "id", item)

if isinstance(tags, str):
tags = [tags]
tag_paths = "|".join(tags)
tag_paths = f"({tag_paths})"
matcher = re.compile(rf"{endpoint.baseurl}\/{id_}\/tags\/{tag_paths}")
with requests_mock.mock() as m:
m.delete(
matcher,
status_code=200,
text=add_tags_xml,
)
endpoint.delete_tags(item, tags)
history = m.request_history

assert len(history) == len(tags)
urls = sorted([r.url.split("/")[-1] for r in history])
assert set(urls) == set(tags)

0 comments on commit 9897eed

Please sign in to comment.