Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revalidate stages upon view reload #2890

Merged
merged 3 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions fiftyone/core/clips.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def keep_fields(self):
super().keep_fields()

def reload(self):
"""Reloads this view from the source collection in the database.
"""Reloads the view.

Note that :class:`ClipView` instances are not singletons, so any
in-memory clips extracted from this view will not be updated by calling
Expand All @@ -298,11 +298,14 @@ def reload(self):
# This assumes that calling `load_view()` when the current clips
# dataset has been deleted will cause a new one to be generated
#

self._clips_dataset.delete()
_view = self._clips_stage.load_view(self._source_collection)
self._clips_dataset = _view._clips_dataset

_view = self._base_view
for stage in self._stages:
_view = _view.add_stage(stage)

def _sync_source_sample(self, sample):
if not self._classification_field:
return
Expand Down
7 changes: 5 additions & 2 deletions fiftyone/core/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def keep_fields(self):
super().keep_fields()

def reload(self):
"""Reloads this view from the source collection in the database.
"""Reloads the view.

Note that :class:`PatchView` instances are not singletons, so any
in-memory patches extracted from this view will not be updated by
Expand All @@ -311,11 +311,14 @@ def reload(self):
# This assumes that calling `load_view()` when the current patches
# dataset has been deleted will cause a new one to be generated
#

self._patches_dataset.delete()
_view = self._patches_stage.load_view(self._source_collection)
self._patches_dataset = _view._patches_dataset

_view = self._base_view
for stage in self._stages:
_view = _view.add_stage(stage)

def _sync_source_sample(self, sample):
for field in self._label_fields:
self._sync_source_sample_field(sample, field)
Expand Down
7 changes: 5 additions & 2 deletions fiftyone/core/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def keep_fields(self):
super().keep_fields()

def reload(self):
"""Reloads this view from the source collection in the database.
"""Reloads the view.

Note that :class:`FrameView` instances are not singletons, so any
in-memory frames extracted from this view will not be updated by
Expand All @@ -272,11 +272,14 @@ def reload(self):
# This assumes that calling `load_view()` when the current frames
# dataset has been deleted will cause a new one to be generated
#

self._frames_dataset.delete()
_view = self._frames_stage.load_view(self._source_collection)
self._frames_dataset = _view._frames_dataset

_view = self._base_view
for stage in self._stages:
_view = _view.add_stage(stage)

def _set_labels(self, field_name, sample_ids, label_docs):
super()._set_labels(field_name, sample_ids, label_docs)

Expand Down
6 changes: 5 additions & 1 deletion fiftyone/core/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,14 +1100,18 @@ def clone(self, name=None, persistent=False):
)

def reload(self):
"""Reloads the underlying dataset from the database.
"""Reloads the view.

Note that :class:`fiftyone.core.sample.SampleView` instances are not
singletons, so any in-memory samples extracted from this view will not
be updated by calling this method.
"""
self._dataset.reload()

_view = self._base_view
for stage in self._stages:
_view = _view.add_stage(stage)

def to_dict(
self,
rel_dir=None,
Expand Down
46 changes: 36 additions & 10 deletions tests/unittests/similarity_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,32 @@ def test_image_similarity(self):
query_id = dataset.first().id

view1 = dataset.sort_by_similarity(query_id)
values1 = view1.values("id")

view2 = dataset.sort_by_similarity(query_id, reverse=True)
values2 = view2.values("id")

self.assertEqual(
view1.values("id"), list(reversed(view2.values("id")))
)
self.assertListEqual(values2, values1[::-1])

view3 = dataset.sort_by_similarity(query_id, k=4)
values3 = view3.values("id")

self.assertEqual(len(view3), 4)
self.assertListEqual(values3, values1[:4])

view4 = dataset.sort_by_similarity(query_id, brain_key="img_sim")
values4 = view4.values("id")

self.assertListEqual(values4, values1)

view5 = view4.limit(2)
values5 = view5.values("id")

self.assertListEqual(values5, values1[:2])

view5.reload()
values5 = view5.values("id")

self.assertEqual(view1.values("id"), view4.values("id"))
self.assertListEqual(values5, values1[:2])

@drop_datasets
def test_object_similarity(self):
Expand Down Expand Up @@ -163,19 +176,32 @@ def test_object_similarity(self):
patches = dataset.to_patches("ground_truth")

view1 = patches.sort_by_similarity(query_id)
values1 = view1.values("id")

view2 = patches.sort_by_similarity(query_id, reverse=True)
values2 = view2.values("id")

self.assertEqual(
view1.values("id"), list(reversed(view2.values("id")))
)
self.assertListEqual(values2, values1[::-1])

view3 = patches.sort_by_similarity(query_id, k=4)
values3 = view3.values("id")

self.assertEqual(len(view3), 4)
self.assertListEqual(values3, values1[:4])

view4 = patches.sort_by_similarity(query_id, brain_key="obj_sim")
values4 = view4.values("id")

self.assertEqual(values4, values1)

view5 = view4.limit(2)
values5 = view5.values("id")

self.assertListEqual(values5, values1[:2])

view5.reload()
values5 = view5.values("id")

self.assertEqual(view1.values("id"), view4.values("id"))
self.assertListEqual(values5, values1[:2])


if __name__ == "__main__":
Expand Down
29 changes: 29 additions & 0 deletions tests/unittests/view_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,35 @@ def test_view_name_readonly(self):
with self.assertRaises(AttributeError):
view.name = "new_name"

@drop_datasets
def test_reload(self):
    """Verifies that reloading a view keeps stage parameters stable and
    re-validates the view's stages.
    """
    dataset = fo.Dataset()

    samples = [
        fo.Sample(filepath="image1.jpg", foo="bar"),
        fo.Sample(filepath="image2.jpg", spam="eggs"),
        fo.Sample(filepath="image3.jpg"),
        fo.Sample(filepath="image4.jpg"),
        fo.Sample(filepath="image5.jpg"),
    ]
    dataset.add_samples(samples)

    taken = dataset.take(3)
    ordered = taken.sort_by("filepath")
    view = ordered.select_fields("foo")

    ids_before = view.values("id")

    # A reload must leave dataset-independent stage parameters untouched
    # (e.g., the random seed held internally by Take), so the same samples
    # are expected afterwards
    view.reload()
    ids_after = view.values("id")

    self.assertListEqual(ids_before, ids_after)

    dataset.delete_sample_field("foo")

    # With field `foo` gone, the stages can no longer be validated, so
    # reloading is expected to raise
    self.assertRaises(ValueError, view.reload)


class ViewFieldTests(unittest.TestCase):
@skip_windows # TODO: don't skip on Windows
Expand Down