Merge pull request #160 from redlib-org/oauth_oppenheimer
fix(oauth): even more atomics to avoid simultaneous token rollover
This commit is contained in:
commit
4dc7ff8165
@ -11,7 +11,7 @@ use percent_encoding::{percent_encode, CONTROLS};
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use std::sync::atomic::Ordering;
|
use std::sync::atomic::Ordering;
|
||||||
use std::sync::atomic::{AtomicU16, Ordering::SeqCst};
|
use std::sync::atomic::{AtomicBool, AtomicU16};
|
||||||
use std::{io, result::Result};
|
use std::{io, result::Result};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
@ -40,6 +40,8 @@ pub static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
|
|||||||
|
|
||||||
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
||||||
|
|
||||||
|
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
||||||
|
|
||||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||||
///
|
///
|
||||||
@ -318,11 +320,12 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||||||
|
|
||||||
// First, handle rolling over the OAUTH_CLIENT if need be.
|
// First, handle rolling over the OAUTH_CLIENT if need be.
|
||||||
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst);
|
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst);
|
||||||
if current_rate_limit < 10 {
|
let is_rolling_over = OAUTH_IS_ROLLING_OVER.load(Ordering::SeqCst);
|
||||||
|
if current_rate_limit < 10 && !is_rolling_over {
|
||||||
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
|
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
|
||||||
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
|
|
||||||
tokio::spawn(force_refresh_token());
|
tokio::spawn(force_refresh_token());
|
||||||
}
|
}
|
||||||
|
OAUTH_RATELIMIT_REMAINING.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
|
||||||
// Fetch the url...
|
// Fetch the url...
|
||||||
match reddit_get(path.clone(), quarantine).await {
|
match reddit_get(path.clone(), quarantine).await {
|
||||||
@ -331,12 +334,10 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||||||
|
|
||||||
// Ratelimit remaining
|
// Ratelimit remaining
|
||||||
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
|
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
|
||||||
trace!("Ratelimit remaining: {}", remaining);
|
trace!(
|
||||||
if let Ok(remaining) = remaining.parse::<f32>().map(|f| f.round() as u16) {
|
"Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. {}",
|
||||||
OAUTH_RATELIMIT_REMAINING.store(remaining, SeqCst);
|
if is_rolling_over { "Rolling over" } else { "" }
|
||||||
} else {
|
);
|
||||||
warn!("Failed to parse rate limit {remaining} from header.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ratelimit used
|
// Ratelimit used
|
||||||
@ -358,8 +359,13 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||||||
let has_remaining = body.has_remaining();
|
let has_remaining = body.has_remaining();
|
||||||
|
|
||||||
if !has_remaining {
|
if !has_remaining {
|
||||||
|
// Rate limited, so spawn a force_refresh_token()
|
||||||
|
tokio::spawn(force_refresh_token());
|
||||||
return match reset {
|
return match reset {
|
||||||
Some(val) => Err(format!("Reddit rate limit exceeded. Will reset in: {val}")),
|
Some(val) => Err(format!(
|
||||||
|
"Reddit rate limit exceeded. Try refreshing in a few seconds.\
|
||||||
|
Rate limit will reset in: {val}"
|
||||||
|
)),
|
||||||
None => Err("Reddit rate limit exceeded".to_string()),
|
None => Err("Reddit rate limit exceeded".to_string()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use std::{collections::HashMap, sync::atomic::Ordering, time::Duration};
|
use std::{collections::HashMap, sync::atomic::Ordering, time::Duration};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING},
|
client::{CLIENT, OAUTH_CLIENT, OAUTH_IS_ROLLING_OVER, OAUTH_RATELIMIT_REMAINING},
|
||||||
oauth_resources::ANDROID_APP_VERSION_LIST,
|
oauth_resources::ANDROID_APP_VERSION_LIST,
|
||||||
};
|
};
|
||||||
use base64::{engine::general_purpose, Engine as _};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
@ -131,8 +131,15 @@ pub async fn token_daemon() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn force_refresh_token() {
|
pub async fn force_refresh_token() {
|
||||||
|
if OAUTH_IS_ROLLING_OVER.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
|
||||||
|
trace!("Skipping refresh token roll over, already in progress");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst));
|
trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst));
|
||||||
OAUTH_CLIENT.write().await.refresh().await;
|
OAUTH_CLIENT.write().await.refresh().await;
|
||||||
|
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
|
||||||
|
OAUTH_IS_ROLLING_OVER.store(false, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
|
Loading…
Reference in New Issue
Block a user