use arc_swap::ArcSwap; use cached::proc_macro::cached; use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; use hyper::client::HttpConnector; use hyper::header::HeaderValue; use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri}; use hyper_rustls::HttpsConnector; use libflate::gzip; use log::{error, trace, warn}; use once_cell::sync::Lazy; use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicBool, AtomicU16}; use std::{io, result::Result}; use crate::dbg_msg; use crate::oauth::{force_refresh_token, token_daemon, Oauth}; use crate::server::RequestExt; use crate::utils::format_url; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com"; const REDDIT_SHORT_URL_BASE: &str = "https://redd.it"; const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it"; const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com"; const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com"; pub static CLIENT: Lazy>> = Lazy::new(|| { let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build(); client::Client::builder().build(https) }); pub static OAUTH_CLIENT: Lazy> = Lazy::new(|| { let client = block_on(Oauth::new()); tokio::spawn(token_daemon()); ArcSwap::new(client.into()) }); pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99); pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false); static URL_PAIRS: [(&str, &str); 2] = [ (ALTERNATIVE_REDDIT_URL_BASE, ALTERNATIVE_REDDIT_URL_BASE_HOST), (REDDIT_SHORT_URL_BASE, REDDIT_SHORT_URL_BASE_HOST), ]; /// Gets the canonical path for a resource on Reddit. This is accomplished by /// making a `HEAD` request to Reddit at the path given in `path`. /// /// This function returns `Ok(Some(path))`, where `path`'s value is identical /// to that of the value of the argument `path`, if Reddit responds to our /// `HEAD` request with a 2xx-family HTTP code. It will also return an /// `Ok(Some(String))` if Reddit responds to our `HEAD` request with a /// `Location` header in the response, and the HTTP code is in the 3xx-family; /// the `String` will contain the path as reported in `Location`. The return /// value is `Ok(None)` if Reddit responded with a 3xx, but did not provide a /// `Location` header. An `Err(String)` is returned if Reddit responds with a /// 429, or if we were unable to decode the value in the `Location` header. #[cached(size = 1024, time = 600, result = true)] #[async_recursion::async_recursion] pub async fn canonical_path(path: String, tries: i8) -> Result, String> { if tries == 0 { return Ok(None); } // for each URL pair, try the HEAD request let res = { // for url base and host in URL_PAIRS, try reddit_short_head(path.clone(), true, url_base, url_base_host) and if it succeeds, set res. else, res = None let mut res = None; for (url_base, url_base_host) in URL_PAIRS { res = reddit_short_head(path.clone(), true, url_base, url_base_host).await.ok(); if let Some(res) = &res { if !res.status().is_client_error() { break; } } } res }; let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?; let status = res.status().as_u16(); let policy_error = res.headers().get(header::RETRY_AFTER).is_some(); match status { // If Reddit responds with a 2xx, then the path is already canonical. 200..=299 => Ok(Some(path)), // If Reddit responds with a 301, then the path is redirected. 301 => match res.headers().get(header::LOCATION) { Some(val) => { let Ok(original) = val.to_str() else { return Err("Unable to decode Location header.".to_string()); }; // We need to strip the .json suffix from the original path. // In addition, we want to remove share parameters. // Cut it off here instead of letting it propagate all the way // to main.rs let stripped_uri = original.strip_suffix(".json").unwrap_or(original).split('?').next().unwrap_or_default(); // The reason why we now have to format_url, is because the new OAuth // endpoints seem to return full paths, instead of relative paths. // So we need to strip the .json suffix from the original path, and // also remove all Reddit domain parts with format_url. // Otherwise, it will literally redirect to Reddit.com. let uri = format_url(stripped_uri); // Decrement tries and try again canonical_path(uri, tries - 1).await } None => Ok(None), }, // If Reddit responds with anything other than 3xx (except for the 2xx and 301 // as above), return a None. 300..=399 => Ok(None), // Rate limiting 429 => Err("Too many requests.".to_string()), // Special condition rate limiting - https://github.com/redlib-org/redlib/issues/229 403 if policy_error => Err("Too many requests.".to_string()), _ => Ok( res .headers() .get(header::LOCATION) .map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()), ), } } pub async fn proxy(req: Request, format: &str) -> Result, String> { let mut url = format!("{format}?{}", req.uri().query().unwrap_or_default()); // For each parameter in request for (name, value) in &req.params() { // Fill the parameter value in the url url = url.replace(&format!("{{{name}}}"), value); } stream(&url, &req).await } async fn stream(url: &str, req: &Request) -> Result, String> { // First parameter is target URL (mandatory). let parsed_uri = url.parse::().map_err(|_| "Couldn't parse URL".to_string())?; // Build the hyper client from the HTTPS connector. let client: Client<_, Body> = { let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build(); client::Client::builder().pool_max_idle_per_host(0).build(https) }; let mut builder = Request::get(parsed_uri); // Copy useful headers from original request for &key in &["Range", "If-Modified-Since", "Cache-Control"] { if let Some(value) = req.headers().get(key) { builder = builder.header(key, value); } } let stream_request = builder.body(Body::empty()).map_err(|_| "Couldn't build empty body in stream".to_string())?; client .request(stream_request) .await .map(|mut res| { let mut rm = |key: &str| res.headers_mut().remove(key); rm("access-control-expose-headers"); rm("server"); rm("vary"); rm("etag"); rm("x-cdn"); rm("x-cdn-client-region"); rm("x-cdn-name"); rm("x-cdn-server-region"); rm("x-reddit-cdn"); rm("x-reddit-video-features"); rm("Nel"); rm("Report-To"); res }) .map_err(|e| e.to_string()) } /// Makes a GET request to Reddit at `path`. By default, this will honor HTTP /// 3xx codes Reddit returns and will automatically redirect. fn reddit_get(path: String, quarantine: bool) -> Boxed, String>> { request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST) } /// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects. fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> { request(&Method::HEAD, path, false, quarantine, base_path, host) } // /// Makes a HEAD request to Reddit at `path`. This will not follow redirects. // fn reddit_head(path: String, quarantine: bool) -> Boxed, String>> { // request(&Method::HEAD, path, false, quarantine, false) // } // Unused - reddit_head is only ever called in the context of a short URL /// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect` /// will recurse on the URL that Reddit provides in the Location HTTP header /// in its response. fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> { // Build Reddit URL from path. let url = format!("{base_path}{path}"); // Construct the hyper client from the HTTPS connector. let client: Client<_, Body> = { let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build(); client::Client::builder().pool_max_idle_per_host(0).build(https) }; let (token, vendor_id, device_id, user_agent, loid) = { let client = OAUTH_CLIENT.load_full(); ( client.token.clone(), client.headers_map.get("Client-Vendor-Id").cloned().unwrap_or_default(), client.headers_map.get("X-Reddit-Device-Id").cloned().unwrap_or_default(), client.headers_map.get("User-Agent").cloned().unwrap_or_default(), client.headers_map.get("x-reddit-loid").cloned().unwrap_or_default(), ) }; // Build request to Reddit. When making a GET, request gzip compression. // (Reddit doesn't do brotli yet.) let builder = Request::builder() .method(method) .uri(&url) .header("User-Agent", user_agent) .header("Client-Vendor-Id", vendor_id) .header("X-Reddit-Device-Id", device_id) .header("x-reddit-loid", loid) .header("Host", host) .header("Authorization", &format!("Bearer {token}")) .header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" }) .header("Accept-Language", "en-US,en;q=0.5") .header("Connection", "keep-alive") .header( "Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%2C%20%22pref_gated_sr_optin%22%3A%20true%7D" } else { "" }, ) .body(Body::empty()); async move { match builder { Ok(req) => match client.request(req).await { Ok(mut response) => { // Reddit may respond with a 3xx. Decide whether or not to // redirect based on caller params. if response.status().is_redirection() { if !redirect { return Ok(response); }; let location_header = response.headers().get(header::LOCATION); if location_header == Some(&HeaderValue::from_static("https://www.reddit.com/")) { return Err("Reddit response was invalid".to_string()); } return request( method, location_header .map(|val| { // We need to make adjustments to the URI // we get back from Reddit. Namely, we // must: // // 1. Remove the authority (e.g. // https://www.reddit.com) that may be // present, so that we recurse on the // path (and query parameters) as // required. // // 2. Percent-encode the path. let new_path = percent_encode(val.as_bytes(), CONTROLS) .to_string() .trim_start_matches(REDDIT_URL_BASE) .trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE) .to_string(); format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" }) }) .unwrap_or_default() .to_string(), true, quarantine, base_path, host, ) .await; }; match response.headers().get(header::CONTENT_ENCODING) { // Content not compressed. None => Ok(response), // Content encoded (hopefully with gzip). Some(hdr) => { match hdr.to_str() { Ok(val) => match val { "gzip" => {} "identity" => return Ok(response), _ => return Err("Reddit response was encoded with an unsupported compressor".to_string()), }, Err(_) => return Err("Reddit response was invalid".to_string()), } // We get here if the body is gzip-compressed. // The body must be something that implements // std::io::Read, hence the conversion to // bytes::buf::Buf and then transformation into a // Reader. let mut decompressed: Vec; { let mut aggregated_body = match body::aggregate(response.body_mut()).await { Ok(b) => b.reader(), Err(e) => return Err(e.to_string()), }; let mut decoder = match gzip::Decoder::new(&mut aggregated_body) { Ok(decoder) => decoder, Err(e) => return Err(e.to_string()), }; decompressed = Vec::::new(); if let Err(e) = io::copy(&mut decoder, &mut decompressed) { return Err(e.to_string()); }; } response.headers_mut().remove(header::CONTENT_ENCODING); response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into()); *(response.body_mut()) = Body::from(decompressed); Ok(response) } } } Err(e) => { dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e); Err(e.to_string()) } }, Err(_) => Err("Post url contains non-ASCII characters".to_string()), } } .boxed() } // Make a request to a Reddit API and parse the JSON response #[cached(size = 100, time = 30, result = true)] pub async fn json(path: String, quarantine: bool) -> Result { // Closure to quickly build errors let err = |msg: &str, e: String, path: String| -> Result { // eprintln!("{} - {}: {}", url, msg, e); Err(format!("{msg}: {e} | {path}")) }; // First, handle rolling over the OAUTH_CLIENT if need be. let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst); let is_rolling_over = OAUTH_IS_ROLLING_OVER.load(Ordering::SeqCst); if current_rate_limit < 10 && !is_rolling_over { warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()"); tokio::spawn(force_refresh_token()); } OAUTH_RATELIMIT_REMAINING.fetch_sub(1, Ordering::SeqCst); // Fetch the url... match reddit_get(path.clone(), quarantine).await { Ok(response) => { let status = response.status(); let reset: Option = if let (Some(remaining), Some(reset), Some(used)) = ( response.headers().get("x-ratelimit-remaining").and_then(|val| val.to_str().ok().map(|s| s.to_string())), response.headers().get("x-ratelimit-reset").and_then(|val| val.to_str().ok().map(|s| s.to_string())), response.headers().get("x-ratelimit-used").and_then(|val| val.to_str().ok().map(|s| s.to_string())), ) { trace!( "Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. Resets in {reset}. Rollover: {}. Ratelimit used: {used}", if is_rolling_over { "yes" } else { "no" }, ); Some(reset) } else { None }; // asynchronously aggregate the chunks of the body match hyper::body::aggregate(response).await { Ok(body) => { let has_remaining = body.has_remaining(); if !has_remaining { // Rate limited, so spawn a force_refresh_token() tokio::spawn(force_refresh_token()); return match reset { Some(val) => Err(format!( "Reddit rate limit exceeded. Try refreshing in a few seconds.\ Rate limit will reset in: {val}" )), None => Err("Reddit rate limit exceeded".to_string()), }; } // Parse the response from Reddit as JSON match serde_json::from_reader(body.reader()) { Ok(value) => { let json: Value = value; // If user is suspended if let Some(data) = json.get("data") { if let Some(is_suspended) = data.get("is_suspended").and_then(Value::as_bool) { if is_suspended { return Err("suspended".into()); } } } // If Reddit returned an error if json["error"].is_i64() { // OAuth token has expired; http status 401 if json["message"] == "Unauthorized" { error!("Forcing a token refresh"); let () = force_refresh_token().await; return Err("OAuth token has expired. Please refresh the page!".to_string()); } // Handle quarantined if json["reason"] == "quarantined" { return Err("quarantined".into()); } // Handle gated if json["reason"] == "gated" { return Err("gated".into()); } // Handle private subs if json["reason"] == "private" { return Err("private".into()); } // Handle banned subs if json["reason"] == "banned" { return Err("banned".into()); } Err(format!("Reddit error {} \"{}\": {} | {path}", json["error"], json["reason"], json["message"])) } else { Ok(json) } } Err(e) => { error!("Got an invalid response from reddit {e}. Status code: {status}"); if status.is_server_error() { Err("Reddit is having issues, check if there's an outage".to_string()) } else { err("Failed to parse page JSON data", e.to_string(), path) } } } } Err(e) => err("Failed receiving body from Reddit", e.to_string(), path), } } Err(e) => err("Couldn't send request to Reddit", e, path), } } #[cfg(test)] static POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL"; #[tokio::test(flavor = "multi_thread")] async fn test_localization_popular() { let val = json(POPULAR_URL.to_string(), false).await.unwrap(); assert_eq!("GLOBAL", val["data"]["geo_filter"].as_str().unwrap()); } #[tokio::test(flavor = "multi_thread")] async fn test_obfuscated_share_link() { let share_link = "/r/rust/s/kPgq8WNHRK".into(); // Correct link without share parameters let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc/".into(); assert_eq!(canonical_path(share_link, 3).await, Ok(Some(canonical_link))); } #[tokio::test(flavor = "multi_thread")] async fn test_share_link_strip_json() { let link = "/17krzvz".into(); let canonical_link = "/comments/17krzvz".into(); assert_eq!(canonical_path(link, 3).await, Ok(Some(canonical_link))); } #[tokio::test(flavor = "multi_thread")] async fn test_private_sub() { let link = json("/r/suicide/about.json?raw_json=1".into(), true).await; assert!(link.is_err()); assert_eq!(link, Err("private".into())); } #[tokio::test(flavor = "multi_thread")] async fn test_banned_sub() { let link = json("/r/aaa/about.json?raw_json=1".into(), true).await; assert!(link.is_err()); assert_eq!(link, Err("banned".into())); } #[tokio::test(flavor = "multi_thread")] async fn test_gated_sub() { // quarantine to false to specifically catch when we _don't_ catch it let link = json("/r/drugs/about.json?raw_json=1".into(), false).await; assert!(link.is_err()); assert_eq!(link, Err("gated".into())); }