diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py
index 4ef914f9..d05cb38c 100644
--- a/src/syrupy/assertion.py
+++ b/src/syrupy/assertion.py
@@ -43,6 +43,7 @@ class SnapshotAssertion:
_extension_class: Type["AbstractSyrupyExtension"] = attr.ib(kw_only=True)
_test_location: "TestLocation" = attr.ib(kw_only=True)
_update_snapshots: bool = attr.ib(kw_only=True)
+ _extension: Optional["AbstractSyrupyExtension"] = attr.ib(init=False, default=None)
_executions: int = attr.ib(init=False, default=0, kw_only=True)
_execution_results: Dict[int, "AssertionResult"] = attr.ib(
init=False, factory=dict, kw_only=True
@@ -51,12 +52,15 @@ class SnapshotAssertion:
def __attrs_post_init__(self) -> None:
self._session.register_request(self)
+ def __init_extension(
+ self, extension_class: Type["AbstractSyrupyExtension"]
+ ) -> "AbstractSyrupyExtension":
+ return extension_class(test_location=self._test_location)
+
@property
def extension(self) -> "AbstractSyrupyExtension":
- if not getattr(self, "_extension", None):
- self._extension: "AbstractSyrupyExtension" = self._extension_class(
- test_location=self._test_location
- )
+ if not self._extension:
+ self._extension = self.__init_extension(self._extension_class)
return self._extension
@property
@@ -70,6 +74,10 @@ def executions(self) -> Dict[int, AssertionResult]:
def use_extension(
self, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
) -> "SnapshotAssertion":
+ """
+ Creates a new snapshot assertion fixture with the same options but using
+ specified extension class. This does not preserve assertion index or state.
+ """
return self.__class__(
update_snapshots=self._update_snapshots,
test_location=self._test_location,
@@ -92,6 +100,16 @@ def get_assert_diff(self, data: "SerializableData") -> List[str]:
diff.extend(self.extension.diff_lines(serialized_data, snapshot_data))
return diff
+ def __call__(
+ self, *, extension_class: Optional[Type["AbstractSyrupyExtension"]]
+ ) -> "SnapshotAssertion":
+ """
+ Modifies assertion instance options
+ """
+ if extension_class:
+ self._extension = self.__init_extension(extension_class)
+ return self
+
def __repr__(self) -> str:
attrs_to_repr = ["name", "num_executions"]
attrs_repr = ", ".join(f"{a}={repr(getattr(self, a))}" for a in attrs_to_repr)
@@ -131,6 +149,13 @@ def _assert(self, data: "SerializableData") -> bool:
updated=snapshot_updated,
)
self._executions += 1
+ self._post_assert()
+
+ def _post_assert(self) -> None:
+ """
+ Restores assertion instance options
+ """
+ self._extension = None
def _recall_data(self, index: int) -> Optional["SerializableData"]:
try:
diff --git a/tests/__snapshots__/test_extension_image.ambr b/tests/__snapshots__/test_extension_image.ambr
new file mode 100644
index 00000000..6a118b13
--- /dev/null
+++ b/tests/__snapshots__/test_extension_image.ambr
@@ -0,0 +1,3 @@
+# name: test_multiple_snapshot_extensions.1
+ ''
+---
diff --git a/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.2.png b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.2.png
new file mode 100644
index 00000000..7eb2b9ad
Binary files /dev/null and b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.2.png differ
diff --git a/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.3.svg b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.3.svg
new file mode 100644
index 00000000..90ebb8dd
--- /dev/null
+++ b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.3.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.svg b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.svg
new file mode 100644
index 00000000..90ebb8dd
--- /dev/null
+++ b/tests/__snapshots__/test_extension_image/test_multiple_snapshot_extensions.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/tests/test_extension_image.py b/tests/test_extension_image.py
index 329d0742..af67e103 100644
--- a/tests/test_extension_image.py
+++ b/tests/test_extension_image.py
@@ -8,40 +8,48 @@
)
+actual_png = base64.b64decode(
+ b"iVBORw0KGgoAAAANSUhEUgAAADIAAAAyBAMAAADsEZWCAAAAG1BMVEXMzMy"
+ b"Wlpaqqqq3t7exsbGcnJy+vr6jo6PFxcUFpPI/AAAACXBIWXMAAA7EAAAOxA"
+ b"GVKw4bAAAAQUlEQVQ4jWNgGAWjgP6ASdncAEaiAhaGiACmFhCJLsMaIiDAE"
+ b"QEi0WXYEiMCOCJAJIY9KuYGTC0gknpuHwXDGwAA5fsIZw0iYWYAAAAASUVO"
+ b"RK5CYII="
+)
+actual_svg = (
+ ''
+ ''
+)
+
+
@pytest.fixture
def snapshot_png(snapshot):
return snapshot.use_extension(PNGImageSnapshotExtension)
-def test_image(snapshot_png, snapshot_svg):
- actual_png = base64.b64decode(
- b"iVBORw0KGgoAAAANSUhEUgAAADIAAAAyBAMAAADsEZWCAAAAG1BMVEXMzMy"
- b"Wlpaqqqq3t7exsbGcnJy+vr6jo6PFxcUFpPI/AAAACXBIWXMAAA7EAAAOxA"
- b"GVKw4bAAAAQUlEQVQ4jWNgGAWjgP6ASdncAEaiAhaGiACmFhCJLsMaIiDAE"
- b"QEi0WXYEiMCOCJAJIY9KuYGTC0gknpuHwXDGwAA5fsIZw0iYWYAAAAASUVO"
- b"RK5CYII="
- )
+def test_image(snapshot_png):
assert actual_png == snapshot_png
-@pytest.fixture
-def snapshot_svg(snapshot):
- return snapshot.use_extension(SVGImageSnapshotExtension)
+def test_image_vector(snapshot):
+ """
+ Example of creating a previewable svg snapshot
+ """
+ assert snapshot(extension_class=SVGImageSnapshotExtension) == actual_svg
-def test_image_vector(snapshot_svg):
+def test_multiple_snapshot_extensions(snapshot):
"""
- Example of creating a previewable svg snapshot
+ Example of switching extension classes on the fly.
+ These should be indexed in order of assertion.
"""
- actual_svg = (
- ''
- ''
- )
- assert snapshot_svg == actual_svg
+ assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)
+ assert actual_svg == snapshot # uses initial extension class
+ assert actual_png == snapshot(extension_class=PNGImageSnapshotExtension)
+ assert actual_svg == snapshot(extension_class=SVGImageSnapshotExtension)