Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(logger): Adding thread safe logging keys #5141

Open
wants to merge 10 commits into
base: v2
Choose a base branch
from
75 changes: 73 additions & 2 deletions aws_lambda_powertools/logging/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
import traceback
from abc import ABCMeta, abstractmethod
from contextvars import ContextVar
from datetime import datetime, timezone
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -48,17 +49,30 @@ class BasePowertoolsFormatter(logging.Formatter, metaclass=ABCMeta):
def append_keys(self, **additional_keys) -> None:
raise NotImplementedError()

def append_thread_local_keys(self, **additional_keys) -> None:
raise NotImplementedError()

def get_current_keys(self) -> Dict[str, Any]:
return {}

def get_current_thread_keys(self) -> Dict[str, Any]:
return {}

def remove_keys(self, keys: Iterable[str]) -> None:
raise NotImplementedError()

def remove_thread_local_keys(self, keys: Iterable[str]) -> None:
raise NotImplementedError()

@abstractmethod
def clear_state(self) -> None:
"""Removes any previously added logging keys"""
raise NotImplementedError()

def clear_thread_local_keys(self) -> None:
"""Removes any previously added logging keys in a specific thread"""
raise NotImplementedError()


class LambdaPowertoolsFormatter(BasePowertoolsFormatter):
"""Powertools for AWS Lambda (Python) Logging formatter.
Expand Down Expand Up @@ -234,17 +248,29 @@ def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -
def append_keys(self, **additional_keys) -> None:
self.log_format.update(additional_keys)

def append_thread_local_keys(self, **additional_keys) -> None:
set_context_keys(**additional_keys)

def get_current_keys(self) -> Dict[str, Any]:
return self.log_format

def get_current_thread_keys(self) -> Dict[str, Any]:
return _get_context().get()

def remove_keys(self, keys: Iterable[str]) -> None:
for key in keys:
self.log_format.pop(key, None)

def remove_thread_local_keys(self, keys: Iterable[str]) -> None:
remove_context_keys(keys)

def clear_state(self) -> None:
self.log_format = dict.fromkeys(self.log_record_order)
self.log_format.update(**self.keys_combined)

def clear_thread_local_keys(self) -> None:
clear_context_keys()

@staticmethod
def _build_default_keys() -> Dict[str, str]:
return {
Expand Down Expand Up @@ -343,14 +369,33 @@ def _extract_log_keys(self, log_record: logging.LogRecord) -> Dict[str, Any]:
record_dict["asctime"] = self.formatTime(record=log_record)
extras = {k: v for k, v in record_dict.items() if k not in RESERVED_LOG_ATTRS}

formatted_log = {}
formatted_log: Dict[str, Any] = {}

# Iterate over a default or existing log structure
# then replace any std log attribute e.g. '%(level)s' to 'INFO', '%(process)d to '4773'
# check if the value is a str if the key is a reserved attribute, the modulo operator only supports string
# lastly add or replace incoming keys (those added within the constructor or .structure_logs method)
for key, value in self.log_format.items():
if value and key in RESERVED_LOG_ATTRS:
formatted_log[key] = value % record_dict
if isinstance(value, str):
formatted_log[key] = value % record_dict
else:
raise ValueError(
"Logging keys that override reserved log attributes need to be type 'str', "
f"instead got '{type(value).__name__}'",
)
else:
formatted_log[key] = value

for key, value in _get_context().get().items():
if value and key in RESERVED_LOG_ATTRS:
if isinstance(value, str):
formatted_log[key] = value % record_dict
else:
raise ValueError(
"Logging keys that override reserved log attributes need to be type 'str', "
f"instead got '{type(value).__name__}'",
)
else:
formatted_log[key] = value

Expand All @@ -368,3 +413,29 @@ def _strip_none_records(records: Dict[str, Any]) -> Dict[str, Any]:

# Fetch current and future parameters from PowertoolsFormatter that should be reserved
RESERVED_FORMATTER_CUSTOM_KEYS: List[str] = inspect.getfullargspec(LambdaPowertoolsFormatter).args[1:]

# ContextVar for thread local keys
THREAD_LOCAL_KEYS: ContextVar[dict[str, Any]] = ContextVar("THREAD_LOCAL_KEYS", default={})


def _get_context() -> ContextVar[dict[str, Any]]:
return THREAD_LOCAL_KEYS


def clear_context_keys() -> None:
_get_context().set({})


def set_context_keys(**kwargs: Dict[str, Any]) -> None:
context = _get_context()
context.set({**context.get(), **kwargs})


def remove_context_keys(keys: Iterable[str]) -> None:
context = _get_context()
context_values = context.get()

for k in keys:
context_values.pop(k, None)

context.set(context_values)
13 changes: 13 additions & 0 deletions aws_lambda_powertools/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,12 +586,24 @@ def debug(
def append_keys(self, **additional_keys: object) -> None:
self.registered_formatter.append_keys(**additional_keys)

def append_thread_local_keys(self, **additional_keys: object) -> None:
self.registered_formatter.append_thread_local_keys(**additional_keys)

def get_current_keys(self) -> Dict[str, Any]:
return self.registered_formatter.get_current_keys()

def get_current_thread_keys(self) -> Dict[str, Any]:
return self.registered_formatter.get_current_thread_keys()

def remove_keys(self, keys: Iterable[str]) -> None:
self.registered_formatter.remove_keys(keys)

def remove_thread_local_keys(self, keys: Iterable[str]) -> None:
self.registered_formatter.remove_thread_local_keys(keys)

def clear_thread_local_keys(self) -> None:
self.registered_formatter.clear_thread_local_keys()

def structure_logs(self, append: bool = False, formatter_options: Optional[Dict] = None, **keys) -> None:
"""Sets logging formatting to JSON.

Expand Down Expand Up @@ -636,6 +648,7 @@ def structure_logs(self, append: bool = False, formatter_options: Optional[Dict]

# Mode 3
self.registered_formatter.clear_state()
self.registered_formatter.clear_thread_local_keys()
self.registered_formatter.append_keys(**log_keys)

def set_correlation_id(self, value: Optional[str]) -> None:
Expand Down
69 changes: 67 additions & 2 deletions docs/core/logger.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,14 @@ To ease routine tasks like extracting correlation ID from popular event sources,

You can append additional keys using either mechanism:

* Persist new keys across all future log messages via `append_keys` method
* New keys persist across all future log messages via `append_keys` method
* New keys persist across all future logs in a specific thread via `append_thread_local_keys` method
* Add additional keys on a per log message basis as a keyword=value, or via `extra` parameter

#### append_keys method

???+ warning
`append_keys` is not thread-safe, please see [RFC](https://github.com/aws-powertools/powertools-lambda-python/issues/991){target="_blank"}.
`append_keys` is not thread-safe, use `append_thread_local_keys` instead

You can append your own keys to your existing Logger via `append_keys(**additional_key_values)` method.

Expand All @@ -186,6 +187,22 @@ You can append your own keys to your existing Logger via `append_keys(**addition

This example will add `order_id` if its value is not empty, and in subsequent invocations where `order_id` might not be present it'll remove it from the Logger.

#### append_thread_local_keys method

You can append your own thread-local keys in your existing Logger via the `append_thread_local_keys` method

=== "append_thread_local_keys.py"

```python hl_lines="11"
--8<-- "examples/logger/src/append_thread_local_keys.py"
```

=== "append_thread_local_keys_output.json"

```json hl_lines="8 9 17 18"
--8<-- "examples/logger/src/append_thread_local_keys_output.json"
```

#### ephemeral metadata

You can pass an arbitrary number of keyword arguments (kwargs) to all log level's methods, e.g. `logger.info, logger.warning`.
Expand Down Expand Up @@ -228,6 +245,17 @@ It accepts any dictionary, and all keyword arguments will be added as part of th

### Removing additional keys

You can remove additional keys using either mechanism:

* Remove new keys across all future log messages via `remove_keys` method
* Remove new keys across all future logs in a specific thread via `remove_thread_local_keys` method
* Remove **all** new keys across all future logs in a specific thread via `clear_thread_local_keys` method

???+ danger
Keys added by `append_keys` can only be removed by `remove_keys` and thread-local keys added by `append_thread_local_key` can only be removed by `remove_thread_local_keys` or `clear_thread_local_keys`. Thread-local and normal logger keys are distinct values and can't be manipulated interchangably.

#### remove_keys method

You can remove any additional key from Logger state using `remove_keys`.

=== "remove_keys.py"
Expand All @@ -242,6 +270,40 @@ You can remove any additional key from Logger state using `remove_keys`.
--8<-- "examples/logger/src/remove_keys_output.json"
```

#### remove_thread_local_keys method

You can remove any additional thread-local keys from Logger using either `remove_thread_local_keys` or `clear_thread_local_keys`.

Use the `remove_thread_local_keys` method to remove a list of thread-local keys that were previously added using the `append_thread_local_keys` method.

=== "remove_thread_local_keys.py"

```python hl_lines="13"
--8<-- "examples/logger/src/remove_thread_local_keys.py"
```

=== "remove_thread_local_keys_output.json"

```json hl_lines="8 9 17 18 26 34"
--8<-- "examples/logger/src/remove_thread_local_keys_output.json"
```

#### clear_thread_local_keys method

Use the `clear_thread_local_keys` method to remove all thread-local keys that were previously added using the `append_thread_local_keys` method.

=== "clear_thread_local_keys.py"

```python hl_lines="13"
--8<-- "examples/logger/src/clear_thread_local_keys.py"
```

=== "clear_thread_local_keys_output.json"

```json hl_lines="8 9 17 18"
--8<-- "examples/logger/src/clear_thread_local_keys_output.json"
```

#### Clearing all state

Logger is commonly initialized in the global scope. Due to [Lambda Execution Context reuse](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-context.html){target="_blank"}, this means that custom keys can be persisted across invocations. If you want all custom keys to be deleted, you can use `clear_state=True` param in `inject_lambda_context` decorator.
Expand Down Expand Up @@ -284,6 +346,9 @@ You can view all currently configured keys from the Logger state using the `get_
--8<-- "examples/logger/src/get_current_keys.py"
```

???+ info
For thread-local additional logging keys, use `get_current_thread_keys` instead

### Log levels

The default log level is `INFO`. It can be set using the `level` constructor option, `setLevel()` method or by using the `POWERTOOLS_LOG_LEVEL` environment variable.
Expand Down
21 changes: 21 additions & 0 deletions examples/logger/src/append_thread_local_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import threading
from typing import List

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = Logger()


def threaded_func(order_id: str):
logger.append_thread_local_keys(order_id=order_id, thread_id=threading.get_ident())
logger.info("Collecting payment")


def lambda_handler(event: dict, context: LambdaContext) -> str:
order_ids: List[str] = event["order_ids"]

threading.Thread(target=threaded_func, args=(order_ids[0],)).start()
threading.Thread(target=threaded_func, args=(order_ids[1],)).start()

return "hello world"
20 changes: 20 additions & 0 deletions examples/logger/src/append_thread_local_keys_output.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 03:04:11,316-0400",
"service": "payment",
"order_id": "order_id_value_1",
"thread_id": "3507187776085958"
},
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 03:04:11,316-0400",
"service": "payment",
"order_id": "order_id_value_2",
"thread_id": "140718447808512"
}
]
23 changes: 23 additions & 0 deletions examples/logger/src/clear_thread_local_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import threading
from typing import List

from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = Logger()


def threaded_func(order_id: str):
logger.append_thread_local_keys(order_id=order_id, thread_id=threading.get_ident())
logger.info("Collecting payment")
logger.clear_thread_local_keys()
logger.info("Exiting thread")


def lambda_handler(event: dict, context: LambdaContext) -> str:
order_ids: List[str] = event["order_ids"]

threading.Thread(target=threaded_func, args=(order_ids[0],)).start()
threading.Thread(target=threaded_func, args=(order_ids[1],)).start()

return "hello world"
34 changes: 34 additions & 0 deletions examples/logger/src/clear_thread_local_keys_output.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 12:26:10,648-0400",
"service": "payment",
"order_id": "order_id_value_1",
"thread_id": 140077070292544
},
{
"level": "INFO",
"location": "threaded_func:11",
"message": "Collecting payment",
"timestamp": "2024-09-08 12:26:10,649-0400",
"service": "payment",
"order_id": "order_id_value_2",
"thread_id": 140077061899840
},
{
"level": "INFO",
"location": "threaded_func:13",
"message": "Exiting thread",
"timestamp": "2024-09-08 12:26:10,649-0400",
"service": "payment"
},
{
"level": "INFO",
"location": "threaded_func:13",
"message": "Exiting thread",
"timestamp": "2024-09-08 12:26:10,649-0400",
"service": "payment"
}
]
Loading