Skip to content

Commit

Permalink
Refactor PublicTask into a decorator task (#4656)
Browse files Browse the repository at this point in the history
* Refactor PublicTask into a decorator task

* Refactor and docs
  • Loading branch information
stsewd authored and agjohnson committed Oct 2, 2018
1 parent c669284 commit a8bd00a
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 78 deletions.
1 change: 0 additions & 1 deletion readthedocs/core/utils/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .permission_checks import user_id_matches # noqa for unused import
from .public import PublicTask # noqa
from .public import TaskNoPermission # noqa
from .public import permission_check # noqa
from .public import get_public_task_data # noqa
from .retrieve import TaskNotFound # noqa
from .retrieve import get_task_data # noqa
75 changes: 43 additions & 32 deletions readthedocs/core/utils/tasks/public.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
"""Celery tasks with publicly viewable status"""

from __future__ import absolute_import
from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)

from celery import Task, states
from django.conf import settings

from .retrieve import TaskNotFound
from .retrieve import get_task_data

from .retrieve import TaskNotFound, get_task_data

__all__ = (
'PublicTask', 'TaskNoPermission', 'permission_check',
'get_public_task_data')
'PublicTask', 'TaskNoPermission', 'get_public_task_data'
)


STATUS_UPDATES_ENABLED = not getattr(settings, 'CELERY_ALWAYS_EAGER', False)
Expand All @@ -19,22 +23,20 @@
class PublicTask(Task):

"""
See oauth.tasks for usage example.
Encapsulates common behaviour to expose a task publicly.
Subclasses need to define a ``run_public`` method.
"""
Tasks should use this class as ``base``. And define a ``check_permission``
property or use the ``permission_check`` decorator.
public_name = 'unknown'
The check_permission should be a function like:
function(request, state, context), and needs to return a boolean value.
@classmethod
def check_permission(cls, request, state, context):
"""Override this method to define who can monitor this task."""
# pylint: disable=unused-argument
return False
See oauth.tasks for usage example.
"""

def get_task_data(self):
"""Return tuple with state to be set next and results task."""
state = 'STARTED'
state = states.STARTED
info = {
'task_name': self.name,
'context': self.request.get('permission_context', {}),
Expand Down Expand Up @@ -66,12 +68,13 @@ def set_public_data(self, data):
self.request.update(public_data=data)
self.update_progress_data()

def run(self, *args, **kwargs):
def __call__(self, *args, **kwargs):
# We override __call__ to let tasks use the run method.
error = False
exception_raised = None
self.set_permission_context(kwargs)
try:
result = self.run_public(*args, **kwargs)
result = self.run(*args, **kwargs)
except Exception as e:
# With Celery 4 we lost the ability to keep our data dictionary into
# ``AsyncResult.info`` when an exception was raised inside the
Expand All @@ -90,22 +93,26 @@ def run(self, *args, **kwargs):

return info

@staticmethod
def permission_check(check):
"""
Decorator for tasks that have PublicTask as base.
def permission_check(check):
"""
Class decorator for subclasses of PublicTask to sprinkle in re-usable
.. note::
The decorator should be on top of the task decorator.
permission checks::
permission checks::
@permission_check(user_id_matches)
class MyTask(PublicTask):
def run_public(self, user_id):
@PublicTask.permission_check(user_id_matches)
@celery.task(base=PublicTask)
def my_public_task(user_id):
pass
"""
def decorator(cls):
cls.check_permission = staticmethod(check)
return cls
return decorator
"""
def decorator(func):
func.check_permission = check
return func
return decorator


class TaskNoPermission(Exception):
Expand Down Expand Up @@ -139,5 +146,9 @@ def get_public_task_data(request, task_id):
context = info.get('context', {})
if not task.check_permission(request, state, context):
raise TaskNoPermission(task_id)
public_name = task.public_name
return public_name, state, info.get('public_data', {}), info.get('error', None)
return (
task.name,
state,
info.get('public_data', {}),
info.get('error', None),
)
12 changes: 9 additions & 3 deletions readthedocs/core/utils/tasks/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Utilities for retrieving task data."""

from __future__ import absolute_import
from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)

from celery import states
from celery.result import AsyncResult


__all__ = ('TaskNotFound', 'get_task_data')


Expand All @@ -23,7 +29,7 @@ def get_task_data(task_id):

result = AsyncResult(task_id)
state, info = result.state, result.info
if state == 'PENDING':
if state == states.PENDING:
raise TaskNotFound(task_id)
if 'task_name' not in info:
raise TaskNotFound(task_id)
Expand Down
5 changes: 0 additions & 5 deletions readthedocs/oauth/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,3 @@

class OAuthConfig(AppConfig):
name = 'readthedocs.oauth'

def ready(self):
from .tasks import SyncRemoteRepositories
from readthedocs.worker import app
app.tasks.register(SyncRemoteRepositories)
35 changes: 16 additions & 19 deletions readthedocs/oauth/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@
"""Tasks for OAuth services."""

from __future__ import (
absolute_import, division, print_function, unicode_literals)
absolute_import,
division,
print_function,
unicode_literals,
)

import logging

from allauth.socialaccount.providers import registry as allauth_registry
from django.contrib.auth.models import User

from readthedocs.core.utils.tasks import (
PublicTask, permission_check, user_id_matches)
from readthedocs.core.utils.tasks import PublicTask, user_id_matches
from readthedocs.oauth.notifications import (
AttachWebhookNotification, InvalidProjectWebhookNotification)
AttachWebhookNotification,
InvalidProjectWebhookNotification,
)
from readthedocs.projects.models import Project
from readthedocs.worker import app

Expand All @@ -21,21 +26,13 @@
log = logging.getLogger(__name__)


@permission_check(user_id_matches)
class SyncRemoteRepositories(PublicTask):

name = __name__ + '.sync_remote_repositories'
public_name = 'sync_remote_repositories'
queue = 'web'

def run_public(self, user_id):
user = User.objects.get(pk=user_id)
for service_cls in registry:
for service in service_cls.for_user(user):
service.sync()


sync_remote_repositories = SyncRemoteRepositories()
@PublicTask.permission_check(user_id_matches)
@app.task(queue='web', base=PublicTask)
def sync_remote_repositories(user_id):
user = User.objects.get(pk=user_id)
for service_cls in registry:
for service in service_cls.for_user(user):
service.sync()


@app.task(queue='web')
Expand Down
29 changes: 19 additions & 10 deletions readthedocs/restapi/views/task_views.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
"""Endpoints relating to task/job status, etc."""

from __future__ import absolute_import
from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)

import logging

from django.core.urlresolvers import reverse
from redis import ConnectionError
from rest_framework import decorators, permissions
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from redis import ConnectionError

from readthedocs.core.utils.tasks import TaskNoPermission
from readthedocs.core.utils.tasks import get_public_task_data
from readthedocs.core.utils.tasks import TaskNoPermission, get_public_task_data
from readthedocs.oauth import tasks


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -43,20 +47,25 @@ def get_status_data(task_name, state, data, error=None):
@decorators.renderer_classes((JSONRenderer,))
def job_status(request, task_id):
try:
task_name, state, public_data, error = get_public_task_data(request, task_id)
task_name, state, public_data, error = get_public_task_data(
request, task_id
)
except (TaskNoPermission, ConnectionError):
return Response(
get_status_data('unknown', 'PENDING', {}))
get_status_data('unknown', 'PENDING', {})
)
return Response(
get_status_data(task_name, state, public_data, error))
get_status_data(task_name, state, public_data, error)
)


@decorators.api_view(['POST'])
@decorators.permission_classes((permissions.IsAuthenticated,))
@decorators.renderer_classes((JSONRenderer,))
def sync_remote_repositories(request):
result = tasks.SyncRemoteRepositories().delay(
user_id=request.user.id)
result = tasks.sync_remote_repositories.delay(
user_id=request.user.id
)
task_id = result.task_id
return Response({
'task_id': task_id,
Expand Down
12 changes: 4 additions & 8 deletions readthedocs/rtd_tests/tests/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,11 @@ def test_public_task_exception(self):
from readthedocs.core.utils.tasks import PublicTask
from readthedocs.worker import app

class PublicTaskException(PublicTask):
name = 'public_task_exception'
@app.task(name='public_task_exception', base=PublicTask)
def public_task_exception():
raise Exception('Something bad happened')

def run_public(self):
raise Exception('Something bad happened')

app.tasks.register(PublicTaskException)
exception_task = PublicTaskException()
result = exception_task.delay()
result = public_task_exception.delay()

# although the task risen an exception, it's success since we add the
# exception into the ``info`` attributes
Expand Down

0 comments on commit a8bd00a

Please sign in to comment.