From 3b2ad212d50d9c8122b96e9f54933c4711caa1ca Mon Sep 17 00:00:00 2001 From: Matthew Esposito Date: Fri, 28 Jun 2024 18:14:47 -0400 Subject: [PATCH] fix(oauth): arc_swap --- Cargo.lock | 7 +++++++ Cargo.toml | 1 + src/client.rs | 8 ++++---- src/oauth.rs | 23 ++++++++--------------- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 916093c..30bdf02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -71,6 +71,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "askama" version = "0.12.1" @@ -1034,6 +1040,7 @@ checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" name = "redlib" version = "0.34.0" dependencies = [ + "arc-swap", "askama", "base64", "brotli", diff --git a/Cargo.toml b/Cargo.toml index e33560e..df90167 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ fastrand = "2.0.1" log = "0.4.20" pretty_env_logger = "0.5.0" dotenvy = "0.15.7" +arc-swap = "1.7.1" [dev-dependencies] lipsum = "0.9.0" diff --git a/src/client.rs b/src/client.rs index ab40236..7281df1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,3 +1,4 @@ +use arc_swap::ArcSwap; use cached::proc_macro::cached; use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; @@ -13,7 +14,6 @@ use serde_json::Value; use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicBool, AtomicU16}; use std::{io, result::Result}; -use tokio::sync::RwLock; use crate::dbg_msg; use crate::oauth::{force_refresh_token, token_daemon, Oauth}; @@ -32,10 +32,10 @@ pub static CLIENT: Lazy>> = Lazy::new(|| { client::Client::builder().build(https) }); -pub static OAUTH_CLIENT: Lazy> = Lazy::new(|| { +pub static OAUTH_CLIENT: Lazy> = Lazy::new(|| { let client = block_on(Oauth::new()); tokio::spawn(token_daemon()); - RwLock::new(client) + ArcSwap::new(client.into()) }); pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99); @@ -177,7 +177,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo let client: Client<_, Body> = CLIENT.clone(); let (token, vendor_id, device_id, user_agent, loid) = { - let client = block_on(OAUTH_CLIENT.read()); + let client = OAUTH_CLIENT.load_full(); ( client.token.clone(), client.headers_map.get("Client-Vendor-Id").cloned().unwrap_or_default(), diff --git a/src/oauth.rs b/src/oauth.rs index a3f4dc0..efdf41e 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -98,21 +98,13 @@ impl Oauth { Some(()) } - - async fn refresh(&mut self) -> Option<()> { - // Refresh is actually just a subsequent login with the same headers (without the old token - // or anything). This logic is handled in login, so we just call login again. - let refresh = self.login().await; - info!("Refreshing OAuth token... {}", if refresh.is_some() { "success" } else { "failed" }); - refresh - } } pub async fn token_daemon() { // Monitor for refreshing token loop { // Get expiry time - be sure to not hold the read lock - let expires_in = { OAUTH_CLIENT.read().await.expires_in }; + let expires_in = { OAUTH_CLIENT.load_full().expires_in }; // sleep for the expiry time minus 2 minutes let duration = Duration::from_secs(expires_in - 120); @@ -125,7 +117,7 @@ pub async fn token_daemon() { // Refresh token - in its own scope { - OAUTH_CLIENT.write().await.refresh().await; + force_refresh_token().await; } } } @@ -137,7 +129,8 @@ pub async fn force_refresh_token() { } trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst)); - OAUTH_CLIENT.write().await.refresh().await; + let new_client = Oauth::new().await; + OAUTH_CLIENT.swap(new_client.into()); OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst); OAUTH_IS_ROLLING_OVER.store(false, Ordering::SeqCst); } @@ -187,21 +180,21 @@ fn choose(list: &[T]) -> T { #[tokio::test(flavor = "multi_thread")] async fn test_oauth_client() { - assert!(!OAUTH_CLIENT.read().await.token.is_empty()); + assert!(!OAUTH_CLIENT.load_full().token.is_empty()); } #[tokio::test(flavor = "multi_thread")] async fn test_oauth_client_refresh() { - OAUTH_CLIENT.write().await.refresh().await.unwrap(); + force_refresh_token().await; } #[tokio::test(flavor = "multi_thread")] async fn test_oauth_token_exists() { - assert!(!OAUTH_CLIENT.read().await.token.is_empty()); + assert!(!OAUTH_CLIENT.load_full().token.is_empty()); } #[tokio::test(flavor = "multi_thread")] async fn test_oauth_headers_len() { - assert!(OAUTH_CLIENT.read().await.headers_map.len() >= 3); + assert!(OAUTH_CLIENT.load_full().headers_map.len() >= 3); } #[test]