Skip to content

Commit

Permalink
Fix self-refreshing token
Browse files Browse the repository at this point in the history
  • Loading branch information
marioortizmanero committed Sep 25, 2021
1 parent cc3e13e commit 2b2411b
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 160 deletions.
32 changes: 3 additions & 29 deletions src/auth_code.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
use crate::{
auth_urls,
clients::{BaseClient, OAuthClient},
clients::{mutex::Mutex, BaseClient, OAuthClient},
headers,
http::{Form, HttpClient},
ClientResult, Config, Credentials, OAuth, Token,
};

use std::{collections::HashMap, sync::Arc};

#[cfg(feature = "__sync")]
use std::sync::Mutex;

#[cfg(feature = "__async")]
use futures::lock::Mutex;
use maybe_async::maybe_async;
use url::Url;

Expand Down Expand Up @@ -101,21 +96,7 @@ impl BaseClient for AuthCodeSpotify {
// NOTE: this can't use `get_token` because `get_token` itself might
// call this function when automatic reauthentication is enabled.

// The sync and async versions of Mutex have different function signatures
let tok = self.get_token().await;
let locked_token = tok.lock().await;
let mut tmp_locked_lock = Option::None;
#[cfg(feature = "__async")]
{
tmp_locked_lock = locked_token.as_ref();
}

#[cfg(feature = "__sync")]
{
tmp_locked_lock = locked_token.unwrap().as_ref();
}

match tmp_locked_lock {
match self.get_token().await.lock().await.unwrap().as_ref() {
Some(Token {
refresh_token: Some(refresh_token),
..
Expand Down Expand Up @@ -160,14 +141,7 @@ impl OAuthClient for AuthCodeSpotify {

let token = self.fetch_access_token(&data).await?;

#[cfg(feature = "__async")]
{
*self.token.lock().await = Some(token);
}
#[cfg(feature = "__sync")]
{
*self.token.lock().unwrap() = Some(token);
}
*self.token.lock().await.unwrap() = Some(token);

self.write_token_cache().await
}
Expand Down
32 changes: 3 additions & 29 deletions src/auth_code_pkce.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
use crate::{
auth_urls,
clients::{BaseClient, OAuthClient},
clients::{mutex::Mutex, BaseClient, OAuthClient},
headers,
http::{Form, HttpClient},
ClientResult, Config, Credentials, OAuth, Token,
};

use std::{collections::HashMap, sync::Arc};

#[cfg(feature = "__sync")]
use std::sync::Mutex;

#[cfg(feature = "__async")]
use futures::lock::Mutex;
use maybe_async::maybe_async;
use url::Url;

Expand Down Expand Up @@ -64,21 +59,7 @@ impl BaseClient for AuthCodePkceSpotify {
// NOTE: this can't use `get_token` because `get_token` itself might
// call this function when automatic reauthentication is enabled.

// The sync and async versions of Mutex have different function signatures
let tok = self.get_token().await;
let locked_token = tok.lock().await;
let mut tmp_locked_lock = Option::None;
#[cfg(feature = "__async")]
{
tmp_locked_lock = locked_token.as_ref();
}

#[cfg(feature = "__sync")]
{
tmp_locked_lock = locked_token.unwrap().as_ref();
}

match tmp_locked_lock {
match self.get_token().await.lock().await.unwrap().as_ref() {
Some(Token {
refresh_token: Some(refresh_token),
..
Expand Down Expand Up @@ -122,14 +103,7 @@ impl OAuthClient for AuthCodePkceSpotify {

let token = self.fetch_access_token(&data).await?;

#[cfg(feature = "__async")]
{
*self.token.lock().await = Some(token);
}
#[cfg(feature = "__sync")]
{
*self.token.lock().unwrap() = Some(token);
}
*self.token.lock().await.unwrap() = Some(token);

self.write_token_cache().await
}
Expand Down
40 changes: 4 additions & 36 deletions src/client_creds.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
use crate::{
clients::BaseClient,
clients::{mutex::Mutex, BaseClient},
headers,
http::{Form, HttpClient},
ClientResult, Config, Credentials, Token,
};

use std::sync::Arc;

#[cfg(feature = "__sync")]
use std::sync::Mutex;

#[cfg(feature = "__async")]
use futures::lock::Mutex;
use maybe_async::maybe_async;

/// The [Client Credentials Flow][reference] client for the Spotify API.
Expand Down Expand Up @@ -119,14 +114,7 @@ impl ClientCredsSpotify {
/// saved internally.
#[maybe_async]
pub async fn request_token(&self) -> ClientResult<()> {
#[cfg(feature = "__async")]
{
*self.token.lock().await = Some(self.fetch_token().await?);
}
#[cfg(feature = "__sync")]
{
*self.token.lock().unwrap() = Some(self.fetch_token()?);
}
*self.token.lock().await.unwrap() = Some(self.fetch_token().await?);

self.write_token_cache().await
}
Expand All @@ -152,20 +140,7 @@ impl ClientCredsSpotify {
// You could not have read lock and write lock at the same time, which
// will result in deadlock, so obtain the write lock and use it in the
// whole process.
let tok = self.get_token().await;
let locked_token = tok.lock().await;
let mut tmp_locked_lock = Option::None;
#[cfg(feature = "__async")]
{
tmp_locked_lock = locked_token.as_ref();
}

#[cfg(feature = "__sync")]
{
tmp_locked_lock = locked_token.unwrap().as_ref();
}

if let Some(token) = tmp_locked_lock {
if let Some(token) = self.get_token().await.lock().await.unwrap().as_ref() {
if !token.is_expired() {
return Ok(());
}
Expand All @@ -179,14 +154,7 @@ impl ClientCredsSpotify {
async fn refresh_token(&self) -> ClientResult<()> {
let token = self.refetch_token().await?;
if let Some(token) = token {
#[cfg(feature = "__async")]
{
self.token.lock().await.replace(token);
}
#[cfg(feature = "__sync")]
{
self.token.lock().unwrap().replace(token);
}
self.token.lock().await.unwrap().replace(token);
}

self.write_token_cache().await
Expand Down
47 changes: 7 additions & 40 deletions src/clients/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
auth_urls,
clients::{
basic_auth, bearer_auth, convert_result, join_ids,
mutex::Mutex,
pagination::{paginate, Paginator},
},
http::{BaseHttpClient, Form, Headers, HttpClient, Query},
Expand All @@ -12,12 +13,6 @@ use crate::{

use std::{collections::HashMap, fmt, sync::Arc};

#[cfg(feature = "__sync")]
use std::sync::Mutex;

#[cfg(feature = "__async")]
use futures::lock::Mutex;

use chrono::Utc;
use maybe_async::maybe_async;
use serde_json::{Map, Value};
Expand Down Expand Up @@ -62,14 +57,7 @@ where
async fn refresh_token(&self) -> ClientResult<()> {
let token = self.refetch_token().await?;
if let Some(token) = token {
#[cfg(feature = "__async")]
{
self.get_token().await.lock().await.replace(token);
}
#[cfg(feature = "__sync")]
{
self.get_token().lock().unwrap().replace(token);
}
self.get_token().await.lock().await.unwrap().replace(token);
}

self.write_token_cache().await
Expand All @@ -89,19 +77,10 @@ where
/// The headers required for authenticated requests to the API
async fn auth_headers(&self) -> ClientResult<Headers> {
let mut auth = Headers::new();
let tok = self.get_token().await;
let locked_token = tok.lock().await;
let mut tmp_locked_lock = Option::None;
#[cfg(feature = "__async")]
{
tmp_locked_lock = locked_token.as_ref();
}

#[cfg(feature = "__sync")]
{
tmp_locked_lock = locked_token.unwrap().as_ref();
}
let (key, val) = bearer_auth(&tmp_locked_lock.expect("Rspotify not authenticated"));
let token = self.get_token().await;
let token = token.lock().await.unwrap();
let token = token.as_ref().expect("Rspotify not authenticated");
let (key, val) = bearer_auth(token);
auth.insert(key, val);

Ok(auth)
Expand Down Expand Up @@ -212,19 +191,7 @@ where
return Ok(());
}

let tok = self.get_token().await;
let locked_token = tok.lock().await;
let mut tmp_locked_lock = Option::None;
#[cfg(feature = "__async")]
{
tmp_locked_lock = locked_token.as_ref();
}

#[cfg(feature = "__sync")]
{
tmp_locked_lock = locked_token.unwrap().as_ref();
}
if let Some(token) = tmp_locked_lock {
if let Some(token) = self.get_token().await.lock().await.unwrap().as_ref() {
token.write_cache(&self.get_config().cache_path)?;
}

Expand Down
1 change: 1 addition & 0 deletions src/clients/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod base;
pub mod mutex;
pub mod oauth;
pub mod pagination;

Expand Down
18 changes: 18 additions & 0 deletions src/clients/mutex/futures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use futures::lock::{Mutex, MutexGuard};

#[derive(Debug, Default)]
pub struct FuturesMutex<T: ?Sized>(Mutex<T>);

#[derive(Debug)]
pub struct LockError;

impl<T> FuturesMutex<T> {
pub fn new(val: T) -> Self {
FuturesMutex(Mutex::new(val))
}

pub async fn lock(&self) -> Result<MutexGuard<'_, T>, LockError> {
let val = self.0.lock().await;
Ok(val)
}
}
12 changes: 12 additions & 0 deletions src/clients/mutex/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//! This mutex wraps both the synchronous and asynchronous versions under the
//! same interface.

#[cfg(feature = "__async")]
mod futures;
#[cfg(feature = "__sync")]
mod sync;

#[cfg(feature = "__async")]
pub use self::futures::FuturesMutex as Mutex;
#[cfg(feature = "__sync")]
pub use self::sync::SyncMutex as Mutex;
1 change: 1 addition & 0 deletions src/clients/mutex/sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub type SyncMutex<T> = std::sync::Mutex<T>;
34 changes: 8 additions & 26 deletions src/clients/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,7 @@ pub trait OAuthClient: BaseClient {
// You cannot have read lock and write lock at the same time, which
// would result in a deadlock, so obtain the write lock and use it in
// the whole process.
let tok = self.get_token().await;
let locked_token = tok.lock().await;
let mut tmp_locked_lock = Option::None;
#[cfg(feature = "__async")]
{
tmp_locked_lock = locked_token.as_ref();
}

#[cfg(feature = "__sync")]
{
tmp_locked_lock = locked_token.unwrap().as_ref();
}
if let Some(token) = tmp_locked_lock {
if let Some(token) = self.get_token().await.lock().await.unwrap().as_ref() {
if !token.can_reauth() {
return Ok(());
}
Expand Down Expand Up @@ -127,19 +115,13 @@ pub trait OAuthClient: BaseClient {
async fn prompt_for_token(&mut self, url: &str) -> ClientResult<()> {
match self.read_token_cache().await {
Some(new_token) => {
let tok = self.get_token().await;
let locked_token = tok.lock().await;
let mut tmp_locked_lock = Option::None;
#[cfg(feature = "__async")]
{
tmp_locked_lock = locked_token.as_ref();
}

#[cfg(feature = "__sync")]
{
tmp_locked_lock = locked_token.unwrap().as_ref();
}
tmp_locked_lock.replace(&new_token);
self.get_token()
.await
.lock()
.await
.unwrap()
.as_ref()
.replace(&new_token);
}
// Otherwise following the usual procedure to get the token.
None => {
Expand Down

0 comments on commit 2b2411b

Please sign in to comment.