diff --git a/src/client.rs b/src/client.rs index fff2cf0..ab40236 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,7 +11,7 @@ use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; use std::sync::atomic::Ordering; -use std::sync::atomic::{AtomicU16, Ordering::SeqCst}; +use std::sync::atomic::{AtomicBool, AtomicU16}; use std::{io, result::Result}; use tokio::sync::RwLock; @@ -40,6 +40,8 @@ pub static OAUTH_CLIENT: Lazy> = Lazy::new(|| { pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99); +pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false); + /// Gets the canonical path for a resource on Reddit. This is accomplished by /// making a `HEAD` request to Reddit at the path given in `path`. /// @@ -318,11 +320,12 @@ pub async fn json(path: String, quarantine: bool) -> Result { // First, handle rolling over the OAUTH_CLIENT if need be. let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst); - if current_rate_limit < 10 { + let is_rolling_over = OAUTH_IS_ROLLING_OVER.load(Ordering::SeqCst); + if current_rate_limit < 10 && !is_rolling_over { warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()"); - OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst); tokio::spawn(force_refresh_token()); } + OAUTH_RATELIMIT_REMAINING.fetch_sub(1, Ordering::SeqCst); // Fetch the url... match reddit_get(path.clone(), quarantine).await { @@ -331,12 +334,10 @@ pub async fn json(path: String, quarantine: bool) -> Result { // Ratelimit remaining if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) { - trace!("Ratelimit remaining: {}", remaining); - if let Ok(remaining) = remaining.parse::().map(|f| f.round() as u16) { - OAUTH_RATELIMIT_REMAINING.store(remaining, SeqCst); - } else { - warn!("Failed to parse rate limit {remaining} from header."); - } + trace!( + "Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. {}", + if is_rolling_over { "Rolling over" } else { "" } + ); } // Ratelimit used @@ -358,8 +359,13 @@ pub async fn json(path: String, quarantine: bool) -> Result { let has_remaining = body.has_remaining(); if !has_remaining { + // Rate limited, so spawn a force_refresh_token() + tokio::spawn(force_refresh_token()); return match reset { - Some(val) => Err(format!("Reddit rate limit exceeded. Will reset in: {val}")), + Some(val) => Err(format!( + "Reddit rate limit exceeded. Try refreshing in a few seconds.\ + Rate limit will reset in: {val}" + )), None => Err("Reddit rate limit exceeded".to_string()), }; } diff --git a/src/oauth.rs b/src/oauth.rs index 03b56f6..a3f4dc0 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::atomic::Ordering, time::Duration}; use crate::{ - client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING}, + client::{CLIENT, OAUTH_CLIENT, OAUTH_IS_ROLLING_OVER, OAUTH_RATELIMIT_REMAINING}, oauth_resources::ANDROID_APP_VERSION_LIST, }; use base64::{engine::general_purpose, Engine as _}; @@ -131,8 +131,15 @@ pub async fn token_daemon() { } pub async fn force_refresh_token() { + if OAUTH_IS_ROLLING_OVER.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() { + trace!("Skipping refresh token roll over, already in progress"); + return; + } + trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst)); OAUTH_CLIENT.write().await.refresh().await; + OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst); + OAUTH_IS_ROLLING_OVER.store(false, Ordering::SeqCst); } #[derive(Debug, Clone, Default)]