Merge pull request #156 from redlib-org/fix_oauth_ratelimit
feat(oauth): roll over oauth key on rate limit
This commit is contained in:
commit
d045a5760a
@ -5,11 +5,13 @@ use hyper::client::HttpConnector;
|
|||||||
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
|
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
|
||||||
use hyper_rustls::HttpsConnector;
|
use hyper_rustls::HttpsConnector;
|
||||||
use libflate::gzip;
|
use libflate::gzip;
|
||||||
use log::{error, trace};
|
use log::{error, trace, warn};
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use percent_encoding::{percent_encode, CONTROLS};
|
use percent_encoding::{percent_encode, CONTROLS};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
use std::sync::atomic::{AtomicU16, Ordering::Relaxed};
|
||||||
use std::{io, result::Result};
|
use std::{io, result::Result};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
@ -36,6 +38,8 @@ pub static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
|
|||||||
RwLock::new(client)
|
RwLock::new(client)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
||||||
|
|
||||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||||
///
|
///
|
||||||
@ -304,7 +308,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make a request to a Reddit API and parse the JSON response
|
// Make a request to a Reddit API and parse the JSON response
|
||||||
// #[cached(size = 100, time = 30, result = true)]
|
#[cached(size = 100, time = 30, result = true)]
|
||||||
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
||||||
// Closure to quickly build errors
|
// Closure to quickly build errors
|
||||||
let err = |msg: &str, e: String, path: String| -> Result<Value, String> {
|
let err = |msg: &str, e: String, path: String| -> Result<Value, String> {
|
||||||
@ -312,6 +316,13 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||||||
Err(format!("{msg}: {e} | {path}"))
|
Err(format!("{msg}: {e} | {path}"))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// First, handle rolling over the OAUTH_CLIENT if need be.
|
||||||
|
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::Relaxed);
|
||||||
|
if current_rate_limit < 10 {
|
||||||
|
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
|
||||||
|
tokio::spawn(force_refresh_token());
|
||||||
|
}
|
||||||
|
|
||||||
// Fetch the url...
|
// Fetch the url...
|
||||||
match reddit_get(path.clone(), quarantine).await {
|
match reddit_get(path.clone(), quarantine).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
@ -320,6 +331,11 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||||||
// Ratelimit remaining
|
// Ratelimit remaining
|
||||||
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
|
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
|
||||||
trace!("Ratelimit remaining: {}", remaining);
|
trace!("Ratelimit remaining: {}", remaining);
|
||||||
|
if let Ok(remaining) = remaining.parse::<f32>().map(|f| f.round() as u16) {
|
||||||
|
OAUTH_RATELIMIT_REMAINING.store(remaining, Relaxed);
|
||||||
|
} else {
|
||||||
|
warn!("Failed to parse rate limit {remaining} from header.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ratelimit used
|
// Ratelimit used
|
||||||
@ -381,9 +397,12 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
static POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL";
|
||||||
|
|
||||||
#[tokio::test(flavor = "multi_thread")]
|
#[tokio::test(flavor = "multi_thread")]
|
||||||
async fn test_localization_popular() {
|
async fn test_localization_popular() {
|
||||||
let val = json("/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL".to_string(), false).await.unwrap();
|
let val = json(POPULAR_URL.to_string(), false).await.unwrap();
|
||||||
assert_eq!("GLOBAL", val["data"]["geo_filter"].as_str().unwrap());
|
assert_eq!("GLOBAL", val["data"]["geo_filter"].as_str().unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
use std::{collections::HashMap, time::Duration};
|
use std::{collections::HashMap, sync::atomic::Ordering, time::Duration};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
client::{CLIENT, OAUTH_CLIENT},
|
client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING},
|
||||||
oauth_resources::ANDROID_APP_VERSION_LIST,
|
oauth_resources::ANDROID_APP_VERSION_LIST,
|
||||||
};
|
};
|
||||||
use base64::{engine::general_purpose, Engine as _};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
use hyper::{client, Body, Method, Request};
|
use hyper::{client, Body, Method, Request};
|
||||||
use log::info;
|
use log::{info, trace};
|
||||||
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
@ -131,7 +131,9 @@ pub async fn token_daemon() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn force_refresh_token() {
|
pub async fn force_refresh_token() {
|
||||||
|
trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::Relaxed));
|
||||||
OAUTH_CLIENT.write().await.refresh().await;
|
OAUTH_CLIENT.write().await.refresh().await;
|
||||||
|
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
|
Loading…
Reference in New Issue
Block a user