Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CVAT v2.4 support #2903

Merged
merged 8 commits into from
May 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 118 additions & 45 deletions fiftyone/utils/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
import math
from collections import defaultdict
from copy import copy, deepcopy
from datetime import datetime
import itertools
import logging
import math
import multiprocessing
import multiprocessing.dummy
import os
from packaging.version import Version
import time
import warnings
import webbrowser

Expand Down Expand Up @@ -3525,7 +3527,7 @@ def base_url(self):

@property
def base_api_url(self):
if self._server_version == 1:
if self._server_version.major == 1:
return "%s/api/v1" % self.base_url

return "%s/api" % self.base_url
Expand All @@ -3534,6 +3536,10 @@ def base_api_url(self):
def login_url(self):
return "%s/auth/login" % self.base_api_url

@property
def about_url(self):
return "%s/server/about" % self.base_api_url

@property
def users_url(self):
return "%s/users" % self.base_api_url
Expand Down Expand Up @@ -3592,11 +3598,21 @@ def task_annotation_formatted_url(
anno_filepath,
)

def labels_url(self, task_id):
# server_version >= 2.4 only
return "%s/labels?task_id=%d" % (self.base_api_url, task_id)

def jobs_url(self, task_id):
return "%s/jobs" % self.task_url(task_id)
if self._server_version >= Version("2.4"):
return "%s/jobs?task_id=%d" % (self.base_api_url, task_id)
else:
return "%s/jobs" % self.task_url(task_id)

def job_url(self, task_id, job_id):
return "%s/%d" % (self.jobs_url(task_id), job_id)
if self._server_version >= Version("2.4"):
return self.taskless_job_url(job_id)
else:
return "%s/%d" % (self.jobs_url(task_id), job_id)

def taskless_job_url(self, job_id):
return "%s/jobs/%d" % (self.base_api_url, job_id)
Expand All @@ -3621,13 +3637,13 @@ def project_id_search_url(self, project_id):

@property
def assignee_key(self):
if self._server_version == 1:
if self._server_version.major == 1:
return "assignee_id"

return "assignee"

def _parse_reviewers(self, job_reviewers):
if self._server_version == 2 and job_reviewers is not None:
if self._server_version.major > 1 and job_reviewers is not None:
logger.warning("CVAT v2 servers do not support `job_reviewers`")
return None

Expand Down Expand Up @@ -3655,20 +3671,38 @@ def _setup(self):
# pylint: disable=too-many-function-args
self._session.headers.update(self._headers)

self._server_version = 2
self._server_version = Version("2")

try:
self._login(username, password)
except requests.exceptions.HTTPError as e:
if e.response.status_code != 404:
raise e

self._server_version = 1
self._server_version = Version("1")
self._login(username, password)

self._add_referer()
self._add_organization()

try:
response = self.get(self.about_url).json()
ver = Version(response["version"])
if ver.major != self._server_version.major:
logger.warning(
"CVAT server major versions don't match: %s vs %s",
ver.major,
self._server_version.major,
)

self._server_version = ver
except Exception as e:
logger.debug(
"Failed to access or parse CVAT server version: %s", e
)

logger.debug("CVAT server version: %s", self._server_version)

def _add_referer(self):
if "Referer" not in self._session.headers:
self._session.headers["Referer"] = self.login_url
Expand All @@ -3688,7 +3722,7 @@ def _login(self, username, password):
self._session.post,
self.login_url,
print_error_info=False,
data={"username": username, "password": password},
json={"username": username, "password": password},
)

if "csrftoken" in response.cookies:
Expand Down Expand Up @@ -3878,7 +3912,7 @@ def list_projects(self):
the list of project IDs
"""
return self._get_paginated_results(
self.projects_url, self.projects_page_url, value="id"
self.projects_url, get_page_url=self.projects_page_url, value="id"
)

def project_exists(self, project_id):
Expand Down Expand Up @@ -3934,16 +3968,23 @@ def get_project_tasks(self, project_id):
the list of task IDs
"""
resp = self.get(self.project_url(project_id)).json()
tasks = []
for task in resp.get("tasks", []):
if isinstance(task, int):
# For CVATv2 servers, task ids are stored directly as an array
# of integers
tasks.append(task)
else:
# For CVATv1 servers, project tasks are dictionaries we need to
# exctract "id" from
tasks.append(task["id"])
val = resp.get("tasks", [])

if self._server_version >= Version("2.4"):
tasks = self._get_paginated_results(val["url"])
tasks = [x["id"] for x in tasks]
else:
tasks = []
for task in val:
if isinstance(task, int):
# For v2 servers, task ids are stored directly as an array
# of integers
tasks.append(task)
else:
# For v1 servers, project tasks are dictionaries we need to
# exctract "id" from
tasks.append(task["id"])

return tasks

def create_task(
Expand Down Expand Up @@ -4002,12 +4043,12 @@ def create_task(
if issue_tracker is not None:
task_json["bug_tracker"] = issue_tracker

task_resp = self.post(self.tasks_url, json=task_json).json()
task_id = task_resp["id"]
task_id, labels = self._get_task_id_labels_json(task_json)

# @todo: see _get_attr_class_maps
class_id_map = {}
attr_id_map = {}
for label in task_resp["labels"]:
for label in labels:
class_id = label["id"]
class_id_map[label["name"]] = class_id
attr_id_map[class_id] = {}
Expand All @@ -4031,7 +4072,7 @@ def list_tasks(self):
the list of task IDs
"""
return self._get_paginated_results(
self.tasks_url, self.tasks_page_url, value="id"
self.tasks_url, get_page_url=self.tasks_page_url, value="id"
)

def task_exists(self, task_id):
Expand Down Expand Up @@ -4133,35 +4174,42 @@ def upload_data(
if len(paths) == 1 and fom.get_media_type(paths[0]) == fom.VIDEO:
# Video task
filename = os.path.basename(paths[0])
open_file = open(paths[0], "rb")
files["client_files[0]"] = (filename, open_file)
open_files.append(open_file)
f = open(paths[0], "rb")
files["client_files[0]"] = (filename, f)
open_files.append(f)
else:
# Image task
for idx, path in enumerate(paths):
# IMPORTANT: CVAT organizes media within a task alphabetically
# by filename, so we must give CVAT filenames whose
# alphabetical order matches the order of `paths`
filename = "%06d_%s" % (idx, os.path.basename(path))
open_file = open(path, "rb")
files["client_files[%d]" % idx] = (filename, open_file)
open_files.append(open_file)
with open(path, "rb") as f:
files["client_files[%d]" % idx] = (filename, f.read())

try:
self.post(self.task_data_url(task_id), data=data, files=files)
except Exception as e:
raise e
finally:
for f in open_files:
f.close()

# @todo is this loop really needed?
# It can take a bit for jobs to show up, so we poll
job_ids = []
while not job_ids:
job_resp = self.get(self.jobs_url(task_id))
job_resp_json = job_resp.json()
if "results" in job_resp_json:
job_resp_json = job_resp_json["results"]
url = self.jobs_url(task_id)
if self._server_version >= Version("2.4"):
job_resp_json = self._get_paginated_results(url)
else:
job_resp = self.get(url)
job_resp_json = job_resp.json()
if "results" in job_resp_json:
job_resp_json = job_resp_json["results"]

job_ids = [j["id"] for j in job_resp_json]
if not job_ids:
time.sleep(1)

if job_assignees is not None:
num_assignees = len(job_assignees)
Expand All @@ -4174,7 +4222,7 @@ def upload_data(
job_patch = {self.assignee_key: user_id}
self.patch(self.taskless_job_url(job_id), json=job_patch)

if self._server_version == 1 and job_reviewers is not None:
if self._server_version.major == 1 and job_reviewers is not None:
num_reviewers = len(job_reviewers)
for idx, job_id in enumerate(job_ids):
# Round robin strategy
Expand Down Expand Up @@ -4599,37 +4647,43 @@ def download_annotations(self, results):
return annotations

def _get_attr_class_maps(self, task_id):
task_json = self.get(self.task_url(task_id)).json()

labels = self._get_task_labels(task_id)
_class_map = {}
attr_id_map = {}
for label in task_json["labels"]:
for label in labels:
_class_map[label["id"]] = label["name"]
attr_id_map[label["id"]] = {
i["name"]: i["id"] for i in label["attributes"]
}

# AL: not sure why we didn't just reverse keys/vals initially
class_map_rev = {n: i for i, n in _class_map.items()}

return attr_id_map, class_map_rev

def _get_paginated_results(self, base_url, get_page_url, value=None):
def _get_paginated_results(self, base_url, get_page_url=None, value=None):
results = []
page_number = 1
page = base_url
while True:
response = self.get(page).json()
if "results" not in response:
break

for result in response["results"]:
if value is not None:
results.append(result[value])
else:
results.append(result)

if not response["next"]:
page = response.get("next", None)

if not page:
break

page_number += 1
page = get_page_url(page_number)
if get_page_url is not None:
page_number += 1
page = get_page_url(page_number)

return results

Expand Down Expand Up @@ -4665,14 +4719,33 @@ def _get_project_labels(self, project_id):
raise ValueError("Project '%s' not found" % project_id)

resp = self.get(self.project_url(project_id)).json()
return resp["labels"]
labels = resp["labels"]

if self._server_version >= Version("2.4"):
labels = self._get_paginated_results(labels["url"])

return labels

def _get_task_labels(self, task_id):
resp = self.get(self.task_url(task_id)).json()
if "labels" not in resp:
raise ValueError("Task '%s' not found" % task_id)

return resp["labels"]
labels = resp["labels"]
if self._server_version >= Version("2.4"):
labels = self._get_paginated_results(labels["url"])

return labels

def _get_task_id_labels_json(self, task_json):
resp = self.post(self.tasks_url, json=task_json).json()
task_id = resp["id"]

labels = resp["labels"]
if self._server_version >= Version("2.4"):
labels = self._get_paginated_results(labels["url"])

return task_id, labels

def _parse_project_details(self, project_name, project_id):
if project_id is not None:
Expand Down