Skip to content

Commit

Permalink
changes made in response to feedback from review
Browse files Browse the repository at this point in the history
  • Loading branch information
travishathaway committed Aug 31, 2023
1 parent 67ff4e3 commit 98763b4
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 153 deletions.
6 changes: 2 additions & 4 deletions conda_auth/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

PLUGIN_NAME = "conda-auth"

# move to the handlers module
OAUTH2_NAME = f"{PLUGIN_NAME}-oauth2"

# move to the handlers module
HTTP_BASIC_AUTH_NAME = f"{PLUGIN_NAME}-basic-auth"

# Error messages
LOGOUT_ERROR_MESSAGE = "Unable to logout."

INVALID_CREDENTIALS_ERROR_MESSAGE = "Provided credentials are not correct."

USERNAME_AND_PASSWORD_NOT_SET_ERROR_MESSAGE = "Username and password not set."
8 changes: 0 additions & 8 deletions conda_auth/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
from conda.base.context import context

from .base import AuthManager # noqa: F401
from .oauth2 import OAuth2Manager, OAuth2Handler # noqa: F401
from .basic_auth import BasicAuthManager, BasicAuthHandler # noqa: F401

oauth2 = OAuth2Manager(context)
basic_auth = BasicAuthManager(context)

OAuth2Handler.set_cache(oauth2.cache)
BasicAuthHandler.set_cache(basic_auth.cache)
161 changes: 72 additions & 89 deletions conda_auth/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,17 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import Any

import conda.base.context
import keyring
import requests
from conda.gateways.connection.session import get_session
from conda.models.channel import Channel
from conda.plugins.types import ChannelAuthBase

from ..constants import (
INVALID_CREDENTIALS_ERROR_MESSAGE,
USERNAME_AND_PASSWORD_NOT_SET_ERROR_MESSAGE,
)
from ..exceptions import CondaAuthError

INVALID_CREDENTIALS_ERROR_MESSAGE = "Provided credentials are not correct."


class AuthManager(ABC):
"""
Expand All @@ -30,7 +25,7 @@ def __init__(self, context: conda.base.context.Context, cache: dict | None = Non
``conda.base.context.context.channel_settings``.
"""
self._context = context
self.cache = {} if cache is None else cache
self._cache = {} if cache is None else cache

def get_action_func(self) -> Callable[[str], None]:
"""Return a callable to be used as the action function for the pre-command plugin hook"""
Expand All @@ -40,26 +35,69 @@ def action(command: str):
if channel := settings.get("channel"):
channel = Channel(channel)
# Only attempt to authenticate for actively used channels
if channel.canonical_name in self._context.channels:
if (
channel.canonical_name in self._context.channels
and settings.get("auth") == self.get_auth_type()
):
self.authenticate(channel, settings)

return action

def authenticate(self, channel: Channel, settings: dict[str, str]) -> None:
"""Used to retrieve credentials and store them on the ``cache`` property"""
if settings.get("auth") == self.get_auth_type():
extra_params = {
param: settings.get(param) for param in self.get_config_parameters()
}
self.set_secrets(channel, extra_params)
extra_params = {
param: settings.get(param) for param in self.get_config_parameters()
}
username, secret = self.fetch_secret(channel, extra_params)

verify_credentials(channel)
self.save_credentials(channel, username, secret)

def save_credentials(self, channel: Channel, username: str, secret: str) -> None:
"""
Saves the provided credentials to our credential store.
TODO: Method may be expanded in the future to allow the use of other storage
mechanisms.
"""
keyring.set_password(
self.get_keyring_id(channel.canonical_name), username, secret
)

def fetch_secret(
self, channel: Channel, settings: dict[str, str | None]
) -> tuple[str, str]:
"""
Fetch secrets and handle updating cache.
"""
if secrets := self._cache.get(channel.canonical_name):
return secrets

secrets = self._fetch_secret(channel, settings)
self._cache[channel.canonical_name] = secrets

return secrets

def get_secret(self, channel_name: str) -> tuple[str | None, str | None]:
"""
Get the secret that is currently cached for the channel
"""
secrets = self._cache.get(channel_name)

if secrets is None:
return None, None

return secrets

@abstractmethod
def set_secrets(self, channel: Channel, settings: dict[str, str | None]) -> None:
"""Implementations should include routine for fetching and storing secrets"""
def _fetch_secret(
self, channel: Channel, settings: dict[str, str | None]
) -> tuple[str, str]:
"""Implementations should include routine for fetching secret"""

@abstractmethod
def remove_secrets(self, channel: Channel, settings: dict[str, str | None]) -> None:
"""Implementations should include routine for removing secrets"""
def remove_secret(self, channel: Channel, settings: dict[str, str | None]) -> None:
"""Implementations should include routine for removing secret"""

@abstractmethod
def get_auth_type(self) -> str:
Expand All @@ -82,77 +120,22 @@ def get_keyring_id(self, channel_name: str) -> str:
"""


class CacheChannelAuthBase(ChannelAuthBase):
"""
Adds a class instance cache object for storage of authentication information.
"""

def __init__(self, channel_name: str):
"""
Makes sure we have initialized the cache object.
"""
super().__init__(channel_name)

if not hasattr(self, "_cache"):
raise CondaAuthError(
"Cache not initialized on class; please run `BasicAuthHandler.set_cache`"
" before using"
)

@classmethod
def set_cache(cls, cache: dict[str, Any]) -> None:
cls._cache = cache


def test_credentials(func):
def verify_credentials(channel: Channel) -> None:
"""
Decorator function used to test whether the collected credentials can successfully make a
request.
Verify the credentials that have been currently set for the channel.
This decorator could be applied to any function which updates the ``AuthManager.cache``
property.
Raises exception if unable to make a successful request.
"""

@wraps(func)
def wrapper(self, channel: Channel, *args, **kwargs):
func(self, channel, *args, **kwargs)

for url in channel.base_urls:
session = get_session(url)
resp = session.head(url)

try:
resp.raise_for_status()
except requests.exceptions.HTTPError as exc:
if exc.response.status_code == requests.codes["unauthorized"]:
error_message = INVALID_CREDENTIALS_ERROR_MESSAGE
else:
error_message = str(exc)

raise CondaAuthError(error_message)

return wrapper


def save_credentials(func):
"""
Decorator function used to save credentials to the keyring storage system.
This decorator could be applied to any function which updates the ``AuthManager.cache``
property.
"""

@wraps(func)
def wrapper(self, channel: Channel, *args, **kwargs):
func(self, channel, *args, **kwargs)

username, secret = self.cache.get(channel.canonical_name, (None, None))

if username is None and secret is None:
raise CondaAuthError(USERNAME_AND_PASSWORD_NOT_SET_ERROR_MESSAGE)

keyring.set_password(
self.get_keyring_id(channel.canonical_name), username, secret
)

return wrapper
for url in channel.base_urls:
session = get_session(url)
resp = session.head(url)

try:
resp.raise_for_status()
except requests.exceptions.HTTPError as exc:
if exc.response.status_code == requests.codes["unauthorized"]:
error_message = INVALID_CREDENTIALS_ERROR_MESSAGE
else:
error_message = str(exc)

raise CondaAuthError(error_message)
33 changes: 17 additions & 16 deletions conda_auth/handlers/basic_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@
import keyring
from keyring.errors import PasswordDeleteError
from requests.auth import _basic_auth_str # type: ignore
from conda.base.context import context
from conda.exceptions import CondaError
from conda.models.channel import Channel
from conda.plugins.types import ChannelAuthBase

from ..constants import HTTP_BASIC_AUTH_NAME, LOGOUT_ERROR_MESSAGE
from ..exceptions import CondaAuthError
from .base import (
AuthManager,
CacheChannelAuthBase,
test_credentials,
save_credentials,
)
from .base import AuthManager

USERNAME_PARAM_NAME = "username"
"""
Expand All @@ -31,12 +28,13 @@ class BasicAuthManager(AuthManager):
def get_keyring_id(self, channel_name: str):
return f"{HTTP_BASIC_AUTH_NAME}::{channel_name}"

@save_credentials
@test_credentials
def set_secrets(self, channel: Channel, settings: dict[str, str]) -> None:
if self.cache.get(channel.canonical_name) is not None:
return

def _fetch_secret(
self, channel: Channel, settings: dict[str, str | None]
) -> tuple[str, str]:
"""
Gets the secrets by checking the keyring and then falling back to interrupting
the program and asking the user for the credentials.
"""
username = settings.get(USERNAME_PARAM_NAME)
keyring_id = self.get_keyring_id(channel.canonical_name)

Expand All @@ -49,9 +47,9 @@ def set_secrets(self, channel: Channel, settings: dict[str, str]) -> None:
if password is None:
password = getpass()

self.cache[channel.canonical_name] = (username, password)
return username, password

def remove_secrets(self, channel: Channel, settings: dict[str, str | None]) -> None:
def remove_secret(self, channel: Channel, settings: dict[str, str | None]) -> None:
keyring_id = self.get_keyring_id(channel.canonical_name)
username = settings.get(USERNAME_PARAM_NAME)

Expand All @@ -71,7 +69,10 @@ def get_config_parameters(self) -> tuple[str, ...]:
return (USERNAME_PARAM_NAME,)


class BasicAuthHandler(CacheChannelAuthBase):
manager = BasicAuthManager(context)


class BasicAuthHandler(ChannelAuthBase):
"""
Implementation of HTTPBasicAuth that relies on a cache location for
retrieving login credentials on object instantiation.
Expand All @@ -81,7 +82,7 @@ class BasicAuthHandler(CacheChannelAuthBase):

def __init__(self, channel_name: str):
super().__init__(channel_name)
self.username, self.password = self._cache.get(channel_name, (None, None))
self.username, self.password = manager.get_secret(channel_name)

if self.username is None and self.password is None:
raise CondaError(
Expand Down
38 changes: 17 additions & 21 deletions conda_auth/handlers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@

import keyring
from keyring.errors import PasswordDeleteError
from conda.base.context import context
from conda.exceptions import CondaError
from conda.models.channel import Channel
from conda.plugins.types import ChannelAuthBase

from ..constants import OAUTH2_NAME, LOGOUT_ERROR_MESSAGE
from ..exceptions import CondaAuthError
from .base import (
AuthManager,
CacheChannelAuthBase,
test_credentials,
save_credentials,
)
from .base import AuthManager

LOGIN_URL_PARAM_NAME = "login_url"
"""
Expand All @@ -30,12 +27,13 @@ class OAuth2Manager(AuthManager):
def get_keyring_id(self, channel_name: str) -> str:
return f"{OAUTH2_NAME}::{channel_name}"

@save_credentials
@test_credentials
def set_secrets(self, channel: Channel, settings: dict[str, str | None]) -> None:
if self.cache.get(channel.canonical_name) is not None:
return

def _fetch_secret(
self, channel: Channel, settings: dict[str, str | None]
) -> tuple[str, str]:
"""
Gets the secrets by checking the keyring and then falling back to interrupting
the program and asking the user for secret.
"""
login_url = settings.get(LOGIN_URL_PARAM_NAME)

if login_url is None:
Expand All @@ -54,9 +52,9 @@ def set_secrets(self, channel: Channel, settings: dict[str, str | None]) -> None
print(f"Follow link to login: {login_url}")
token = input("Copy and paste login token here: ")

self.cache[channel.canonical_name] = (USERNAME, token)
return USERNAME, token

def remove_secrets(self, channel: Channel, settings: dict[str, str | None]) -> None:
def remove_secret(self, channel: Channel, settings: dict[str, str | None]) -> None:
keyring_id = self.get_keyring_id(channel.canonical_name)

try:
Expand All @@ -71,19 +69,17 @@ def get_config_parameters(self) -> tuple[str, ...]:
return (LOGIN_URL_PARAM_NAME,)


class OAuth2Handler(CacheChannelAuthBase):
manager = OAuth2Manager(context)


class OAuth2Handler(ChannelAuthBase):
"""
Implementation of HTTPBasicAuth that relies on a cache location for
retrieving login credentials on object instantiation.
"""

def __init__(self, channel_name: str):
if not hasattr(self, "_cache"):
raise CondaAuthError(
"Cache not initialized on class; please run `OAuth2Hanlder.set_cache` before using"
)

self.token = self._cache.get(channel_name)
self.token, _ = manager.get_secret(channel_name)

if self.token is None:
raise CondaError(
Expand Down
Loading

0 comments on commit 98763b4

Please sign in to comment.