Merge remote-tracking branch 'upstream/main'

+ repository link update
This commit is contained in:
2024-07-07 17:19:28 +12:00
9 changed files with 171 additions and 77 deletions

View File

@ -1,7 +1,9 @@
use arc_swap::ArcSwap;
use cached::proc_macro::cached;
use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt};
use hyper::client::HttpConnector;
use hyper::header::HeaderValue;
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
use hyper_rustls::HttpsConnector;
use libflate::gzip;
@ -11,9 +13,8 @@ use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicU16, Ordering::SeqCst};
use std::sync::atomic::{AtomicBool, AtomicU16};
use std::{io, result::Result};
use tokio::sync::RwLock;
use crate::dbg_msg;
use crate::oauth::{force_refresh_token, token_daemon, Oauth};
@ -21,6 +22,7 @@ use crate::server::RequestExt;
use crate::utils::format_url;
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
let https = hyper_rustls::HttpsConnectorBuilder::new()
@ -32,14 +34,16 @@ pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
client::Client::builder().build(https)
});
pub static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
pub static OAUTH_CLIENT: Lazy<ArcSwap<Oauth>> = Lazy::new(|| {
let client = block_on(Oauth::new());
tokio::spawn(token_daemon());
RwLock::new(client)
ArcSwap::new(client.into())
});
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
/// Gets the canonical path for a resource on Reddit. This is accomplished by
/// making a `HEAD` request to Reddit at the path given in `path`.
///
@ -175,7 +179,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
let client: Client<_, Body> = CLIENT.clone();
let (token, vendor_id, device_id, user_agent, loid) = {
let client = block_on(OAUTH_CLIENT.read());
let client = OAUTH_CLIENT.load_full();
(
client.token.clone(),
client.headers_map.get("Client-Vendor-Id").cloned().unwrap_or_default(),
@ -219,12 +223,13 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
if !redirect {
return Ok(response);
};
let location_header = response.headers().get(header::LOCATION);
if location_header == Some(&HeaderValue::from_static("https://www.reddit.com/")) {
return Err("Reddit response was invalid".to_string());
}
return request(
method,
response
.headers()
.get(header::LOCATION)
location_header
.map(|val| {
// We need to make adjustments to the URI
// we get back from Reddit. Namely, we
@ -237,7 +242,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// required.
//
// 2. Percent-encode the path.
let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string();
let new_path = percent_encode(val.as_bytes(), CONTROLS)
.to_string()
.trim_start_matches(REDDIT_URL_BASE)
.trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE)
.to_string();
format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" })
})
.unwrap_or_default()
@ -296,7 +305,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
}
}
Err(e) => {
dbg_msg!("{} {}: {}", method, path, e);
dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);
Err(e.to_string())
}
@ -318,36 +327,28 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
// First, handle rolling over the OAUTH_CLIENT if need be.
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst);
if current_rate_limit < 10 {
let is_rolling_over = OAUTH_IS_ROLLING_OVER.load(Ordering::SeqCst);
if current_rate_limit < 10 && !is_rolling_over {
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
tokio::spawn(force_refresh_token());
}
OAUTH_RATELIMIT_REMAINING.fetch_sub(1, Ordering::SeqCst);
// Fetch the url...
match reddit_get(path.clone(), quarantine).await {
Ok(response) => {
let status = response.status();
// Ratelimit remaining
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
trace!("Ratelimit remaining: {}", remaining);
if let Ok(remaining) = remaining.parse::<f32>().map(|f| f.round() as u16) {
OAUTH_RATELIMIT_REMAINING.store(remaining, SeqCst);
} else {
warn!("Failed to parse rate limit {remaining} from header.");
}
}
// Ratelimit used
if let Some(Ok(used)) = response.headers().get("x-ratelimit-used").map(|val| val.to_str()) {
trace!("Ratelimit used: {}", used);
}
// Ratelimit reset
let reset = if let Some(Ok(reset)) = response.headers().get("x-ratelimit-reset").map(|val| val.to_str()) {
trace!("Ratelimit reset: {}", reset);
Some(reset.to_string())
let reset: Option<String> = if let (Some(remaining), Some(reset), Some(used)) = (
response.headers().get("x-ratelimit-remaining").and_then(|val| val.to_str().ok().map(|s| s.to_string())),
response.headers().get("x-ratelimit-reset").and_then(|val| val.to_str().ok().map(|s| s.to_string())),
response.headers().get("x-ratelimit-used").and_then(|val| val.to_str().ok().map(|s| s.to_string())),
) {
trace!(
"Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. Resets in {reset}. Rollover: {}. Ratelimit used: {used}",
if is_rolling_over { "yes" } else { "no" },
);
Some(reset)
} else {
None
};
@ -358,8 +359,13 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
let has_remaining = body.has_remaining();
if !has_remaining {
// Rate limited, so spawn a force_refresh_token()
tokio::spawn(force_refresh_token());
return match reset {
Some(val) => Err(format!("Reddit rate limit exceeded. Will reset in: {val}")),
Some(val) => Err(format!(
"Reddit rate limit exceeded. Try refreshing in a few seconds.\
Rate limit will reset in: {val}"
)),
None => Err("Reddit rate limit exceeded".to_string()),
};
}

View File

@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::atomic::Ordering, time::Duration};
use crate::{
client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING},
client::{CLIENT, OAUTH_CLIENT, OAUTH_IS_ROLLING_OVER, OAUTH_RATELIMIT_REMAINING},
oauth_resources::ANDROID_APP_VERSION_LIST,
};
use base64::{engine::general_purpose, Engine as _};
@ -98,21 +98,13 @@ impl Oauth {
Some(())
}
async fn refresh(&mut self) -> Option<()> {
// Refresh is actually just a subsequent login with the same headers (without the old token
// or anything). This logic is handled in login, so we just call login again.
let refresh = self.login().await;
info!("Refreshing OAuth token... {}", if refresh.is_some() { "success" } else { "failed" });
refresh
}
}
pub async fn token_daemon() {
// Monitor for refreshing token
loop {
// Get expiry time - be sure to not hold the read lock
let expires_in = { OAUTH_CLIENT.read().await.expires_in };
let expires_in = { OAUTH_CLIENT.load_full().expires_in };
// sleep for the expiry time minus 2 minutes
let duration = Duration::from_secs(expires_in - 120);
@ -125,14 +117,22 @@ pub async fn token_daemon() {
// Refresh token - in its own scope
{
OAUTH_CLIENT.write().await.refresh().await;
force_refresh_token().await;
}
}
}
pub async fn force_refresh_token() {
if OAUTH_IS_ROLLING_OVER.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
trace!("Skipping refresh token roll over, already in progress");
return;
}
trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst));
OAUTH_CLIENT.write().await.refresh().await;
let new_client = Oauth::new().await;
OAUTH_CLIENT.swap(new_client.into());
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
OAUTH_IS_ROLLING_OVER.store(false, Ordering::SeqCst);
}
#[derive(Debug, Clone, Default)]
@ -180,21 +180,21 @@ fn choose<T: Copy>(list: &[T]) -> T {
#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_client() {
assert!(!OAUTH_CLIENT.read().await.token.is_empty());
assert!(!OAUTH_CLIENT.load_full().token.is_empty());
}
#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_client_refresh() {
OAUTH_CLIENT.write().await.refresh().await.unwrap();
force_refresh_token().await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_token_exists() {
assert!(!OAUTH_CLIENT.read().await.token.is_empty());
assert!(!OAUTH_CLIENT.load_full().token.is_empty());
}
#[tokio::test(flavor = "multi_thread")]
async fn test_oauth_headers_len() {
assert!(OAUTH_CLIENT.read().await.headers_map.len() >= 3);
assert!(OAUTH_CLIENT.load_full().headers_map.len() >= 3);
}
#[test]

View File

@ -64,7 +64,7 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
let post_sort = req.cookie("post_sort").map_or_else(|| "hot".to_string(), |c| c.value().to_string());
let sort = req.param("sort").unwrap_or_else(|| req.param("id").unwrap_or(post_sort));
let mut sub_name = req.param("sub").unwrap_or(if front_page == "default" || front_page.is_empty() {
let sub_name = req.param("sub").unwrap_or(if front_page == "default" || front_page.is_empty() {
if subscribed.is_empty() {
"popular".to_string()
} else {
@ -84,11 +84,6 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
return Ok(redirect(&["/user/", &sub_name[2..]].concat()));
}
// If multi-sub, replace + with url encoded +
if sub_name.contains('+') {
sub_name = sub_name.replace('+', "%2B");
}
// Request subreddit metadata
let sub = if !sub_name.contains('+') && sub_name != subscribed && sub_name != "popular" && sub_name != "all" {
// Regular subreddit
@ -124,7 +119,7 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
params.push_str(&format!("&geo_filter={geo_filter}"));
}
let path = format!("/r/{sub_name}/{sort}.json?{}{params}", req.uri().query().unwrap_or_default());
let path = format!("/r/{}/{sort}.json?{}{params}", sub_name.replace('+', "%2B"), req.uri().query().unwrap_or_default());
let url = String::from(req.uri().path_and_query().map_or("", |val| val.as_str()));
let redirect_url = url[1..].replace('?', "%3F").replace('&', "%26").replace('+', "%2B");
let filters = get_filters(&req);
@ -150,6 +145,10 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
let (_, all_posts_filtered) = filter_posts(&mut posts, &filters);
let no_posts = posts.is_empty();
let all_posts_hidden_nsfw = !no_posts && (posts.iter().all(|p| p.flags.nsfw) && setting(&req, "show_nsfw") != "on");
if sort == "new" {
posts.sort_by(|a, b| b.created_ts.cmp(&a.created_ts));
posts.sort_by(|a, b| b.flags.stickied.cmp(&a.flags.stickied));
}
Ok(template(&SubredditTemplate {
sub,
posts,

View File

@ -169,6 +169,7 @@ pub struct Media {
pub width: i64,
pub height: i64,
pub poster: String,
pub download_name: String,
}
impl Media {
@ -235,6 +236,15 @@ impl Media {
let alt_url = alt_url_val.map_or(String::new(), |val| format_url(val.as_str().unwrap_or_default()));
let download_name = if post_type == "image" || post_type == "gif" || post_type == "video" {
let permalink_base = url_path_basename(data["permalink"].as_str().unwrap_or_default());
let media_url_base = url_path_basename(url_val.as_str().unwrap_or_default());
format!("redlib_{permalink_base}_{media_url_base}")
} else {
String::new()
};
(
post_type.to_string(),
Self {
@ -245,6 +255,7 @@ impl Media {
width: source["width"].as_i64().unwrap_or_default(),
height: source["height"].as_i64().unwrap_or_default(),
poster: format_url(source["url"].as_str().unwrap_or_default()),
download_name,
},
gallery,
)
@ -298,6 +309,7 @@ pub struct Post {
pub body: String,
pub author: Author,
pub permalink: String,
pub link_title: String,
pub poll: Option<Poll>,
pub score: (String, String),
pub upvote_ratio: i64,
@ -309,6 +321,7 @@ pub struct Post {
pub domain: String,
pub rel_time: String,
pub created: String,
pub created_ts: u64,
pub num_duplicates: u64,
pub comments: (String, String),
pub gallery: Vec<GalleryMedia>,
@ -340,6 +353,7 @@ impl Post {
let data = &post["data"];
let (rel_time, created) = time(data["created_utc"].as_f64().unwrap_or_default());
let created_ts = data["created_utc"].as_f64().unwrap_or_default().round() as u64;
let score = data["score"].as_i64().unwrap_or_default();
let ratio: f64 = data["upvote_ratio"].as_f64().unwrap_or(1.0) * 100.0;
let title = val(post, "title");
@ -386,6 +400,7 @@ impl Post {
width: data["thumbnail_width"].as_i64().unwrap_or_default(),
height: data["thumbnail_height"].as_i64().unwrap_or_default(),
poster: String::new(),
download_name: String::new(),
},
media,
domain: val(post, "domain"),
@ -409,9 +424,11 @@ impl Post {
stickied: data["stickied"].as_bool().unwrap_or_default() || data["pinned"].as_bool().unwrap_or_default(),
},
permalink: val(post, "permalink"),
link_title: val(post, "link_title"),
poll: Poll::parse(&data["poll_data"]),
rel_time,
created,
created_ts,
num_duplicates: post["data"]["num_duplicates"].as_u64().unwrap_or(0),
comments: format_num(data["num_comments"].as_i64().unwrap_or_default()),
gallery,
@ -420,7 +437,6 @@ impl Post {
ws_url: val(post, "websocket_url"),
});
}
Ok((posts, res["data"]["after"].as_str().unwrap_or_default().to_string()))
}
}
@ -691,6 +707,8 @@ pub async fn parse_post(post: &Value) -> Post {
// Determine the type of media along with the media URL
let (post_type, media, gallery) = Media::parse(&post["data"]).await;
let created_ts = post["data"]["created_utc"].as_f64().unwrap_or_default().round() as u64;
let awards: Awards = Awards::parse(&post["data"]["all_awardings"]);
let permalink = val(post, "permalink");
@ -727,6 +745,7 @@ pub async fn parse_post(post: &Value) -> Post {
distinguished: val(post, "distinguished"),
},
permalink,
link_title: val(post, "link_title"),
poll,
score: format_num(score),
upvote_ratio: ratio as i64,
@ -738,6 +757,7 @@ pub async fn parse_post(post: &Value) -> Post {
width: post["data"]["thumbnail_width"].as_i64().unwrap_or_default(),
height: post["data"]["thumbnail_height"].as_i64().unwrap_or_default(),
poster: String::new(),
download_name: String::new(),
},
flair: Flair {
flair_parts: FlairPart::parse(
@ -761,6 +781,7 @@ pub async fn parse_post(post: &Value) -> Post {
domain: val(post, "domain"),
rel_time,
created,
created_ts,
num_duplicates: post["data"]["num_duplicates"].as_u64().unwrap_or(0),
comments: format_num(post["data"]["num_comments"].as_i64().unwrap_or_default()),
gallery,
@ -1120,6 +1141,20 @@ pub async fn nsfw_landing(req: Request<Body>, req_url: String) -> Result<Respons
Ok(Response::builder().status(403).header("content-type", "text/html").body(body.into()).unwrap_or_default())
}
// Returns the last (non-empty) segment of a path string
pub fn url_path_basename(path: &str) -> String {
let url_result = Url::parse(format!("https://libredd.it/{path}").as_str());
if url_result.is_err() {
path.to_string()
} else {
let mut url = url_result.unwrap();
url.path_segments_mut().unwrap().pop_if_empty();
url.path_segments().unwrap().last().unwrap().to_string()
}
}
#[cfg(test)]
mod tests {
use super::{format_num, format_url, rewrite_urls};
@ -1228,3 +1263,19 @@ fn test_rewriting_image_links() {
let output = r#"<p><figure><a href="/preview/pre/6awags382xo31.png?width=2560&amp;format=png&amp;auto=webp&amp;s=9c563aed4f07a91bdd249b5a3cea43a79710dcfc"><img loading="lazy" src="/preview/pre/6awags382xo31.png?width=2560&amp;format=png&amp;auto=webp&amp;s=9c563aed4f07a91bdd249b5a3cea43a79710dcfc"></a><figcaption>caption 1</figcaption></figure></p"#;
assert_eq!(rewrite_urls(input), output);
}
#[test]
fn test_url_path_basename() {
// without trailing slash
assert_eq!(url_path_basename("/first/last"), "last");
// with trailing slash
assert_eq!(url_path_basename("/first/last/"), "last");
// with query parameters
assert_eq!(url_path_basename("/first/last/?some=query"), "last");
// file path
assert_eq!(url_path_basename("/cdn/image.jpg"), "image.jpg");
// when a full url is passed instead of just a path
assert_eq!(url_path_basename("https://doma.in/first/last"), "last");
// empty path
assert_eq!(url_path_basename("/"), "");
}