Skip to content

Commit

Permalink
feat: enable bulk add and remove users
Browse files Browse the repository at this point in the history
  • Loading branch information
jorwoods committed Jun 29, 2024
1 parent 7a8e54e commit a535e46
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 15 deletions.
55 changes: 41 additions & 14 deletions tableauserverclient/server/endpoint/groups_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
import logging

from .endpoint import QuerysetEndpoint, api
from .exceptions import MissingRequiredFieldError
from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api
from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError
from tableauserverclient.server import RequestFactory
from tableauserverclient.models import GroupItem, UserItem, PaginationItem, JobItem
from ..pager import Pager
from tableauserverclient.server.pager import Pager

from tableauserverclient.helpers.logging import logger

from typing import List, Optional, TYPE_CHECKING, Tuple, Union
from typing import Iterable, List, Optional, TYPE_CHECKING, Tuple, Union

if TYPE_CHECKING:
from ..request_options import RequestOptions
from tableauserverclient.server.request_options import RequestOptions


class Groups(QuerysetEndpoint[GroupItem]):
@property
def baseurl(self) -> str:
return "{0}/sites/{1}/groups".format(self.parent_srv.baseurl, self.parent_srv.site_id)

# Gets all groups
@api(version="2.0")
def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[GroupItem], PaginationItem]:
"""Gets all groups"""
logger.info("Querying all groups on site")
url = self.baseurl
server_response = self.get_request(url, req_options)
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
all_group_items = GroupItem.from_response(server_response.content, self.parent_srv.namespace)
return all_group_items, pagination_item

# Gets all users in a given group
@api(version="2.0")
def populate_users(self, group_item, req_options: Optional["RequestOptions"] = None) -> None:
def populate_users(self, group_item: GroupItem, req_options: Optional["RequestOptions"] = None) -> None:
"""Gets all users in a given group"""
if not group_item.id:
error = "Group item missing ID. Group must be retrieved from server first."
raise MissingRequiredFieldError(error)
Expand All @@ -47,7 +47,7 @@ def user_pager():
group_item._set_users(user_pager)

def _get_users_for_group(
self, group_item, req_options: Optional["RequestOptions"] = None
self, group_item: GroupItem, req_options: Optional["RequestOptions"] = None
) -> Tuple[List[UserItem], PaginationItem]:
url = "{0}/{1}/users".format(self.baseurl, group_item.id)
server_response = self.get_request(url, req_options)
Expand All @@ -56,9 +56,9 @@ def _get_users_for_group(
logger.info("Populated users for group (ID: {0})".format(group_item.id))
return user_item, pagination_item

# Deletes 1 group by id
@api(version="2.0")
def delete(self, group_id: str) -> None:
"""Deletes 1 group by id"""
if not group_id:
error = "Group ID undefined."
raise ValueError(error)
Expand Down Expand Up @@ -87,17 +87,17 @@ def update(self, group_item: GroupItem, as_job: bool = False) -> Union[GroupItem
else:
return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0]

# Create a 'local' Tableau group
@api(version="2.0")
def create(self, group_item: GroupItem) -> GroupItem:
"""Create a 'local' Tableau group"""
url = self.baseurl
create_req = RequestFactory.Group.create_local_req(group_item)
server_response = self.post_request(url, create_req)
return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0]

# Create a group based on Active Directory
@api(version="2.0")
def create_AD_group(self, group_item: GroupItem, asJob: bool = False) -> Union[GroupItem, JobItem]:
"""Create a group based on Active Directory"""
asJobparameter = "?asJob=true" if asJob else ""
url = self.baseurl + asJobparameter
create_req = RequestFactory.Group.create_ad_req(group_item)
Expand All @@ -107,9 +107,9 @@ def create_AD_group(self, group_item: GroupItem, asJob: bool = False) -> Union[G
else:
return GroupItem.from_response(server_response.content, self.parent_srv.namespace)[0]

# Removes 1 user from 1 group
@api(version="2.0")
def remove_user(self, group_item: GroupItem, user_id: str) -> None:
"""Removes 1 user from 1 group"""
if not group_item.id:
error = "Group item missing ID."
raise MissingRequiredFieldError(error)
Expand All @@ -120,9 +120,22 @@ def remove_user(self, group_item: GroupItem, user_id: str) -> None:
self.delete_request(url)
logger.info("Removed user (id: {0}) from group (ID: {1})".format(user_id, group_item.id))

# Adds 1 user to 1 group
@api(version="3.21")
def remove_users(self, group_item: GroupItem, users: Iterable[Union[str, UserItem]]) -> None:
"""Removes multiple users from 1 group"""
group_id = group_item.id if hasattr(group_item, "id") else group_item
if not isinstance(group_id, str):
raise ValueError(f"Invalid group provided: {group_item}")

url = f"{self.baseurl}/{group_id}/users/remove"
add_req = RequestFactory.Group.remove_users_req(users)
_ = self.put_request(url, add_req)
logger.info("Removed users to group (ID: {0})".format(group_item.id))
return None

@api(version="2.0")
def add_user(self, group_item: GroupItem, user_id: str) -> UserItem:
"""Adds 1 user to 1 group"""
if not group_item.id:
error = "Group item missing ID."
raise MissingRequiredFieldError(error)
Expand All @@ -135,3 +148,17 @@ def add_user(self, group_item: GroupItem, user_id: str) -> UserItem:
user = UserItem.from_response(server_response.content, self.parent_srv.namespace).pop()
logger.info("Added user (id: {0}) to group (ID: {1})".format(user_id, group_item.id))
return user

@api(version="3.21")
def add_users(self, group_item: GroupItem, users: Iterable[Union[str, UserItem]]) -> List[UserItem]:
"""Adds multiple users to 1 group"""
group_id = group_item.id if hasattr(group_item, "id") else group_item
if not isinstance(group_id, str):
raise ValueError(f"Invalid group provided: {group_item}")

url = f"{self.baseurl}/{group_id}/users"
add_req = RequestFactory.Group.add_users_req(users)
server_response = self.post_request(url, add_req)
users = UserItem.from_response(server_response.content, self.parent_srv.namespace)
logger.info("Added users to group (ID: {0})".format(group_item.id))
return users
24 changes: 23 additions & 1 deletion tableauserverclient/server/request_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import xml.etree.ElementTree as ET
from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union

from requests.packages.urllib3.fields import RequestField
from requests.packages.urllib3.filepost import encode_multipart_formdata
Expand Down Expand Up @@ -387,6 +387,28 @@ def add_user_req(self, user_id: str) -> bytes:
user_element.attrib["id"] = user_id
return ET.tostring(xml_request)

@_tsrequest_wrapped
def add_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes:
users_element = ET.SubElement(xml_request, "users")
for user in users:
user_element = ET.SubElement(users_element, "user")
if not (user_id := user.id if isinstance(user, UserItem) else user):
raise ValueError("User ID must be populated")
user_element.attrib["id"] = user_id

return ET.tostring(xml_request)

@_tsrequest_wrapped
def remove_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes:
users_element = ET.SubElement(xml_request, "users")
for user in users:
user_element = ET.SubElement(users_element, "user")
if not (user_id := user.id if isinstance(user, UserItem) else user):
raise ValueError("User ID must be populated")
user_element.attrib["id"] = user_id

return ET.tostring(xml_request)

def create_local_req(self, group_item: GroupItem) -> bytes:
xml_request = ET.Element("tsRequest")
group_element = ET.SubElement(xml_request, "group")
Expand Down
8 changes: 8 additions & 0 deletions test/assets/group_add_users.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<?xml version='1.0' encoding='UTF-8'?>
<tsResponse xmlns="http://tableau.com/api" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://tableau.com/api http://tableau.com/api/ts-api-2.3.xsd">
<users>
<user id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7" name="Alice" siteRole="ServerAdministrator" />
<user id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8" name="Bob" siteRole="Explorer" />
<user id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8" name="Charlie" siteRole="Viewer" />
</users>
</tsResponse>
49 changes: 49 additions & 0 deletions test/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
POPULATE_USERS = os.path.join(TEST_ASSET_DIR, "group_populate_users.xml")
POPULATE_USERS_EMPTY = os.path.join(TEST_ASSET_DIR, "group_populate_users_empty.xml")
ADD_USER = os.path.join(TEST_ASSET_DIR, "group_add_user.xml")
ADD_USERS = TEST_ASSET_DIR / "group_add_users.xml"
ADD_USER_POPULATE = os.path.join(TEST_ASSET_DIR, "group_users_added.xml")
CREATE_GROUP = os.path.join(TEST_ASSET_DIR, "group_create.xml")
CREATE_GROUP_AD = os.path.join(TEST_ASSET_DIR, "group_create_ad.xml")
Expand Down Expand Up @@ -123,6 +124,54 @@ def test_add_user(self) -> None:
self.assertEqual("testuser", user.name)
self.assertEqual("ServerAdministrator", user.site_role)

def test_add_users(self) -> None:
self.server.version = "3.21"
self.baseurl = self.server.groups.baseurl

def make_user(id: str, name: str, siteRole: str) -> TSC.UserItem:
user = TSC.UserItem(name, siteRole)
user._id = id
return user

users = [
make_user(id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7", name="Alice", siteRole="ServerAdministrator"),
make_user(id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8", name="Bob", siteRole="Explorer"),
make_user(id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8", name="Charlie", siteRole="Viewer"),
]
group = TSC.GroupItem("test")
group._id = "e7833b48-c6f7-47b5-a2a7-36e7dd232758"

with requests_mock.mock() as m:
m.post(f"{self.baseurl}/{group.id}/users", text=ADD_USERS.read_text())
resp_users = self.server.groups.add_users(group, users)

for user, resp_user in zip(users, resp_users):
with self.subTest(user=user, resp_user=resp_user):
assert user.id == resp_user.id
assert user.name == resp_user.name
assert user.site_role == resp_user.site_role

def test_remove_users(self) -> None:
self.server.version = "3.21"
self.baseurl = self.server.groups.baseurl

def make_user(id: str, name: str, siteRole: str) -> TSC.UserItem:
user = TSC.UserItem(name, siteRole)
user._id = id
return user

users = [
make_user(id="5de011f8-4aa9-4d5b-b991-f464c8dd6bb7", name="Alice", siteRole="ServerAdministrator"),
make_user(id="5de011f8-3aa9-4d5b-b991-f467c8dd6bb8", name="Bob", siteRole="Explorer"),
make_user(id="5de011f8-2aa9-4d5b-b991-f466c8dd6bb8", name="Charlie", siteRole="Viewer"),
]
group = TSC.GroupItem("test")
group._id = "e7833b48-c6f7-47b5-a2a7-36e7dd232758"

with requests_mock.mock() as m:
m.put(f"{self.baseurl}/{group.id}/users/remove")
self.server.groups.remove_users(group, users)

def test_add_user_before_populating(self) -> None:
with open(GET_XML, "rb") as f:
get_xml_response = f.read().decode("utf-8")
Expand Down

0 comments on commit a535e46

Please sign in to comment.