fix(oauth): atomics to avoid simultaneous token rollover

This commit is contained in:
Matthew Esposito 2024-06-27 23:26:31 -04:00
parent 3bd8b511a7
commit 89313f73e6
2 changed files with 14 additions and 10 deletions

View File

@ -11,7 +11,7 @@ use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value; use serde_json::Value;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicU16, Ordering::SeqCst}; use std::sync::atomic::{AtomicBool, AtomicU16};
use std::{io, result::Result}; use std::{io, result::Result};
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -40,6 +40,8 @@ pub static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99); pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
/// Gets the canonical path for a resource on Reddit. This is accomplished by /// Gets the canonical path for a resource on Reddit. This is accomplished by
/// making a `HEAD` request to Reddit at the path given in `path`. /// making a `HEAD` request to Reddit at the path given in `path`.
/// ///
@ -318,11 +320,12 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
// First, handle rolling over the OAUTH_CLIENT if need be. // First, handle rolling over the OAUTH_CLIENT if need be.
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst); let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst);
if current_rate_limit < 10 { let is_rolling_over = OAUTH_IS_ROLLING_OVER.load(Ordering::SeqCst);
if current_rate_limit < 10 && !is_rolling_over {
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()"); warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
tokio::spawn(force_refresh_token()); tokio::spawn(force_refresh_token());
} }
OAUTH_RATELIMIT_REMAINING.fetch_sub(1, Ordering::SeqCst);
// Fetch the url... // Fetch the url...
match reddit_get(path.clone(), quarantine).await { match reddit_get(path.clone(), quarantine).await {
@ -331,12 +334,10 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
// Ratelimit remaining // Ratelimit remaining
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) { if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
trace!("Ratelimit remaining: {}", remaining); trace!(
if let Ok(remaining) = remaining.parse::<f32>().map(|f| f.round() as u16) { "Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. {}",
OAUTH_RATELIMIT_REMAINING.store(remaining, SeqCst); if is_rolling_over { "Rolling over" } else { "" }
} else { );
warn!("Failed to parse rate limit {remaining} from header.");
}
} }
// Ratelimit used // Ratelimit used

View File

@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::atomic::Ordering, time::Duration}; use std::{collections::HashMap, sync::atomic::Ordering, time::Duration};
use crate::{ use crate::{
client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING}, client::{CLIENT, OAUTH_CLIENT, OAUTH_IS_ROLLING_OVER, OAUTH_RATELIMIT_REMAINING},
oauth_resources::ANDROID_APP_VERSION_LIST, oauth_resources::ANDROID_APP_VERSION_LIST,
}; };
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
@ -131,8 +131,11 @@ pub async fn token_daemon() {
} }
pub async fn force_refresh_token() { pub async fn force_refresh_token() {
OAUTH_IS_ROLLING_OVER.store(true, Ordering::SeqCst);
trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst)); trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst));
OAUTH_CLIENT.write().await.refresh().await; OAUTH_CLIENT.write().await.refresh().await;
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
OAUTH_IS_ROLLING_OVER.store(false, Ordering::SeqCst);
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]