From f478a6af9b243678006e4802b0546a9c21740861 Mon Sep 17 00:00:00 2001 From: Adam Sachs Date: Wed, 7 Sep 2022 16:29:38 -0400 Subject: [PATCH] Refactor strategy instantiation for more extensitiliby (#1254) * Instantiate strategies via abstract Strategy base class A generalized Strategy abstract base class provides generalized getter methods that instantiate strategy subclasses (implementations). These methods rely on the builtin __subclasses__() method to identify Strategy subclasses, which allows for more dynamic and extensible strategy implementation, removing the need for a hardcoded enumeration of supported Strategy implementations. Abstract strategy types inherit from this new abstract base class, and strategy subclasses (implementations) must provide `name` and `configuration_model` attributes that are leveraged by new instantiation mechanism in the abstract base class. * Update get_description() to be a class rather than static method This allows the method to leverage the new `name` class variable rather than relying on a static constant variable. * Remove strategy factories and update references Strategy factories are no longer needed with refactored Strategy getters. Update the uses (references) of strategy factories throughout the codebase to now rely on the new Strategy getters. Strategy subclasses (implementations) now need to be imported explicitly in __init__.py's because they used to be imported in factory modules. Also remove the old MaskingStrategy registration/factory mechanisms. * Remove strategy name constants Now that the abstract Strategy base class enforces implementation subclasses to have a `name` class attribute, this attribute should be relied upon rather than the arbitrary name constants declared previously. The get_strategy_name() abstract method is also superfluous, as the `name` class attribute can be used as a standardized way to retrieve the strategy name. * Remove get_configuration_model() abstract method The generalized strategy getter now relies upon the `configuration_model` class variable that's on each Strategy. Therefore we no longer need the get_configuration_model() getter on each Strategy subclass. * Update MaskingStrategy docs with new Strategy functionality * Update changelog * Improve recursion in _find_all_strategy_subclasses * Fix recursion bug when finding all strategies Update associated tests to make sure the recursion is properly tested * Tweak conditional for falsy check * Make get_strategies endpoint test more robust * Fix typo in documentation Co-authored-by: Adam Sachs --- CHANGELOG.md | 1 + .../docs/guides/masking_strategies.md | 38 ++-- .../ops/api/v1/endpoints/masking_endpoints.py | 14 +- .../ops/api/v1/endpoints/oauth_endpoints.py | 6 +- .../api/v1/endpoints/saas_config_endpoints.py | 11 +- .../55d61eb8ed12_add_default_policies.py | 4 +- src/fidesops/ops/schemas/saas/saas_config.py | 6 +- .../ops/service/authentication/__init__.py | 7 + .../authentication/authentication_strategy.py | 11 +- .../authentication_strategy_basic.py | 8 +- .../authentication_strategy_bearer.py | 8 +- .../authentication_strategy_factory.py | 78 -------- ...tion_strategy_oauth2_authorization_code.py | 8 +- ...tion_strategy_oauth2_client_credentials.py | 12 +- .../authentication_strategy_query_param.py | 8 +- .../ops/service/connectors/query_config.py | 9 +- .../connectors/saas/authenticated_client.py | 6 +- .../ops/service/connectors/saas_connector.py | 10 +- .../masking/strategy/masking_strategy.py | 19 +- .../strategy/masking_strategy_aes_encrypt.py | 36 ++-- .../strategy/masking_strategy_factory.py | 68 ------- .../masking/strategy/masking_strategy_hash.py | 34 ++-- .../masking/strategy/masking_strategy_hmac.py | 36 ++-- .../strategy/masking_strategy_nullify.py | 26 +-- .../masking_strategy_random_string_rewrite.py | 22 +-- .../masking_strategy_string_rewrite.py | 22 +-- .../ops/service/pagination/__init__.py | 5 + .../service/pagination/pagination_strategy.py | 14 +- .../pagination/pagination_strategy_cursor.py | 14 +- .../pagination/pagination_strategy_factory.py | 72 -------- .../pagination/pagination_strategy_link.py | 14 +- .../pagination/pagination_strategy_offset.py | 14 +- .../privacy_request/request_service.py | 8 +- .../post_processor_strategy/__init__.py | 4 + .../post_processor_strategy.py | 10 +- .../post_processor_strategy_factory.py | 66 ------- .../post_processor_strategy_filter.py | 17 +- .../post_processor_strategy_unwrap.py | 17 +- src/fidesops/ops/service/strategy.py | 77 ++++++++ .../v1/endpoints/test_masking_endpoints.py | 37 ++-- .../api/v1/endpoints/test_policy_endpoints.py | 8 +- tests/ops/fixtures/application_fixtures.py | 18 +- tests/ops/models/test_policy.py | 12 +- .../test_authentication_strategy_basic.py | 12 +- .../test_authentication_strategy_bearer.py | 10 +- .../test_authentication_strategy_factory.py | 20 ++- ...tion_strategy_oauth2_authorization_code.py | 52 +++--- ...tion_strategy_oauth2_client_credentials.py | 44 +++-- ...est_authentication_strategy_query_param.py | 8 +- .../service/connectors/test_queryconfig.py | 5 +- .../test_masking_strategy_aes_encrypt.py | 7 +- .../strategy/test_masking_strategy_factory.py | 14 +- .../strategy/test_masking_strategy_hash.py | 24 ++- .../strategy/test_masking_strategy_hmac.py | 21 ++- .../test_pagination_strategy_factory.py | 20 ++- .../request_runner_service_test.py | 8 +- .../test_post_processor_strategy_factory.py | 18 +- tests/ops/service/test_strategy_retrieval.py | 169 ++++++++++++++++++ .../ops/util/encryption/test_secrets_util.py | 13 +- 59 files changed, 622 insertions(+), 738 deletions(-) delete mode 100644 src/fidesops/ops/service/authentication/authentication_strategy_factory.py delete mode 100644 src/fidesops/ops/service/masking/strategy/masking_strategy_factory.py delete mode 100644 src/fidesops/ops/service/pagination/pagination_strategy_factory.py delete mode 100644 src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_factory.py create mode 100644 src/fidesops/ops/service/strategy.py create mode 100644 tests/ops/service/test_strategy_retrieval.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 152f09821..d15c7d3dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The types of changes are: * Created a docker image for the privacy center [#1165](https://github.com/ethyca/fidesops/pull/1165) * Adds email scopes to postman collection [#1241](https://github.com/ethyca/fidesops/pull/1241) * Clean up docker build [#1252](https://github.com/ethyca/fidesops/pull/1252) +* Add `Strategy` abstract base class for more extensible strategy development [1254](https://github.com/ethyca/fidesops/pull/1254) ### Added diff --git a/docs/fidesops/docs/guides/masking_strategies.md b/docs/fidesops/docs/guides/masking_strategies.md index 71bb1bfc0..825b13aef 100644 --- a/docs/fidesops/docs/guides/masking_strategies.md +++ b/docs/fidesops/docs/guides/masking_strategies.md @@ -169,28 +169,31 @@ strategies available, along with their configuration options. ## Extensibility In fidesops, masking strategies are all built on top of an abstract base class - `MaskingStrategy`. -`MaskingStrategy` has five methods - `mask`, `secrets_required`, `get_configuration_model`, `get_description`, and `data_type_supported`. For more detail on these +`MaskingStrategy` has four methods - `mask`, `secrets_required`, `get_description`, and `data_type_supported`. For more detail on these methods, visit the class in the fidesops repository. For now, we will focus on the implementation of `RandomStringRewriteMaskingStrategy` below: ```python import string -from typing import Optional from secrets import choice +from typing import List, Optional, Type -from fidesops.ops.schemas.masking.masking_configuration import RandomStringMaskingConfiguration, MaskingConfiguration -from fidesops.ops.schemas.masking.masking_strategy_description import MaskingStrategyDescription +from fidesops.ops.schemas.masking.masking_configuration import ( + RandomStringMaskingConfiguration, +) +from fidesops.ops.schemas.masking.masking_strategy_description import ( + MaskingStrategyConfigurationDescription, + MaskingStrategyDescription, +) from fidesops.ops.service.masking.strategy.format_preservation import FormatPreservation from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) -RANDOM_STRING_REWRITE_STRATEGY_NAME = "random_string_rewrite" -@MaskingStrategyFactory.register(RANDOM_STRING_REWRITE_STRATEGY_NAME) class RandomStringRewriteMaskingStrategy(MaskingStrategy): - """Masks a value with a random string of the length specified in the configuration.""" + """Masks each provied value with a random string of the length specified in the configuration.""" + + name = "random_string_rewrite" + configuration_model = RandomStringMaskingConfiguration def __init__( self, @@ -199,7 +202,9 @@ class RandomStringRewriteMaskingStrategy(MaskingStrategy): self.length = configuration.length self.format_preservation = configuration.format_preservation - def mask(self, values: Optional[List[str]], privacy_request_id: Optional[str]) -> Optional[List[str]]: + def mask( + self, values: Optional[List[str]], request_id: Optional[str] + ) -> Optional[List[str]]: """Replaces the value with a random lowercase string of the configured length""" if values is None: return None @@ -217,12 +222,11 @@ class RandomStringRewriteMaskingStrategy(MaskingStrategy): masked_values.append(masked) return masked_values - @staticmethod - def get_configuration_model() -> MaskingConfiguration: + def secrets_required(self) -> bool: """Not covered in this example""" - @staticmethod - def get_description() -> MaskingStrategyDescription: + @classmethod + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: """Not covered in this example""" @staticmethod @@ -241,6 +245,4 @@ any defaults that should be applied in their absence. All configuration classes ### Integrate the masking strategy factory -In order to leverage an implemented masking strategy, the `MaskingStrategy` subclass must be registered with the `MaskingStrategyFactory`. To register a new `MaskingStrategy`, use the `register` decorator on the `MaskingStrategy` subclass definition, as shown in the above example. - -The value passed as the argument to the decorator must be the registered name of the `MaskingStrategy` subclass. This is the same value defined by [callers](#using-fidesops-as-a-masking-service) in the `"masking_strategy"."strategy"` field. +In order to leverage an implemented masking strategy, the `MaskingStrategy` subclass must be imported into the application runtime. Also, the `MaskingStrategy` class must define two class variables: `name`, which is the unique, registered name that [callers](#using-fidesops-as-a-masking-service) will use in their `"masking_strategy"."strategy"` field to invoke the strategy; and `configuration_model`, which references the configuration class used to parameterize the strategy. diff --git a/src/fidesops/ops/api/v1/endpoints/masking_endpoints.py b/src/fidesops/ops/api/v1/endpoints/masking_endpoints.py index 1d8990c94..c9535b818 100644 --- a/src/fidesops/ops/api/v1/endpoints/masking_endpoints.py +++ b/src/fidesops/ops/api/v1/endpoints/masking_endpoints.py @@ -5,7 +5,7 @@ from starlette.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND from fidesops.ops.api.v1.urn_registry import MASKING, MASKING_STRATEGY, V1_URL_PREFIX -from fidesops.ops.common_exceptions import ValidationError +from fidesops.ops.common_exceptions import NoSuchStrategyException, ValidationError from fidesops.ops.schemas.masking.masking_api import ( MaskingAPIRequest, MaskingAPIResponse, @@ -13,10 +13,7 @@ from fidesops.ops.schemas.masking.masking_strategy_description import ( MaskingStrategyDescription, ) -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, - NoSuchStrategyException, -) +from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy from fidesops.ops.util.api_router import APIRouter router = APIRouter(tags=["Masking"], prefix=V1_URL_PREFIX) @@ -30,7 +27,7 @@ def mask_value(request: MaskingAPIRequest) -> MaskingAPIResponse: try: values = request.values masking_strategy = request.masking_strategy - strategy = MaskingStrategyFactory.get_strategy( + strategy = MaskingStrategy.get_strategy( masking_strategy.strategy, masking_strategy.configuration ) logger.info( @@ -52,7 +49,4 @@ def mask_value(request: MaskingAPIRequest) -> MaskingAPIResponse: def list_masking_strategies() -> List[MaskingStrategyDescription]: """Lists available masking strategies with instructions on how to use them""" logger.info("Getting available masking strategies") - return [ - strategy.get_description() - for strategy in MaskingStrategyFactory.get_strategies() - ] + return [strategy.get_description() for strategy in MaskingStrategy.get_strategies()] diff --git a/src/fidesops/ops/api/v1/endpoints/oauth_endpoints.py b/src/fidesops/ops/api/v1/endpoints/oauth_endpoints.py index 301d3dfd2..c13a7d25f 100644 --- a/src/fidesops/ops/api/v1/endpoints/oauth_endpoints.py +++ b/src/fidesops/ops/api/v1/endpoints/oauth_endpoints.py @@ -42,8 +42,8 @@ from fidesops.ops.models.authentication_request import AuthenticationRequest from fidesops.ops.models.connectionconfig import ConnectionConfig from fidesops.ops.schemas.client import ClientCreatedResponse -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, ) from fidesops.ops.service.authentication.authentication_strategy_oauth2_authorization_code import ( OAuth2AuthorizationCodeAuthenticationStrategy, @@ -214,7 +214,7 @@ def oauth_callback(code: str, state: str, db: Session = Depends(get_db)) -> None authentication = ( connection_config.get_saas_config().client_config.authentication # type: ignore ) - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( # type: ignore + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = AuthenticationStrategy.get_strategy( # type: ignore authentication.strategy, authentication.configuration # type: ignore ) connection_config.secrets = {**connection_config.secrets, "code": code} # type: ignore diff --git a/src/fidesops/ops/api/v1/endpoints/saas_config_endpoints.py b/src/fidesops/ops/api/v1/endpoints/saas_config_endpoints.py index c5e6a8df0..5cafb7e2a 100644 --- a/src/fidesops/ops/api/v1/endpoints/saas_config_endpoints.py +++ b/src/fidesops/ops/api/v1/endpoints/saas_config_endpoints.py @@ -44,8 +44,8 @@ ValidateSaaSConfigResponse, ) from fidesops.ops.schemas.shared_schemas import FidesOpsKey -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, ) from fidesops.ops.service.authentication.authentication_strategy_oauth2_authorization_code import ( OAuth2AuthorizationCodeAuthenticationStrategy, @@ -113,10 +113,7 @@ def verify_oauth_connection_config( detail="The connection config does not contain an authentication configuration.", ) - if ( - authentication.strategy - != OAuth2AuthorizationCodeAuthenticationStrategy.strategy_name - ): + if authentication.strategy != OAuth2AuthorizationCodeAuthenticationStrategy.name: raise HTTPException( status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail="The connection config does not use OAuth2 Authorization Code authentication.", @@ -262,7 +259,7 @@ def authorize_connection( authentication = connection_config.get_saas_config().client_config.authentication # type: ignore try: - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = AuthenticationStrategy.get_strategy( authentication.strategy, authentication.configuration # type: ignore ) return auth_strategy.get_authorization_url(db, connection_config) diff --git a/src/fidesops/ops/migrations/versions/55d61eb8ed12_add_default_policies.py b/src/fidesops/ops/migrations/versions/55d61eb8ed12_add_default_policies.py index 42476f266..6b3b0fd74 100644 --- a/src/fidesops/ops/migrations/versions/55d61eb8ed12_add_default_policies.py +++ b/src/fidesops/ops/migrations/versions/55d61eb8ed12_add_default_policies.py @@ -30,9 +30,6 @@ from fidesops.ops.db.base_class import JSONTypeOverride from fidesops.ops.models.policy import ActionType, DrpAction from fidesops.ops.schemas.storage.storage import StorageType -from fidesops.ops.service.masking.strategy.masking_strategy_string_rewrite import ( - STRING_REWRITE_STRATEGY_NAME, -) from fidesops.ops.util.data_category import DataCategory logging.basicConfig() @@ -48,6 +45,7 @@ FIDESOPS_AUTOGENERATED_STORAGE_KEY = "fidesops_autogenerated_storage_destination" AUTOGENERATED_ACCESS_KEY = "download" AUTOGENERATED_ERASURE_KEY = "delete" +STRING_REWRITE_STRATEGY_NAME = "string_rewrite" client_select_query: TextClause = text( """SELECT client.id FROM client WHERE fides_key = :fides_key""" diff --git a/src/fidesops/ops/schemas/saas/saas_config.py b/src/fidesops/ops/schemas/saas/saas_config.py index 67b94e94b..d59d80320 100644 --- a/src/fidesops/ops/schemas/saas/saas_config.py +++ b/src/fidesops/ops/schemas/saas/saas_config.py @@ -116,13 +116,13 @@ def validate_request_for_pagination(cls, values: Dict[str, Any]) -> Dict[str, An """ # delay import to avoid cyclic-dependency error - We still ignore the pylint error - from fidesops.ops.service.pagination.pagination_strategy_factory import ( # pylint: disable=R0401 - get_strategy, + from fidesops.ops.service.pagination.pagination_strategy import ( # pylint: disable=R0401 + PaginationStrategy, ) pagination = values.get("pagination") if pagination is not None: - pagination_strategy = get_strategy( + pagination_strategy = PaginationStrategy.get_strategy( pagination.get("strategy"), pagination.get("configuration") ) pagination_strategy.validate_request(values) diff --git a/src/fidesops/ops/service/authentication/__init__.py b/src/fidesops/ops/service/authentication/__init__.py index e69de29bb..14160fb96 100644 --- a/src/fidesops/ops/service/authentication/__init__.py +++ b/src/fidesops/ops/service/authentication/__init__.py @@ -0,0 +1,7 @@ +from fidesops.ops.service.authentication import ( + authentication_strategy_basic, + authentication_strategy_bearer, + authentication_strategy_oauth2_authorization_code, + authentication_strategy_oauth2_client_credentials, + authentication_strategy_query_param, +) diff --git a/src/fidesops/ops/service/authentication/authentication_strategy.py b/src/fidesops/ops/service/authentication/authentication_strategy.py index 7934171f0..6eed7bd7e 100644 --- a/src/fidesops/ops/service/authentication/authentication_strategy.py +++ b/src/fidesops/ops/service/authentication/authentication_strategy.py @@ -1,12 +1,12 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from requests import PreparedRequest from fidesops.ops.models.connectionconfig import ConnectionConfig -from fidesops.ops.schemas.saas.strategy_configuration import StrategyConfiguration +from fidesops.ops.service.strategy import Strategy -class AuthenticationStrategy(ABC): +class AuthenticationStrategy(Strategy): """Abstract base class for SaaS authentication strategies""" @abstractmethod @@ -14,8 +14,3 @@ def add_authentication( self, request: PreparedRequest, connection_config: ConnectionConfig ) -> PreparedRequest: """Add authentication to the request""" - - @staticmethod - @abstractmethod - def get_configuration_model() -> StrategyConfiguration: - """Used to get the configuration model to configure the strategy""" diff --git a/src/fidesops/ops/service/authentication/authentication_strategy_basic.py b/src/fidesops/ops/service/authentication/authentication_strategy_basic.py index 79f9716fb..dbfc5ca01 100644 --- a/src/fidesops/ops/service/authentication/authentication_strategy_basic.py +++ b/src/fidesops/ops/service/authentication/authentication_strategy_basic.py @@ -3,7 +3,6 @@ from fidesops.ops.models.connectionconfig import ConnectionConfig from fidesops.ops.schemas.saas.strategy_configuration import ( BasicAuthenticationConfiguration, - StrategyConfiguration, ) from fidesops.ops.service.authentication.authentication_strategy import ( AuthenticationStrategy, @@ -17,7 +16,8 @@ class BasicAuthenticationStrategy(AuthenticationStrategy): and uses them to add a basic authentication header to the incoming request. """ - strategy_name = "basic" + name = "basic" + configuration_model = BasicAuthenticationConfiguration def __init__(self, configuration: BasicAuthenticationConfiguration): self.username = configuration.username @@ -36,7 +36,3 @@ def add_authentication( ) ) return request - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return BasicAuthenticationConfiguration # type: ignore diff --git a/src/fidesops/ops/service/authentication/authentication_strategy_bearer.py b/src/fidesops/ops/service/authentication/authentication_strategy_bearer.py index c932a6b0e..9f3c8e1df 100644 --- a/src/fidesops/ops/service/authentication/authentication_strategy_bearer.py +++ b/src/fidesops/ops/service/authentication/authentication_strategy_bearer.py @@ -3,7 +3,6 @@ from fidesops.ops.models.connectionconfig import ConnectionConfig from fidesops.ops.schemas.saas.strategy_configuration import ( BearerAuthenticationConfiguration, - StrategyConfiguration, ) from fidesops.ops.service.authentication.authentication_strategy import ( AuthenticationStrategy, @@ -17,7 +16,8 @@ class BearerAuthenticationStrategy(AuthenticationStrategy): and uses it to add a bearer authentication header to the incoming request. """ - strategy_name = "bearer" + name = "bearer" + configuration_model = BearerAuthenticationConfiguration def __init__(self, configuration: BearerAuthenticationConfiguration): self.token = configuration.token @@ -30,7 +30,3 @@ def add_authentication( self.token, connection_config.secrets # type: ignore ) return request - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return BearerAuthenticationConfiguration # type: ignore diff --git a/src/fidesops/ops/service/authentication/authentication_strategy_factory.py b/src/fidesops/ops/service/authentication/authentication_strategy_factory.py deleted file mode 100644 index 2cd8ab624..000000000 --- a/src/fidesops/ops/service/authentication/authentication_strategy_factory.py +++ /dev/null @@ -1,78 +0,0 @@ -import logging -from enum import Enum -from typing import Any, Dict, List - -from pydantic import ValidationError - -from fidesops.ops.common_exceptions import NoSuchStrategyException -from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError -from fidesops.ops.schemas.saas.strategy_configuration import StrategyConfiguration -from fidesops.ops.service.authentication.authentication_strategy import ( - AuthenticationStrategy, -) -from fidesops.ops.service.authentication.authentication_strategy_basic import ( - BasicAuthenticationStrategy, -) -from fidesops.ops.service.authentication.authentication_strategy_bearer import ( - BearerAuthenticationStrategy, -) -from fidesops.ops.service.authentication.authentication_strategy_oauth2_authorization_code import ( - OAuth2AuthorizationCodeAuthenticationStrategy, -) -from fidesops.ops.service.authentication.authentication_strategy_oauth2_client_credentials import ( - OAuth2ClientCredentialsAuthenticationStrategy, -) -from fidesops.ops.service.authentication.authentication_strategy_query_param import ( - QueryParamAuthenticationStrategy, -) - -logger = logging.getLogger(__name__) - - -class SupportedAuthenticationStrategies(Enum): - """ - The supported strategies for authenticating against SaaS APIs. - """ - - basic = BasicAuthenticationStrategy - bearer = BearerAuthenticationStrategy - query_param = QueryParamAuthenticationStrategy - oauth2_authorization_code = OAuth2AuthorizationCodeAuthenticationStrategy - oauth2_client_credentials = OAuth2ClientCredentialsAuthenticationStrategy - - @classmethod - def __contains__(cls, item: str) -> bool: - try: - cls[item] - except KeyError: - return False - - return True - - -def get_strategy( - strategy_name: str, - configuration: Dict[str, Any], -) -> AuthenticationStrategy: - """ - Returns the strategy given the name and configuration. - Raises NoSuchStrategyException if the strategy does not exist - """ - if not SupportedAuthenticationStrategies.__contains__(strategy_name): - valid_strategies = ", ".join(get_strategy_names()) - raise NoSuchStrategyException( - f"Strategy '{strategy_name}' does not exist. Valid strategies are [{valid_strategies}]" - ) - strategy = SupportedAuthenticationStrategies[strategy_name].value - try: - strategy_config: StrategyConfiguration = strategy.get_configuration_model()( - **configuration - ) - return strategy(configuration=strategy_config) - except ValidationError as e: - raise FidesopsValidationError(message=str(e)) - - -def get_strategy_names() -> List[str]: - """Returns all supported authentication strategies""" - return [s.name for s in SupportedAuthenticationStrategies] diff --git a/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_authorization_code.py b/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_authorization_code.py index 2d1d7b8b3..408c7d981 100644 --- a/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_authorization_code.py +++ b/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_authorization_code.py @@ -12,7 +12,6 @@ from fidesops.ops.models.connectionconfig import ConnectionConfig from fidesops.ops.schemas.saas.strategy_configuration import ( OAuth2AuthorizationCodeConfiguration, - StrategyConfiguration, ) from fidesops.ops.service.authentication.authentication_strategy_oauth2_base import ( OAuth2AuthenticationStrategyBase, @@ -28,7 +27,8 @@ class OAuth2AuthorizationCodeAuthenticationStrategy(OAuth2AuthenticationStrategy it if needed using the configured token refresh request. """ - strategy_name = "oauth2_authorization_code" + name = "oauth2_authorization_code" + configuration_model = OAuth2AuthorizationCodeConfiguration def __init__(self, configuration: OAuth2AuthorizationCodeConfiguration): super().__init__(configuration) @@ -126,7 +126,3 @@ def _generate_state() -> str: if config.oauth_instance: state = f"{config.oauth_instance}-{state}" return state - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return OAuth2AuthorizationCodeConfiguration # type: ignore diff --git a/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_client_credentials.py b/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_client_credentials.py index ef2d2bfee..3dea09062 100644 --- a/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_client_credentials.py +++ b/src/fidesops/ops/service/authentication/authentication_strategy_oauth2_client_credentials.py @@ -3,10 +3,7 @@ from requests import PreparedRequest from fidesops.ops.models.connectionconfig import ConnectionConfig -from fidesops.ops.schemas.saas.strategy_configuration import ( - OAuth2BaseConfiguration, - StrategyConfiguration, -) +from fidesops.ops.schemas.saas.strategy_configuration import OAuth2BaseConfiguration from fidesops.ops.service.authentication.authentication_strategy_oauth2_base import ( OAuth2AuthenticationStrategyBase, ) @@ -20,7 +17,8 @@ class OAuth2ClientCredentialsAuthenticationStrategy(OAuth2AuthenticationStrategy it if needed using the configured token refresh request. """ - strategy_name = "oauth2_client_credentials" + name = "oauth2_client_credentials" + configuration_model = OAuth2BaseConfiguration def add_authentication( self, request: PreparedRequest, connection_config: ConnectionConfig @@ -39,7 +37,3 @@ def add_authentication( # add access_token to request request.headers["Authorization"] = "Bearer " + access_token return request - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return OAuth2BaseConfiguration # type: ignore diff --git a/src/fidesops/ops/service/authentication/authentication_strategy_query_param.py b/src/fidesops/ops/service/authentication/authentication_strategy_query_param.py index 5aaf5d1c3..ee1f51ba1 100644 --- a/src/fidesops/ops/service/authentication/authentication_strategy_query_param.py +++ b/src/fidesops/ops/service/authentication/authentication_strategy_query_param.py @@ -3,7 +3,6 @@ from fidesops.ops.models.connectionconfig import ConnectionConfig from fidesops.ops.schemas.saas.strategy_configuration import ( QueryParamAuthenticationConfiguration, - StrategyConfiguration, ) from fidesops.ops.service.authentication.authentication_strategy import ( AuthenticationStrategy, @@ -18,7 +17,8 @@ class QueryParamAuthenticationStrategy(AuthenticationStrategy): and adds it as a query param to the incoming request. """ - strategy_name = "query_param" + name = "query_param" + configuration_model = QueryParamAuthenticationConfiguration def __init__(self, configuration: QueryParamAuthenticationConfiguration): self.name = configuration.name @@ -34,7 +34,3 @@ def add_authentication( assign_placeholders(self.value, connection_config.secrets), # type: ignore ) return request - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return QueryParamAuthenticationConfiguration # type: ignore diff --git a/src/fidesops/ops/service/connectors/query_config.py b/src/fidesops/ops/service/connectors/query_config.py index 22b6fce49..0c2ff67f7 100644 --- a/src/fidesops/ops/service/connectors/query_config.py +++ b/src/fidesops/ops/service/connectors/query_config.py @@ -20,11 +20,8 @@ from fidesops.ops.models.policy import ActionType, Policy, Rule from fidesops.ops.models.privacy_request import ManualAction, PrivacyRequest from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) from fidesops.ops.service.masking.strategy.masking_strategy_nullify import ( - NULL_REWRITE_STRATEGY_NAME, + NullMaskingStrategy, ) from fidesops.ops.task.refine_target_path import ( build_refined_target_paths, @@ -147,7 +144,7 @@ def update_value_map( # pylint: disable=R0914 strategy_config = rule.masking_strategy if not strategy_config: continue - strategy: MaskingStrategy = MaskingStrategyFactory.get_strategy( + strategy: MaskingStrategy = MaskingStrategy.get_strategy( strategy_config["strategy"], strategy_config["configuration"] ) for rule_field_path in field_paths: @@ -157,7 +154,7 @@ def update_value_map( # pylint: disable=R0914 if field_path == rule_field_path ][0] null_masking: bool = ( - strategy_config.get("strategy") == NULL_REWRITE_STRATEGY_NAME + strategy_config.get("strategy") == NullMaskingStrategy.name ) if not self._supported_data_type( masking_override, null_masking, strategy diff --git a/src/fidesops/ops/service/connectors/saas/authenticated_client.py b/src/fidesops/ops/service/connectors/saas/authenticated_client.py index a87935c62..5ef6c5aa4 100644 --- a/src/fidesops/ops/service/connectors/saas/authenticated_client.py +++ b/src/fidesops/ops/service/connectors/saas/authenticated_client.py @@ -52,8 +52,8 @@ def get_authenticated_request( incoming path, headers, query, and body params. """ - from fidesops.ops.service.authentication.authentication_strategy_factory import ( # pylint: disable=R0401 - get_strategy, + from fidesops.ops.service.authentication.authentication_strategy import ( # pylint: disable=R0401 + AuthenticationStrategy, ) req: PreparedRequest = Request( @@ -66,7 +66,7 @@ def get_authenticated_request( # add authentication if provided if self.client_config.authentication: - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( self.client_config.authentication.strategy, self.client_config.authentication.configuration, ) diff --git a/src/fidesops/ops/service/connectors/saas_connector.py b/src/fidesops/ops/service/connectors/saas_connector.py index d96c68d35..9b456c26e 100644 --- a/src/fidesops/ops/service/connectors/saas_connector.py +++ b/src/fidesops/ops/service/connectors/saas_connector.py @@ -18,15 +18,9 @@ ) from fidesops.ops.service.connectors.saas_query_config import SaaSQueryConfig from fidesops.ops.service.pagination.pagination_strategy import PaginationStrategy -from fidesops.ops.service.pagination.pagination_strategy_factory import ( - get_strategy as get_pagination_strategy, -) from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy import ( PostProcessorStrategy, ) -from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy_factory import ( - get_strategy as get_postprocessor_strategy, -) from fidesops.ops.service.saas_request.saas_request_override_factory import ( SaaSRequestOverrideFactory, SaaSRequestType, @@ -189,7 +183,7 @@ def execute_prepared_request( # use the pagination strategy (if available) to get the next request next_request = None if saas_request.pagination: - strategy: PaginationStrategy = get_pagination_strategy( + strategy: PaginationStrategy = PaginationStrategy.get_strategy( saas_request.pagination.strategy, saas_request.pagination.configuration, ) @@ -222,7 +216,7 @@ def process_response_data( rows: List[Row] = [] processed_data = response_data for postprocessor in postprocessors or []: - strategy: PostProcessorStrategy = get_postprocessor_strategy( + strategy: PostProcessorStrategy = PostProcessorStrategy.get_strategy( postprocessor.strategy, postprocessor.configuration # type: ignore ) logger.info( diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy.py b/src/fidesops/ops/service/masking/strategy/masking_strategy.py index e40eb358b..5336076b7 100644 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy.py +++ b/src/fidesops/ops/service/masking/strategy/masking_strategy.py @@ -1,15 +1,17 @@ # MR Note - It would be nice to enforce this at compile time -from abc import ABC, abstractmethod -from typing import Any, List, Optional +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, List, Optional, Type -from fidesops.ops.schemas.masking.masking_configuration import MaskingConfiguration from fidesops.ops.schemas.masking.masking_secrets import MaskingSecretCache from fidesops.ops.schemas.masking.masking_strategy_description import ( MaskingStrategyDescription, ) +from fidesops.ops.service.strategy import Strategy -class MaskingStrategy(ABC): +class MaskingStrategy(Strategy): """Abstract base class for masking strategies""" @abstractmethod @@ -25,14 +27,9 @@ def secrets_required(self) -> bool: def generate_secrets_for_cache(self) -> List[MaskingSecretCache]: """Generates secrets for strategy""" - @staticmethod - @abstractmethod - def get_configuration_model() -> MaskingConfiguration: - """Used to get the configuration model to configure the strategy""" - - @staticmethod + @classmethod @abstractmethod - def get_description() -> MaskingStrategyDescription: + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: """Returns the description used for documentation. In particular, used by the documentation endpoint in masking_endpoints.list_masking_strategies""" diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy_aes_encrypt.py b/src/fidesops/ops/service/masking/strategy/masking_strategy_aes_encrypt.py index fa1a34226..f2afe29ed 100644 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy_aes_encrypt.py +++ b/src/fidesops/ops/service/masking/strategy/masking_strategy_aes_encrypt.py @@ -1,11 +1,10 @@ from __future__ import annotations -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from fidesops.ops.schemas.masking.masking_configuration import ( AesEncryptionMaskingConfiguration, HmacMaskingConfiguration, - MaskingConfiguration, ) from fidesops.ops.schemas.masking.masking_secrets import ( MaskingSecretCache, @@ -18,20 +17,18 @@ ) from fidesops.ops.service.masking.strategy.format_preservation import FormatPreservation from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) from fidesops.ops.util.encryption.aes_gcm_encryption_scheme import encrypt from fidesops.ops.util.encryption.hmac_encryption_scheme import ( hmac_encrypt_return_bytes, ) from fidesops.ops.util.encryption.secrets_util import SecretsUtil -AES_ENCRYPT_STRATEGY_NAME = "aes_encrypt" - -@MaskingStrategyFactory.register(AES_ENCRYPT_STRATEGY_NAME) class AesEncryptionMaskingStrategy(MaskingStrategy): + + name = "aes_encrypt" + configuration_model = AesEncryptionMaskingConfiguration + def __init__(self, configuration: AesEncryptionMaskingConfiguration): self.mode = configuration.mode self.format_preservation = configuration.format_preservation @@ -81,17 +78,12 @@ def generate_secrets_for_cache(self) -> List[MaskingSecretCache]: ] = self._build_masking_secret_meta() return SecretsUtil.build_masking_secrets_for_cache(masking_meta) - @staticmethod - def get_configuration_model() -> MaskingConfiguration: - """Used to get the configuration model to configure the strategy""" - return AesEncryptionMaskingConfiguration # type: ignore - - @staticmethod - def get_description() -> MaskingStrategyDescription: + @classmethod + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: """Returns the description used for documentation. In particular, used by the documentation endpoint in masking_endpoints.list_masking_strategies""" return MaskingStrategyDescription( - name=AES_ENCRYPT_STRATEGY_NAME, + name=cls.name, description="Masks by encrypting the value using AES", configurations=[ MaskingStrategyConfigurationDescription( @@ -127,19 +119,21 @@ def _generate_nonce( value, key, salt, HmacMaskingConfiguration.Algorithm.sha_256 # type: ignore )[:12] - @staticmethod - def _build_masking_secret_meta() -> Dict[SecretType, MaskingSecretMeta]: + @classmethod + def _build_masking_secret_meta( + cls: Type[MaskingStrategy], + ) -> Dict[SecretType, MaskingSecretMeta]: return { SecretType.key: MaskingSecretMeta[bytes]( - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=cls.name, generate_secret_func=SecretsUtil.generate_secret_bytes, ), SecretType.key_hmac: MaskingSecretMeta[str]( - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=cls.name, generate_secret_func=SecretsUtil.generate_secret_string, ), SecretType.salt_hmac: MaskingSecretMeta[str]( - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=cls.name, generate_secret_func=SecretsUtil.generate_secret_string, ), } diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy_factory.py b/src/fidesops/ops/service/masking/strategy/masking_strategy_factory.py deleted file mode 100644 index 22a110d78..000000000 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy_factory.py +++ /dev/null @@ -1,68 +0,0 @@ -import logging -from typing import Callable, Dict, Type, Union, ValuesView - -from pydantic import ValidationError - -from fidesops.ops.common_exceptions import NoSuchStrategyException -from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError -from fidesops.ops.schemas.masking.masking_configuration import FormatPreservationConfig -from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy - -logger = logging.getLogger(__name__) - - -class MaskingStrategyFactory: - registry: Dict[str, Type[MaskingStrategy]] = {} - valid_strategies: str = "" - - @classmethod - def register( - cls, name: str - ) -> Callable[[Type[MaskingStrategy]], Type[MaskingStrategy]]: - def wrapper(strategy_class: Type[MaskingStrategy]) -> Type[MaskingStrategy]: - logger.debug( - "Registering new masking strategy '%s' under name '%s'", - strategy_class, - name, - ) - - if name in cls.registry: - logger.warning( - "Masking strategy with name '%s' already exists. It previously referred to class '%s', but will now refer to '%s'", - name, - cls.registry[name], - strategy_class, - ) - - cls.registry[name] = strategy_class - cls.valid_strategies = ", ".join(cls.registry.keys()) - return cls.registry[name] - - return wrapper - - @classmethod - def get_strategy( - cls, - strategy_name: str, - configuration: Dict[str, Union[str, FormatPreservationConfig]], - ) -> MaskingStrategy: - """ - Returns the strategy given the name and configuration. - Raises NoSuchStrategyException if the strategy does not exist - """ - try: - strategy = cls.registry[strategy_name] - except KeyError: - raise NoSuchStrategyException( - f"Strategy '{strategy_name}' does not exist. Valid strategies are [{cls.valid_strategies}]" - ) - try: - strategy_config = strategy.get_configuration_model()(**configuration) # type: ignore - except ValidationError as e: - raise FidesopsValidationError(message=str(e)) - return strategy(configuration=strategy_config) # type: ignore - - @classmethod - def get_strategies(cls) -> ValuesView[MaskingStrategy]: - """Returns all supported masking strategies""" - return cls.registry.values() # type: ignore diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy_hash.py b/src/fidesops/ops/service/masking/strategy/masking_strategy_hash.py index 3266577a5..4ffecd242 100644 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy_hash.py +++ b/src/fidesops/ops/service/masking/strategy/masking_strategy_hash.py @@ -1,13 +1,10 @@ from __future__ import annotations import hashlib -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from fidesops.ops.core.config import config -from fidesops.ops.schemas.masking.masking_configuration import ( - HashMaskingConfiguration, - MaskingConfiguration, -) +from fidesops.ops.schemas.masking.masking_configuration import HashMaskingConfiguration from fidesops.ops.schemas.masking.masking_secrets import ( MaskingSecretCache, MaskingSecretMeta, @@ -19,18 +16,15 @@ ) from fidesops.ops.service.masking.strategy.format_preservation import FormatPreservation from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) from fidesops.ops.util.encryption.secrets_util import SecretsUtil -HASH_STRATEGY_NAME = "hash" - -@MaskingStrategyFactory.register(HASH_STRATEGY_NAME) class HashMaskingStrategy(MaskingStrategy): """Masks a value by hashing it""" + name = "hash" + configuration_model = HashMaskingConfiguration + def __init__( self, configuration: HashMaskingConfiguration, @@ -76,16 +70,12 @@ def generate_secrets_for_cache(self) -> List[MaskingSecretCache]: ] = self._build_masking_secret_meta() return SecretsUtil.build_masking_secrets_for_cache(masking_meta) - @staticmethod - def get_configuration_model() -> MaskingConfiguration: - return HashMaskingConfiguration # type: ignore - # MR Note - We will need a way to ensure that this does not fall out of date. Given that it # includes subjective instructions, this is not straightforward to automate - @staticmethod - def get_description() -> MaskingStrategyDescription: + @classmethod + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: return MaskingStrategyDescription( - name=HASH_STRATEGY_NAME, + name=cls.name, description="Masks the input value by returning a hashed version of the input value", configurations=[ MaskingStrategyConfigurationDescription( @@ -120,11 +110,13 @@ def _hash_sha512(value: str, salt: str) -> str: (value + salt).encode(config.security.encoding) ).hexdigest() - @staticmethod - def _build_masking_secret_meta() -> Dict[SecretType, MaskingSecretMeta]: + @classmethod + def _build_masking_secret_meta( + cls: Type[MaskingStrategy], + ) -> Dict[SecretType, MaskingSecretMeta]: return { SecretType.salt: MaskingSecretMeta[str]( - masking_strategy=HASH_STRATEGY_NAME, + masking_strategy=cls.name, generate_secret_func=SecretsUtil.generate_secret_string, ) } diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy_hmac.py b/src/fidesops/ops/service/masking/strategy/masking_strategy_hmac.py index 4e8b8f128..0a2dfb660 100644 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy_hmac.py +++ b/src/fidesops/ops/service/masking/strategy/masking_strategy_hmac.py @@ -1,11 +1,8 @@ from __future__ import annotations -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type -from fidesops.ops.schemas.masking.masking_configuration import ( - HmacMaskingConfiguration, - MaskingConfiguration, -) +from fidesops.ops.schemas.masking.masking_configuration import HmacMaskingConfiguration from fidesops.ops.schemas.masking.masking_secrets import ( MaskingSecretCache, MaskingSecretMeta, @@ -17,21 +14,18 @@ ) from fidesops.ops.service.masking.strategy.format_preservation import FormatPreservation from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) from fidesops.ops.util.encryption.hmac_encryption_scheme import hmac_encrypt_return_str from fidesops.ops.util.encryption.secrets_util import SecretsUtil -HMAC_STRATEGY_NAME = "hmac" - -@MaskingStrategyFactory.register(HMAC_STRATEGY_NAME) class HmacMaskingStrategy(MaskingStrategy): """ Masks a value by generating a hash using a hash algorithm and a required secret key. One of the differences between this and the HashMaskingStrategy is the required secret key.""" + name = "hmac" + configuration_model = HmacMaskingConfiguration + def __init__( self, configuration: HmacMaskingConfiguration, @@ -76,14 +70,10 @@ def generate_secrets_for_cache(self) -> List[MaskingSecretCache]: ] = self._build_masking_secret_meta() return SecretsUtil.build_masking_secrets_for_cache(masking_meta) - @staticmethod - def get_configuration_model() -> MaskingConfiguration: - return HmacMaskingConfiguration # type: ignore - - @staticmethod - def get_description() -> MaskingStrategyDescription: + @classmethod + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: return MaskingStrategyDescription( - name=HMAC_STRATEGY_NAME, + name=cls.name, description="Masks the input value by using the HMAC algorithm along with a hashed version of the data " "and a secret key.", configurations=[ @@ -105,15 +95,17 @@ def data_type_supported(data_type: Optional[str]) -> bool: supported_data_types = {"string"} return data_type in supported_data_types - @staticmethod - def _build_masking_secret_meta() -> Dict[SecretType, MaskingSecretMeta]: + @classmethod + def _build_masking_secret_meta( + cls: Type[MaskingStrategy], + ) -> Dict[SecretType, MaskingSecretMeta]: return { SecretType.key: MaskingSecretMeta[str]( - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=cls.name, generate_secret_func=SecretsUtil.generate_secret_string, ), SecretType.salt: MaskingSecretMeta[str]( - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=cls.name, generate_secret_func=SecretsUtil.generate_secret_string, ), } diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy_nullify.py b/src/fidesops/ops/service/masking/strategy/masking_strategy_nullify.py index 7c8c28d09..517e77b89 100644 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy_nullify.py +++ b/src/fidesops/ops/service/masking/strategy/masking_strategy_nullify.py @@ -1,24 +1,18 @@ -from typing import List, Optional +from typing import List, Optional, Type -from fidesops.ops.schemas.masking.masking_configuration import ( - MaskingConfiguration, - NullMaskingConfiguration, -) +from fidesops.ops.schemas.masking.masking_configuration import NullMaskingConfiguration from fidesops.ops.schemas.masking.masking_strategy_description import ( MaskingStrategyDescription, ) from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) -NULL_REWRITE_STRATEGY_NAME = "null_rewrite" - -@MaskingStrategyFactory.register(NULL_REWRITE_STRATEGY_NAME) class NullMaskingStrategy(MaskingStrategy): """Masks provided values each with a null value.""" + name = "null_rewrite" + configuration_model = NullMaskingConfiguration + def __init__( self, configuration: NullMaskingConfiguration, @@ -39,14 +33,10 @@ def mask( def secrets_required(self) -> bool: return False - @staticmethod - def get_configuration_model() -> MaskingConfiguration: - return NullMaskingConfiguration # type: ignore - - @staticmethod - def get_description() -> MaskingStrategyDescription: + @classmethod + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: return MaskingStrategyDescription( - name=NULL_REWRITE_STRATEGY_NAME, + name=cls.name, description="Masks the input value with a null value", configurations=[], ) diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy_random_string_rewrite.py b/src/fidesops/ops/service/masking/strategy/masking_strategy_random_string_rewrite.py index ce4de2685..8d0a73bde 100644 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy_random_string_rewrite.py +++ b/src/fidesops/ops/service/masking/strategy/masking_strategy_random_string_rewrite.py @@ -1,9 +1,8 @@ import string from secrets import choice -from typing import List, Optional +from typing import List, Optional, Type from fidesops.ops.schemas.masking.masking_configuration import ( - MaskingConfiguration, RandomStringMaskingConfiguration, ) from fidesops.ops.schemas.masking.masking_strategy_description import ( @@ -12,17 +11,14 @@ ) from fidesops.ops.service.masking.strategy.format_preservation import FormatPreservation from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) - -RANDOM_STRING_REWRITE_STRATEGY_NAME = "random_string_rewrite" -@MaskingStrategyFactory.register(RANDOM_STRING_REWRITE_STRATEGY_NAME) class RandomStringRewriteMaskingStrategy(MaskingStrategy): """Masks each provied value with a random string of the length specified in the configuration.""" + name = "random_string_rewrite" + configuration_model = RandomStringMaskingConfiguration + def __init__( self, configuration: RandomStringMaskingConfiguration, @@ -53,14 +49,10 @@ def mask( def secrets_required(self) -> bool: return False - @staticmethod - def get_configuration_model() -> MaskingConfiguration: - return RandomStringMaskingConfiguration # type: ignore - - @staticmethod - def get_description() -> MaskingStrategyDescription: + @classmethod + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: return MaskingStrategyDescription( - name=RANDOM_STRING_REWRITE_STRATEGY_NAME, + name=cls.name, description="Masks the input value with a random string of a specified length", configurations=[ MaskingStrategyConfigurationDescription( diff --git a/src/fidesops/ops/service/masking/strategy/masking_strategy_string_rewrite.py b/src/fidesops/ops/service/masking/strategy/masking_strategy_string_rewrite.py index d851cad8e..8a352dec8 100644 --- a/src/fidesops/ops/service/masking/strategy/masking_strategy_string_rewrite.py +++ b/src/fidesops/ops/service/masking/strategy/masking_strategy_string_rewrite.py @@ -1,7 +1,6 @@ -from typing import List, Optional +from typing import List, Optional, Type from fidesops.ops.schemas.masking.masking_configuration import ( - MaskingConfiguration, StringRewriteMaskingConfiguration, ) from fidesops.ops.schemas.masking.masking_strategy_description import ( @@ -10,17 +9,14 @@ ) from fidesops.ops.service.masking.strategy.format_preservation import FormatPreservation from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) - -STRING_REWRITE_STRATEGY_NAME = "string_rewrite" -@MaskingStrategyFactory.register(STRING_REWRITE_STRATEGY_NAME) class StringRewriteMaskingStrategy(MaskingStrategy): """Masks the values with a pre-determined value""" + name = "string_rewrite" + configuration_model = StringRewriteMaskingConfiguration + def __init__( self, configuration: StringRewriteMaskingConfiguration, @@ -47,16 +43,12 @@ def mask( def secrets_required(self) -> bool: return False - @staticmethod - def get_configuration_model() -> MaskingConfiguration: - return StringRewriteMaskingConfiguration # type: ignore - # MR Note - We will need a way to ensure that this does not fall out of date. Given that it # includes subjective instructions, this is not straightforward to automate - @staticmethod - def get_description() -> MaskingStrategyDescription: + @classmethod + def get_description(cls: Type[MaskingStrategy]) -> MaskingStrategyDescription: return MaskingStrategyDescription( - name=STRING_REWRITE_STRATEGY_NAME, + name=cls.name, description="Masks the input value with a default string value", configurations=[ MaskingStrategyConfigurationDescription( diff --git a/src/fidesops/ops/service/pagination/__init__.py b/src/fidesops/ops/service/pagination/__init__.py index e69de29bb..378e7b2c2 100644 --- a/src/fidesops/ops/service/pagination/__init__.py +++ b/src/fidesops/ops/service/pagination/__init__.py @@ -0,0 +1,5 @@ +from fidesops.ops.service.pagination import ( + pagination_strategy_cursor, + pagination_strategy_link, + pagination_strategy_offset, +) diff --git a/src/fidesops/ops/service/pagination/pagination_strategy.py b/src/fidesops/ops/service/pagination/pagination_strategy.py index 783965d28..ad1b61910 100644 --- a/src/fidesops/ops/service/pagination/pagination_strategy.py +++ b/src/fidesops/ops/service/pagination/pagination_strategy.py @@ -1,23 +1,20 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, Optional from requests import Response from fidesops.ops.schemas.saas.shared_schemas import SaaSRequestParams +from fidesops.ops.service.strategy import Strategy if TYPE_CHECKING: from fidesops.ops.schemas.saas.strategy_configuration import StrategyConfiguration -class PaginationStrategy(ABC): +class PaginationStrategy(Strategy): """Abstract base class for SaaS pagination strategies""" - @abstractmethod - def get_strategy_name(self) -> str: - """Returns strategy name""" - @abstractmethod def get_next_request( self, @@ -28,11 +25,6 @@ def get_next_request( ) -> Optional[SaaSRequestParams]: """Build request for next page of data""" - @staticmethod - @abstractmethod - def get_configuration_model() -> StrategyConfiguration: - """Used to get the configuration model to configure the strategy""" - def validate_request(self, request: Dict[str, Any]) -> None: """ Accepts the raw SaaSRequest data and validates that the request diff --git a/src/fidesops/ops/service/pagination/pagination_strategy_cursor.py b/src/fidesops/ops/service/pagination/pagination_strategy_cursor.py index fa17036d4..38068bc0f 100644 --- a/src/fidesops/ops/service/pagination/pagination_strategy_cursor.py +++ b/src/fidesops/ops/service/pagination/pagination_strategy_cursor.py @@ -6,21 +6,19 @@ from fidesops.ops.schemas.saas.shared_schemas import SaaSRequestParams from fidesops.ops.schemas.saas.strategy_configuration import ( CursorPaginationConfiguration, - StrategyConfiguration, ) from fidesops.ops.service.pagination.pagination_strategy import PaginationStrategy -STRATEGY_NAME = "cursor" - class CursorPaginationStrategy(PaginationStrategy): + + name = "cursor" + configuration_model = CursorPaginationConfiguration + def __init__(self, configuration: CursorPaginationConfiguration): self.cursor_param = configuration.cursor_param self.field = configuration.field - def get_strategy_name(self) -> str: - return STRATEGY_NAME - def get_next_request( self, request_params: SaaSRequestParams, @@ -52,7 +50,3 @@ def get_next_request( query_params=request_params.query_params, body=request_params.body, ) - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return CursorPaginationConfiguration # type: ignore diff --git a/src/fidesops/ops/service/pagination/pagination_strategy_factory.py b/src/fidesops/ops/service/pagination/pagination_strategy_factory.py deleted file mode 100644 index d68679461..000000000 --- a/src/fidesops/ops/service/pagination/pagination_strategy_factory.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -import logging -from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List - -from pydantic import ValidationError - -from fidesops.ops.common_exceptions import NoSuchStrategyException -from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError -from fidesops.ops.service.pagination.pagination_strategy_cursor import ( - CursorPaginationStrategy, -) -from fidesops.ops.service.pagination.pagination_strategy_link import ( - LinkPaginationStrategy, -) -from fidesops.ops.service.pagination.pagination_strategy_offset import ( - OffsetPaginationStrategy, -) - -if TYPE_CHECKING: - from fidesops.ops.schemas.saas.strategy_configuration import StrategyConfiguration - from fidesops.ops.service.pagination.pagination_strategy import PaginationStrategy - -logger = logging.getLogger(__name__) - - -class SupportedPaginationStrategies(Enum): - """ - The supported methods by which Fidesops can post-process Saas connector data. - """ - - offset = OffsetPaginationStrategy - link = LinkPaginationStrategy - cursor = CursorPaginationStrategy - - @classmethod - def __contains__(cls, item: str) -> bool: - try: - cls[item] - except KeyError: - return False - - return True - - -def get_strategy( - strategy_name: str, - configuration: Dict[str, Any], -) -> PaginationStrategy: - """ - Returns the strategy given the name and configuration. - Raises NoSuchStrategyException if the strategy does not exist - """ - if not SupportedPaginationStrategies.__contains__(strategy_name): - valid_strategies = ", ".join([s.name for s in SupportedPaginationStrategies]) - raise NoSuchStrategyException( - f"Strategy '{strategy_name}' does not exist. Valid strategies are [{valid_strategies}]" - ) - strategy = SupportedPaginationStrategies[strategy_name].value - try: - strategy_config: StrategyConfiguration = strategy.get_configuration_model()( - **configuration - ) - return strategy(configuration=strategy_config) - except ValidationError as e: - raise FidesopsValidationError(message=str(e)) - - -def get_strategies() -> List[PaginationStrategy]: - """Returns all supported pagination strategies""" - return [e.value for e in SupportedPaginationStrategies] diff --git a/src/fidesops/ops/service/pagination/pagination_strategy_link.py b/src/fidesops/ops/service/pagination/pagination_strategy_link.py index 4cd970960..be3baf204 100644 --- a/src/fidesops/ops/service/pagination/pagination_strategy_link.py +++ b/src/fidesops/ops/service/pagination/pagination_strategy_link.py @@ -10,25 +10,23 @@ from fidesops.ops.schemas.saas.strategy_configuration import ( LinkPaginationConfiguration, LinkSource, - StrategyConfiguration, ) from fidesops.ops.service.pagination.pagination_strategy import PaginationStrategy from fidesops.ops.util.logger import Pii -STRATEGY_NAME = "link" - logger = logging.getLogger(__name__) class LinkPaginationStrategy(PaginationStrategy): + + name = "link" + configuration_model = LinkPaginationConfiguration + def __init__(self, configuration: LinkPaginationConfiguration): self.source = configuration.source self.rel = configuration.rel self.path = configuration.path - def get_strategy_name(self) -> str: - return STRATEGY_NAME - def get_next_request( self, request_params: SaaSRequestParams, @@ -71,7 +69,3 @@ def get_next_request( query_params=updated_query_params, body=request_params.body, ) - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return LinkPaginationConfiguration # type: ignore diff --git a/src/fidesops/ops/service/pagination/pagination_strategy_offset.py b/src/fidesops/ops/service/pagination/pagination_strategy_offset.py index 6aad7da2d..1ebf8b15d 100644 --- a/src/fidesops/ops/service/pagination/pagination_strategy_offset.py +++ b/src/fidesops/ops/service/pagination/pagination_strategy_offset.py @@ -8,22 +8,20 @@ from fidesops.ops.schemas.saas.strategy_configuration import ( ConnectorParamRef, OffsetPaginationConfiguration, - StrategyConfiguration, ) from fidesops.ops.service.pagination.pagination_strategy import PaginationStrategy -STRATEGY_NAME = "offset" - class OffsetPaginationStrategy(PaginationStrategy): + + name = "offset" + configuration_model = OffsetPaginationConfiguration + def __init__(self, configuration: OffsetPaginationConfiguration): self.incremental_param = configuration.incremental_param self.increment_by = configuration.increment_by self.limit = configuration.limit - def get_strategy_name(self) -> str: - return STRATEGY_NAME - def get_next_request( self, request_params: SaaSRequestParams, @@ -69,10 +67,6 @@ def get_next_request( body=request_params.body, ) - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return OffsetPaginationConfiguration # type: ignore - def validate_request(self, request: Dict[str, Any]) -> None: """Ensures that the query param specified by 'incremental_param' exists in the request""" query_params = ( diff --git a/src/fidesops/ops/service/privacy_request/request_service.py b/src/fidesops/ops/service/privacy_request/request_service.py index d1b148ffd..b63a5e3ff 100644 --- a/src/fidesops/ops/service/privacy_request/request_service.py +++ b/src/fidesops/ops/service/privacy_request/request_service.py @@ -8,9 +8,7 @@ from fidesops.ops.schemas.drp_privacy_request import DrpPrivacyRequestCreate from fidesops.ops.schemas.masking.masking_secrets import MaskingSecretCache from fidesops.ops.schemas.redis_cache import PrivacyRequestIdentity -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) +from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy logger = logging.getLogger(__name__) @@ -58,9 +56,7 @@ def cache_data( if strategy_name in unique_masking_strategies_by_name: continue unique_masking_strategies_by_name.add(strategy_name) - masking_strategy = MaskingStrategyFactory.get_strategy( - strategy_name, configuration - ) + masking_strategy = MaskingStrategy.get_strategy(strategy_name, configuration) if masking_strategy.secrets_required(): masking_secrets: List[ MaskingSecretCache diff --git a/src/fidesops/ops/service/processors/post_processor_strategy/__init__.py b/src/fidesops/ops/service/processors/post_processor_strategy/__init__.py index e69de29bb..0d13c74a7 100644 --- a/src/fidesops/ops/service/processors/post_processor_strategy/__init__.py +++ b/src/fidesops/ops/service/processors/post_processor_strategy/__init__.py @@ -0,0 +1,4 @@ +from fidesops.ops.service.processors.post_processor_strategy import ( + post_processor_strategy_filter, + post_processor_strategy_unwrap, +) diff --git a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy.py b/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy.py index 201b1fa1a..2796ce82a 100644 --- a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy.py +++ b/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy.py @@ -1,13 +1,11 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Any, Dict, List, Union +from fidesops.ops.service.strategy import Strategy -class PostProcessorStrategy(ABC): - """Abstract base class for SaaS post processor strategies""" - @abstractmethod - def get_strategy_name(self) -> str: - """Returns strategy name""" +class PostProcessorStrategy(Strategy): + """Abstract base class for SaaS post processor strategies""" @abstractmethod def process( diff --git a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_factory.py b/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_factory.py deleted file mode 100644 index a0c61aaeb..000000000 --- a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_factory.py +++ /dev/null @@ -1,66 +0,0 @@ -import logging -from enum import Enum -from typing import Any, Dict, List - -from pydantic import ValidationError - -from fidesops.ops.common_exceptions import NoSuchStrategyException -from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError -from fidesops.ops.schemas.saas.strategy_configuration import StrategyConfiguration -from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy import ( - PostProcessorStrategy, -) -from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy_filter import ( - FilterPostProcessorStrategy, -) -from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy_unwrap import ( - UnwrapPostProcessorStrategy, -) - -logger = logging.getLogger(__name__) - - -class SupportedPostProcessorStrategies(Enum): - """ - The supported methods by which Fidesops can post-process Saas connector data. - """ - - unwrap = UnwrapPostProcessorStrategy - filter = FilterPostProcessorStrategy - - @classmethod - def __contains__(cls, item: str) -> bool: - try: - cls[item] - except KeyError: - return False - - return True - - -def get_strategy( - strategy_name: str, - configuration: Dict[str, Any], -) -> PostProcessorStrategy: - """ - Returns the strategy given the name and configuration. - Raises NoSuchStrategyException if the strategy does not exist - """ - if not SupportedPostProcessorStrategies.__contains__(strategy_name): - valid_strategies = ", ".join([s.name for s in SupportedPostProcessorStrategies]) - raise NoSuchStrategyException( - f"Strategy '{strategy_name}' does not exist. Valid strategies are [{valid_strategies}]" - ) - strategy = SupportedPostProcessorStrategies[strategy_name].value - try: - strategy_config: StrategyConfiguration = strategy.get_configuration_model()( - **configuration - ) - return strategy(configuration=strategy_config) - except ValidationError as e: - raise FidesopsValidationError(message=str(e)) - - -def get_strategies() -> List[PostProcessorStrategy]: - """Returns all supported postprocessor strategies""" - return [e.value for e in SupportedPostProcessorStrategies] diff --git a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_filter.py b/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_filter.py index c8f89a64c..a46792228 100644 --- a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_filter.py +++ b/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_filter.py @@ -7,14 +7,11 @@ from fidesops.ops.schemas.saas.shared_schemas import IdentityParamRef from fidesops.ops.schemas.saas.strategy_configuration import ( FilterPostProcessorConfiguration, - StrategyConfiguration, ) from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy import ( PostProcessorStrategy, ) -STRATEGY_NAME = "filter" - logger = logging.getLogger(__name__) @@ -44,15 +41,15 @@ class FilterPostProcessorStrategy(PostProcessorStrategy): } """ + name = "filter" + configuration_model = FilterPostProcessorConfiguration + def __init__(self, configuration: FilterPostProcessorConfiguration): self.field = configuration.field self.value = configuration.value self.exact = configuration.exact self.case_sensitive = configuration.case_sensitive - def get_strategy_name(self) -> str: - return STRATEGY_NAME - def process( self, data: Union[List[Dict[str, Any]], Dict[str, Any]], @@ -75,7 +72,7 @@ def process( logger.warning( "Could not retrieve identity reference '%s' due to missing identity data for the following post processing strategy: %s", self.value.identity, - self.get_strategy_name(), + self.name, ) return [] filter_value = identity_data.get(self.value.identity) # type: ignore @@ -106,7 +103,7 @@ def process( logger.warning( "%s could not be found on data for the following post processing strategy: %s", self.field, - self.get_strategy_name(), + self.name, ) return [] @@ -159,7 +156,3 @@ def _matches( # base case, compare filter_value to a single string return filter_value == target if exact else filter_value in target - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return FilterPostProcessorConfiguration # type: ignore diff --git a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_unwrap.py b/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_unwrap.py index 1c9a315c6..526dd1a2d 100644 --- a/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_unwrap.py +++ b/src/fidesops/ops/service/processors/post_processor_strategy/post_processor_strategy_unwrap.py @@ -4,15 +4,12 @@ import pydash from fidesops.ops.schemas.saas.strategy_configuration import ( - StrategyConfiguration, UnwrapPostProcessorConfiguration, ) from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy import ( PostProcessorStrategy, ) -STRATEGY_NAME = "unwrap" - logger = logging.getLogger(__name__) @@ -37,12 +34,12 @@ class UnwrapPostProcessorStrategy(PostProcessorStrategy): If given a list, the unwrap will apply to the dicts inside the list. """ + name = "unwrap" + configuration_model = UnwrapPostProcessorConfiguration + def __init__(self, configuration: UnwrapPostProcessorConfiguration): self.data_path = configuration.data_path - def get_strategy_name(self) -> str: - return STRATEGY_NAME - def process( self, data: Union[List[Dict[str, Any]], Dict[str, Any]], @@ -61,7 +58,7 @@ def process( logger.warning( "%s could not be found for the following post processing strategy: %s", self.data_path, - self.get_strategy_name(), + self.name, ) else: result = unwrapped @@ -72,7 +69,7 @@ def process( logger.warning( "%s could not be found for the following post processing strategy: %s", self.data_path, - self.get_strategy_name(), + self.name, ) else: result.append(unwrapped) @@ -81,7 +78,3 @@ def process( result = pydash.flatten(result) return result - - @staticmethod - def get_configuration_model() -> StrategyConfiguration: - return UnwrapPostProcessorConfiguration # type: ignore diff --git a/src/fidesops/ops/service/strategy.py b/src/fidesops/ops/service/strategy.py new file mode 100644 index 000000000..43e9520bc --- /dev/null +++ b/src/fidesops/ops/service/strategy.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import logging +from abc import ABC +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar + +from pydantic import ValidationError + +from fidesops.ops.common_exceptions import NoSuchStrategyException +from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError +from fidesops.ops.schemas.saas.strategy_configuration import StrategyConfiguration + +logger = logging.getLogger(__name__) +T = TypeVar("T", bound="Strategy") +C = TypeVar("C", bound=StrategyConfiguration) + + +def _find_strategy_subclass( + cls: Type[Strategy], strategy_name: str +) -> Optional[Type[Strategy]]: + if hasattr(cls, "name") and cls.name == strategy_name: + return cls + for sub in cls.__subclasses__(): + found = _find_strategy_subclass(sub, strategy_name) + if found: + return found + return None + + +def _find_all_strategy_subclasses( + cls: Type[T], subs: List[Type[T]] = None +) -> List[Type[T]]: + if subs is None: + subs = [] + if hasattr(cls, "name"): + subs.append(cls) + for sub in cls.__subclasses__(): + _find_all_strategy_subclasses(sub, subs) + return subs + + +class Strategy(ABC, Generic[C]): + """Abstract base class for strategies""" + + name: str + configuration_model: Type[C] + + @classmethod + def get_strategy( + cls: Type[T], + strategy_name: str, + configuration: Dict[str, Any], + ) -> T: + """ + Returns the strategy given the name and configuration. + Raises NoSuchStrategyException if the strategy does not exist + """ + + strategy_class = _find_strategy_subclass(cls, strategy_name) + + if strategy_class is None: + valid_strategies = ", ".join( + [sub.name for sub in _find_all_strategy_subclasses(cls)] + ) + raise NoSuchStrategyException( + f"Strategy '{strategy_name}' does not exist. Valid strategies are [{valid_strategies}]" + ) + try: + strategy_config = strategy_class.configuration_model(**configuration) + except ValidationError as e: + raise FidesopsValidationError(message=str(e)) + return strategy_class(strategy_config) # type: ignore + + @classmethod + def get_strategies(cls: Type[T]) -> List[Type[T]]: + """Returns all supported strategies""" + return _find_all_strategy_subclasses(cls) diff --git a/tests/ops/api/v1/endpoints/test_masking_endpoints.py b/tests/ops/api/v1/endpoints/test_masking_endpoints.py index b8f924092..7f5bb820e 100644 --- a/tests/ops/api/v1/endpoints/test_masking_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_masking_endpoints.py @@ -7,39 +7,38 @@ from fidesops.ops.schemas.masking.masking_configuration import ( AesEncryptionMaskingConfiguration, ) +from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy from fidesops.ops.service.masking.strategy.masking_strategy_aes_encrypt import ( - AES_ENCRYPT_STRATEGY_NAME, -) -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, + AesEncryptionMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_hash import ( - HASH_STRATEGY_NAME, + HashMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_hmac import ( - HMAC_STRATEGY_NAME, + HmacMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_nullify import ( - NULL_REWRITE_STRATEGY_NAME, + NullMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_random_string_rewrite import ( - RANDOM_STRING_REWRITE_STRATEGY_NAME, + RandomStringRewriteMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_string_rewrite import ( - STRING_REWRITE_STRATEGY_NAME, + StringRewriteMaskingStrategy, ) class TestGetMaskingStrategies: def test_read_strategies(self, api_client: TestClient): expected_response = [] - for strategy in MaskingStrategyFactory.get_strategies(): + for strategy in MaskingStrategy.get_strategies(): expected_response.append(strategy.get_description()) response = api_client.get(V1_URL_PREFIX + MASKING_STRATEGY) response_body = json.loads(response.text) assert 200 == response.status_code + assert response_body assert expected_response == response_body @@ -50,7 +49,7 @@ def test_mask_value_string_rewrite(self, api_client: TestClient): request = { "values": [value], "masking_strategy": { - "strategy": STRING_REWRITE_STRATEGY_NAME, + "strategy": StringRewriteMaskingStrategy.name, "configuration": {"rewrite_value": rewrite_val}, }, } @@ -69,7 +68,7 @@ def test_mask_value_random_string_rewrite(self, api_client: TestClient): request = { "values": [value], "masking_strategy": { - "strategy": RANDOM_STRING_REWRITE_STRATEGY_NAME, + "strategy": RandomStringRewriteMaskingStrategy.name, "configuration": {"length": length}, }, } @@ -84,7 +83,7 @@ def test_mask_value_hmac(self, api_client: TestClient): request = { "values": [value], "masking_strategy": { - "strategy": HMAC_STRATEGY_NAME, + "strategy": HmacMaskingStrategy.name, "configuration": {}, }, } @@ -99,7 +98,7 @@ def test_mask_value_hash(self, api_client: TestClient): request = { "values": [value], "masking_strategy": { - "strategy": HASH_STRATEGY_NAME, + "strategy": HashMaskingStrategy.name, "configuration": {}, }, } @@ -115,7 +114,7 @@ def test_mask_value_hash_multi_value(self, api_client: TestClient): request = { "values": [value, value2], "masking_strategy": { - "strategy": HASH_STRATEGY_NAME, + "strategy": HashMaskingStrategy.name, "configuration": {}, }, } @@ -135,7 +134,7 @@ def test_mask_value_hash_multi_value_same_value(self, api_client: TestClient): request = { "values": [value, value], "masking_strategy": { - "strategy": HASH_STRATEGY_NAME, + "strategy": HashMaskingStrategy.name, "configuration": {}, }, } @@ -155,7 +154,7 @@ def test_mask_value_aes_encrypt(self, api_client: TestClient): request = { "values": [value], "masking_strategy": { - "strategy": AES_ENCRYPT_STRATEGY_NAME, + "strategy": AesEncryptionMaskingStrategy.name, "configuration": { "mode": AesEncryptionMaskingConfiguration.Mode.GCM.value }, @@ -187,7 +186,7 @@ def test_mask_value_invalid_config(self, api_client: TestClient): request = { "values": [value], "masking_strategy": { - "strategy": STRING_REWRITE_STRATEGY_NAME, + "strategy": StringRewriteMaskingStrategy.name, "configuration": {"wrong": "config"}, }, } @@ -201,7 +200,7 @@ def test_masking_value_null(self, api_client: TestClient): request = { "values": [value], "masking_strategy": { - "strategy": NULL_REWRITE_STRATEGY_NAME, + "strategy": NullMaskingStrategy.name, "configuration": {}, }, } diff --git a/tests/ops/api/v1/endpoints/test_policy_endpoints.py b/tests/ops/api/v1/endpoints/test_policy_endpoints.py index bbc70d8ed..0eba3ae63 100644 --- a/tests/ops/api/v1/endpoints/test_policy_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_policy_endpoints.py @@ -17,7 +17,7 @@ ) from fidesops.ops.models.policy import ActionType, DrpAction, Policy, Rule, RuleTarget from fidesops.ops.service.masking.strategy.masking_strategy_nullify import ( - NULL_REWRITE_STRATEGY_NAME, + NullMaskingStrategy, ) from fidesops.ops.util.data_category import DataCategory, generate_fides_data_categories @@ -711,7 +711,7 @@ def test_create_erasure_rule_for_policy( "name": "test erasure rule", "action_type": ActionType.erasure.value, "masking_strategy": { - "strategy": NULL_REWRITE_STRATEGY_NAME, + "strategy": NullMaskingStrategy.name, "configuration": {}, }, } @@ -729,7 +729,7 @@ def test_create_erasure_rule_for_policy( rule_data = response_data[0] assert "masking_strategy" in rule_data masking_strategy_data = rule_data["masking_strategy"] - assert masking_strategy_data["strategy"] == NULL_REWRITE_STRATEGY_NAME + assert masking_strategy_data["strategy"] == NullMaskingStrategy.name assert "configuration" not in masking_strategy_data def test_update_rule_policy_id_fails( @@ -1070,7 +1070,7 @@ def test_create_conflicting_rule_targets( "name": "Erasure Rule", "policy_id": policy.id, "masking_strategy": { - "strategy": NULL_REWRITE_STRATEGY_NAME, + "strategy": NullMaskingStrategy.name, "configuration": {}, }, }, diff --git a/tests/ops/fixtures/application_fixtures.py b/tests/ops/fixtures/application_fixtures.py index dfd6a563e..8f7a9da10 100644 --- a/tests/ops/fixtures/application_fixtures.py +++ b/tests/ops/fixtures/application_fixtures.py @@ -49,13 +49,13 @@ StorageType, ) from fidesops.ops.service.masking.strategy.masking_strategy_hmac import ( - HMAC_STRATEGY_NAME, + HmacMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_nullify import ( - NULL_REWRITE_STRATEGY_NAME, + NullMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_string_rewrite import ( - STRING_REWRITE_STRATEGY_NAME, + StringRewriteMaskingStrategy, ) from fidesops.ops.util.data_category import DataCategory @@ -417,7 +417,7 @@ def erasure_policy_string_rewrite_long( "name": "Erasure Rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": STRING_REWRITE_STRATEGY_NAME, + "strategy": StringRewriteMaskingStrategy.name, "configuration": { "rewrite_value": "some rewrite value that is very long and goes on and on" }, @@ -461,7 +461,7 @@ def erasure_policy_two_rules( "name": "Second Erasure Rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": NULL_REWRITE_STRATEGY_NAME, + "strategy": NullMaskingStrategy.name, "configuration": {}, }, }, @@ -469,7 +469,7 @@ def erasure_policy_two_rules( # TODO set masking strategy in Rule.create() call above, once more masking strategies beyond NULL_REWRITE are supported. second_erasure_rule.masking_strategy = { - "strategy": STRING_REWRITE_STRATEGY_NAME, + "strategy": StringRewriteMaskingStrategy.name, "configuration": {"rewrite_value": "*****"}, } @@ -616,7 +616,7 @@ def policy_drp_action_erasure(db: Session, oauth_client: ClientDetail) -> Genera "name": "Erasure Request Rule DRP", "policy_id": erasure_request_policy.id, "masking_strategy": { - "strategy": STRING_REWRITE_STRATEGY_NAME, + "strategy": StringRewriteMaskingStrategy.name, "configuration": {"rewrite_value": "MASKED"}, }, }, @@ -668,7 +668,7 @@ def erasure_policy_string_rewrite( "name": "string rewrite erasure rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": STRING_REWRITE_STRATEGY_NAME, + "strategy": StringRewriteMaskingStrategy.name, "configuration": {"rewrite_value": "MASKED"}, }, }, @@ -721,7 +721,7 @@ def erasure_policy_hmac( "name": "hmac erasure rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": HMAC_STRATEGY_NAME, + "strategy": HmacMaskingStrategy.name, "configuration": {}, }, }, diff --git a/tests/ops/models/test_policy.py b/tests/ops/models/test_policy.py index 87fa17097..2188270f9 100644 --- a/tests/ops/models/test_policy.py +++ b/tests/ops/models/test_policy.py @@ -15,10 +15,10 @@ _is_ancestor_of_contained_categories, ) from fidesops.ops.service.masking.strategy.masking_strategy_hash import ( - HASH_STRATEGY_NAME, + HashMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_nullify import ( - NULL_REWRITE_STRATEGY_NAME, + NullMaskingStrategy, ) from fidesops.ops.util.data_category import DataCategory from fidesops.ops.util.text import to_snake_case @@ -88,7 +88,7 @@ def test_create_erasure_rule_with_destination_is_invalid( "policy_id": policy.id, "storage_destination_id": policy.rules[0].storage_destination.id, "masking_strategy": { - "strategy": HASH_STRATEGY_NAME, + "strategy": HashMaskingStrategy.name, "configuration": { "algorithm": "SHA-512", "format_preservation": {"suffix": "@masked.com"}, @@ -216,7 +216,7 @@ def test_create_erasure_rule( "name": "Valid Erasure Rule", "policy_id": policy.id, "masking_strategy": { - "strategy": NULL_REWRITE_STRATEGY_NAME, + "strategy": NullMaskingStrategy.name, "configuration": {}, }, }, @@ -313,7 +313,7 @@ def test_validate_policy( "name": "Erasure Rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": NULL_REWRITE_STRATEGY_NAME, + "strategy": NullMaskingStrategy.name, "configuration": {}, }, }, @@ -336,7 +336,7 @@ def test_validate_policy( "name": "Another Erasure Rule", "policy_id": erasure_policy.id, "masking_strategy": { - "strategy": NULL_REWRITE_STRATEGY_NAME, + "strategy": NullMaskingStrategy.name, "configuration": {}, }, }, diff --git a/tests/ops/service/authentication/test_authentication_strategy_basic.py b/tests/ops/service/authentication/test_authentication_strategy_basic.py index 06deec2c7..10ba90d3f 100644 --- a/tests/ops/service/authentication/test_authentication_strategy_basic.py +++ b/tests/ops/service/authentication/test_authentication_strategy_basic.py @@ -4,8 +4,8 @@ from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError from fidesops.ops.models.connectionconfig import ConnectionConfig -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, ) @@ -16,7 +16,7 @@ def test_basic_auth_with_username_and_password(): password = "sufficientlylongpassword" secrets = {"username": username, "password": password} - authenticated_request = get_strategy( + authenticated_request = AuthenticationStrategy.get_strategy( "basic", {"username": "", "password": ""} ).add_authentication(req, ConnectionConfig(secrets=secrets)) assert ( @@ -31,7 +31,7 @@ def test_basic_auth_with_username_only(): username = "admin" secrets = {"username": username} - authenticated_request = get_strategy( + authenticated_request = AuthenticationStrategy.get_strategy( "basic", {"username": ""} ).add_authentication(req, ConnectionConfig(secrets=secrets)) # The requests library still calls str(password) even if the password is None @@ -45,4 +45,6 @@ def test_basic_auth_with_no_credentials(): req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() with pytest.raises(FidesopsValidationError): - get_strategy("basic", {}).add_authentication(req, ConnectionConfig(secrets={})) + AuthenticationStrategy.get_strategy("basic", {}).add_authentication( + req, ConnectionConfig(secrets={}) + ) diff --git a/tests/ops/service/authentication/test_authentication_strategy_bearer.py b/tests/ops/service/authentication/test_authentication_strategy_bearer.py index 2bf53867e..00aec52cf 100644 --- a/tests/ops/service/authentication/test_authentication_strategy_bearer.py +++ b/tests/ops/service/authentication/test_authentication_strategy_bearer.py @@ -3,8 +3,8 @@ from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError from fidesops.ops.models.connectionconfig import ConnectionConfig -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, ) @@ -14,7 +14,7 @@ def test_bearer_auth_with_token(): api_key = "imnotasecretitsok" secrets = {"api_key": api_key} - authenticated_request = get_strategy( + authenticated_request = AuthenticationStrategy.get_strategy( "bearer", {"token": ""} ).add_authentication(req, ConnectionConfig(secrets=secrets)) assert authenticated_request.headers["Authorization"] == f"Bearer {api_key}" @@ -24,4 +24,6 @@ def test_bearer_auth_without_token(): req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() with pytest.raises(FidesopsValidationError): - get_strategy("bearer", {}).add_authentication(req, ConnectionConfig(secrets={})) + AuthenticationStrategy.get_strategy("bearer", {}).add_authentication( + req, ConnectionConfig(secrets={}) + ) diff --git a/tests/ops/service/authentication/test_authentication_strategy_factory.py b/tests/ops/service/authentication/test_authentication_strategy_factory.py index 0ed62657f..2824aa4b5 100644 --- a/tests/ops/service/authentication/test_authentication_strategy_factory.py +++ b/tests/ops/service/authentication/test_authentication_strategy_factory.py @@ -1,15 +1,15 @@ import pytest from fidesops.ops.common_exceptions import NoSuchStrategyException +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, +) from fidesops.ops.service.authentication.authentication_strategy_basic import ( BasicAuthenticationStrategy, ) from fidesops.ops.service.authentication.authentication_strategy_bearer import ( BearerAuthenticationStrategy, ) -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, -) from fidesops.ops.service.authentication.authentication_strategy_query_param import ( QueryParamAuthenticationStrategy, ) @@ -20,22 +20,28 @@ def test_get_strategy_basic(): "username": "", "password": "", } - strategy = get_strategy(strategy_name="basic", configuration=config) + strategy = AuthenticationStrategy.get_strategy( + strategy_name="basic", configuration=config + ) assert isinstance(strategy, BasicAuthenticationStrategy) def test_get_strategy_bearer(): config = {"token": ""} - strategy = get_strategy(strategy_name="bearer", configuration=config) + strategy = AuthenticationStrategy.get_strategy( + strategy_name="bearer", configuration=config + ) assert isinstance(strategy, BearerAuthenticationStrategy) def test_get_strategy_query_param(): config = {"name": "api_key", "value": ""} - strategy = get_strategy(strategy_name="query_param", configuration=config) + strategy = AuthenticationStrategy.get_strategy( + strategy_name="query_param", configuration=config + ) assert isinstance(strategy, QueryParamAuthenticationStrategy) def test_get_strategy_invalid_strategy(): with pytest.raises(NoSuchStrategyException): - get_strategy("invalid", {}) + AuthenticationStrategy.get_strategy("invalid", {}) diff --git a/tests/ops/service/authentication/test_authentication_strategy_oauth2_authorization_code.py b/tests/ops/service/authentication/test_authentication_strategy_oauth2_authorization_code.py index 0dbe8262b..d0a50692f 100644 --- a/tests/ops/service/authentication/test_authentication_strategy_oauth2_authorization_code.py +++ b/tests/ops/service/authentication/test_authentication_strategy_oauth2_authorization_code.py @@ -7,8 +7,8 @@ from sqlalchemy.orm import Session from fidesops.ops.common_exceptions import FidesopsException, OAuth2TokenException -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, ) from fidesops.ops.service.authentication.authentication_strategy_oauth2_authorization_code import ( OAuth2AuthorizationCodeAuthenticationStrategy, @@ -29,8 +29,10 @@ def test_oauth2_authentication( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( - "oauth2_authorization_code", oauth2_authorization_code_configuration + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_authorization_code", oauth2_authorization_code_configuration + ) ) authenticated_request = auth_strategy.add_authentication( req, oauth2_authorization_code_connection_config @@ -50,7 +52,7 @@ def test_oauth2_authentication_missing_access_token( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_authorization_code", oauth2_authorization_code_configuration ) with pytest.raises(FidesopsException) as exc: @@ -72,7 +74,7 @@ def test_oauth2_authentication_empty_access_token( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_authorization_code", oauth2_authorization_code_configuration ) with pytest.raises(FidesopsException) as exc: @@ -95,7 +97,7 @@ def test_oauth2_authentication_missing_secrets( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_authorization_code", oauth2_authorization_code_configuration ) with pytest.raises(FidesopsException) as exc: @@ -128,7 +130,7 @@ def test_oauth2_authentication_successful_refresh( # the request we want to authenticate req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_authorization_code", oauth2_authorization_code_configuration ) authenticated_request = auth_strategy.add_authentication( @@ -163,7 +165,7 @@ def test_oauth2_authentication_no_refresh( # the request we want to authenticate req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_authorization_code", oauth2_authorization_code_configuration ) authenticated_request = auth_strategy.add_authentication( @@ -193,7 +195,7 @@ def test_oauth2_authentication_failed_refresh( # the request we want to authenticate req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_authorization_code", oauth2_authorization_code_configuration ) with pytest.raises(OAuth2TokenException) as exc: @@ -223,8 +225,10 @@ def test_get_authorization_url( ): state = "unique_value" mock_state.return_value = state - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( - "oauth2_authorization_code", oauth2_authorization_code_configuration + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_authorization_code", oauth2_authorization_code_configuration + ) ) assert ( auth_strategy.get_authorization_url( @@ -251,8 +255,10 @@ def test_get_authorization_url_missing_secrets( oauth2_authorization_code_connection_config.secrets["client_id"] = None oauth2_authorization_code_connection_config.secrets["client_secret"] = "" - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( - "oauth2_authorization_code", oauth2_authorization_code_configuration + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_authorization_code", oauth2_authorization_code_configuration + ) ) with pytest.raises(FidesopsException) as exc: auth_strategy.get_authorization_url( @@ -290,8 +296,10 @@ def test_get_access_token( "expires_in": expires_in, } - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( - "oauth2_authorization_code", oauth2_authorization_code_configuration + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_authorization_code", oauth2_authorization_code_configuration + ) ) oauth2_authorization_code_connection_config.secrets = { **oauth2_authorization_code_connection_config.secrets, @@ -345,8 +353,10 @@ def test_get_access_token_no_expires_in( } oauth2_authorization_code_configuration["expires_in"] = 3600 - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( - "oauth2_authorization_code", oauth2_authorization_code_configuration + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_authorization_code", oauth2_authorization_code_configuration + ) ) oauth2_authorization_code_connection_config.secrets = { **oauth2_authorization_code_connection_config.secrets, @@ -382,8 +392,10 @@ def test_get_access_token_missing_secrets( oauth2_authorization_code_connection_config.secrets["client_id"] = None oauth2_authorization_code_connection_config.secrets["client_secret"] = "" - auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = get_strategy( - "oauth2_authorization_code", oauth2_authorization_code_configuration + auth_strategy: OAuth2AuthorizationCodeAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_authorization_code", oauth2_authorization_code_configuration + ) ) with pytest.raises(FidesopsException) as exc: oauth2_authorization_code_connection_config.secrets = { diff --git a/tests/ops/service/authentication/test_authentication_strategy_oauth2_client_credentials.py b/tests/ops/service/authentication/test_authentication_strategy_oauth2_client_credentials.py index 998affc9f..9bd7cfdf6 100644 --- a/tests/ops/service/authentication/test_authentication_strategy_oauth2_client_credentials.py +++ b/tests/ops/service/authentication/test_authentication_strategy_oauth2_client_credentials.py @@ -13,8 +13,8 @@ ConnectionConfig, ConnectionType, ) -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, ) from fidesops.ops.service.authentication.authentication_strategy_oauth2_client_credentials import ( OAuth2ClientCredentialsAuthenticationStrategy, @@ -118,7 +118,7 @@ def test_oauth2_authentication( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_client_credentials", oauth2_client_credentials_configuration ) authenticated_request = auth_strategy.add_authentication( @@ -146,8 +146,10 @@ def test_oauth2_authentication_missing_access_token( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = get_strategy( - "oauth2_client_credentials", oauth2_client_credentials_configuration + auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_client_credentials", oauth2_client_credentials_configuration + ) ) authenticated_request = auth_strategy.add_authentication( req, oauth2_client_credentials_connection_config @@ -173,7 +175,7 @@ def test_oauth2_authentication_empty_access_token( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_client_credentials", oauth2_client_credentials_configuration ) authenticated_request = auth_strategy.add_authentication( @@ -195,8 +197,10 @@ def test_oauth2_authentication_missing_secrets( req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = get_strategy( - "oauth2_client_credentials", oauth2_client_credentials_configuration + auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_client_credentials", oauth2_client_credentials_configuration + ) ) with pytest.raises(FidesopsException) as exc: auth_strategy.add_authentication( @@ -228,7 +232,7 @@ def test_oauth2_authentication_successful_refresh( # the request we want to authenticate req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_client_credentials", oauth2_client_credentials_configuration ) authenticated_request = auth_strategy.add_authentication( @@ -262,7 +266,7 @@ def test_oauth2_authentication_no_refresh( # the request we want to authenticate req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_client_credentials", oauth2_client_credentials_configuration ) authenticated_request = auth_strategy.add_authentication( @@ -292,7 +296,7 @@ def test_oauth2_authentication_failed_refresh( # the request we want to authenticate req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() - auth_strategy = get_strategy( + auth_strategy = AuthenticationStrategy.get_strategy( "oauth2_client_credentials", oauth2_client_credentials_configuration ) with pytest.raises(OAuth2TokenException) as exc: @@ -331,8 +335,10 @@ def test_get_access_token( "expires_in": expires_in, } - auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = get_strategy( - "oauth2_client_credentials", oauth2_client_credentials_configuration + auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_client_credentials", oauth2_client_credentials_configuration + ) ) auth_strategy.get_access_token(oauth2_client_credentials_connection_config, db) @@ -380,8 +386,10 @@ def test_get_access_token_no_expires_in( } oauth2_client_credentials_configuration["expires_in"] = 3600 - auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = get_strategy( - "oauth2_client_credentials", oauth2_client_credentials_configuration + auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_client_credentials", oauth2_client_credentials_configuration + ) ) auth_strategy.get_access_token(oauth2_client_credentials_connection_config, db) @@ -411,8 +419,10 @@ def test_get_access_token_missing_secrets( oauth2_client_credentials_connection_config.secrets["client_id"] = None oauth2_client_credentials_connection_config.secrets["client_secret"] = "" - auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = get_strategy( - "oauth2_client_credentials", oauth2_client_credentials_configuration + auth_strategy: OAuth2ClientCredentialsAuthenticationStrategy = ( + AuthenticationStrategy.get_strategy( + "oauth2_client_credentials", oauth2_client_credentials_configuration + ) ) with pytest.raises(FidesopsException) as exc: auth_strategy.get_access_token( diff --git a/tests/ops/service/authentication/test_authentication_strategy_query_param.py b/tests/ops/service/authentication/test_authentication_strategy_query_param.py index 5e4b7ed01..84c4b6d21 100644 --- a/tests/ops/service/authentication/test_authentication_strategy_query_param.py +++ b/tests/ops/service/authentication/test_authentication_strategy_query_param.py @@ -3,8 +3,8 @@ from fidesops.ops.common_exceptions import ValidationError as FidesopsValidationError from fidesops.ops.models.connectionconfig import ConnectionConfig -from fidesops.ops.service.authentication.authentication_strategy_factory import ( - get_strategy, +from fidesops.ops.service.authentication.authentication_strategy import ( + AuthenticationStrategy, ) @@ -15,7 +15,7 @@ def test_query_param_auth(): api_key = "imakeyblademaster" secrets = {"api_key": api_key} - authenticated_request = get_strategy( + authenticated_request = AuthenticationStrategy.get_strategy( "query_param", {"name": "account", "value": ""} ).add_authentication(req, ConnectionConfig(secrets=secrets)) assert authenticated_request.url == f"https://localhost/?{name}={api_key}" @@ -25,6 +25,6 @@ def test_query_param_auth_without_config(): req: PreparedRequest = Request(method="POST", url="https://localhost").prepare() with pytest.raises(FidesopsValidationError): - get_strategy("query_param", {}).add_authentication( + AuthenticationStrategy.get_strategy("query_param", {}).add_authentication( req, ConnectionConfig(secrets={}) ) diff --git a/tests/ops/service/connectors/test_queryconfig.py b/tests/ops/service/connectors/test_queryconfig.py index a6b5cd1aa..075db4e0d 100644 --- a/tests/ops/service/connectors/test_queryconfig.py +++ b/tests/ops/service/connectors/test_queryconfig.py @@ -26,7 +26,6 @@ ) from fidesops.ops.service.connectors.saas_query_config import SaaSQueryConfig from fidesops.ops.service.masking.strategy.masking_strategy_hash import ( - HASH_STRATEGY_NAME, HashMaskingStrategy, ) from fidesops.ops.util.data_category import DataCategory @@ -297,7 +296,7 @@ def test_generate_update_stmt_multiple_fields_same_rule( # cache secrets for hash strategy secret = MaskingSecretCache[str]( secret="adobo", - masking_strategy=HASH_STRATEGY_NAME, + masking_strategy=HashMaskingStrategy.name, secret_type=SecretType.salt, ) cache_secret(secret, privacy_request.id) @@ -596,7 +595,7 @@ def test_generate_update_stmt_multiple_rules( # cache secrets for hash strategy secret = MaskingSecretCache[str]( secret="adobo", - masking_strategy=HASH_STRATEGY_NAME, + masking_strategy=HashMaskingStrategy.name, secret_type=SecretType.salt, ) cache_secret(secret, privacy_request.id) diff --git a/tests/ops/service/masking/strategy/test_masking_strategy_aes_encrypt.py b/tests/ops/service/masking/strategy/test_masking_strategy_aes_encrypt.py index 7ddf98d4b..5cdd4070f 100644 --- a/tests/ops/service/masking/strategy/test_masking_strategy_aes_encrypt.py +++ b/tests/ops/service/masking/strategy/test_masking_strategy_aes_encrypt.py @@ -6,7 +6,6 @@ ) from fidesops.ops.schemas.masking.masking_secrets import MaskingSecretCache, SecretType from fidesops.ops.service.masking.strategy.masking_strategy_aes_encrypt import ( - AES_ENCRYPT_STRATEGY_NAME, AesEncryptionMaskingStrategy, ) @@ -51,19 +50,19 @@ def test_mask_all_aes_modes(mock_encrypt: Mock): def cache_secrets() -> None: secret_key = MaskingSecretCache[bytes]( secret=b"\x94Y\xa8Z", - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=AesEncryptionMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) secret_hmac_key = MaskingSecretCache[str]( secret="other_key", - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=AesEncryptionMaskingStrategy.name, secret_type=SecretType.key_hmac, ) cache_secret(secret_hmac_key, request_id) secret_hmac_salt = MaskingSecretCache[str]( secret="some_salt", - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=AesEncryptionMaskingStrategy.name, secret_type=SecretType.salt_hmac, ) cache_secret(secret_hmac_salt, request_id) diff --git a/tests/ops/service/masking/strategy/test_masking_strategy_factory.py b/tests/ops/service/masking/strategy/test_masking_strategy_factory.py index 67d4628e2..21926f282 100644 --- a/tests/ops/service/masking/strategy/test_masking_strategy_factory.py +++ b/tests/ops/service/masking/strategy/test_masking_strategy_factory.py @@ -1,12 +1,10 @@ import pytest +from fidesops.ops.common_exceptions import NoSuchStrategyException +from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy from fidesops.ops.service.masking.strategy.masking_strategy_aes_encrypt import ( AesEncryptionMaskingStrategy, ) -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, - NoSuchStrategyException, -) from fidesops.ops.service.masking.strategy.masking_strategy_hash import ( HashMaskingStrategy, ) @@ -16,22 +14,22 @@ def test_get_strategy_hash(): - strategy = MaskingStrategyFactory.get_strategy("hash", {}) + strategy = MaskingStrategy.get_strategy("hash", {}) assert isinstance(strategy, HashMaskingStrategy) def test_get_strategy_rewrite(): config = {"rewrite_value": "val"} - strategy = MaskingStrategyFactory.get_strategy("string_rewrite", config) + strategy = MaskingStrategy.get_strategy("string_rewrite", config) assert isinstance(strategy, StringRewriteMaskingStrategy) def test_get_strategy_aes_encrypt(): config = {"mode": "GCM", "key": "keycard", "nonce": "none"} - strategy = MaskingStrategyFactory.get_strategy("aes_encrypt", config) + strategy = MaskingStrategy.get_strategy("aes_encrypt", config) assert isinstance(strategy, AesEncryptionMaskingStrategy) def test_get_strategy_invalid(): with pytest.raises(NoSuchStrategyException): - MaskingStrategyFactory.get_strategy("invalid", {}) + MaskingStrategy.get_strategy("invalid", {}) diff --git a/tests/ops/service/masking/strategy/test_masking_strategy_hash.py b/tests/ops/service/masking/strategy/test_masking_strategy_hash.py index eda1805a2..bafc4be8f 100644 --- a/tests/ops/service/masking/strategy/test_masking_strategy_hash.py +++ b/tests/ops/service/masking/strategy/test_masking_strategy_hash.py @@ -1,7 +1,9 @@ from fidesops.ops.schemas.masking.masking_configuration import HashMaskingConfiguration from fidesops.ops.schemas.masking.masking_secrets import MaskingSecretCache, SecretType +from fidesops.ops.service.masking.strategy.masking_strategy_aes_encrypt import ( + AesEncryptionMaskingStrategy, +) from fidesops.ops.service.masking.strategy.masking_strategy_hash import ( - HASH_STRATEGY_NAME, HashMaskingStrategy, ) @@ -16,7 +18,9 @@ def test_mask_sha256(): expected = "1c015e801323afa54bde5e4d510809e6b5f14ad9b9961c48cbd7143106b6e596" secret = MaskingSecretCache[str]( - secret="adobo", masking_strategy=HASH_STRATEGY_NAME, secret_type=SecretType.salt + secret="adobo", + masking_strategy=HashMaskingStrategy.name, + secret_type=SecretType.salt, ) cache_secret(secret, request_id) @@ -31,7 +35,9 @@ def test_mask_sha512(): expected = "527ca44f5c95400d161c503e6ddad7be01941ec9e7a03c2201338a16ba8a36bb765a430bd6b276a590661154f3f743a3a91efecd056645b4ea13b4b8cf39e8e3" secret = MaskingSecretCache[str]( - secret="adobo", masking_strategy=HASH_STRATEGY_NAME, secret_type=SecretType.salt + secret="adobo", + masking_strategy=HashMaskingStrategy.name, + secret_type=SecretType.salt, ) cache_secret(secret, request_id) @@ -46,7 +52,9 @@ def test_mask_sha256_default(): expected = "1c015e801323afa54bde5e4d510809e6b5f14ad9b9961c48cbd7143106b6e596" secret = MaskingSecretCache[str]( - secret="adobo", masking_strategy=HASH_STRATEGY_NAME, secret_type=SecretType.salt + secret="adobo", + masking_strategy=HashMaskingStrategy.name, + secret_type=SecretType.salt, ) cache_secret(secret, request_id) @@ -62,7 +70,9 @@ def test_mask_sha256_default_multi_value(): expected2 = "f37d3290343da298f2471fa8cff444d242052529e4fa27a1b9361bd1fdc02fd4" secret = MaskingSecretCache[str]( - secret="adobo", masking_strategy=HASH_STRATEGY_NAME, secret_type=SecretType.salt + secret="adobo", + masking_strategy=HashMaskingStrategy.name, + secret_type=SecretType.salt, ) cache_secret(secret, request_id) @@ -78,7 +88,9 @@ def test_mask_arguments_null(): expected = None secret = MaskingSecretCache[str]( - secret="adobo", masking_strategy=HASH_STRATEGY_NAME, secret_type=SecretType.salt + secret="adobo", + masking_strategy=HashMaskingStrategy.name, + secret_type=SecretType.salt, ) cache_secret(secret, request_id) diff --git a/tests/ops/service/masking/strategy/test_masking_strategy_hmac.py b/tests/ops/service/masking/strategy/test_masking_strategy_hmac.py index 502d6cfb4..239626959 100644 --- a/tests/ops/service/masking/strategy/test_masking_strategy_hmac.py +++ b/tests/ops/service/masking/strategy/test_masking_strategy_hmac.py @@ -1,7 +1,6 @@ from fidesops.ops.schemas.masking.masking_configuration import HmacMaskingConfiguration from fidesops.ops.schemas.masking.masking_secrets import MaskingSecretCache, SecretType from fidesops.ops.service.masking.strategy.masking_strategy_hmac import ( - HMAC_STRATEGY_NAME, HmacMaskingStrategy, ) @@ -17,13 +16,13 @@ def test_hmac_sha_256(): secret_key = MaskingSecretCache[str]( secret="test_key", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) secret_salt = MaskingSecretCache[str]( secret="test_salt", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.salt, ) cache_secret(secret_salt, request_id) @@ -40,13 +39,13 @@ def test_mask_sha512(): secret_key = MaskingSecretCache[str]( secret="test_key", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) secret_salt = MaskingSecretCache[str]( secret="test_salt", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.salt, ) cache_secret(secret_salt, request_id) @@ -63,13 +62,13 @@ def test_mask_sha256_default(): secret_key = MaskingSecretCache[str]( secret="test_key", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) secret_salt = MaskingSecretCache[str]( secret="test_salt", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.salt, ) cache_secret(secret_salt, request_id) @@ -87,13 +86,13 @@ def test_mask_sha256_default_multi_value(): secret_key = MaskingSecretCache[str]( secret="test_key", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) secret_salt = MaskingSecretCache[str]( secret="test_salt", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.salt, ) cache_secret(secret_salt, request_id) @@ -111,13 +110,13 @@ def test_mask_arguments_null(): secret_key = MaskingSecretCache[str]( secret="test_key", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) secret_salt = MaskingSecretCache[str]( secret="test_salt", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.salt, ) cache_secret(secret_salt, request_id) diff --git a/tests/ops/service/pagination/test_pagination_strategy_factory.py b/tests/ops/service/pagination/test_pagination_strategy_factory.py index f8466ce32..aa7ed6261 100644 --- a/tests/ops/service/pagination/test_pagination_strategy_factory.py +++ b/tests/ops/service/pagination/test_pagination_strategy_factory.py @@ -1,10 +1,10 @@ import pytest from fidesops.ops.common_exceptions import NoSuchStrategyException, ValidationError +from fidesops.ops.service.pagination.pagination_strategy import PaginationStrategy from fidesops.ops.service.pagination.pagination_strategy_cursor import ( CursorPaginationStrategy, ) -from fidesops.ops.service.pagination.pagination_strategy_factory import get_strategy from fidesops.ops.service.pagination.pagination_strategy_link import ( LinkPaginationStrategy, ) @@ -19,27 +19,35 @@ def test_get_strategy_offset(): "increment_by": 1, "limit": 100, } - strategy = get_strategy(strategy_name="offset", configuration=config) + strategy = PaginationStrategy.get_strategy( + strategy_name="offset", configuration=config + ) assert isinstance(strategy, OffsetPaginationStrategy) def test_get_strategy_link(): config = {"source": "body", "path": "body.next_link"} - strategy = get_strategy(strategy_name="link", configuration=config) + strategy = PaginationStrategy.get_strategy( + strategy_name="link", configuration=config + ) assert isinstance(strategy, LinkPaginationStrategy) def test_get_strategy_cursor(): config = {"cursor_param": "after", "field": "id"} - strategy = get_strategy(strategy_name="cursor", configuration=config) + strategy = PaginationStrategy.get_strategy( + strategy_name="cursor", configuration=config + ) assert isinstance(strategy, CursorPaginationStrategy) def test_get_strategy_invalid_config(): with pytest.raises(ValidationError): - get_strategy(strategy_name="offset", configuration={"invalid": "thing"}) + PaginationStrategy.get_strategy( + strategy_name="offset", configuration={"invalid": "thing"} + ) def test_get_strategy_invalid_strategy(): with pytest.raises(NoSuchStrategyException): - get_strategy("invalid", {}) + PaginationStrategy.get_strategy("invalid", {}) diff --git a/tests/ops/service/privacy_request/request_runner_service_test.py b/tests/ops/service/privacy_request/request_runner_service_test.py index c8337b27b..a9a22fba9 100644 --- a/tests/ops/service/privacy_request/request_runner_service_test.py +++ b/tests/ops/service/privacy_request/request_runner_service_test.py @@ -42,9 +42,7 @@ RedshiftConnector, SnowflakeConnector, ) -from fidesops.ops.service.masking.strategy.masking_strategy_factory import ( - MaskingStrategyFactory, -) +from fidesops.ops.service.masking.strategy.masking_strategy import MaskingStrategy from fidesops.ops.service.masking.strategy.masking_strategy_hmac import ( HmacMaskingStrategy, ) @@ -239,9 +237,7 @@ def get_privacy_request_results( if strategy_name in unique_masking_strategies_by_name: continue unique_masking_strategies_by_name.add(strategy_name) - masking_strategy = MaskingStrategyFactory.get_strategy( - strategy_name, configuration - ) + masking_strategy = MaskingStrategy.get_strategy(strategy_name, configuration) if masking_strategy.secrets_required(): masking_secrets: List[ MaskingSecretCache diff --git a/tests/ops/service/processors/post_processor_strategy/test_post_processor_strategy_factory.py b/tests/ops/service/processors/post_processor_strategy/test_post_processor_strategy_factory.py index d2bd60ee2..2b258f694 100644 --- a/tests/ops/service/processors/post_processor_strategy/test_post_processor_strategy_factory.py +++ b/tests/ops/service/processors/post_processor_strategy/test_post_processor_strategy_factory.py @@ -1,8 +1,8 @@ import pytest from fidesops.ops.common_exceptions import NoSuchStrategyException, ValidationError -from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy_factory import ( - get_strategy, +from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy import ( + PostProcessorStrategy, ) from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy_filter import ( FilterPostProcessorStrategy, @@ -14,21 +14,27 @@ def test_get_strategy_filter(): config = {"field": "email_contact", "value": "somebody@email.com"} - strategy = get_strategy(strategy_name="filter", configuration=config) + strategy = PostProcessorStrategy.get_strategy( + strategy_name="filter", configuration=config + ) assert isinstance(strategy, FilterPostProcessorStrategy) def test_get_strategy_unwrap(): config = {"data_path": "exact_matches.members"} - strategy = get_strategy(strategy_name="unwrap", configuration=config) + strategy = PostProcessorStrategy.get_strategy( + strategy_name="unwrap", configuration=config + ) assert isinstance(strategy, UnwrapPostProcessorStrategy) def test_get_strategy_invalid_config(): with pytest.raises(ValidationError): - get_strategy(strategy_name="unwrap", configuration={"invalid": "thing"}) + PostProcessorStrategy.get_strategy( + strategy_name="unwrap", configuration={"invalid": "thing"} + ) def test_get_strategy_invalid_strategy(): with pytest.raises(NoSuchStrategyException): - get_strategy("invalid", {}) + PostProcessorStrategy.get_strategy("invalid", {}) diff --git a/tests/ops/service/test_strategy_retrieval.py b/tests/ops/service/test_strategy_retrieval.py new file mode 100644 index 000000000..7d827bc69 --- /dev/null +++ b/tests/ops/service/test_strategy_retrieval.py @@ -0,0 +1,169 @@ +from abc import abstractmethod +from typing import Any, Dict, List, Union + +import pytest + +from fidesops.ops.common_exceptions import NoSuchStrategyException +from fidesops.ops.schemas.saas.strategy_configuration import StrategyConfiguration +from fidesops.ops.service.processors.post_processor_strategy.post_processor_strategy import ( + PostProcessorStrategy, +) + + +class SomeStrategyConfiguration(StrategyConfiguration): + some_key: str = "default value" + + +class SomeStrategy(PostProcessorStrategy): + name = "some postprocessor strategy" + configuration_model = SomeStrategyConfiguration + + def __init__(self, configuration: SomeStrategyConfiguration): + self.some_config = configuration.some_key + + def process( + self, data: Any, identity_data: Dict[str, Any] = None + ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: + pass + + +class SomeSubStrategy(SomeStrategy): + """ + A strategy class that subclasses another strategy class + Its parent class is also a valid strategy, i.e. it has a name + """ + + name = "some subclassed strategy" + + +class AnotherSubStrategy(SomeStrategy): + """ + A strategy class that subclasses another strategy class + This is to test two subclasses at the same level + in the strategy class hierarchy + """ + + name = "another subclassed strategy" + + +class SomeSubSubStrategy(SomeSubStrategy): + """ + A strategy class that subclasses another strategy subclass + This is to test a 3-level strategy hierarchy + """ + + name = "some sub-subclassed strategy" + + +class SomeAbstractStrategyClass(PostProcessorStrategy): + """ + This class does not provide a name, which indicates + that it's "abstract", i.e. it should not be retrievable + """ + + @abstractmethod + def some_abstract_method(self): + """Placeholder for an abstract method""" + + +class DifferentStrategySubClass(SomeAbstractStrategyClass): + """ + This strategy class subclasses an abstract strategy class + that does not provide a name and is not a strategy + """ + + name = "different subclassed strategy" + configuration_model = SomeStrategyConfiguration + + def some_abstract_method(self): + pass + + def __init__(self, configuration: SomeStrategyConfiguration): + self.some_config = configuration.some_key + + def process( + self, data: Any, identity_data: Dict[str, Any] = None + ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: + pass + + +class TestStrategyRetrieval: + """ + Unit tests for abstract strategy retrieval functionality. + Uses PostProcessorStrategy as an example + """ + + def test_valid_strategy(self): + """ + Test registering a valid Strategy + """ + + config = SomeStrategyConfiguration(some_key="non default value") + retrieved_strategy = PostProcessorStrategy.get_strategy( + SomeStrategy.name, config.dict() + ) + assert isinstance(retrieved_strategy, SomeStrategy) + assert retrieved_strategy.some_config == "non default value" + + def test_multi_level_inheritance_strategy(self): + """ + Test that strategy classes with multiple levels + of inheritance can be properly retrieved + """ + + config = SomeStrategyConfiguration(some_key="non default value") + retrieved_strategy = PostProcessorStrategy.get_strategy( + SomeStrategy.name, config.dict() + ) + assert isinstance(retrieved_strategy, SomeStrategy) + + retrieved_strategy = PostProcessorStrategy.get_strategy( + SomeSubStrategy.name, config.dict() + ) + assert isinstance(retrieved_strategy, SomeSubStrategy) + assert issubclass(type(retrieved_strategy), SomeStrategy) + + retrieved_strategy = PostProcessorStrategy.get_strategy( + AnotherSubStrategy.name, config.dict() + ) + assert isinstance(retrieved_strategy, AnotherSubStrategy) + assert issubclass(type(retrieved_strategy), SomeStrategy) + + retrieved_strategy = PostProcessorStrategy.get_strategy( + SomeSubSubStrategy.name, config.dict() + ) + assert isinstance(retrieved_strategy, SomeSubSubStrategy) + assert issubclass(type(retrieved_strategy), SomeStrategy) + assert issubclass(type(retrieved_strategy), SomeSubStrategy) + + retrieved_strategy = PostProcessorStrategy.get_strategy( + DifferentStrategySubClass.name, config.dict() + ) + assert isinstance(retrieved_strategy, DifferentStrategySubClass) + assert issubclass(type(retrieved_strategy), SomeAbstractStrategyClass) + + def test_retrieve_nonexistent_strategy(self): + """ + Test attempt to retrieve a nonexistent strategy + """ + + with pytest.raises(NoSuchStrategyException) as exc: + PostProcessorStrategy.get_strategy("a nonexistent strategy", {}) + assert "'a nonexistent strategy'" in str(exc.value) + assert "some postprocessor strategy" in str(exc.value) + + def test_get_strategies(self): + """ + Test `get_strategies` method returns expected list of strategies + """ + strats = PostProcessorStrategy.get_strategies() + expected_strats = [ + SomeStrategy, + SomeSubStrategy, + SomeSubSubStrategy, + DifferentStrategySubClass, + ] + for expected_strat in expected_strats: + assert expected_strat in strats + + assert SomeAbstractStrategyClass not in strats diff --git a/tests/ops/util/encryption/test_secrets_util.py b/tests/ops/util/encryption/test_secrets_util.py index 0ce0acdbb..a98b76d89 100644 --- a/tests/ops/util/encryption/test_secrets_util.py +++ b/tests/ops/util/encryption/test_secrets_util.py @@ -6,10 +6,9 @@ SecretType, ) from fidesops.ops.service.masking.strategy.masking_strategy_aes_encrypt import ( - AES_ENCRYPT_STRATEGY_NAME, + AesEncryptionMaskingStrategy, ) from fidesops.ops.service.masking.strategy.masking_strategy_hmac import ( - HMAC_STRATEGY_NAME, HmacMaskingStrategy, ) from fidesops.ops.util.encryption.secrets_util import SecretsUtil @@ -23,7 +22,7 @@ def test_get_secret_from_cache_str() -> None: # build masking secret meta for HMAC key masking_meta_key: Dict[SecretType, MaskingSecretMeta] = { SecretType.key: MaskingSecretMeta[str]( - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, generate_secret_func=SecretsUtil.generate_secret_string, ) } @@ -31,7 +30,7 @@ def test_get_secret_from_cache_str() -> None: # cache secrets for HMAC secret_key = MaskingSecretCache[str]( secret="test_key", - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) @@ -47,7 +46,7 @@ def test_get_secret_from_cache_bytes() -> None: # build masking secret meta for AES key masking_meta_key: Dict[SecretType, MaskingSecretMeta] = { SecretType.key: MaskingSecretMeta[bytes]( - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=AesEncryptionMaskingStrategy.name, generate_secret_func=SecretsUtil.generate_secret_bytes, ) } @@ -55,7 +54,7 @@ def test_get_secret_from_cache_bytes() -> None: # cache secret AES key secret_key = MaskingSecretCache[str]( secret=b"\x94Y\xa8Z", - masking_strategy=AES_ENCRYPT_STRATEGY_NAME, + masking_strategy=AesEncryptionMaskingStrategy.name, secret_type=SecretType.key, ) cache_secret(secret_key, request_id) @@ -71,7 +70,7 @@ def test_generate_secret() -> None: # build masking secret meta for HMAC key masking_meta_key: Dict[SecretType, MaskingSecretMeta] = { SecretType.key: MaskingSecretMeta[str]( - masking_strategy=HMAC_STRATEGY_NAME, + masking_strategy=HmacMaskingStrategy.name, generate_secret_func=SecretsUtil.generate_secret_string, ) }