Add MSK IAM authentication capability and framework for other Kafka SASL/OAuth mechanisms (#31)

* Add MSK IAM authentication capability and framework for other Kafka SASL/OAuth mechanisms

* Add AWS_ROLE_SESSION_NAME config to stop errors about changing principals when re-authing

* Add support for assuming a specified role
woodlee authored Mar 14, 2024
1 parent 902d310 commit 6d480f3
Showing 9 changed files with 197 additions and 27 deletions.
44 changes: 33 additions & 11 deletions cdc_kafka/kafka.py
@@ -14,7 +14,7 @@
import confluent_kafka.avro
import fastavro

-from . import constants
+from . import constants, kafka_oauth

from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -40,8 +40,7 @@ def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', boots
'bootstrap.servers': bootstrap_servers,
'group.id': f'cdc_to_kafka_{socket.getfqdn()}',
'enable.partition.eof': True,
-'enable.auto.commit': False,
-'broker.address.family': 'v4'
+'enable.auto.commit': False
}, **extra_kafka_consumer_config}
producer_config: Dict[str, Any] = {**{
'bootstrap.servers': bootstrap_servers,
@@ -50,14 +49,24 @@ def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', boots
'statistics.interval.ms': 30 * 60 * 1000,
'enable.gapless.guarantee': True,
'retry.backoff.ms': 250,
-'compression.codec': 'snappy',
-'broker.address.family': 'v4'
+'compression.codec': 'snappy'
}, **extra_kafka_producer_config}
admin_config: Dict[str, Any] = {
-'bootstrap.servers': bootstrap_servers,
-'broker.address.family': 'v4'
+'bootstrap.servers': bootstrap_servers
}

oauth_provider = kafka_oauth.get_kafka_oauth_provider()

if oauth_provider is not None:
logger.debug('Using Kafka OAuth provider class %s', type(oauth_provider).__name__)
for config_dict in (consumer_config, producer_config, admin_config):
if not config_dict.get('security.protocol'):
config_dict['security.protocol'] = 'SASL_SSL'
if not config_dict.get('sasl.mechanisms'):
config_dict['sasl.mechanisms'] = 'OAUTHBEARER'
if not config_dict.get('client.id'):
config_dict['client.id'] = socket.gethostname()

logger.debug('Kafka consumer configuration: %s', json.dumps(consumer_config))
logger.debug('Kafka producer configuration: %s', json.dumps(producer_config))
logger.debug('Kafka admin client configuration: %s', json.dumps(admin_config))
@@ -73,16 +82,17 @@ def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', boots
admin_config['throttle_cb'] = KafkaClient._log_kafka_throttle_event
admin_config['logger'] = logger

if oauth_provider is not None:
consumer_config['oauth_cb'] = oauth_provider.consumer_oauth_cb
producer_config['oauth_cb'] = oauth_provider.producer_oauth_cb
admin_config['oauth_cb'] = oauth_provider.admin_oauth_cb

self._use_transactions: bool = False
if transactional_id is not None:
producer_config['transactional.id'] = transactional_id
self._use_transactions = True

self._producer: confluent_kafka.Producer = confluent_kafka.Producer(producer_config)

-if self._use_transactions:
-self._producer.init_transactions()

self._schema_registry: confluent_kafka.avro.CachedSchemaRegistryClient = \
confluent_kafka.avro.CachedSchemaRegistryClient(schema_registry_url)
self._consumer: confluent_kafka.Consumer = confluent_kafka.Consumer(consumer_config)
@@ -96,6 +106,18 @@ def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', boots
]]] = collections.defaultdict(list)
self._disable_writing = disable_writing
self._creation_warned_topic_names: Set[str] = set()

if oauth_provider is not None:
# I dislike this, but it seems like these polls are needed to trigger initial invocations of the oauth_cb
# before we act further with the producer, admin client, or consumer:
oauth_cb_poll_timeout = 3
self._producer.poll(oauth_cb_poll_timeout)
self._consumer.poll(oauth_cb_poll_timeout)
self._admin.poll(oauth_cb_poll_timeout)

if self._use_transactions:
self._producer.init_transactions(constants.KAFKA_REQUEST_TIMEOUT_SECS)

self._cluster_metadata: confluent_kafka.admin.ClusterMetadata = self._get_cluster_metadata()

KafkaClient._instance = self
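Note: this change drops the hardcoded 'broker.address.family': 'v4' entry from the consumer, producer, and admin configurations. A minimal sketch (not part of this commit) of how a deployment that still needs IPv4-only broker resolution could restore it for the consumer and producer, since those two configs are built as {**defaults, **extra_config}; the bootstrap and schema registry values below are placeholders:

# Illustrative only: the extra config dicts are merged over the defaults shown above,
# so per-deployment settings such as the removed IPv4-only flag can be reinstated here.
ipv4_only = {'broker.address.family': 'v4'}
client = kafka.KafkaClient(
    accumulator.NoopAccumulator(),      # any metrics accumulator implementation
    'broker1:9092,broker2:9092',        # placeholder bootstrap servers
    'http://schema-registry:8081',      # placeholder schema registry URL
    ipv4_only,                          # extra_kafka_consumer_config
    ipv4_only,                          # extra_kafka_producer_config
    disable_writing=True,
)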
56 changes: 56 additions & 0 deletions cdc_kafka/kafka_oauth/__init__.py
@@ -0,0 +1,56 @@
import argparse
import importlib
import os
from abc import ABC, abstractmethod

from typing import TypeVar, Type, Tuple, Optional

KafkaOauthProviderAbstractType = TypeVar('KafkaOauthProviderAbstractType', bound='KafkaOauthProviderAbstract')


class KafkaOauthProviderAbstract(ABC):
@abstractmethod
def consumer_oauth_cb(self, config_str: str) -> Tuple[str, float]:
pass

@abstractmethod
def producer_oauth_cb(self, config_str: str) -> Tuple[str, float]:
pass

@abstractmethod
def admin_oauth_cb(self, config_str: str) -> Tuple[str, float]:
pass

@staticmethod
def add_arguments(parser: argparse.ArgumentParser) -> None:
pass

@classmethod
@abstractmethod
def construct_with_options(cls: Type[KafkaOauthProviderAbstractType],
opts: argparse.Namespace) -> KafkaOauthProviderAbstractType:
pass


def add_kafka_oauth_arg(parser: argparse.ArgumentParser) -> None:
parser.add_argument('--kafka-oauth-provider',
default=os.environ.get('KAFKA_OAUTH_PROVIDER'),
help="A string of form <module_name>.<class_name> indicating an implementation of "
"kafka_oauth.KafkaOauthProviderAbstract that provides OAuth callback functions specified "
"when instantiating Kafka consumers, producers, or admin clients.")


def get_kafka_oauth_provider() -> Optional[KafkaOauthProviderAbstract]:
parser = argparse.ArgumentParser()
add_kafka_oauth_arg(parser)
opts, _ = parser.parse_known_args()

if not opts.kafka_oauth_provider:
return None

package_module, class_name = opts.kafka_oauth_provider.rsplit('.', 1)
module = importlib.import_module(package_module)
oauth_class: KafkaOauthProviderAbstract = getattr(module, class_name)
oauth_class.add_arguments(parser)
opts, _ = parser.parse_known_args()
return oauth_class.construct_with_options(opts)
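To illustrate the plug-in contract defined above: a provider is any class implementing KafkaOauthProviderAbstract, named via --kafka-oauth-provider <module_name>.<class_name> (or the KAFKA_OAUTH_PROVIDER environment variable). A minimal hypothetical sketch, assuming a pre-issued token in an environment variable and a fixed five-minute validity; confluent-kafka expects each callback to return a (token, absolute expiry as a Unix timestamp in seconds) tuple:

# my_oauth.py -- hypothetical example, not part of this commit
import argparse
import os
import time
from typing import Tuple

from cdc_kafka.kafka_oauth import KafkaOauthProviderAbstract


class StaticTokenProvider(KafkaOauthProviderAbstract):
    # Hands the same pre-issued token to the consumer, producer, and admin clients.
    def _token(self) -> Tuple[str, float]:
        return os.environ['STATIC_OAUTH_TOKEN'], time.time() + 300.0

    def consumer_oauth_cb(self, config_str: str) -> Tuple[str, float]:
        return self._token()

    def producer_oauth_cb(self, config_str: str) -> Tuple[str, float]:
        return self._token()

    def admin_oauth_cb(self, config_str: str) -> Tuple[str, float]:
        return self._token()

    @classmethod
    def construct_with_options(cls, opts: argparse.Namespace) -> 'StaticTokenProvider':
        return cls()

Selecting it would then look like KAFKA_OAUTH_PROVIDER=my_oauth.StaticTokenProvider.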
60 changes: 60 additions & 0 deletions cdc_kafka/kafka_oauth/aws_msk.py
@@ -0,0 +1,60 @@
import argparse
import datetime
import logging
import os
from typing import Tuple, TypeVar, Type, Optional

from aws_msk_iam_sasl_signer import MSKAuthTokenProvider

from . import KafkaOauthProviderAbstract

logger = logging.getLogger(__name__)

AwsMskOauthCallbackProviderType = TypeVar('AwsMskOauthCallbackProviderType', bound='AwsMskOauthCallbackProvider')


class AwsMskOauthCallbackProvider(KafkaOauthProviderAbstract):
def __init__(self, aws_region: str, role_arn: Optional[str] = None):
self.aws_region: str = aws_region
self.role_arn: Optional[str] = role_arn
self._auth_token: str = ''
self._expiry_ts: float = datetime.datetime.now(datetime.timezone.utc).timestamp()

def consumer_oauth_cb(self, config_str: str) -> Tuple[str, float]:
return self._common_cb()

def producer_oauth_cb(self, config_str: str) -> Tuple[str, float]:
return self._common_cb()

def admin_oauth_cb(self, config_str: str) -> Tuple[str, float]:
return self._common_cb()

def _common_cb(self) -> Tuple[str, float]:
if not self._auth_token or datetime.datetime.now(datetime.timezone.utc).timestamp() > self._expiry_ts:
if self.role_arn:
self._auth_token, expiry_ms = MSKAuthTokenProvider.generate_auth_token_from_role_arn(
self.aws_region, self.role_arn)
else:
self._auth_token, expiry_ms = MSKAuthTokenProvider.generate_auth_token(self.aws_region)
self._expiry_ts = expiry_ms / 1000
logger.debug('AwsMskOauthCallbackProvider generated an auth token that expires at %s',
datetime.datetime.fromtimestamp(self._expiry_ts, datetime.timezone.utc))
return self._auth_token, self._expiry_ts

@staticmethod
def add_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument('--msk-cluster-aws-region', default=os.environ.get('MSK_CLUSTER_AWS_REGION'),
help='AWS region name to use for IAM-based authentication to an AWS MSK cluster.')
parser.add_argument('--msk-cluster-access-role-arn', default=os.environ.get('MSK_CLUSTER_ACCESS_ROLE_ARN'),
help='Optional name of an AWS IAM role to assume for authentication to an AWS MSK cluster.')
parser.add_argument('--aws-role-session-name', default=os.environ.get('AWS_ROLE_SESSION_NAME'),
help='A session name for the process to maintain principal-name stability when '
're-authenticating for AWS IAM/SASL')

@classmethod
def construct_with_options(cls: Type[AwsMskOauthCallbackProviderType],
opts: argparse.Namespace) -> AwsMskOauthCallbackProviderType:
if not opts.msk_cluster_aws_region:
raise Exception('AwsMskOauthCallbackProvider cannot be used without specifying a value for '
'MSK_CLUSTER_AWS_REGION')
return cls(opts.msk_cluster_aws_region, opts.msk_cluster_access_role_arn)
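A hedged usage sketch for the provider above, configured entirely through the environment variables read by the argparse defaults in this commit (region, ARN, and session name are placeholders):

# Illustrative only; values are placeholders.
import os

os.environ['KAFKA_OAUTH_PROVIDER'] = 'cdc_kafka.kafka_oauth.aws_msk.AwsMskOauthCallbackProvider'
os.environ['MSK_CLUSTER_AWS_REGION'] = 'us-east-1'
# Optional: assume a specific role instead of using the ambient credentials:
os.environ['MSK_CLUSTER_ACCESS_ROLE_ARN'] = 'arn:aws:iam::123456789012:role/example-msk-access'
# Per the commit message, a stable session name avoids 'changing principal' errors on re-auth:
os.environ['AWS_ROLE_SESSION_NAME'] = 'cdc-to-kafka'

from cdc_kafka import kafka_oauth
provider = kafka_oauth.get_kafka_oauth_provider()  # returns an AwsMskOauthCallbackProvider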
8 changes: 5 additions & 3 deletions cdc_kafka/options.py
@@ -5,7 +5,7 @@
import socket

from typing import Tuple, List
-from . import constants
+from . import constants, kafka_oauth
from .metric_reporting import reporter_base


@@ -285,7 +285,9 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report
default=os.environ.get('DB_ROW_BATCH_SIZE', 2000),
help="Maximum number of rows to retrieve in a single change data or snapshot query. Default 2000.")

-opts = p.parse_args()
+kafka_oauth.add_kafka_oauth_arg(p)
+
+opts, _ = p.parse_known_args()

reporter_classes: List[reporter_base.ReporterBase] = []
reporters: List[reporter_base.ReporterBase] = []
@@ -298,7 +300,7 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List[report
reporter_classes.append(reporter_class)
reporter_class.add_arguments(p)

-opts = p.parse_args()
+opts, _ = p.parse_known_args()

for reporter_class in reporter_classes:
reporters.append(reporter_class.construct_with_options(opts))
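The moves from parse_args() to parse_known_args() above are what make the provider framework workable: the first parsing pass only knows about --kafka-oauth-provider, and the provider's own options (for example --msk-cluster-aws-region) are registered afterwards, so a strict parse would reject them. A small sketch of that two-pass pattern:

# Sketch of the two-pass argument parsing used above.
import argparse

argv = ['--kafka-oauth-provider', 'cdc_kafka.kafka_oauth.aws_msk.AwsMskOauthCallbackProvider',
        '--msk-cluster-aws-region', 'us-east-1']
p = argparse.ArgumentParser()
p.add_argument('--kafka-oauth-provider')
opts, _ = p.parse_known_args(argv)          # parse_args(argv) would fail on the unknown MSK option
p.add_argument('--msk-cluster-aws-region')  # normally registered by the provider's add_arguments()
opts, _ = p.parse_known_args(argv)          # now both options are recognized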
15 changes: 11 additions & 4 deletions cdc_kafka/progress_reset_tool.py
@@ -1,9 +1,10 @@
import argparse
import json
import logging
import os
import socket

-from cdc_kafka import kafka, constants, progress_tracking, options
+from cdc_kafka import kafka, constants, progress_tracking, options, kafka_oauth
from .metric_reporting import accumulator

logger = logging.getLogger(__name__)
@@ -27,7 +28,12 @@ def main() -> None:
p.add_argument('--execute',
type=options.str2bool, nargs='?', const=True,
default=options.str2bool(os.environ.get('EXECUTE', '0')))
-opts = p.parse_args()
+p.add_argument('--extra-kafka-producer-config',
+default=os.environ.get('EXTRA_KAFKA_PRODUCER_CONFIG', {}), type=json.loads)
+p.add_argument('--extra-kafka-consumer-config',
+default=os.environ.get('EXTRA_KAFKA_CONSUMER_CONFIG', {}), type=json.loads)
+kafka_oauth.add_kafka_oauth_arg(p)
+opts, _ = p.parse_known_args()

logger.info(f"""
@@ -42,8 +48,9 @@ def main() -> None:
""")

-with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.schema_registry_url, {},
-{}, disable_writing=True) as kafka_client:
+with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.schema_registry_url,
+opts.extra_kafka_consumer_config, opts.extra_kafka_producer_config,
+disable_writing=True) as kafka_client:
progress_tracker = progress_tracking.ProgressTracker(kafka_client, opts.progress_topic_name, socket.getfqdn(),
opts.snapshot_logging_topic_name)
progress_entries = progress_tracker.get_prior_progress()
11 changes: 8 additions & 3 deletions cdc_kafka/progress_topic_validator.py
@@ -2,6 +2,7 @@
import collections
import copy
import datetime
import json
import logging
import os
import re
@@ -10,7 +11,7 @@
import confluent_kafka
from tabulate import tabulate

-from cdc_kafka import kafka, constants, progress_tracking, options, helpers
+from cdc_kafka import kafka, constants, progress_tracking, options, helpers, kafka_oauth
from .metric_reporting import accumulator

logger = logging.getLogger(__name__)
@@ -41,16 +42,20 @@ def main() -> None:
default=os.environ.get('KAFKA_BOOTSTRAP_SERVERS'))
p.add_argument('--progress-topic-name',
default=os.environ.get('PROGRESS_TOPIC_NAME', '_cdc_to_kafka_progress'))
p.add_argument('--extra-kafka-consumer-config',
default=os.environ.get('EXTRA_KAFKA_CONSUMER_CONFIG', {}), type=json.loads)
kafka_oauth.add_kafka_oauth_arg(p)
p.add_argument('--show-all',
type=options.str2bool, nargs='?', const=True,
default=options.str2bool(os.environ.get('SHOW_ALL', '0')))
-opts = p.parse_args()
+opts, _ = p.parse_known_args()

if not (opts.schema_registry_url and opts.kafka_bootstrap_servers):
raise Exception('Arguments schema_registry_url and kafka_bootstrap_servers are required.')

with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers,
-opts.schema_registry_url, {}, {}, disable_writing=True) as kafka_client:
+opts.schema_registry_url, opts.extra_kafka_consumer_config, {},
+disable_writing=True) as kafka_client:
if kafka_client.get_topic_partition_count(opts.progress_topic_name) is None:
logger.error('Progress topic %s not found.', opts.progress_topic_name)
exit(1)
18 changes: 16 additions & 2 deletions cdc_kafka/replayer.py
@@ -35,6 +35,7 @@
"""

import argparse
import json
import logging.config
import multiprocessing as mp
import os
@@ -52,6 +53,8 @@
from confluent_kafka.schema_registry.avro import AvroDeserializer
from faster_fifo import Queue

from cdc_kafka import kafka_oauth

log_level = os.getenv('LOG_LEVEL', 'INFO').upper()

logging.config.dictConfig({
@@ -224,6 +227,9 @@ def main() -> None:
default=os.environ.get('KAFKA_BOOTSTRAP_SERVERS'))
p.add_argument('--schema-registry-url',
default=os.environ.get('SCHEMA_REGISTRY_URL'))
p.add_argument('--extra-kafka-consumer-config',
default=os.environ.get('EXTRA_KAFKA_CONSUMER_CONFIG', {}), type=json.loads)
kafka_oauth.add_kafka_oauth_arg(p)

# Config for data target / progress tracking
p.add_argument('--target-db-server',
@@ -257,7 +263,7 @@
p.add_argument('--consumed-messages-limit', type=int,
default=os.environ.get('CONSUMED_MESSAGES_LIMIT', 0))

-opts = p.parse_args()
+opts, _ = p.parse_known_args()

if not (opts.replay_topic and opts.kafka_bootstrap_servers and opts.schema_registry_url and
opts.target_db_server and opts.target_db_user and opts.target_db_password and opts.target_db_database and
@@ -537,7 +543,15 @@ def consumer_process(opts: argparse.Namespace, stop_event: EventClass, queue: Qu
'enable.auto.offset.store': False,
'enable.auto.commit': False, # We don't use Kafka for offset management in this code
'auto.offset.reset': "earliest",
-'on_commit': commit_cb}
+'on_commit': commit_cb,
+**opts.extra_kafka_consumer_config}
oauth_provider = kafka_oauth.get_kafka_oauth_provider()
if oauth_provider is not None:
if not consumer_conf.get('security.protocol'):
consumer_conf['security.protocol'] = 'SASL_SSL'
if not consumer_conf.get('sasl.mechanisms'):
consumer_conf['sasl.mechanisms'] = 'OAUTHBEARER'
consumer_conf['oauth_cb'] = oauth_provider.consumer_oauth_cb
consumer: Consumer = Consumer(consumer_conf)
start_offset_by_partition: Dict[int, int] = {
p.source_topic_partition: p.last_handled_message_offset + 1 for p in progress
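Because the extra consumer config is unpacked last in the dict literal above, values supplied via --extra-kafka-consumer-config (a JSON object, per type=json.loads) override the replayer's hardcoded defaults, while the OAuth-related keys are only filled in when absent. A small illustrative sketch of that merge behavior (the override values are placeholders):

# Illustrative only, e.g. EXTRA_KAFKA_CONSUMER_CONFIG='{"auto.offset.reset": "latest"}'
import json

extra = json.loads('{"auto.offset.reset": "latest"}')
consumer_conf = {
    'auto.offset.reset': 'earliest',   # hardcoded default from the diff above
    **extra,                           # merged last, so the JSON-supplied value wins
}
assert consumer_conf['auto.offset.reset'] == 'latest'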
11 changes: 7 additions & 4 deletions cdc_kafka/show_snapshot_history.py
@@ -7,7 +7,7 @@
import confluent_kafka
from tabulate import tabulate

-from cdc_kafka import kafka
+from cdc_kafka import kafka, kafka_oauth
from . import constants
from .metric_reporting import accumulator

@@ -31,10 +31,13 @@ def main() -> None:
default=os.environ.get('KAFKA_BOOTSTRAP_SERVERS'))
p.add_argument('--snapshot-logging-topic-name', required=True,
default=os.environ.get('SNAPSHOT_LOGGING_TOPIC_NAME'))
-opts = p.parse_args()
+p.add_argument('--extra-kafka-consumer-config',
+default=os.environ.get('EXTRA_KAFKA_CONSUMER_CONFIG', {}), type=json.loads)
+kafka_oauth.add_kafka_oauth_arg(p)
+opts, _ = p.parse_known_args()

-with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.schema_registry_url, {},
-{}, disable_writing=True) as kafka_client:
+with kafka.KafkaClient(accumulator.NoopAccumulator(), opts.kafka_bootstrap_servers, opts.schema_registry_url,
+opts.extra_kafka_consumer_config, {}, disable_writing=True) as kafka_client:
last_start: Optional[Dict[str, Any]] = None
consumed_count: int = 0
relevant_count: int = 0
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
avro==1.11.3
aws-msk-iam-sasl-signer-python==1.0.1
bitarray==2.9.2
confluent-kafka==2.3.0
fastavro==1.9.4
