From 07bf20dbc0913065ab0326168bf0b4a623fb0534 Mon Sep 17 00:00:00 2001 From: Matthew Esposito Date: Wed, 26 Jun 2024 19:19:30 -0400 Subject: [PATCH] feat(oauth): roll over oauth key on rate limit --- src/client.rs | 25 ++++++++++++++++++++++--- src/oauth.rs | 8 +++++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/client.rs b/src/client.rs index 1aeb10c..fb6b56d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,11 +5,13 @@ use hyper::client::HttpConnector; use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri}; use hyper_rustls::HttpsConnector; use libflate::gzip; -use log::{error, trace}; +use log::{error, trace, warn}; use once_cell::sync::Lazy; use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; +use std::sync::atomic::Ordering; +use std::sync::atomic::{AtomicU16, Ordering::Relaxed}; use std::{io, result::Result}; use tokio::sync::RwLock; @@ -36,6 +38,8 @@ pub static OAUTH_CLIENT: Lazy> = Lazy::new(|| { RwLock::new(client) }); +pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99); + /// Gets the canonical path for a resource on Reddit. This is accomplished by /// making a `HEAD` request to Reddit at the path given in `path`. /// @@ -304,7 +308,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo } // Make a request to a Reddit API and parse the JSON response -// #[cached(size = 100, time = 30, result = true)] +#[cached(size = 100, time = 30, result = true)] pub async fn json(path: String, quarantine: bool) -> Result { // Closure to quickly build errors let err = |msg: &str, e: String, path: String| -> Result { @@ -312,6 +316,13 @@ pub async fn json(path: String, quarantine: bool) -> Result { Err(format!("{msg}: {e} | {path}")) }; + // First, handle rolling over the OAUTH_CLIENT if need be. + let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::Relaxed); + if current_rate_limit < 10 { + warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()"); + tokio::spawn(force_refresh_token()); + } + // Fetch the url... match reddit_get(path.clone(), quarantine).await { Ok(response) => { @@ -320,6 +331,11 @@ pub async fn json(path: String, quarantine: bool) -> Result { // Ratelimit remaining if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) { trace!("Ratelimit remaining: {}", remaining); + if let Ok(remaining) = remaining.parse::().map(|f| f.round() as u16) { + OAUTH_RATELIMIT_REMAINING.store(remaining, Relaxed); + } else { + warn!("Failed to parse rate limit {remaining} from header."); + } } // Ratelimit used @@ -381,9 +397,12 @@ pub async fn json(path: String, quarantine: bool) -> Result { } } +#[cfg(test)] +static POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL"; + #[tokio::test(flavor = "multi_thread")] async fn test_localization_popular() { - let val = json("/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL".to_string(), false).await.unwrap(); + let val = json(POPULAR_URL.to_string(), false).await.unwrap(); assert_eq!("GLOBAL", val["data"]["geo_filter"].as_str().unwrap()); } diff --git a/src/oauth.rs b/src/oauth.rs index cea7693..61e8044 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -1,12 +1,12 @@ -use std::{collections::HashMap, time::Duration}; +use std::{collections::HashMap, sync::atomic::Ordering, time::Duration}; use crate::{ - client::{CLIENT, OAUTH_CLIENT}, + client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING}, oauth_resources::ANDROID_APP_VERSION_LIST, }; use base64::{engine::general_purpose, Engine as _}; use hyper::{client, Body, Method, Request}; -use log::info; +use log::{info, trace}; use serde_json::json; @@ -131,7 +131,9 @@ pub async fn token_daemon() { } pub async fn force_refresh_token() { + trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::Relaxed)); OAUTH_CLIENT.write().await.refresh().await; + OAUTH_RATELIMIT_REMAINING.store(99, Ordering::Relaxed); } #[derive(Debug, Clone, Default)]