From 3bd8b511a7234aabe045761ddc423a914608910a Mon Sep 17 00:00:00 2001 From: Matthew Esposito Date: Wed, 26 Jun 2024 23:41:26 -0400 Subject: [PATCH] fix(oauth): strengthen sync guarantees --- src/client.rs | 8 ++++---- src/oauth.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/client.rs b/src/client.rs index 1e545a1..fff2cf0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,7 +11,7 @@ use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; use std::sync::atomic::Ordering; -use std::sync::atomic::{AtomicU16, Ordering::Relaxed}; +use std::sync::atomic::{AtomicU16, Ordering::SeqCst}; use std::{io, result::Result}; use tokio::sync::RwLock; @@ -317,10 +317,10 @@ pub async fn json(path: String, quarantine: bool) -> Result { }; // First, handle rolling over the OAUTH_CLIENT if need be. - let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::Relaxed); + let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst); if current_rate_limit < 10 { warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()"); - OAUTH_RATELIMIT_REMAINING.store(99, Ordering::Relaxed); + OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst); tokio::spawn(force_refresh_token()); } @@ -333,7 +333,7 @@ pub async fn json(path: String, quarantine: bool) -> Result { if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) { trace!("Ratelimit remaining: {}", remaining); if let Ok(remaining) = remaining.parse::().map(|f| f.round() as u16) { - OAUTH_RATELIMIT_REMAINING.store(remaining, Relaxed); + OAUTH_RATELIMIT_REMAINING.store(remaining, SeqCst); } else { warn!("Failed to parse rate limit {remaining} from header."); } diff --git a/src/oauth.rs b/src/oauth.rs index 161310e..03b56f6 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -131,7 +131,7 @@ pub async fn token_daemon() { } pub async fn force_refresh_token() { - trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::Relaxed)); + trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst)); OAUTH_CLIENT.write().await.refresh().await; }