Skip to content

Commit

Permalink
feat(aws): add tags to Global Accelerator (#5233)
Browse files Browse the repository at this point in the history
  • Loading branch information
puchy22 authored Sep 27, 2024
1 parent b402ced commit 13e40eb
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional

from pydantic import BaseModel

from prowler.lib.logger import logger
from prowler.lib.scan_filters.scan_filters import is_resource_filtered
from prowler.providers.aws.lib.service.service import AWSService


################### GlobalAccelerator
class GlobalAccelerator(AWSService):
def __init__(self, provider):
# Call AWSService's __init__
Expand All @@ -18,6 +19,7 @@ def __init__(self, provider):
self.region = "us-west-2"
self.client = self.session.client(self.service, self.region)
self._list_accelerators()
self.__threading_call__(self._list_tags, self.accelerators.values())

def _list_accelerators(self):
logger.info("GlobalAccelerator - Listing Accelerators...")
Expand Down Expand Up @@ -46,9 +48,23 @@ def _list_accelerators(self):
f"{self.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

def _list_tags(self, resource: any):
try:
resource.tags = (
self.regional_clients[resource.region]
.list_tags_for_resource(ResourceArn=resource.arn)
.get("Tags", [])
)

except Exception as error:
logger.error(
f"{resource.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)


class Accelerator(BaseModel):
arn: str
name: str
region: str
enabled: bool
tags: Optional[list]
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def execute(self):
report.region = shield_client.region
report.resource_id = accelerator.name
report.resource_arn = accelerator.arn
report.resource_tags = accelerator.tags
report.status = "FAIL"
report.status_extended = f"Global Accelerator {accelerator.name} is not protected by AWS Shield Advanced."

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def mock_make_api_call(self, operation_name, kwarg):
}
if operation_name == "GetSubscriptionState":
return {"SubscriptionState": "ACTIVE"}
if operation_name == "ListTagsForResource":
return {"Tags": [{"Key": "Name", "Value": "TestAccelerator"}]}

return make_api_call(self, operation_name, kwarg)

Expand Down Expand Up @@ -94,3 +96,18 @@ def test_list_accelerators(self):
== AWS_REGION_US_WEST_2
)
assert globalaccelerator.accelerators[TEST_ACCELERATOR_ARN].enabled

def test_list_tags(self):
# GlobalAccelerator client for this test class
aws_provider = set_mocked_aws_provider()
globalaccelerator = GlobalAccelerator(aws_provider)

assert len(globalaccelerator.accelerators) == 1
assert (
globalaccelerator.accelerators[TEST_ACCELERATOR_ARN].tags[0]["Key"]
== "Name"
)
assert (
globalaccelerator.accelerators[TEST_ACCELERATOR_ARN].tags[0]["Value"]
== "TestAccelerator"
)
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_shield_enabled_globalaccelerator_protected(self):
name=accelerator_name,
region=AWS_REGION_EU_WEST_1,
enabled=True,
tags=[{"Key": "Name", "Value": "TestAccelerator"}],
)
}

Expand Down Expand Up @@ -85,6 +86,9 @@ def test_shield_enabled_globalaccelerator_protected(self):
result[0].status_extended
== f"Global Accelerator {accelerator_id} is protected by AWS Shield Advanced."
)
assert result[0].resource_tags == [
{"Key": "Name", "Value": "TestAccelerator"}
]

def test_shield_enabled_globalaccelerator_not_protected(self):
# GlobalAccelerator Client
Expand All @@ -98,6 +102,7 @@ def test_shield_enabled_globalaccelerator_not_protected(self):
name=accelerator_name,
region=AWS_REGION_EU_WEST_1,
enabled=True,
tags=[{"Key": "Name", "Value": "TestAccelerator"}],
)
}

Expand Down Expand Up @@ -131,6 +136,9 @@ def test_shield_enabled_globalaccelerator_not_protected(self):
result[0].status_extended
== f"Global Accelerator {accelerator_id} is not protected by AWS Shield Advanced."
)
assert result[0].resource_tags == [
{"Key": "Name", "Value": "TestAccelerator"}
]

def test_shield_disabled_globalaccelerator_not_protected(self):
# GlobalAccelerator Client
Expand Down

0 comments on commit 13e40eb

Please sign in to comment.