Skip to content

Commit

Permalink
Fix OTel context loss in parallel bulk helper (#2616)
Browse files Browse the repository at this point in the history
Co-authored-by: Quentin Pradet <[email protected]>
(cherry picked from commit d4df09f)
  • Loading branch information
claudinoac authored and github-actions[bot] committed Sep 3, 2024
1 parent 86be9d1 commit 9ad6298
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 130 deletions.
22 changes: 22 additions & 0 deletions elasticsearch/_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,25 @@ def span(
endpoint_id=endpoint_id,
body_strategy=self.body_strategy,
)

@contextlib.contextmanager
def helpers_span(self, span_name: str) -> Generator[OpenTelemetrySpan, None, None]:
if not self.enabled or self.tracer is None:
yield OpenTelemetrySpan(None)
return

with self.tracer.start_as_current_span(span_name) as otel_span:
otel_span.set_attribute("db.system", "elasticsearch")
otel_span.set_attribute("db.operation", span_name)
# Without a request method, Elastic APM does not display the traces
otel_span.set_attribute("http.request.method", "null")
yield otel_span

@contextlib.contextmanager
def use_span(self, span: OpenTelemetrySpan) -> Generator[None, None, None]:
if not self.enabled or self.tracer is None:
yield
return

with trace.use_span(span):
yield
226 changes: 120 additions & 106 deletions elasticsearch/helpers/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
Union,
)

from elastic_transport import OpenTelemetrySpan

from .. import Elasticsearch
from ..compat import to_bytes
from ..exceptions import ApiError, NotFoundError, TransportError
Expand Down Expand Up @@ -322,6 +324,7 @@ def _process_bulk_chunk(
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
]
],
otel_span: OpenTelemetrySpan,
raise_on_exception: bool = True,
raise_on_error: bool = True,
ignore_status: Union[int, Collection[int]] = (),
Expand All @@ -331,28 +334,29 @@ def _process_bulk_chunk(
"""
Send a bulk request to elasticsearch and process the output.
"""
if isinstance(ignore_status, int):
ignore_status = (ignore_status,)

try:
# send the actual request
resp = client.bulk(*args, operations=bulk_actions, **kwargs) # type: ignore[arg-type]
except ApiError as e:
gen = _process_bulk_chunk_error(
error=e,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_exception=raise_on_exception,
raise_on_error=raise_on_error,
)
else:
gen = _process_bulk_chunk_success(
resp=resp.body,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_error=raise_on_error,
)
yield from gen
with client._otel.use_span(otel_span):
if isinstance(ignore_status, int):
ignore_status = (ignore_status,)

try:
# send the actual request
resp = client.bulk(*args, operations=bulk_actions, **kwargs) # type: ignore[arg-type]
except ApiError as e:
gen = _process_bulk_chunk_error(
error=e,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_exception=raise_on_exception,
raise_on_error=raise_on_error,
)
else:
gen = _process_bulk_chunk_success(
resp=resp.body,
bulk_data=bulk_data,
ignore_status=ignore_status,
raise_on_error=raise_on_error,
)
yield from gen


def streaming_bulk(
Expand All @@ -370,6 +374,7 @@ def streaming_bulk(
max_backoff: float = 600,
yield_ok: bool = True,
ignore_status: Union[int, Collection[int]] = (),
span_name: str = "helpers.streaming_bulk",
*args: Any,
**kwargs: Any,
) -> Iterable[Tuple[bool, Dict[str, Any]]]:
Expand Down Expand Up @@ -406,73 +411,78 @@ def streaming_bulk(
:arg yield_ok: if set to False will skip successful documents in the output
:arg ignore_status: list of HTTP status code that you want to ignore
"""
client = client.options()
client._client_meta = (("h", "bp"),)
with client._otel.helpers_span(span_name) as otel_span:
client = client.options()
client._client_meta = (("h", "bp"),)

serializer = client.transport.serializers.get_serializer("application/json")
serializer = client.transport.serializers.get_serializer("application/json")

bulk_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
bulk_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
]
]
]
bulk_actions: List[bytes]
for bulk_data, bulk_actions in _chunk_actions(
map(expand_action_callback, actions), chunk_size, max_chunk_bytes, serializer
):
for attempt in range(max_retries + 1):
to_retry: List[bytes] = []
to_retry_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
]
] = []
if attempt:
time.sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1)))

try:
for data, (ok, info) in zip(
bulk_data,
_process_bulk_chunk(
client,
bulk_actions,
bulk_actions: List[bytes]
for bulk_data, bulk_actions in _chunk_actions(
map(expand_action_callback, actions),
chunk_size,
max_chunk_bytes,
serializer,
):
for attempt in range(max_retries + 1):
to_retry: List[bytes] = []
to_retry_data: List[
Union[
Tuple[_TYPE_BULK_ACTION_HEADER],
Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY],
]
] = []
if attempt:
time.sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1)))

try:
for data, (ok, info) in zip(
bulk_data,
raise_on_exception,
raise_on_error,
ignore_status,
*args,
**kwargs,
),
):
if not ok:
action, info = info.popitem()
# retry if retries enabled, we get 429, and we are not
# in the last attempt
if (
max_retries
and info["status"] == 429
and (attempt + 1) <= max_retries
):
# _process_bulk_chunk expects bytes so we need to
# re-serialize the data
to_retry.extend(map(serializer.dumps, data))
to_retry_data.append(data)
else:
yield ok, {action: info}
elif yield_ok:
yield ok, info

except ApiError as e:
# suppress 429 errors since we will retry them
if attempt == max_retries or e.status_code != 429:
raise
else:
if not to_retry:
break
# retry only subset of documents that didn't succeed
bulk_actions, bulk_data = to_retry, to_retry_data
_process_bulk_chunk(
client,
bulk_actions,
bulk_data,
otel_span,
raise_on_exception,
raise_on_error,
ignore_status,
*args,
**kwargs,
),
):
if not ok:
action, info = info.popitem()
# retry if retries enabled, we get 429, and we are not
# in the last attempt
if (
max_retries
and info["status"] == 429
and (attempt + 1) <= max_retries
):
# _process_bulk_chunk expects bytes so we need to
# re-serialize the data
to_retry.extend(map(serializer.dumps, data))
to_retry_data.append(data)
else:
yield ok, {action: info}
elif yield_ok:
yield ok, info

except ApiError as e:
# suppress 429 errors since we will retry them
if attempt == max_retries or e.status_code != 429:
raise
else:
if not to_retry:
break
# retry only subset of documents that didn't succeed
bulk_actions, bulk_data = to_retry, to_retry_data


def bulk(
Expand Down Expand Up @@ -519,7 +529,7 @@ def bulk(
# make streaming_bulk yield successful results so we can count them
kwargs["yield_ok"] = True
for ok, item in streaming_bulk(
client, actions, ignore_status=ignore_status, *args, **kwargs # type: ignore[misc]
client, actions, ignore_status=ignore_status, span_name="helpers.bulk", *args, **kwargs # type: ignore[misc]
):
# go through request-response pairs and detect failures
if not ok:
Expand Down Expand Up @@ -589,27 +599,31 @@ def _setup_queues(self) -> None:
] = Queue(max(queue_size, thread_count))
self._quick_put = self._inqueue.put

pool = BlockingPool(thread_count)
with client._otel.helpers_span("helpers.parallel_bulk") as otel_span:
pool = BlockingPool(thread_count)

try:
for result in pool.imap(
lambda bulk_chunk: list(
_process_bulk_chunk(
client,
bulk_chunk[1],
bulk_chunk[0],
ignore_status=ignore_status, # type: ignore[misc]
*args,
**kwargs,
)
),
_chunk_actions(expanded_actions, chunk_size, max_chunk_bytes, serializer),
):
yield from result

finally:
pool.close()
pool.join()
try:
for result in pool.imap(
lambda bulk_chunk: list(
_process_bulk_chunk(
client,
bulk_chunk[1],
bulk_chunk[0],
otel_span=otel_span,
ignore_status=ignore_status, # type: ignore[misc]
*args,
**kwargs,
)
),
_chunk_actions(
expanded_actions, chunk_size, max_chunk_bytes, serializer
),
):
yield from result

finally:
pool.close()
pool.join()


def scan(
Expand Down
25 changes: 25 additions & 0 deletions test_elasticsearch/test_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
# under the License.

import os
from unittest import mock

import pytest

from elasticsearch import Elasticsearch, helpers

try:
from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
Expand Down Expand Up @@ -95,3 +98,25 @@ def test_detailed_span():
"db.elasticsearch.cluster.name": "e9106fc68e3044f0b1475b04bf4ffd5f",
"db.elasticsearch.node.name": "instance-0000000001",
}


@mock.patch("elasticsearch._otel.OpenTelemetry.use_span")
@mock.patch("elasticsearch._otel.OpenTelemetry.helpers_span")
@mock.patch("elasticsearch.helpers.actions._process_bulk_chunk_success")
@mock.patch("elasticsearch.Elasticsearch.bulk")
def test_forward_otel_context_to_subthreads(
_call_bulk_mock,
_process_bulk_success_mock,
_mock_otel_helpers_span,
_mock_otel_use_span,
):
tracer, memory_exporter = setup_tracing()
es_client = Elasticsearch("http://localhost:9200")
es_client._otel = OpenTelemetry(enabled=True, tracer=tracer)

_call_bulk_mock.return_value = mock.Mock()
actions = ({"x": i} for i in range(100))
list(helpers.parallel_bulk(es_client, actions, chunk_size=4))
# Ensures that the OTEL context has been forwarded to all chunks
assert es_client._otel.helpers_span.call_count == 1
assert es_client._otel.use_span.call_count == 25
1 change: 1 addition & 0 deletions test_elasticsearch/test_server/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
),
):
self.client = client
self._otel = client._otel
self._called = 0
self._fail_at = fail_at
self.transport = client.transport
Expand Down
Loading

0 comments on commit 9ad6298

Please sign in to comment.