Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle replication lag when authenticating with a Bearer Token #1922

Merged
merged 4 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions onadata/apps/api/viewsets/connect_viewset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rest_framework.exceptions import ParseError
from rest_framework.response import Response
from rest_framework import mixins
from multidb.pinning import use_master

from onadata.apps.api.models.odk_token import ODKToken
from onadata.apps.api.models.temp_token import TempToken
Expand Down Expand Up @@ -47,11 +48,12 @@ def user_profile_w_token_response(request, status):
user_profile = cache.get(
f'{USER_PROFILE_PREFIX}{request.user.username}')
if not user_profile:
user_profile, __ = UserProfile.objects.get_or_create(
user=request.user)
cache.set(
f'{USER_PROFILE_PREFIX}{request.user.username}',
user_profile)
with use_master:
user_profile, _ = UserProfile.objects.get_or_create(
user=request.user)
cache.set(
f'{USER_PROFILE_PREFIX}{request.user.username}',
user_profile)

serializer = UserProfileWithTokenSerializer(
instance=user_profile, context={"request": request})
Expand Down
52 changes: 52 additions & 0 deletions onadata/libs/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import jwt
from django_digest import HttpDigestAuthenticator
from multidb.pinning import use_master
from oauth2_provider.models import AccessToken
from rest_framework import exceptions
from rest_framework.authentication import (
BaseAuthentication,
Expand All @@ -24,6 +25,8 @@
)
from rest_framework.authtoken.models import Token
from rest_framework.exceptions import AuthenticationFailed
from oauth2_provider.oauth2_validators import OAuth2Validator
from oauth2_provider.settings import oauth2_settings

from onadata.apps.api.models.temp_token import TempToken
from onadata.apps.api.tasks import send_account_lockout_email
Expand Down Expand Up @@ -363,3 +366,52 @@ def send_lockout_email(username):
],
countdown=getattr(settings, "LOCKOUT_TIME", 1800) + 60,
)


class MasterReplicaOAuth2Validator(OAuth2Validator):
"""
Custom OAuth2Validator class that takes into account replication lag
between Master & Replica databases
https://github.com/jazzband/django-oauth-toolkit/blob/3bde632d5722f1f85ffcd8277504955321f00fff/oauth2_provider/oauth2_validators.py#L49
"""
def validate_bearer_token(self, token, scopes, request):
if not token:
return False

introspection_url = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL
introspection_token = oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN
introspection_credentials = oauth2_settings.\
RESOURCE_SERVER_INTROSPECTION_CREDENTIALS

try:
access_token = AccessToken.objects.select_related(
"application", "user").get(token=token)
except AccessToken.DoesNotExist:
# Try retrieving AccessToken from MasterDB if not available
# in Read replica
with use_master:
try:
access_token = AccessToken.objects.select_related(
"application", "user").get(token=token)
except AccessToken.DoesNotExist:
access_token = None

if not access_token or not access_token.is_valid(scopes):
if introspection_url and (
introspection_token or introspection_credentials):
access_token = self._get_token_from_authentication_server(
token,
introspection_url,
introspection_token,
introspection_credentials
)

if access_token and access_token.is_valid(scopes):
request.client = access_token.application
request.user = access_token.user
request.scopes = scopes
request.access_token = access_token
return True
else:
self._set_oauth2_error_on_request(request, access_token, scopes)
return False
30 changes: 30 additions & 0 deletions onadata/libs/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from django.conf import settings
from django.contrib.auth.models import User
from django.test import TestCase
from django.http.request import HttpRequest

from mock import patch, MagicMock
from oauth2_provider.models import AccessToken
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.test import APIRequestFactory

Expand All @@ -13,6 +16,7 @@
TempTokenAuthentication,
TempTokenURLParameterAuthentication,
check_lockout,
MasterReplicaOAuth2Validator
)


Expand Down Expand Up @@ -143,3 +147,29 @@ def test_exception_on_username_with_whitespaces(self):
# passed as a username
with self.assertRaises(AuthenticationFailed):
check_lockout(request)


class TestMasterReplicaOAuth2Validator(TestCase):

@patch('onadata.libs.authentication.AccessToken')
def test_reads_from_master(self, mock_token_class):
def is_valid_mock(*args, **kwargs):
return True
token = MagicMock()
token.is_valid = is_valid_mock
token.user = 'bob'
token.application = 'bob-test'
token.token = 'abc'
mock_token_class.DoesNotExist = AccessToken.DoesNotExist
mock_token_class.objects.select_related(
"application", "user").\
get.side_effect = [AccessToken.DoesNotExist, token]
req = HttpRequest()
self.assertTrue(
MasterReplicaOAuth2Validator().validate_bearer_token(
token, {}, req))
self.assertEqual(
mock_token_class.objects.select_related(
"application", "user").get.call_count, 2)
self.assertEqual(req.access_token, token)
self.assertEqual(req.user, token.user)
3 changes: 2 additions & 1 deletion onadata/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@
'SCOPES': {
'read': 'Read scope',
'write': 'Write scope',
'groups': 'Access to your groups'}
'groups': 'Access to your groups'},
'OAUTH2_VALIDATOR_CLASS': 'onadata.libs.authentication.MasterReplicaOAuth2Validator' # noqa
}

REST_FRAMEWORK = {
Expand Down