Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: OTEL context lost in subthreads of parallel bulk calls #2616

Merged
merged 22 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions elasticsearch/_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from typing import Generator, Literal, Mapping

try:
from opentelemetry import context as otel_context
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)

_tracer: trace.Tracer | None = trace.get_tracer("elasticsearch-api")
except ModuleNotFoundError:
Expand All @@ -41,6 +45,8 @@


class OpenTelemetry:
context_carrier: dict[str, str] = {}

def __init__(
self,
enabled: bool | None = None,
Expand Down Expand Up @@ -86,3 +92,32 @@ def span(
endpoint_id=endpoint_id,
body_strategy=self.body_strategy,
)

@contextlib.contextmanager
def helpers_span(self, span_name: str) -> Generator[None, None, None]:
    """Open a span around a ``helpers.*`` bulk call and publish its trace
    context so that worker threads can later re-attach to it.

    Acts as a no-op context manager when instrumentation is disabled or no
    tracer is available.
    """
    if self.enabled and self.tracer is not None:
        with self.tracer.start_as_current_span(span_name) as span:
            # Serialize the active trace context into the shared carrier;
            # recover_parent_context() reads it back from subthreads.
            TraceContextTextMapPropagator().inject(self.context_carrier)
            span.set_attribute("db.system", "elasticsearch")
            span.set_attribute("db.operation", span_name)
            # Without a request method, Elastic APM does not display the traces
            span.set_attribute("http.request.method", "null")
            yield
    else:
        yield

@contextlib.contextmanager
def recover_parent_context(self) -> Generator[None, None, None]:
    """Re-attach, in the current thread, the trace context previously
    captured by ``helpers_span`` so spans created here get the right parent.

    Acts as a no-op context manager when instrumentation is disabled or no
    tracer is available.
    """
    if not self.enabled or self.tracer is None:
        yield
        return

    # Deserialize the parent context from the shared carrier and make it
    # current for the duration of the block.
    token = otel_context.attach(
        TraceContextTextMapPropagator().extract(carrier=self.context_carrier)
    )
    try:
        yield
    finally:
        # Always detach so the context does not leak into unrelated work
        # executed later on this (possibly pooled) thread.
        otel_context.detach(token)
221 changes: 115 additions & 106 deletions elasticsearch/helpers/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,28 +331,29 @@ def _process_bulk_chunk(
"""
Send a bulk request to elasticsearch and process the output.
"""
if isinstance(ignore_status, int):
ignore_status = (ignore_status,)

try:
# send the actual request
resp = client.bulk(*args, operations=bulk_actions, **kwargs) # type: ignore[arg-type]
except ApiError as e:
gen = _process_bulk_chunk_error(
error=e,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_exception=raise_on_exception,
raise_on_error=raise_on_error,
)
else:
gen = _process_bulk_chunk_success(
resp=resp.body,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_error=raise_on_error,
)
yield from gen
with client._otel.recover_parent_context():
if isinstance(ignore_status, int):
ignore_status = (ignore_status,)

try:
# send the actual request
resp = client.bulk(*args, operations=bulk_actions, **kwargs) # type: ignore[arg-type]
except ApiError as e:
gen = _process_bulk_chunk_error(
error=e,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_exception=raise_on_exception,
raise_on_error=raise_on_error,
)
else:
gen = _process_bulk_chunk_success(
resp=resp.body,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_error=raise_on_error,
)
yield from gen


def streaming_bulk(
Expand All @@ -370,6 +371,7 @@ def streaming_bulk(
max_backoff: float = 600,
yield_ok: bool = True,
ignore_status: Union[int, Collection[int]] = (),
span_name: str = "helpers.streaming_bulk",
*args: Any,
**kwargs: Any,
) -> Iterable[Tuple[bool, Dict[str, Any]]]:
Expand Down Expand Up @@ -406,73 +408,77 @@ def streaming_bulk(
:arg yield_ok: if set to False will skip successful documents in the output
:arg ignore_status: list of HTTP status code that you want to ignore
"""
client = client.options()
client._client_meta = (("h", "bp"),)
with client._otel.helpers_span(span_name):
client = client.options()
client._client_meta = (("h", "bp"),)

serializer = client.transport.serializers.get_serializer("application/json")
serializer = client.transport.serializers.get_serializer("application/json")

bulk_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
bulk_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
]
]
]
bulk_actions: List[bytes]
for bulk_data, bulk_actions in _chunk_actions(
map(expand_action_callback, actions), chunk_size, max_chunk_bytes, serializer
):
for attempt in range(max_retries + 1):
to_retry: List[bytes] = []
to_retry_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
]
] = []
if attempt:
time.sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1)))

try:
for data, (ok, info) in zip(
bulk_data,
_process_bulk_chunk(
client,
bulk_actions,
bulk_actions: List[bytes]
for bulk_data, bulk_actions in _chunk_actions(
map(expand_action_callback, actions),
chunk_size,
max_chunk_bytes,
serializer,
):
for attempt in range(max_retries + 1):
to_retry: List[bytes] = []
to_retry_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
]
] = []
if attempt:
time.sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1)))

try:
for data, (ok, info) in zip(
bulk_data,
raise_on_exception,
raise_on_error,
ignore_status,
*args,
**kwargs,
),
):
if not ok:
action, info = info.popitem()
# retry if retries enabled, we get 429, and we are not
# in the last attempt
if (
max_retries
and info["status"] == 429
and (attempt + 1) <= max_retries
):
# _process_bulk_chunk expects bytes so we need to
# re-serialize the data
to_retry.extend(map(serializer.dumps, data))
to_retry_data.append(data)
else:
yield ok, {action: info}
elif yield_ok:
yield ok, info

except ApiError as e:
# suppress 429 errors since we will retry them
if attempt == max_retries or e.status_code != 429:
raise
else:
if not to_retry:
break
# retry only subset of documents that didn't succeed
bulk_actions, bulk_data = to_retry, to_retry_data
_process_bulk_chunk(
client,
bulk_actions,
bulk_data,
raise_on_exception,
raise_on_error,
ignore_status,
*args,
**kwargs,
),
):
if not ok:
action, info = info.popitem()
# retry if retries enabled, we get 429, and we are not
# in the last attempt
if (
max_retries
and info["status"] == 429
and (attempt + 1) <= max_retries
):
# _process_bulk_chunk expects bytes so we need to
# re-serialize the data
to_retry.extend(map(serializer.dumps, data))
to_retry_data.append(data)
else:
yield ok, {action: info}
elif yield_ok:
yield ok, info

except ApiError as e:
# suppress 429 errors since we will retry them
if attempt == max_retries or e.status_code != 429:
raise
else:
if not to_retry:
break
# retry only subset of documents that didn't succeed
bulk_actions, bulk_data = to_retry, to_retry_data


def bulk(
Expand Down Expand Up @@ -519,7 +525,7 @@ def bulk(
# make streaming_bulk yield successful results so we can count them
kwargs["yield_ok"] = True
for ok, item in streaming_bulk(
client, actions, ignore_status=ignore_status, *args, **kwargs # type: ignore[misc]
client, actions, ignore_status=ignore_status, span_name="helpers.bulk", *args, **kwargs # type: ignore[misc]
):
# go through request-response pairs and detect failures
if not ok:
Expand Down Expand Up @@ -589,27 +595,30 @@ def _setup_queues(self) -> None:
] = Queue(max(queue_size, thread_count))
self._quick_put = self._inqueue.put

pool = BlockingPool(thread_count)
with client._otel.helpers_span("helpers.parallel_bulk"):
pool = BlockingPool(thread_count)

try:
for result in pool.imap(
lambda bulk_chunk: list(
_process_bulk_chunk(
client,
bulk_chunk[1],
bulk_chunk[0],
ignore_status=ignore_status, # type: ignore[misc]
*args,
**kwargs,
)
),
_chunk_actions(expanded_actions, chunk_size, max_chunk_bytes, serializer),
):
yield from result

finally:
pool.close()
pool.join()
try:
for result in pool.imap(
lambda bulk_chunk: list(
_process_bulk_chunk(
client,
bulk_chunk[1],
bulk_chunk[0],
ignore_status=ignore_status, # type: ignore[misc]
*args,
**kwargs,
)
),
_chunk_actions(
expanded_actions, chunk_size, max_chunk_bytes, serializer
),
):
yield from result

finally:
pool.close()
pool.join()


def scan(
Expand Down
25 changes: 25 additions & 0 deletions test_elasticsearch/test_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
# under the License.

import os
from unittest import mock

import pytest

from elasticsearch import Elasticsearch, helpers

try:
from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
Expand Down Expand Up @@ -95,3 +98,25 @@ def test_detailed_span():
"db.elasticsearch.cluster.name": "e9106fc68e3044f0b1475b04bf4ffd5f",
"db.elasticsearch.node.name": "instance-0000000001",
}


@mock.patch("elasticsearch._otel.OpenTelemetry.recover_parent_context")
@mock.patch("elasticsearch._otel.OpenTelemetry.helpers_span")
@mock.patch("elasticsearch.helpers.actions._process_bulk_chunk_success")
@mock.patch("elasticsearch.Elasticsearch.bulk")
def test_forward_otel_context_to_subthreads(
    _call_bulk_mock,
    _process_bulk_success_mock,
    _mock_otel_helpers_span,
    _mock_otel_recv_context,
):
    """parallel_bulk must open exactly one helpers span and re-attach the
    parent OTEL context once per chunk handled in the worker threads."""
    tracer, memory_exporter = setup_tracing()
    client = Elasticsearch("http://localhost:9200")
    client._otel = OpenTelemetry(enabled=True, tracer=tracer)
    _call_bulk_mock.return_value = mock.Mock()

    num_docs, docs_per_chunk = 100, 4
    docs = ({"x": i} for i in range(num_docs))
    list(helpers.parallel_bulk(client, docs, chunk_size=docs_per_chunk))

    # One span wraps the whole helper call...
    assert client._otel.helpers_span.call_count == 1
    # ...and the parent context is recovered once per chunk sent to the pool.
    assert client._otel.recover_parent_context.call_count == num_docs // docs_per_chunk
1 change: 1 addition & 0 deletions test_elasticsearch/test_server/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
),
):
self.client = client
self._otel = client._otel
self._called = 0
self._fail_at = fail_at
self.transport = client.transport
Expand Down
Loading
Loading