diff --git a/src/client.rs b/src/client.rs index d4701ba..6a202ff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -23,7 +23,13 @@ use crate::server::RequestExt; use crate::utils::format_url; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; +const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com"; + +const REDDIT_SHORT_URL_BASE: &str = "https://redd.it"; +const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it"; + const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com"; +const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com"; pub static CLIENT: Lazy>> = Lazy::new(|| { let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build(); @@ -40,6 +46,11 @@ pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99); pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false); +static URL_PAIRS: [(&str, &str); 2] = [ + (ALTERNATIVE_REDDIT_URL_BASE, ALTERNATIVE_REDDIT_URL_BASE_HOST), + (REDDIT_SHORT_URL_BASE, REDDIT_SHORT_URL_BASE_HOST), +]; + /// Gets the canonical path for a resource on Reddit. This is accomplished by /// making a `HEAD` request to Reddit at the path given in `path`. /// @@ -53,8 +64,27 @@ pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false); /// `Location` header. An `Err(String)` is returned if Reddit responds with a /// 429, or if we were unable to decode the value in the `Location` header. #[cached(size = 1024, time = 600, result = true)] -pub async fn canonical_path(path: String) -> Result, String> { - let res = reddit_head(path.clone(), true).await?; +pub async fn canonical_path(path: String, tries: i8) -> Result, String> { + if tries == 0 { + return Ok(None); + } + + // for each URL pair, try the HEAD request + let res = { + // for url base and host in URL_PAIRS, try reddit_short_head(path.clone(), true, url_base, url_base_host) and if it succeeds, set res. else, res = None + let mut res = None; + for (url_base, url_base_host) in URL_PAIRS { + res = reddit_short_head(path.clone(), true, url_base, url_base_host).await.ok(); + if let Some(res) = &res { + if !res.status().is_client_error() { + break; + } + } + } + res + }; + + let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?; let status = res.status().as_u16(); let policy_error = res.headers().get(header::RETRY_AFTER).is_some(); @@ -68,6 +98,7 @@ pub async fn canonical_path(path: String) -> Result, String> { let Ok(original) = val.to_str() else { return Err("Unable to decode Location header.".to_string()); }; + // We need to strip the .json suffix from the original path. // In addition, we want to remove share parameters. // Cut it off here instead of letting it propagate all the way @@ -80,7 +111,9 @@ pub async fn canonical_path(path: String) -> Result, String> { // also remove all Reddit domain parts with format_url. // Otherwise, it will literally redirect to Reddit.com. let uri = format_url(stripped_uri); - Ok(Some(uri)) + + // Decrement tries and try again + Box::pin(canonical_path(uri, tries - 1)).await } None => Ok(None), }, @@ -161,20 +194,26 @@ async fn stream(url: &str, req: &Request) -> Result, String /// Makes a GET request to Reddit at `path`. By default, this will honor HTTP /// 3xx codes Reddit returns and will automatically redirect. fn reddit_get(path: String, quarantine: bool) -> Boxed, String>> { - request(&Method::GET, path, true, quarantine) + request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST) } -/// Makes a HEAD request to Reddit at `path`. This will not follow redirects. -fn reddit_head(path: String, quarantine: bool) -> Boxed, String>> { - request(&Method::HEAD, path, false, quarantine) +/// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects. +fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> { + request(&Method::HEAD, path, false, quarantine, base_path, host) } +// /// Makes a HEAD request to Reddit at `path`. This will not follow redirects. +// fn reddit_head(path: String, quarantine: bool) -> Boxed, String>> { +// request(&Method::HEAD, path, false, quarantine, false) +// } +// Unused - reddit_head is only ever called in the context of a short URL + /// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect` /// will recurse on the URL that Reddit provides in the Location HTTP header /// in its response. -fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed, String>> { +fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> { // Build Reddit URL from path. - let url = format!("{REDDIT_URL_BASE}{path}"); + let url = format!("{base_path}{path}"); // Construct the hyper client from the HTTPS connector. let client: Client<_, Body> = CLIENT.clone(); @@ -199,7 +238,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo .header("Client-Vendor-Id", vendor_id) .header("X-Reddit-Device-Id", device_id) .header("x-reddit-loid", loid) - .header("Host", "oauth.reddit.com") + .header("Host", host) .header("Authorization", &format!("Bearer {token}")) .header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" }) .header("Accept-Language", "en-US,en;q=0.5") @@ -254,6 +293,8 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo .to_string(), true, quarantine, + base_path, + host, ) .await; }; @@ -432,13 +473,13 @@ async fn test_localization_popular() { async fn test_obfuscated_share_link() { let share_link = "/r/rust/s/kPgq8WNHRK".into(); // Correct link without share parameters - let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc".into(); - assert_eq!(canonical_path(share_link).await, Ok(Some(canonical_link))); + let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc/".into(); + assert_eq!(canonical_path(share_link, 3).await, Ok(Some(canonical_link))); } #[tokio::test(flavor = "multi_thread")] async fn test_share_link_strip_json() { let link = "/17krzvz".into(); - let canonical_link = "/r/nfl/comments/17krzvz/rapoport_sources_former_no_2_overall_pick/".into(); - assert_eq!(canonical_path(link).await, Ok(Some(canonical_link))); + let canonical_link = "/comments/17krzvz".into(); + assert_eq!(canonical_path(link, 3).await, Ok(Some(canonical_link))); } diff --git a/src/main.rs b/src/main.rs index 406a0d3..515a2a8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -341,7 +341,7 @@ async fn main() { let sub = req.param("sub").unwrap_or_default(); match req.param("id").as_deref() { // Share link - Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}")).await { + Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}"), 3).await { Ok(Some(path)) => Ok(redirect(&path)), Ok(None) => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await, Err(e) => error(req, &e).await, @@ -360,7 +360,7 @@ async fn main() { Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await, // Short link for post - Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}")).await { + Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}"), 3).await { Ok(path_opt) => match path_opt { Some(path) => Ok(redirect(&path)), None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,