fix(client): fix failing tests, retries for canonical_path

This commit is contained in:
Matthew Esposito 2024-09-20 23:57:18 -04:00
parent 793047f63f
commit 7156be6ad0
2 changed files with 57 additions and 16 deletions

View File

@ -23,7 +23,13 @@ use crate::server::RequestExt;
use crate::utils::format_url; use crate::utils::format_url;
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com";
const REDDIT_SHORT_URL_BASE: &str = "https://redd.it";
const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it";
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com"; const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com";
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| { pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build(); let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build();
@ -40,6 +46,11 @@ pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false); pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
static URL_PAIRS: [(&str, &str); 2] = [
(ALTERNATIVE_REDDIT_URL_BASE, ALTERNATIVE_REDDIT_URL_BASE_HOST),
(REDDIT_SHORT_URL_BASE, REDDIT_SHORT_URL_BASE_HOST),
];
/// Gets the canonical path for a resource on Reddit. This is accomplished by /// Gets the canonical path for a resource on Reddit. This is accomplished by
/// making a `HEAD` request to Reddit at the path given in `path`. /// making a `HEAD` request to Reddit at the path given in `path`.
/// ///
@ -53,8 +64,27 @@ pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
/// `Location` header. An `Err(String)` is returned if Reddit responds with a /// `Location` header. An `Err(String)` is returned if Reddit responds with a
/// 429, or if we were unable to decode the value in the `Location` header. /// 429, or if we were unable to decode the value in the `Location` header.
#[cached(size = 1024, time = 600, result = true)] #[cached(size = 1024, time = 600, result = true)]
pub async fn canonical_path(path: String) -> Result<Option<String>, String> { pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, String> {
let res = reddit_head(path.clone(), true).await?; if tries == 0 {
return Ok(None);
}
// for each URL pair, try the HEAD request
let res = {
// for url base and host in URL_PAIRS, try reddit_short_head(path.clone(), true, url_base, url_base_host) and if it succeeds, set res. else, res = None
let mut res = None;
for (url_base, url_base_host) in URL_PAIRS {
res = reddit_short_head(path.clone(), true, url_base, url_base_host).await.ok();
if let Some(res) = &res {
if !res.status().is_client_error() {
break;
}
}
}
res
};
let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?;
let status = res.status().as_u16(); let status = res.status().as_u16();
let policy_error = res.headers().get(header::RETRY_AFTER).is_some(); let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
@ -68,6 +98,7 @@ pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
let Ok(original) = val.to_str() else { let Ok(original) = val.to_str() else {
return Err("Unable to decode Location header.".to_string()); return Err("Unable to decode Location header.".to_string());
}; };
// We need to strip the .json suffix from the original path. // We need to strip the .json suffix from the original path.
// In addition, we want to remove share parameters. // In addition, we want to remove share parameters.
// Cut it off here instead of letting it propagate all the way // Cut it off here instead of letting it propagate all the way
@ -80,7 +111,9 @@ pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
// also remove all Reddit domain parts with format_url. // also remove all Reddit domain parts with format_url.
// Otherwise, it will literally redirect to Reddit.com. // Otherwise, it will literally redirect to Reddit.com.
let uri = format_url(stripped_uri); let uri = format_url(stripped_uri);
Ok(Some(uri))
// Decrement tries and try again
Box::pin(canonical_path(uri, tries - 1)).await
} }
None => Ok(None), None => Ok(None),
}, },
@ -161,20 +194,26 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP /// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
/// 3xx codes Reddit returns and will automatically redirect. /// 3xx codes Reddit returns and will automatically redirect.
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> { fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
request(&Method::GET, path, true, quarantine) request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST)
} }
/// Makes a HEAD request to Reddit at `path`. This will not follow redirects. /// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects.
fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> { fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
request(&Method::HEAD, path, false, quarantine) request(&Method::HEAD, path, false, quarantine, base_path, host)
} }
// /// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
// fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
// request(&Method::HEAD, path, false, quarantine, false)
// }
// Unused - reddit_head is only ever called in the context of a short URL
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect` /// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
/// will recurse on the URL that Reddit provides in the Location HTTP header /// will recurse on the URL that Reddit provides in the Location HTTP header
/// in its response. /// in its response.
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed<Result<Response<Body>, String>> { fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
// Build Reddit URL from path. // Build Reddit URL from path.
let url = format!("{REDDIT_URL_BASE}{path}"); let url = format!("{base_path}{path}");
// Construct the hyper client from the HTTPS connector. // Construct the hyper client from the HTTPS connector.
let client: Client<_, Body> = CLIENT.clone(); let client: Client<_, Body> = CLIENT.clone();
@ -199,7 +238,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
.header("Client-Vendor-Id", vendor_id) .header("Client-Vendor-Id", vendor_id)
.header("X-Reddit-Device-Id", device_id) .header("X-Reddit-Device-Id", device_id)
.header("x-reddit-loid", loid) .header("x-reddit-loid", loid)
.header("Host", "oauth.reddit.com") .header("Host", host)
.header("Authorization", &format!("Bearer {token}")) .header("Authorization", &format!("Bearer {token}"))
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" }) .header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
.header("Accept-Language", "en-US,en;q=0.5") .header("Accept-Language", "en-US,en;q=0.5")
@ -254,6 +293,8 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
.to_string(), .to_string(),
true, true,
quarantine, quarantine,
base_path,
host,
) )
.await; .await;
}; };
@ -432,13 +473,13 @@ async fn test_localization_popular() {
async fn test_obfuscated_share_link() { async fn test_obfuscated_share_link() {
let share_link = "/r/rust/s/kPgq8WNHRK".into(); let share_link = "/r/rust/s/kPgq8WNHRK".into();
// Correct link without share parameters // Correct link without share parameters
let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc".into(); let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc/".into();
assert_eq!(canonical_path(share_link).await, Ok(Some(canonical_link))); assert_eq!(canonical_path(share_link, 3).await, Ok(Some(canonical_link)));
} }
#[tokio::test(flavor = "multi_thread")] #[tokio::test(flavor = "multi_thread")]
async fn test_share_link_strip_json() { async fn test_share_link_strip_json() {
let link = "/17krzvz".into(); let link = "/17krzvz".into();
let canonical_link = "/r/nfl/comments/17krzvz/rapoport_sources_former_no_2_overall_pick/".into(); let canonical_link = "/comments/17krzvz".into();
assert_eq!(canonical_path(link).await, Ok(Some(canonical_link))); assert_eq!(canonical_path(link, 3).await, Ok(Some(canonical_link)));
} }

View File

@ -341,7 +341,7 @@ async fn main() {
let sub = req.param("sub").unwrap_or_default(); let sub = req.param("sub").unwrap_or_default();
match req.param("id").as_deref() { match req.param("id").as_deref() {
// Share link // Share link
Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}")).await { Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}"), 3).await {
Ok(Some(path)) => Ok(redirect(&path)), Ok(Some(path)) => Ok(redirect(&path)),
Ok(None) => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await, Ok(None) => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
Err(e) => error(req, &e).await, Err(e) => error(req, &e).await,
@ -360,7 +360,7 @@ async fn main() {
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await, Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
// Short link for post // Short link for post
Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}")).await { Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}"), 3).await {
Ok(path_opt) => match path_opt { Ok(path_opt) => match path_opt {
Some(path) => Ok(redirect(&path)), Some(path) => Ok(redirect(&path)),
None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await, None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,