diff --git a/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py b/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py index aa9217c00e..b30b42d3fd 100644 --- a/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py +++ b/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py @@ -6,8 +6,9 @@ from flask import request as flask_request import opentelemetry.ext.wsgi as otel_wsgi -from opentelemetry import propagators, trace +from opentelemetry import context, propagators, trace from opentelemetry.ext.flask.version import __version__ +from opentelemetry.trace.propagation import get_span_from_context from opentelemetry.util import time_ns logger = logging.getLogger(__name__) @@ -15,6 +16,7 @@ _ENVIRON_STARTTIME_KEY = "opentelemetry-flask.starttime_key" _ENVIRON_SPAN_KEY = "opentelemetry-flask.span_key" _ENVIRON_ACTIVATION_KEY = "opentelemetry-flask.activation_key" +_ENVIRON_TOKEN = "opentelemetry-flask.token" def instrument_app(flask): @@ -57,8 +59,8 @@ def _before_flask_request(): span_name = flask_request.endpoint or otel_wsgi.get_default_span_name( environ ) - parent_span = propagators.extract( - otel_wsgi.get_header_from_environ, environ + token = context.attach( + propagators.extract(otel_wsgi.get_header_from_environ, environ) ) tracer = trace.get_tracer(__name__, __version__) @@ -69,7 +71,6 @@ def _before_flask_request(): attributes["http.route"] = flask_request.url_rule.rule span = tracer.start_span( span_name, - parent_span, kind=trace.SpanKind.SERVER, attributes=attributes, start_time=environ.get(_ENVIRON_STARTTIME_KEY), @@ -78,6 +79,7 @@ def _before_flask_request(): activation.__enter__() environ[_ENVIRON_ACTIVATION_KEY] = activation environ[_ENVIRON_SPAN_KEY] = span + environ[_ENVIRON_TOKEN] = token def _teardown_flask_request(exc): @@ -95,3 +97,4 @@ def _teardown_flask_request(exc): activation.__exit__( type(exc), exc, getattr(exc, "__traceback__", None) ) + context.detach(flask_request.environ.get(_ENVIRON_TOKEN)) diff --git a/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py b/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py index d21ca8258c..8e4b3e2cc0 100644 --- a/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py +++ b/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py @@ -76,7 +76,7 @@ def instrumented_request(self, method, url, *args, **kwargs): # to access propagators. headers = kwargs.setdefault("headers", {}) - propagators.inject(tracer, type(headers).__setitem__, headers) + propagators.inject(type(headers).__setitem__, headers) result = wrapped(self, method, url, *args, **kwargs) # *** PROCEED span.set_attribute("http.status_code", result.status_code) diff --git a/ext/opentelemetry-ext-http-requests/tests/test_requests_integration.py b/ext/opentelemetry-ext-http-requests/tests/test_requests_integration.py index 0a61016c77..ea37cbbf1b 100644 --- a/ext/opentelemetry-ext-http-requests/tests/test_requests_integration.py +++ b/ext/opentelemetry-ext-http-requests/tests/test_requests_integration.py @@ -41,6 +41,7 @@ def setUp(self): self.get_tracer = self.get_tracer_patcher.start() self.span_context_manager = mock.MagicMock() self.span = mock.create_autospec(trace.Span, spec_set=True) + self.span.get_context.return_value = trace.INVALID_SPAN_CONTEXT self.span_context_manager.__enter__.return_value = self.span def setspanattr(key, value): diff --git a/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py b/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py index 1ba196d9e0..bd9d22678e 100644 --- a/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py +++ b/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py @@ -93,6 +93,10 @@ from opentelemetry.ext.opentracing_shim import util from opentelemetry.ext.opentracing_shim.version import __version__ from opentelemetry.trace import DefaultSpan +from opentelemetry.trace.propagation import ( + get_span_from_context, + set_span_in_context, +) logger = logging.getLogger(__name__) @@ -677,11 +681,8 @@ def inject(self, span_context, format, carrier): propagator = propagators.get_global_httptextformat() - propagator.inject( - DefaultSpan(span_context.unwrap()), - type(carrier).__setitem__, - carrier, - ) + ctx = set_span_in_context(DefaultSpan(span_context.unwrap())) + propagator.inject(type(carrier).__setitem__, carrier, context=ctx) def extract(self, format, carrier): """Implements the ``extract`` method from the base class.""" @@ -700,6 +701,7 @@ def get_as_list(dict_object, key): return [value] if value is not None else [] propagator = propagators.get_global_httptextformat() - otel_context = propagator.extract(get_as_list, carrier) + ctx = propagator.extract(get_as_list, carrier) + otel_context = get_span_from_context(ctx).get_context() return SpanContextShim(otel_context) diff --git a/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py b/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py index 0d099340ec..2a3fe819c9 100644 --- a/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py +++ b/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py @@ -16,15 +16,26 @@ # pylint:disable=no-member import time +import typing from unittest import TestCase import opentracing import opentelemetry.ext.opentracing_shim as opentracingshim from opentelemetry import propagators, trace -from opentelemetry.context.propagation.httptextformat import HTTPTextFormat +from opentelemetry.context import Context from opentelemetry.ext.opentracing_shim import util from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.trace.propagation import ( + get_span_from_context, + set_span_in_context, +) +from opentelemetry.trace.propagation.httptextformat import ( + Getter, + HTTPTextFormat, + HTTPTextFormatT, + Setter, +) class TestShim(TestCase): @@ -49,7 +60,7 @@ def setUpClass(cls): cls._previous_propagator = propagators.get_global_httptextformat() # Set mock propagator for testing. - propagators.set_global_httptextformat(MockHTTPTextFormat) + propagators.set_global_httptextformat(MockHTTPTextFormat()) @classmethod def tearDownClass(cls): @@ -541,23 +552,37 @@ class MockHTTPTextFormat(HTTPTextFormat): TRACE_ID_KEY = "mock-traceid" SPAN_ID_KEY = "mock-spanid" - @classmethod - def extract(cls, get_from_carrier, carrier): - trace_id_list = get_from_carrier(carrier, cls.TRACE_ID_KEY) - span_id_list = get_from_carrier(carrier, cls.SPAN_ID_KEY) + def extract( + self, + get_from_carrier: Getter[HTTPTextFormatT], + carrier: HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> Context: + trace_id_list = get_from_carrier(carrier, self.TRACE_ID_KEY) + span_id_list = get_from_carrier(carrier, self.SPAN_ID_KEY) if not trace_id_list or not span_id_list: - return trace.INVALID_SPAN_CONTEXT + return set_span_in_context(trace.INVALID_SPAN) - return trace.SpanContext( - trace_id=int(trace_id_list[0]), span_id=int(span_id_list[0]) + return set_span_in_context( + trace.DefaultSpan( + trace.SpanContext( + trace_id=int(trace_id_list[0]), + span_id=int(span_id_list[0]), + ) + ) ) - @classmethod - def inject(cls, span, set_in_carrier, carrier): + def inject( + self, + set_in_carrier: Setter[HTTPTextFormatT], + carrier: HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> None: + span = get_span_from_context(context) set_in_carrier( - carrier, cls.TRACE_ID_KEY, str(span.get_context().trace_id) + carrier, self.TRACE_ID_KEY, str(span.get_context().trace_id) ) set_in_carrier( - carrier, cls.SPAN_ID_KEY, str(span.get_context().span_id) + carrier, self.SPAN_ID_KEY, str(span.get_context().span_id) ) diff --git a/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py b/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py index 37a3a0e9e0..b96fc057d1 100644 --- a/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py +++ b/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py @@ -22,8 +22,9 @@ import typing import wsgiref.util as wsgiref_util -from opentelemetry import propagators, trace +from opentelemetry import context, propagators, trace from opentelemetry.ext.wsgi.version import __version__ +from opentelemetry.trace.propagation import get_span_from_context from opentelemetry.trace.status import Status, StatusCanonicalCode _HTTP_VERSION_PREFIX = "HTTP/" @@ -181,12 +182,13 @@ def __call__(self, environ, start_response): start_response: The WSGI start_response callable. """ - parent_span = propagators.extract(get_header_from_environ, environ) + token = context.attach( + propagators.extract(get_header_from_environ, environ) + ) span_name = get_default_span_name(environ) span = self.tracer.start_span( span_name, - parent_span, kind=trace.SpanKind.SERVER, attributes=collect_request_attributes(environ), ) @@ -197,17 +199,20 @@ def __call__(self, environ, start_response): span, start_response ) iterable = self.wsgi(environ, start_response) - return _end_span_after_iterating(iterable, span, self.tracer) + return _end_span_after_iterating( + iterable, span, self.tracer, token + ) except: # noqa # TODO Set span status (cf. https://github.com/open-telemetry/opentelemetry-python/issues/292) span.end() + context.detach(token) raise # Put this in a subfunction to not delay the call to the wrapped # WSGI application (instrumentation should change the application # behavior as little as possible). -def _end_span_after_iterating(iterable, span, tracer): +def _end_span_after_iterating(iterable, span, tracer, token): try: with tracer.use_span(span): for yielded in iterable: @@ -217,3 +222,4 @@ def _end_span_after_iterating(iterable, span, tracer): if close: close() span.end() + context.detach(token) diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py deleted file mode 100644 index c8706281ad..0000000000 --- a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .binaryformat import BinaryFormat -from .httptextformat import HTTPTextFormat - -__all__ = ["BinaryFormat", "HTTPTextFormat"] diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py deleted file mode 100644 index 7f1a65882f..0000000000 --- a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import typing - -from opentelemetry.trace import SpanContext - - -class BinaryFormat(abc.ABC): - """API for serialization of span context into binary formats. - - This class provides an interface that enables converting span contexts - to and from a binary format. - """ - - @staticmethod - @abc.abstractmethod - def to_bytes(context: SpanContext) -> bytes: - """Creates a byte representation of a SpanContext. - - to_bytes should read values from a SpanContext and return a data - format to represent it, in bytes. - - Args: - context: the SpanContext to serialize - - Returns: - A bytes representation of the SpanContext. - - """ - - @staticmethod - @abc.abstractmethod - def from_bytes(byte_representation: bytes) -> typing.Optional[SpanContext]: - """Return a SpanContext that was represented by bytes. - - from_bytes should return back a SpanContext that was constructed from - the data serialized in the byte_representation passed. If it is not - possible to read in a proper SpanContext, return None. - - Args: - byte_representation: the bytes to deserialize - - Returns: - A bytes representation of the SpanContext if it is valid. - Otherwise return None. - - """ diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/__init__.py deleted file mode 100644 index c8706281ad..0000000000 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .binaryformat import BinaryFormat -from .httptextformat import HTTPTextFormat - -__all__ = ["BinaryFormat", "HTTPTextFormat"] diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/binaryformat.py b/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/binaryformat.py deleted file mode 100644 index d6d083c0da..0000000000 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/binaryformat.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import typing - -from opentelemetry.distributedcontext import DistributedContext - - -class BinaryFormat(abc.ABC): - """API for serialization of span context into binary formats. - - This class provides an interface that enables converting span contexts - to and from a binary format. - """ - - @staticmethod - @abc.abstractmethod - def to_bytes(context: DistributedContext) -> bytes: - """Creates a byte representation of a DistributedContext. - - to_bytes should read values from a DistributedContext and return a data - format to represent it, in bytes. - - Args: - context: the DistributedContext to serialize - - Returns: - A bytes representation of the DistributedContext. - - """ - - @staticmethod - @abc.abstractmethod - def from_bytes( - byte_representation: bytes, - ) -> typing.Optional[DistributedContext]: - """Return a DistributedContext that was represented by bytes. - - from_bytes should return back a DistributedContext that was constructed - from the data serialized in the byte_representation passed. If it is - not possible to read in a proper DistributedContext, return None. - - Args: - byte_representation: the bytes to deserialize - - Returns: - A bytes representation of the DistributedContext if it is valid. - Otherwise return None. - - """ diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/httptextformat.py deleted file mode 100644 index 3e2c186283..0000000000 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/httptextformat.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import typing - -from opentelemetry.distributedcontext import DistributedContext - -Setter = typing.Callable[[object, str, str], None] -Getter = typing.Callable[[object, str], typing.List[str]] - - -class HTTPTextFormat(abc.ABC): - """API for propagation of span context via headers. - - This class provides an interface that enables extracting and injecting - span context into headers of HTTP requests. HTTP frameworks and clients - can integrate with HTTPTextFormat by providing the object containing the - headers, and a getter and setter function for the extraction and - injection of values, respectively. - - Example:: - - import flask - import requests - from opentelemetry.context.propagation import HTTPTextFormat - - PROPAGATOR = HTTPTextFormat() - - def get_header_from_flask_request(request, key): - return request.headers.get_all(key) - - def set_header_into_requests_request(request: requests.Request, - key: str, value: str): - request.headers[key] = value - - def example_route(): - distributed_context = PROPAGATOR.extract( - get_header_from_flask_request, - flask.request - ) - request_to_downstream = requests.Request( - "GET", "http://httpbin.org/get" - ) - PROPAGATOR.inject( - distributed_context, - set_header_into_requests_request, - request_to_downstream - ) - session = requests.Session() - session.send(request_to_downstream.prepare()) - - - .. _Propagation API Specification: - https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md - """ - - @abc.abstractmethod - def extract( - self, get_from_carrier: Getter, carrier: object - ) -> DistributedContext: - """Create a DistributedContext from values in the carrier. - - The extract function should retrieve values from the carrier - object using get_from_carrier, and use values to populate a - DistributedContext value and return it. - - Args: - get_from_carrier: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. - carrier: and object which contains values that are - used to construct a DistributedContext. This object - must be paired with an appropriate get_from_carrier - which understands how to extract a value from it. - Returns: - A DistributedContext with configuration found in the carrier. - - """ - - @abc.abstractmethod - def inject( - self, - context: DistributedContext, - set_in_carrier: Setter, - carrier: object, - ) -> None: - """Inject values from a DistributedContext into a carrier. - - inject enables the propagation of values into HTTP clients or - other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the - carrier. - - Args: - context: The DistributedContext to read values from. - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that a place to define HTTP headers. - Should be paired with set_in_carrier, which should - know how to set header values on the carrier. - - """ diff --git a/opentelemetry-api/src/opentelemetry/propagators/__init__.py b/opentelemetry-api/src/opentelemetry/propagators/__init__.py index 3974a4cb03..f9b537cd86 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/__init__.py +++ b/opentelemetry-api/src/opentelemetry/propagators/__init__.py @@ -12,49 +12,87 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +API for propagation of context. + +Example:: + + import flask + import requests + from opentelemetry import propagators + + + PROPAGATOR = propagators.get_global_httptextformat() + + + def get_header_from_flask_request(request, key): + return request.headers.get_all(key) + + def set_header_into_requests_request(request: requests.Request, + key: str, value: str): + request.headers[key] = value + + def example_route(): + context = PROPAGATOR.extract( + get_header_from_flask_request, + flask.request + ) + request_to_downstream = requests.Request( + "GET", "http://httpbin.org/get" + ) + PROPAGATOR.inject( + set_header_into_requests_request, + request_to_downstream, + context=context + ) + session = requests.Session() + session.send(request_to_downstream.prepare()) + + +.. _Propagation API Specification: + https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md +""" + import typing -import opentelemetry.context.propagation.httptextformat as httptextformat import opentelemetry.trace as trace -from opentelemetry.context.propagation.tracecontexthttptextformat import ( +from opentelemetry.context import get_current +from opentelemetry.context.context import Context +from opentelemetry.trace.propagation import httptextformat +from opentelemetry.trace.propagation.tracecontexthttptextformat import ( TraceContextHTTPTextFormat, ) -_T = typing.TypeVar("_T") - def extract( - get_from_carrier: httptextformat.Getter[_T], carrier: _T -) -> trace.SpanContext: - """Load the parent SpanContext from values in the carrier. - - Using the specified HTTPTextFormatter, the propagator will - extract a SpanContext from the carrier. If one is found, - it will be set as the parent context of the current span. + get_from_carrier: httptextformat.Getter[httptextformat.HTTPTextFormatT], + carrier: httptextformat.HTTPTextFormatT, + context: typing.Optional[Context] = None, +) -> Context: + """ Uses the configured propagator to extract a Context from the carrier. Args: get_from_carrier: a function that can retrieve zero or more values from the carrier. In the case that the value does not exist, return an empty list. carrier: and object which contains values that are - used to construct a SpanContext. This object + used to construct a Context. This object must be paired with an appropriate get_from_carrier which understands how to extract a value from it. + context: an optional Context to use. Defaults to current + context if not set. """ - return get_global_httptextformat().extract(get_from_carrier, carrier) + return get_global_httptextformat().extract( + get_from_carrier, carrier, context + ) def inject( - tracer: trace.Tracer, - set_in_carrier: httptextformat.Setter[_T], - carrier: _T, + set_in_carrier: httptextformat.Setter[httptextformat.HTTPTextFormatT], + carrier: httptextformat.HTTPTextFormatT, + context: typing.Optional[Context] = None, ) -> None: - """Inject values from the current context into the carrier. - - inject enables the propagation of values into HTTP clients or - other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the - carrier. + """ Uses the configured propagator to inject a Context into the carrier. Args: set_in_carrier: A setter function that can set values @@ -62,10 +100,10 @@ def inject( carrier: An object that contains a representation of HTTP headers. Should be paired with set_in_carrier, which should know how to set header values on the carrier. + context: an optional Context to use. Defaults to current + context if not set. """ - get_global_httptextformat().inject( - tracer.get_current_span(), set_in_carrier, carrier - ) + get_global_httptextformat().inject(set_in_carrier, carrier, context) _HTTP_TEXT_FORMAT = ( diff --git a/opentelemetry-api/src/opentelemetry/propagators/composite.py b/opentelemetry-api/src/opentelemetry/propagators/composite.py new file mode 100644 index 0000000000..4ec953c839 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/propagators/composite.py @@ -0,0 +1,69 @@ +# Copyright 2020, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import typing + +from opentelemetry.context.context import Context +from opentelemetry.trace.propagation import httptextformat + +logger = logging.getLogger(__name__) + + +class CompositeHTTPPropagator(httptextformat.HTTPTextFormat): + """ CompositeHTTPPropagator provides a mechanism for combining multiple + propagators into a single one. + + Args: + propagators: the list of propagators to use + """ + + def __init__( + self, propagators: typing.Sequence[httptextformat.HTTPTextFormat] + ) -> None: + self._propagators = propagators + + def extract( + self, + get_from_carrier: httptextformat.Getter[ + httptextformat.HTTPTextFormatT + ], + carrier: httptextformat.HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> Context: + """ Run each of the configured propagators with the given context and carrier. + Propagators are run in the order they are configured, if multiple + propagators write the same context key, the propagator later in the list + will override previous propagators. + + See `opentelemetry.trace.propagation.httptextformat.HTTPTextFormat.extract` + """ + for propagator in self._propagators: + context = propagator.extract(get_from_carrier, carrier, context) + return context # type: ignore + + def inject( + self, + set_in_carrier: httptextformat.Setter[httptextformat.HTTPTextFormatT], + carrier: httptextformat.HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> None: + """ Run each of the configured propagators with the given context and carrier. + Propagators are run in the order they are configured, if multiple + propagators write the same carrier key, the propagator later in the list + will override previous propagators. + + See `opentelemetry.trace.propagation.httptextformat.HTTPTextFormat.inject` + """ + for propagator in self._propagators: + propagator.inject(set_in_carrier, carrier, context) diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/trace/propagation/__init__.py index 881a74287a..90e7f9dcb3 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/__init__.py @@ -13,7 +13,22 @@ # limitations under the License. from typing import Optional -from opentelemetry.trace import INVALID_SPAN_CONTEXT, Span, SpanContext +from opentelemetry import trace as trace_api +from opentelemetry.context import get_value, set_value +from opentelemetry.context.context import Context -_SPAN_CONTEXT_KEY = "extracted-span-context" SPAN_KEY = "current-span" + + +def set_span_in_context( + span: trace_api.Span, context: Optional[Context] = None +) -> Context: + ctx = set_value(SPAN_KEY, span, context=context) + return ctx + + +def get_span_from_context(context: Optional[Context] = None) -> trace_api.Span: + span = get_value(SPAN_KEY, context=context) + if not isinstance(span, trace_api.Span): + return trace_api.INVALID_SPAN + return span diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/trace/propagation/httptextformat.py similarity index 50% rename from opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py rename to opentelemetry-api/src/opentelemetry/trace/propagation/httptextformat.py index b64a298c41..500014d738 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/httptextformat.py @@ -15,89 +15,59 @@ import abc import typing -from opentelemetry.trace import Span, SpanContext +from opentelemetry.context.context import Context -_T = typing.TypeVar("_T") +HTTPTextFormatT = typing.TypeVar("HTTPTextFormatT") -Setter = typing.Callable[[_T, str, str], None] -Getter = typing.Callable[[_T, str], typing.List[str]] +Setter = typing.Callable[[HTTPTextFormatT, str, str], None] +Getter = typing.Callable[[HTTPTextFormatT, str], typing.List[str]] class HTTPTextFormat(abc.ABC): - """API for propagation of span context via headers. - - This class provides an interface that enables extracting and injecting - span context into headers of HTTP requests. HTTP frameworks and clients + """This class provides an interface that enables extracting and injecting + context into headers of HTTP requests. HTTP frameworks and clients can integrate with HTTPTextFormat by providing the object containing the headers, and a getter and setter function for the extraction and injection of values, respectively. - Example:: - - import flask - import requests - from opentelemetry.context.propagation import HTTPTextFormat - - PROPAGATOR = HTTPTextFormat() - - - - def get_header_from_flask_request(request, key): - return request.headers.get_all(key) - - def set_header_into_requests_request(request: requests.Request, - key: str, value: str): - request.headers[key] = value - - def example_route(): - span_context = PROPAGATOR.extract( - get_header_from_flask_request, - flask.request - ) - request_to_downstream = requests.Request( - "GET", "http://httpbin.org/get" - ) - PROPAGATOR.inject( - span_context, - set_header_into_requests_request, - request_to_downstream - ) - session = requests.Session() - session.send(request_to_downstream.prepare()) - - - .. _Propagation API Specification: - https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md """ @abc.abstractmethod def extract( - self, get_from_carrier: Getter[_T], carrier: _T - ) -> SpanContext: - """Create a SpanContext from values in the carrier. + self, + get_from_carrier: Getter[HTTPTextFormatT], + carrier: HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> Context: + """Create a Context from values in the carrier. The extract function should retrieve values from the carrier object using get_from_carrier, and use values to populate a - SpanContext value and return it. + Context value and return it. Args: get_from_carrier: a function that can retrieve zero or more values from the carrier. In the case that the value does not exist, return an empty list. carrier: and object which contains values that are - used to construct a SpanContext. This object + used to construct a Context. This object must be paired with an appropriate get_from_carrier which understands how to extract a value from it. + context: an optional Context to use. Defaults to current + context if not set. Returns: - A SpanContext with configuration found in the carrier. + A Context with configuration found in the carrier. """ @abc.abstractmethod def inject( - self, span: Span, set_in_carrier: Setter[_T], carrier: _T + self, + set_in_carrier: Setter[HTTPTextFormatT], + carrier: HTTPTextFormatT, + context: typing.Optional[Context] = None, ) -> None: - """Inject values from a Span into a carrier. + """Inject values from a Context into a carrier. inject enables the propagation of values into HTTP clients or other objects which perform an HTTP request. Implementations @@ -105,11 +75,12 @@ def inject( carrier. Args: - context: The SpanContext to read values from. set_in_carrier: A setter function that can set values on the carrier. carrier: An object that a place to define HTTP headers. Should be paired with set_in_carrier, which should know how to set header values on the carrier. + context: an optional Context to use. Defaults to current + context if not set. """ diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/tracecontexthttptextformat.py b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontexthttptextformat.py similarity index 71% rename from opentelemetry-api/src/opentelemetry/context/propagation/tracecontexthttptextformat.py rename to opentelemetry-api/src/opentelemetry/trace/propagation/tracecontexthttptextformat.py index 0f07841eb7..28db4e4557 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/tracecontexthttptextformat.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontexthttptextformat.py @@ -16,9 +16,12 @@ import typing import opentelemetry.trace as trace -from opentelemetry.context.propagation import httptextformat - -_T = typing.TypeVar("_T") +from opentelemetry.context.context import Context +from opentelemetry.trace.propagation import ( + get_span_from_context, + httptextformat, + set_span_in_context, +) # Keys and values are strings of up to 256 printable US-ASCII characters. # Implementations should conform to the `W3C Trace Context - Tracestate`_ @@ -59,20 +62,26 @@ class TraceContextHTTPTextFormat(httptextformat.HTTPTextFormat): ) _TRACEPARENT_HEADER_FORMAT_RE = re.compile(_TRACEPARENT_HEADER_FORMAT) - @classmethod def extract( - cls, get_from_carrier: httptextformat.Getter[_T], carrier: _T - ) -> trace.SpanContext: - """Extracts a valid SpanContext from the carrier. + self, + get_from_carrier: httptextformat.Getter[ + httptextformat.HTTPTextFormatT + ], + carrier: httptextformat.HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> Context: + """Extracts SpanContext from the carrier. + + See `opentelemetry.trace.propagation.httptextformat.HTTPTextFormat.extract` """ - header = get_from_carrier(carrier, cls._TRACEPARENT_HEADER_NAME) + header = get_from_carrier(carrier, self._TRACEPARENT_HEADER_NAME) if not header: - return trace.INVALID_SPAN_CONTEXT + return set_span_in_context(trace.INVALID_SPAN, context) - match = re.search(cls._TRACEPARENT_HEADER_FORMAT_RE, header[0]) + match = re.search(self._TRACEPARENT_HEADER_FORMAT_RE, header[0]) if not match: - return trace.INVALID_SPAN_CONTEXT + return set_span_in_context(trace.INVALID_SPAN, context) version = match.group(1) trace_id = match.group(2) @@ -80,16 +89,16 @@ def extract( trace_flags = match.group(4) if trace_id == "0" * 32 or span_id == "0" * 16: - return trace.INVALID_SPAN_CONTEXT + return set_span_in_context(trace.INVALID_SPAN, context) if version == "00": if match.group(5): - return trace.INVALID_SPAN_CONTEXT + return set_span_in_context(trace.INVALID_SPAN, context) if version == "ff": - return trace.INVALID_SPAN_CONTEXT + return set_span_in_context(trace.INVALID_SPAN, context) tracestate_headers = get_from_carrier( - carrier, cls._TRACESTATE_HEADER_NAME + carrier, self._TRACESTATE_HEADER_NAME ) tracestate = _parse_tracestate(tracestate_headers) @@ -99,31 +108,34 @@ def extract( trace_flags=trace.TraceFlags(trace_flags), trace_state=tracestate, ) + return set_span_in_context(trace.DefaultSpan(span_context), context) - return span_context - - @classmethod def inject( - cls, - span: trace.Span, - set_in_carrier: httptextformat.Setter[_T], - carrier: _T, + self, + set_in_carrier: httptextformat.Setter[httptextformat.HTTPTextFormatT], + carrier: httptextformat.HTTPTextFormatT, + context: typing.Optional[Context] = None, ) -> None: + """Injects SpanContext into the carrier. - context = span.get_context() + See `opentelemetry.trace.propagation.httptextformat.HTTPTextFormat.inject` + """ + span_context = get_span_from_context(context).get_context() - if context == trace.INVALID_SPAN_CONTEXT: + if span_context == trace.INVALID_SPAN_CONTEXT: return traceparent_string = "00-{:032x}-{:016x}-{:02x}".format( - context.trace_id, context.span_id, context.trace_flags + span_context.trace_id, + span_context.span_id, + span_context.trace_flags, ) set_in_carrier( - carrier, cls._TRACEPARENT_HEADER_NAME, traceparent_string + carrier, self._TRACEPARENT_HEADER_NAME, traceparent_string ) - if context.trace_state: - tracestate_string = _format_tracestate(context.trace_state) + if span_context.trace_state: + tracestate_string = _format_tracestate(span_context.trace_state) set_in_carrier( - carrier, cls._TRACESTATE_HEADER_NAME, tracestate_string + carrier, self._TRACESTATE_HEADER_NAME, tracestate_string ) diff --git a/opentelemetry-api/tests/propagators/test_composite.py b/opentelemetry-api/tests/propagators/test_composite.py new file mode 100644 index 0000000000..09ac0ecf68 --- /dev/null +++ b/opentelemetry-api/tests/propagators/test_composite.py @@ -0,0 +1,107 @@ +# Copyright 2020, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import Mock + +from opentelemetry.propagators.composite import CompositeHTTPPropagator + + +def get_as_list(dict_object, key): + value = dict_object.get(key) + return [value] if value is not None else [] + + +def mock_inject(name, value="data"): + def wrapped(setter, carrier=None, context=None): + carrier[name] = value + + return wrapped + + +def mock_extract(name, value="context"): + def wrapped(getter, carrier=None, context=None): + new_context = context.copy() + new_context[name] = value + return new_context + + return wrapped + + +class TestCompositePropagator(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mock_propagator_0 = Mock( + inject=mock_inject("mock-0"), extract=mock_extract("mock-0") + ) + cls.mock_propagator_1 = Mock( + inject=mock_inject("mock-1"), extract=mock_extract("mock-1") + ) + cls.mock_propagator_2 = Mock( + inject=mock_inject("mock-0", value="data2"), + extract=mock_extract("mock-0", value="context2"), + ) + + def test_no_propagators(self): + propagator = CompositeHTTPPropagator([]) + new_carrier = {} + propagator.inject(dict.__setitem__, carrier=new_carrier) + self.assertEqual(new_carrier, {}) + + context = propagator.extract( + get_as_list, carrier=new_carrier, context={} + ) + self.assertEqual(context, {}) + + def test_single_propagator(self): + propagator = CompositeHTTPPropagator([self.mock_propagator_0]) + + new_carrier = {} + propagator.inject(dict.__setitem__, carrier=new_carrier) + self.assertEqual(new_carrier, {"mock-0": "data"}) + + context = propagator.extract( + get_as_list, carrier=new_carrier, context={} + ) + self.assertEqual(context, {"mock-0": "context"}) + + def test_multiple_propagators(self): + propagator = CompositeHTTPPropagator( + [self.mock_propagator_0, self.mock_propagator_1] + ) + + new_carrier = {} + propagator.inject(dict.__setitem__, carrier=new_carrier) + self.assertEqual(new_carrier, {"mock-0": "data", "mock-1": "data"}) + + context = propagator.extract( + get_as_list, carrier=new_carrier, context={} + ) + self.assertEqual(context, {"mock-0": "context", "mock-1": "context"}) + + def test_multiple_propagators_same_key(self): + # test that when multiple propagators extract/inject the same + # key, the later propagator values are extracted/injected + propagator = CompositeHTTPPropagator( + [self.mock_propagator_0, self.mock_propagator_2] + ) + + new_carrier = {} + propagator.inject(dict.__setitem__, carrier=new_carrier) + self.assertEqual(new_carrier, {"mock-0": "data2"}) + + context = propagator.extract( + get_as_list, carrier=new_carrier, context={} + ) + self.assertEqual(context, {"mock-0": "context2"}) diff --git a/opentelemetry-api/tests/context/propagation/test_tracecontexthttptextformat.py b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py similarity index 60% rename from opentelemetry-api/tests/context/propagation/test_tracecontexthttptextformat.py rename to opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py index 8f283ef881..6ee4a957d2 100644 --- a/opentelemetry-api/tests/context/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py @@ -14,10 +14,13 @@ import typing import unittest -from unittest.mock import Mock from opentelemetry import trace -from opentelemetry.context.propagation import tracecontexthttptextformat +from opentelemetry.trace.propagation import ( + get_span_from_context, + set_span_in_context, + tracecontexthttptextformat, +) FORMAT = tracecontexthttptextformat.TraceContextHTTPTextFormat() @@ -43,8 +46,8 @@ def test_no_traceparent_header(self): trace-id and parent-id that represents the current request. """ output = {} # type:typing.Dict[str, typing.List[str]] - span_context = FORMAT.extract(get_as_list, output) - self.assertTrue(isinstance(span_context, trace.SpanContext)) + span = get_span_from_context(FORMAT.extract(get_as_list, output)) + self.assertIsInstance(span.get_context(), trace.SpanContext) def test_headers_with_tracestate(self): """When there is a traceparent and tracestate header, data from @@ -55,23 +58,25 @@ def test_headers_with_tracestate(self): span_id=format(self.SPAN_ID, "016x"), ) tracestate_value = "foo=1,bar=2,baz=3" - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [traceparent_value], - "tracestate": [tracestate_value], - }, - ) + span_context = get_span_from_context( + FORMAT.extract( + get_as_list, + { + "traceparent": [traceparent_value], + "tracestate": [tracestate_value], + }, + ) + ).get_context() self.assertEqual(span_context.trace_id, self.TRACE_ID) self.assertEqual(span_context.span_id, self.SPAN_ID) self.assertEqual( span_context.trace_state, {"foo": "1", "bar": "2", "baz": "3"} ) - - mock_span = Mock() - mock_span.configure_mock(**{"get_context.return_value": span_context}) output = {} # type:typing.Dict[str, str] - FORMAT.inject(mock_span, dict.__setitem__, output) + span = trace.DefaultSpan(span_context) + + ctx = set_span_in_context(span) + FORMAT.inject(dict.__setitem__, output, ctx) self.assertEqual(output["traceparent"], traceparent_value) for pair in ["foo=1", "bar=2", "baz=3"]: self.assertIn(pair, output["tracestate"]) @@ -96,16 +101,18 @@ def test_invalid_trace_id(self): Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-00000000000000000000000000000000-1234567890123456-00" - ], - "tracestate": ["foo=1,bar=2,foo=3"], - }, + span = get_span_from_context( + FORMAT.extract( + get_as_list, + { + "traceparent": [ + "00-00000000000000000000000000000000-1234567890123456-00" + ], + "tracestate": ["foo=1,bar=2,foo=3"], + }, + ) ) - self.assertEqual(span_context, trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_context(), trace.INVALID_SPAN_CONTEXT) def test_invalid_parent_id(self): """If the parent id is invalid, we must ignore the full traceparent @@ -125,16 +132,18 @@ def test_invalid_parent_id(self): Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-00000000000000000000000000000000-0000000000000000-00" - ], - "tracestate": ["foo=1,bar=2,foo=3"], - }, + span = get_span_from_context( + FORMAT.extract( + get_as_list, + { + "traceparent": [ + "00-00000000000000000000000000000000-0000000000000000-00" + ], + "tracestate": ["foo=1,bar=2,foo=3"], + }, + ) ) - self.assertEqual(span_context, trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_context(), trace.INVALID_SPAN_CONTEXT) def test_no_send_empty_tracestate(self): """If the tracestate is empty, do not set the header. @@ -145,15 +154,11 @@ def test_no_send_empty_tracestate(self): empty tracestate headers but SHOULD avoid sending them. """ output = {} # type:typing.Dict[str, str] - mock_span = Mock() - mock_span.configure_mock( - **{ - "get_context.return_value": trace.SpanContext( - self.TRACE_ID, self.SPAN_ID - ) - } + span = trace.DefaultSpan( + trace.SpanContext(self.TRACE_ID, self.SPAN_ID) ) - FORMAT.inject(mock_span, dict.__setitem__, output) + ctx = set_span_in_context(span) + FORMAT.inject(dict.__setitem__, output, ctx) self.assertTrue("traceparent" in output) self.assertFalse("tracestate" in output) @@ -165,48 +170,55 @@ def test_format_not_supported(self): If the version cannot be parsed, return an invalid trace header. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-12345678901234567890123456789012-" - "1234567890123456-00-residue" - ], - "tracestate": ["foo=1,bar=2,foo=3"], - }, + span = get_span_from_context( + FORMAT.extract( + get_as_list, + { + "traceparent": [ + "00-12345678901234567890123456789012-" + "1234567890123456-00-residue" + ], + "tracestate": ["foo=1,bar=2,foo=3"], + }, + ) ) - self.assertEqual(span_context, trace.INVALID_SPAN_CONTEXT) + self.assertEqual(span.get_context(), trace.INVALID_SPAN_CONTEXT) def test_propagate_invalid_context(self): """Do not propagate invalid trace context.""" output = {} # type:typing.Dict[str, str] - FORMAT.inject(trace.INVALID_SPAN, dict.__setitem__, output) + ctx = set_span_in_context(trace.INVALID_SPAN) + FORMAT.inject(dict.__setitem__, output, context=ctx) self.assertFalse("traceparent" in output) def test_tracestate_empty_header(self): """Test tracestate with an additional empty header (should be ignored) """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-12345678901234567890123456789012-1234567890123456-00" - ], - "tracestate": ["foo=1", ""], - }, + span = get_span_from_context( + FORMAT.extract( + get_as_list, + { + "traceparent": [ + "00-12345678901234567890123456789012-1234567890123456-00" + ], + "tracestate": ["foo=1", ""], + }, + ) ) - self.assertEqual(span_context.trace_state["foo"], "1") + self.assertEqual(span.get_context().trace_state["foo"], "1") def test_tracestate_header_with_trailing_comma(self): """Do not propagate invalid trace context. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-12345678901234567890123456789012-1234567890123456-00" - ], - "tracestate": ["foo=1,"], - }, + span = get_span_from_context( + FORMAT.extract( + get_as_list, + { + "traceparent": [ + "00-12345678901234567890123456789012-1234567890123456-00" + ], + "tracestate": ["foo=1,"], + }, + ) ) - self.assertEqual(span_context.trace_state["foo"], "1") + self.assertEqual(span.get_context().trace_state["foo"], "1") diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py index 4da487618b..3e03c9aa02 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -15,7 +15,17 @@ import typing import opentelemetry.trace as trace -from opentelemetry.context.propagation.httptextformat import HTTPTextFormat +from opentelemetry.context import Context +from opentelemetry.trace.propagation import ( + get_span_from_context, + set_span_in_context, +) +from opentelemetry.trace.propagation.httptextformat import ( + Getter, + HTTPTextFormat, + HTTPTextFormatT, + Setter, +) class B3Format(HTTPTextFormat): @@ -32,15 +42,19 @@ class B3Format(HTTPTextFormat): FLAGS_KEY = "x-b3-flags" _SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"]) - @classmethod - def extract(cls, get_from_carrier, carrier): + def extract( + self, + get_from_carrier: Getter[HTTPTextFormatT], + carrier: HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> Context: trace_id = format_trace_id(trace.INVALID_TRACE_ID) span_id = format_span_id(trace.INVALID_SPAN_ID) sampled = "0" flags = None single_header = _extract_first_element( - get_from_carrier(carrier, cls.SINGLE_HEADER_KEY) + get_from_carrier(carrier, self.SINGLE_HEADER_KEY) ) if single_header: # The b3 spec calls for the sampling state to be @@ -58,29 +72,29 @@ def extract(cls, get_from_carrier, carrier): elif len(fields) == 4: trace_id, span_id, sampled, _ = fields else: - return trace.INVALID_SPAN_CONTEXT + return set_span_in_context(trace.INVALID_SPAN) else: trace_id = ( _extract_first_element( - get_from_carrier(carrier, cls.TRACE_ID_KEY) + get_from_carrier(carrier, self.TRACE_ID_KEY) ) or trace_id ) span_id = ( _extract_first_element( - get_from_carrier(carrier, cls.SPAN_ID_KEY) + get_from_carrier(carrier, self.SPAN_ID_KEY) ) or span_id ) sampled = ( _extract_first_element( - get_from_carrier(carrier, cls.SAMPLED_KEY) + get_from_carrier(carrier, self.SAMPLED_KEY) ) or sampled ) flags = ( _extract_first_element( - get_from_carrier(carrier, cls.FLAGS_KEY) + get_from_carrier(carrier, self.FLAGS_KEY) ) or flags ) @@ -90,32 +104,41 @@ def extract(cls, get_from_carrier, carrier): # flag values set. Since the setting of at least one implies # the desire for some form of sampling, propagate if either # header is set to allow. - if sampled in cls._SAMPLE_PROPAGATE_VALUES or flags == "1": + if sampled in self._SAMPLE_PROPAGATE_VALUES or flags == "1": options |= trace.TraceFlags.SAMPLED - return trace.SpanContext( - # trace an span ids are encoded in hex, so must be converted - trace_id=int(trace_id, 16), - span_id=int(span_id, 16), - trace_flags=trace.TraceFlags(options), - trace_state=trace.TraceState(), + return set_span_in_context( + trace.DefaultSpan( + trace.SpanContext( + # trace an span ids are encoded in hex, so must be converted + trace_id=int(trace_id, 16), + span_id=int(span_id, 16), + trace_flags=trace.TraceFlags(options), + trace_state=trace.TraceState(), + ) + ) ) - @classmethod - def inject(cls, span, set_in_carrier, carrier): + def inject( + self, + set_in_carrier: Setter[HTTPTextFormatT], + carrier: HTTPTextFormatT, + context: typing.Optional[Context] = None, + ) -> None: + span = get_span_from_context(context=context) sampled = (trace.TraceFlags.SAMPLED & span.context.trace_flags) != 0 set_in_carrier( - carrier, cls.TRACE_ID_KEY, format_trace_id(span.context.trace_id) + carrier, self.TRACE_ID_KEY, format_trace_id(span.context.trace_id), ) set_in_carrier( - carrier, cls.SPAN_ID_KEY, format_span_id(span.context.span_id) + carrier, self.SPAN_ID_KEY, format_span_id(span.context.span_id) ) if span.parent is not None: set_in_carrier( carrier, - cls.PARENT_SPAN_ID_KEY, + self.PARENT_SPAN_ID_KEY, format_span_id(span.parent.context.span_id), ) - set_in_carrier(carrier, cls.SAMPLED_KEY, "1" if sampled else "0") + set_in_carrier(carrier, self.SAMPLED_KEY, "1" if sampled else "0") def format_trace_id(trace_id: int) -> str: @@ -128,10 +151,9 @@ def format_span_id(span_id: int) -> str: return format(span_id, "016x") -_T = typing.TypeVar("_T") - - -def _extract_first_element(items: typing.Iterable[_T]) -> typing.Optional[_T]: +def _extract_first_element( + items: typing.Iterable[HTTPTextFormatT], +) -> typing.Optional[HTTPTextFormatT]: if items is None: return None return next(iter(items), None) diff --git a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py index ae55b02bfd..0cdda1bcd0 100644 --- a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py @@ -17,6 +17,10 @@ import opentelemetry.sdk.context.propagation.b3_format as b3_format import opentelemetry.sdk.trace as trace import opentelemetry.trace as trace_api +from opentelemetry.trace.propagation import ( + get_span_from_context, + set_span_in_context, +) FORMAT = b3_format.B3Format() @@ -28,7 +32,8 @@ def get_as_list(dict_object, key): def get_child_parent_new_carrier(old_carrier): - parent_context = FORMAT.extract(get_as_list, old_carrier) + ctx = FORMAT.extract(get_as_list, old_carrier) + parent_context = get_span_from_context(ctx).get_context() parent = trace.Span("parent", parent_context) child = trace.Span( @@ -43,7 +48,8 @@ def get_child_parent_new_carrier(old_carrier): ) new_carrier = {} - FORMAT.inject(child, dict.__setitem__, new_carrier) + ctx = set_span_in_context(child) + FORMAT.inject(dict.__setitem__, new_carrier, context=ctx) return child, parent, new_carrier @@ -222,7 +228,8 @@ def test_invalid_single_header(self): invalid SpanContext. """ carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - span_context = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(get_as_list, carrier) + span_context = get_span_from_context(ctx).get_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) @@ -232,7 +239,9 @@ def test_missing_trace_id(self): FORMAT.SPAN_ID_KEY: self.serialized_span_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + + ctx = FORMAT.extract(get_as_list, carrier) + span_context = get_span_from_context(ctx).get_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) def test_missing_span_id(self): @@ -241,5 +250,7 @@ def test_missing_span_id(self): FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + + ctx = FORMAT.extract(get_as_list, carrier) + span_context = get_span_from_context(ctx).get_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)