Skip to content

Commit

Permalink
fixing #1945
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Jul 15, 2022
1 parent 0b72305 commit 6eca40f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 13 deletions.
16 changes: 16 additions & 0 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7468,6 +7468,22 @@ def _get_frame_label_fields(self):
).keys()
]

def _get_root_fields(self, fields):
root_fields = []
for field in fields:
if self.media_type == fom.VIDEO and field.startswith(
self._FRAMES_PREFIX
):
# Converts `frames.root[.x.y]` to `frames.root`
root = ".".join(field.split(".", 2)[:2])
else:
# Converts `root[.x.y]` to `root`
root = field.split(".", 1)[0]

root_fields.append(root)

return root_fields

def _validate_root_field(self, field_name, include_private=False):
_ = self._get_root_field_type(
field_name, include_private=include_private
Expand Down
20 changes: 7 additions & 13 deletions fiftyone/core/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,31 +527,25 @@ def load_run_view(cls, samples, key, select_fields=False):
#

fields = cls._get_run_fields(samples, key)
root_fields = [f for f in fields if "." not in f]
_select_fields = root_fields
for field in fields:
if not any(f.startswith(field) for f in root_fields):
_select_fields.append(field)
root_fields = samples._get_root_fields(fields)

view = view.select_fields(_select_fields)
view = view.select_fields(root_fields)

#
# Hide any ancillary info on the same fields
#

_exclude_fields = []
exclude_fields = []
for _key in cls.list_runs(samples):
if _key == key:
continue

for field in cls._get_run_fields(samples, _key):
if "." in field and any(
field.startswith(f) for f in root_fields
):
_exclude_fields.append(field)
if any(field.startswith(r + ".") for r in root_fields):
exclude_fields.append(field)

if _exclude_fields:
view = view.exclude_fields(_exclude_fields)
if exclude_fields:
view = view.exclude_fields(exclude_fields)

return view

Expand Down
28 changes: 28 additions & 0 deletions tests/unittests/evaluation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,34 @@ def test_evaluate_polylines_open_images(self):

self._evaluate_open_images(dataset, kwargs)

@drop_datasets
def test_load_evaluation_view_select_fields(self):
dataset = self._make_detections_dataset()

dataset.clone_sample_field("predictions", "predictions2")

dataset.evaluate_detections(
"predictions", gt_field="ground_truth", eval_key="eval"
)
dataset.evaluate_detections(
"predictions2", gt_field="ground_truth", eval_key="eval2"
)

view = dataset.load_evaluation_view("eval", select_fields=True)

schema = view.get_field_schema()
self.assertNotIn("predictions2", schema)
self.assertNotIn("eval2_tp", schema)
self.assertNotIn("eval2_fp", schema)
self.assertNotIn("eval2_fn", schema)

sample = view.last()
detection = sample["ground_truth"].detections[0]

self.assertIsNotNone(detection["eval"])
with self.assertRaises(KeyError):
detection["eval2"]


class VideoDetectionsTests(unittest.TestCase):
def _make_video_detections_dataset(self):
Expand Down

0 comments on commit 6eca40f

Please sign in to comment.