Skip to content

Commit

Permalink
Make propagators conform to spec (#488)
Browse files Browse the repository at this point in the history
* Make propagators conform to spec

* do not modify / set an invalid span in the passed context in case
  a propagator did not manage to extract
* in case no context is passed to propagator.extract assume the root
  context as default so that a new trace is started instead of continung
  the current active trace in case extraction fails
* fix also ot-trace propagator which compared int with str trace/span ids
  when checking for validity in extract
  • Loading branch information
mariojonke authored May 31, 2021
1 parent 4a8b32b commit 3a7eb53
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 216 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#504](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/504))
- `opentelemetry-instrumentation-asgi` Fix instrumentation default span name.
([#418](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/418))
- Propagators use the root context as default for `extract` and do not modify
the context if extracting from carrier does not work.
([#488](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/488))

### Added
- `opentelemetry-instrumentation-botocore` now supports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def extract(
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()

trace_id = extract_first_element(
getter.get(carrier, self.TRACE_ID_KEY)
)
Expand All @@ -64,7 +67,7 @@ def extract(
trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)

if trace_id is None or span_id is None:
return set_span_in_context(trace.INVALID_SPAN, context)
return context

trace_state = []
if origin is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest.mock import Mock, patch

from opentelemetry import trace as trace_api
from opentelemetry.context import Context
from opentelemetry.exporter.datadog import constants, propagator
from opentelemetry.sdk import trace
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
Expand All @@ -36,42 +37,58 @@ def setUpClass(cls):
)
cls.serialized_origin = "origin-service"

def test_malformed_headers(self):
def test_extract_malformed_headers_to_explicit_ctx(self):
"""Test with no Datadog headers"""
orig_ctx = Context({"k1": "v1"})
malformed_trace_id_key = FORMAT.TRACE_ID_KEY + "-x"
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = get_current_span(
FORMAT.extract(
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
},
)
).get_span_context()
context = FORMAT.extract(
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
},
orig_ctx,
)
self.assertDictEqual(orig_ctx, context)

self.assertNotEqual(context.trace_id, int(self.serialized_trace_id))
self.assertNotEqual(context.span_id, int(self.serialized_parent_id))
self.assertFalse(context.is_remote)
def test_extract_malformed_headers_to_implicit_ctx(self):
malformed_trace_id_key = FORMAT.TRACE_ID_KEY + "-x"
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = FORMAT.extract(
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
}
)
self.assertDictEqual(Context(), context)

def test_missing_trace_id(self):
def test_extract_missing_trace_id_to_explicit_ctx(self):
"""If a trace id is missing, populate an invalid trace id."""
carrier = {
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
}
orig_ctx = Context({"k1": "v1"})
carrier = {FORMAT.PARENT_ID_KEY: self.serialized_parent_id}

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_missing_trace_id_to_implicit_ctx(self):
carrier = {FORMAT.PARENT_ID_KEY: self.serialized_parent_id}

ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)
self.assertDictEqual(Context(), ctx)

def test_missing_parent_id(self):
def test_extract_missing_parent_id_to_explicit_ctx(self):
"""If a parent id is missing, populate an invalid trace id."""
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
}
orig_ctx = Context({"k1": "v1"})
carrier = {FORMAT.TRACE_ID_KEY: self.serialized_trace_id}

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_missing_parent_id_to_implicit_ctx(self):
carrier = {FORMAT.TRACE_ID_KEY: self.serialized_trace_id}

ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)
self.assertDictEqual(Context(), ctx)

def test_context_propagation(self):
"""Test the propagation of Datadog headers."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,19 @@ def extract(
context: Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()

traceid = _extract_first_element(
getter.get(carrier, OT_TRACE_ID_HEADER), INVALID_TRACE_ID
traceid = _extract_identifier(
getter.get(carrier, OT_TRACE_ID_HEADER),
_valid_extract_traceid,
INVALID_TRACE_ID,
)

spanid = _extract_first_element(
getter.get(carrier, OT_SPAN_ID_HEADER), INVALID_SPAN_ID
spanid = _extract_identifier(
getter.get(carrier, OT_SPAN_ID_HEADER),
_valid_extract_spanid,
INVALID_SPAN_ID,
)

sampled = _extract_first_element(
Expand All @@ -73,17 +79,12 @@ def extract(
else:
traceflags = TraceFlags.DEFAULT

if (
traceid != INVALID_TRACE_ID
and _valid_extract_traceid.fullmatch(traceid) is not None
and spanid != INVALID_SPAN_ID
and _valid_extract_spanid.fullmatch(spanid) is not None
):
if traceid != INVALID_TRACE_ID and spanid != INVALID_SPAN_ID:
context = set_span_in_context(
NonRecordingSpan(
SpanContext(
trace_id=int(traceid, 16),
span_id=int(spanid, 16),
trace_id=traceid,
span_id=spanid,
is_remote=True,
trace_flags=TraceFlags(traceflags),
)
Expand Down Expand Up @@ -172,3 +173,16 @@ def _extract_first_element(
if items is None:
return default
return next(iter(items), None)


def _extract_identifier(
items: Iterable[CarrierT], validator_pattern, default: int
) -> int:
header = _extract_first_element(items)
if header is None or validator_pattern.fullmatch(header) is None:
return default

try:
return int(header, 16)
except ValueError:
return default
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from unittest import TestCase

from opentelemetry.baggage import get_all, set_baggage
from opentelemetry.context import Context
from opentelemetry.propagators.ot_trace import (
OT_BAGGAGE_PREFIX,
OT_SAMPLED_HEADER,
Expand All @@ -24,8 +25,6 @@
)
from opentelemetry.sdk.trace import _Span
from opentelemetry.trace import (
INVALID_SPAN_CONTEXT,
INVALID_SPAN_ID,
INVALID_TRACE_ID,
SpanContext,
TraceFlags,
Expand Down Expand Up @@ -275,65 +274,44 @@ def test_extract_trace_id_span_id_sampled_false(self):
get_current_span().get_span_context().trace_flags, TraceFlags
)

def test_extract_malformed_trace_id(self):
"""Test extraction with malformed trace_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: "abc123!",
OT_SPAN_ID_HEADER: "e457b5a2e4d86bd1",
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)

def test_extract_malformed_span_id(self):
"""Test extraction with malformed span_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: "64fe8b2a57d3eff7",
OT_SPAN_ID_HEADER: "abc123!",
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)

def test_extract_invalid_trace_id(self):
"""Test extraction with invalid trace_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: INVALID_TRACE_ID,
OT_SPAN_ID_HEADER: "e457b5a2e4d86bd1",
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)

def test_extract_invalid_span_id(self):
"""Test extraction with invalid span_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: "64fe8b2a57d3eff7",
OT_SPAN_ID_HEADER: INVALID_SPAN_ID,
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)
def test_extract_invalid_trace_header_to_explict_ctx(self):
invalid_headers = [
("abc123!", "e457b5a2e4d86bd1"), # malformed trace id
("64fe8b2a57d3eff7", "abc123!"), # malformed span id
("0" * 32, "e457b5a2e4d86bd1"), # invalid trace id
("64fe8b2a57d3eff7", "0" * 16), # invalid span id
]
for trace_id, span_id in invalid_headers:
with self.subTest(trace_id=trace_id, span_id=span_id):
orig_ctx = Context({"k1": "v1"})

ctx = self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: trace_id,
OT_SPAN_ID_HEADER: span_id,
OT_SAMPLED_HEADER: "false",
},
orig_ctx,
)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_invalid_trace_header_to_implicit_ctx(self):
invalid_headers = [
("abc123!", "e457b5a2e4d86bd1"), # malformed trace id
("64fe8b2a57d3eff7", "abc123!"), # malformed span id
("0" * 32, "e457b5a2e4d86bd1"), # invalid trace id
("64fe8b2a57d3eff7", "0" * 16), # invalid span id
]
for trace_id, span_id in invalid_headers:
with self.subTest(trace_id=trace_id, span_id=span_id):
ctx = self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: trace_id,
OT_SPAN_ID_HEADER: span_id,
OT_SAMPLED_HEADER: "false",
}
)
self.assertDictEqual(Context(), ctx)

def test_extract_baggage(self):
"""Test baggage extraction"""
Expand All @@ -359,11 +337,13 @@ def test_extract_baggage(self):
self.assertEqual(baggage["abc"], "abc")
self.assertEqual(baggage["def"], "def")

def test_extract_empty(self):
"Test extraction when no headers are present"
def test_extract_empty_to_explicit_ctx(self):
"""Test extraction when no headers are present"""
orig_ctx = Context({"k1": "v1"})
ctx = self.ot_trace_propagator.extract({}, orig_ctx)

span_context = get_current_span(
self.ot_trace_propagator.extract({})
).get_span_context()
self.assertDictEqual(orig_ctx, ctx)

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)
def test_extract_empty_to_implicit_ctx(self):
ctx = self.ot_trace_propagator.extract({})
self.assertDictEqual(Context(), ctx)
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,18 @@ def extract(
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()

trace_header_list = getter.get(carrier, TRACE_HEADER_KEY)

if not trace_header_list or len(trace_header_list) != 1:
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

trace_header = trace_header_list[0]

if not trace_header:
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

try:
(
Expand All @@ -128,9 +127,7 @@ def extract(
) = AwsXRayFormat._extract_span_properties(trace_header)
except AwsParseTraceHeaderError as err:
_logger.debug(err.message)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

options = 0
if sampled:
Expand All @@ -148,9 +145,7 @@ def extract(
_logger.debug(
"Invalid Span Extracted. Insertting INVALID span into provided context."
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

return trace.set_span_in_context(
trace.NonRecordingSpan(span_context), context=context
Expand Down
Loading

0 comments on commit 3a7eb53

Please sign in to comment.