Skip to content

Commit

Permalink
CVAT v2.4 support (#2903)
Browse files Browse the repository at this point in the history
* wip: tests passing with cvat v2.4, backwards compat with v2.3 todo

* wip: added branching to support cvat server versions < v2.4

* CVAT v2.4 updates (#2959)

* merge _get_paginated_results_2, parse versions

* version

* remove print

* fix v2.4 login issue

* trigger job id check less frequently

* nit

* Read file contents before passing to requests.post
to avoid hitting limit on number of open filehandles

* linting

* more linting

---------

Co-authored-by: Eric Hofesmann <[email protected]>
Co-authored-by: brimoor <[email protected]>
  • Loading branch information
3 people authored and lanzhenw committed May 10, 2023
1 parent f40ac2a commit 703476d
Showing 1 changed file with 118 additions and 45 deletions.
163 changes: 118 additions & 45 deletions fiftyone/utils/cvat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
import math
from collections import defaultdict
from copy import copy, deepcopy
from datetime import datetime
import itertools
import logging
import math
import multiprocessing
import multiprocessing.dummy
import os
from packaging.version import Version
import time
import warnings
import webbrowser

Expand Down Expand Up @@ -3525,7 +3527,7 @@ def base_url(self):

@property
def base_api_url(self):
if self._server_version == 1:
if self._server_version.major == 1:
return "%s/api/v1" % self.base_url

return "%s/api" % self.base_url
Expand All @@ -3534,6 +3536,10 @@ def base_api_url(self):
def login_url(self):
return "%s/auth/login" % self.base_api_url

@property
def about_url(self):
return "%s/server/about" % self.base_api_url

@property
def users_url(self):
return "%s/users" % self.base_api_url
Expand Down Expand Up @@ -3592,11 +3598,21 @@ def task_annotation_formatted_url(
anno_filepath,
)

def labels_url(self, task_id):
# server_version >= 2.4 only
return "%s/labels?task_id=%d" % (self.base_api_url, task_id)

def jobs_url(self, task_id):
return "%s/jobs" % self.task_url(task_id)
if self._server_version >= Version("2.4"):
return "%s/jobs?task_id=%d" % (self.base_api_url, task_id)
else:
return "%s/jobs" % self.task_url(task_id)

def job_url(self, task_id, job_id):
return "%s/%d" % (self.jobs_url(task_id), job_id)
if self._server_version >= Version("2.4"):
return self.taskless_job_url(job_id)
else:
return "%s/%d" % (self.jobs_url(task_id), job_id)

def taskless_job_url(self, job_id):
return "%s/jobs/%d" % (self.base_api_url, job_id)
Expand All @@ -3621,13 +3637,13 @@ def project_id_search_url(self, project_id):

@property
def assignee_key(self):
if self._server_version == 1:
if self._server_version.major == 1:
return "assignee_id"

return "assignee"

def _parse_reviewers(self, job_reviewers):
if self._server_version == 2 and job_reviewers is not None:
if self._server_version.major > 1 and job_reviewers is not None:
logger.warning("CVAT v2 servers do not support `job_reviewers`")
return None

Expand Down Expand Up @@ -3655,20 +3671,38 @@ def _setup(self):
# pylint: disable=too-many-function-args
self._session.headers.update(self._headers)

self._server_version = 2
self._server_version = Version("2")

try:
self._login(username, password)
except requests.exceptions.HTTPError as e:
if e.response.status_code != 404:
raise e

self._server_version = 1
self._server_version = Version("1")
self._login(username, password)

self._add_referer()
self._add_organization()

try:
response = self.get(self.about_url).json()
ver = Version(response["version"])
if ver.major != self._server_version.major:
logger.warning(
"CVAT server major versions don't match: %s vs %s",
ver.major,
self._server_version.major,
)

self._server_version = ver
except Exception as e:
logger.debug(
"Failed to access or parse CVAT server version: %s", e
)

logger.debug("CVAT server version: %s", self._server_version)

def _add_referer(self):
if "Referer" not in self._session.headers:
self._session.headers["Referer"] = self.login_url
Expand All @@ -3688,7 +3722,7 @@ def _login(self, username, password):
self._session.post,
self.login_url,
print_error_info=False,
data={"username": username, "password": password},
json={"username": username, "password": password},
)

if "csrftoken" in response.cookies:
Expand Down Expand Up @@ -3878,7 +3912,7 @@ def list_projects(self):
the list of project IDs
"""
return self._get_paginated_results(
self.projects_url, self.projects_page_url, value="id"
self.projects_url, get_page_url=self.projects_page_url, value="id"
)

def project_exists(self, project_id):
Expand Down Expand Up @@ -3934,16 +3968,23 @@ def get_project_tasks(self, project_id):
the list of task IDs
"""
resp = self.get(self.project_url(project_id)).json()
tasks = []
for task in resp.get("tasks", []):
if isinstance(task, int):
# For CVATv2 servers, task ids are stored directly as an array
# of integers
tasks.append(task)
else:
# For CVATv1 servers, project tasks are dictionaries we need to
# exctract "id" from
tasks.append(task["id"])
val = resp.get("tasks", [])

if self._server_version >= Version("2.4"):
tasks = self._get_paginated_results(val["url"])
tasks = [x["id"] for x in tasks]
else:
tasks = []
for task in val:
if isinstance(task, int):
# For v2 servers, task ids are stored directly as an array
# of integers
tasks.append(task)
else:
# For v1 servers, project tasks are dictionaries we need to
# exctract "id" from
tasks.append(task["id"])

return tasks

def create_task(
Expand Down Expand Up @@ -4002,12 +4043,12 @@ def create_task(
if issue_tracker is not None:
task_json["bug_tracker"] = issue_tracker

task_resp = self.post(self.tasks_url, json=task_json).json()
task_id = task_resp["id"]
task_id, labels = self._get_task_id_labels_json(task_json)

# @todo: see _get_attr_class_maps
class_id_map = {}
attr_id_map = {}
for label in task_resp["labels"]:
for label in labels:
class_id = label["id"]
class_id_map[label["name"]] = class_id
attr_id_map[class_id] = {}
Expand All @@ -4031,7 +4072,7 @@ def list_tasks(self):
the list of task IDs
"""
return self._get_paginated_results(
self.tasks_url, self.tasks_page_url, value="id"
self.tasks_url, get_page_url=self.tasks_page_url, value="id"
)

def task_exists(self, task_id):
Expand Down Expand Up @@ -4133,35 +4174,42 @@ def upload_data(
if len(paths) == 1 and fom.get_media_type(paths[0]) == fom.VIDEO:
# Video task
filename = os.path.basename(paths[0])
open_file = open(paths[0], "rb")
files["client_files[0]"] = (filename, open_file)
open_files.append(open_file)
f = open(paths[0], "rb")
files["client_files[0]"] = (filename, f)
open_files.append(f)
else:
# Image task
for idx, path in enumerate(paths):
# IMPORTANT: CVAT organizes media within a task alphabetically
# by filename, so we must give CVAT filenames whose
# alphabetical order matches the order of `paths`
filename = "%06d_%s" % (idx, os.path.basename(path))
open_file = open(path, "rb")
files["client_files[%d]" % idx] = (filename, open_file)
open_files.append(open_file)
with open(path, "rb") as f:
files["client_files[%d]" % idx] = (filename, f.read())

try:
self.post(self.task_data_url(task_id), data=data, files=files)
except Exception as e:
raise e
finally:
for f in open_files:
f.close()

# @todo is this loop really needed?
# It can take a bit for jobs to show up, so we poll
job_ids = []
while not job_ids:
job_resp = self.get(self.jobs_url(task_id))
job_resp_json = job_resp.json()
if "results" in job_resp_json:
job_resp_json = job_resp_json["results"]
url = self.jobs_url(task_id)
if self._server_version >= Version("2.4"):
job_resp_json = self._get_paginated_results(url)
else:
job_resp = self.get(url)
job_resp_json = job_resp.json()
if "results" in job_resp_json:
job_resp_json = job_resp_json["results"]

job_ids = [j["id"] for j in job_resp_json]
if not job_ids:
time.sleep(1)

if job_assignees is not None:
num_assignees = len(job_assignees)
Expand All @@ -4174,7 +4222,7 @@ def upload_data(
job_patch = {self.assignee_key: user_id}
self.patch(self.taskless_job_url(job_id), json=job_patch)

if self._server_version == 1 and job_reviewers is not None:
if self._server_version.major == 1 and job_reviewers is not None:
num_reviewers = len(job_reviewers)
for idx, job_id in enumerate(job_ids):
# Round robin strategy
Expand Down Expand Up @@ -4599,37 +4647,43 @@ def download_annotations(self, results):
return annotations

def _get_attr_class_maps(self, task_id):
task_json = self.get(self.task_url(task_id)).json()

labels = self._get_task_labels(task_id)
_class_map = {}
attr_id_map = {}
for label in task_json["labels"]:
for label in labels:
_class_map[label["id"]] = label["name"]
attr_id_map[label["id"]] = {
i["name"]: i["id"] for i in label["attributes"]
}

# AL: not sure why we didn't just reverse keys/vals initially
class_map_rev = {n: i for i, n in _class_map.items()}

return attr_id_map, class_map_rev

def _get_paginated_results(self, base_url, get_page_url, value=None):
def _get_paginated_results(self, base_url, get_page_url=None, value=None):
results = []
page_number = 1
page = base_url
while True:
response = self.get(page).json()
if "results" not in response:
break

for result in response["results"]:
if value is not None:
results.append(result[value])
else:
results.append(result)

if not response["next"]:
page = response.get("next", None)

if not page:
break

page_number += 1
page = get_page_url(page_number)
if get_page_url is not None:
page_number += 1
page = get_page_url(page_number)

return results

Expand Down Expand Up @@ -4665,14 +4719,33 @@ def _get_project_labels(self, project_id):
raise ValueError("Project '%s' not found" % project_id)

resp = self.get(self.project_url(project_id)).json()
return resp["labels"]
labels = resp["labels"]

if self._server_version >= Version("2.4"):
labels = self._get_paginated_results(labels["url"])

return labels

def _get_task_labels(self, task_id):
resp = self.get(self.task_url(task_id)).json()
if "labels" not in resp:
raise ValueError("Task '%s' not found" % task_id)

return resp["labels"]
labels = resp["labels"]
if self._server_version >= Version("2.4"):
labels = self._get_paginated_results(labels["url"])

return labels

def _get_task_id_labels_json(self, task_json):
resp = self.post(self.tasks_url, json=task_json).json()
task_id = resp["id"]

labels = resp["labels"]
if self._server_version >= Version("2.4"):
labels = self._get_paginated_results(labels["url"])

return task_id, labels

def _parse_project_details(self, project_name, project_id):
if project_id is not None:
Expand Down

0 comments on commit 703476d

Please sign in to comment.