fix(client): fix failing tests, retries for canonical_path
This commit is contained in:
parent
793047f63f
commit
7156be6ad0
@ -23,7 +23,13 @@ use crate::server::RequestExt;
|
|||||||
use crate::utils::format_url;
|
use crate::utils::format_url;
|
||||||
|
|
||||||
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
|
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
|
||||||
|
const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com";
|
||||||
|
|
||||||
|
const REDDIT_SHORT_URL_BASE: &str = "https://redd.it";
|
||||||
|
const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it";
|
||||||
|
|
||||||
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
|
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
|
||||||
|
const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com";
|
||||||
|
|
||||||
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
|
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
|
||||||
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build();
|
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build();
|
||||||
@ -40,6 +46,11 @@ pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
|||||||
|
|
||||||
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
||||||
|
|
||||||
|
static URL_PAIRS: [(&str, &str); 2] = [
|
||||||
|
(ALTERNATIVE_REDDIT_URL_BASE, ALTERNATIVE_REDDIT_URL_BASE_HOST),
|
||||||
|
(REDDIT_SHORT_URL_BASE, REDDIT_SHORT_URL_BASE_HOST),
|
||||||
|
];
|
||||||
|
|
||||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||||
///
|
///
|
||||||
@ -53,8 +64,27 @@ pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
|||||||
/// `Location` header. An `Err(String)` is returned if Reddit responds with a
|
/// `Location` header. An `Err(String)` is returned if Reddit responds with a
|
||||||
/// 429, or if we were unable to decode the value in the `Location` header.
|
/// 429, or if we were unable to decode the value in the `Location` header.
|
||||||
#[cached(size = 1024, time = 600, result = true)]
|
#[cached(size = 1024, time = 600, result = true)]
|
||||||
pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, String> {
|
||||||
let res = reddit_head(path.clone(), true).await?;
|
if tries == 0 {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// for each URL pair, try the HEAD request
|
||||||
|
let res = {
|
||||||
|
// for url base and host in URL_PAIRS, try reddit_short_head(path.clone(), true, url_base, url_base_host) and if it succeeds, set res. else, res = None
|
||||||
|
let mut res = None;
|
||||||
|
for (url_base, url_base_host) in URL_PAIRS {
|
||||||
|
res = reddit_short_head(path.clone(), true, url_base, url_base_host).await.ok();
|
||||||
|
if let Some(res) = &res {
|
||||||
|
if !res.status().is_client_error() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res
|
||||||
|
};
|
||||||
|
|
||||||
|
let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?;
|
||||||
let status = res.status().as_u16();
|
let status = res.status().as_u16();
|
||||||
let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
|
let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
|
||||||
|
|
||||||
@ -68,6 +98,7 @@ pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
|||||||
let Ok(original) = val.to_str() else {
|
let Ok(original) = val.to_str() else {
|
||||||
return Err("Unable to decode Location header.".to_string());
|
return Err("Unable to decode Location header.".to_string());
|
||||||
};
|
};
|
||||||
|
|
||||||
// We need to strip the .json suffix from the original path.
|
// We need to strip the .json suffix from the original path.
|
||||||
// In addition, we want to remove share parameters.
|
// In addition, we want to remove share parameters.
|
||||||
// Cut it off here instead of letting it propagate all the way
|
// Cut it off here instead of letting it propagate all the way
|
||||||
@ -80,7 +111,9 @@ pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
|||||||
// also remove all Reddit domain parts with format_url.
|
// also remove all Reddit domain parts with format_url.
|
||||||
// Otherwise, it will literally redirect to Reddit.com.
|
// Otherwise, it will literally redirect to Reddit.com.
|
||||||
let uri = format_url(stripped_uri);
|
let uri = format_url(stripped_uri);
|
||||||
Ok(Some(uri))
|
|
||||||
|
// Decrement tries and try again
|
||||||
|
Box::pin(canonical_path(uri, tries - 1)).await
|
||||||
}
|
}
|
||||||
None => Ok(None),
|
None => Ok(None),
|
||||||
},
|
},
|
||||||
@ -161,20 +194,26 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
|
|||||||
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
|
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
|
||||||
/// 3xx codes Reddit returns and will automatically redirect.
|
/// 3xx codes Reddit returns and will automatically redirect.
|
||||||
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||||
request(&Method::GET, path, true, quarantine)
|
request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
|
/// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects.
|
||||||
fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
|
||||||
request(&Method::HEAD, path, false, quarantine)
|
request(&Method::HEAD, path, false, quarantine, base_path, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// /// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
|
||||||
|
// fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||||
|
// request(&Method::HEAD, path, false, quarantine, false)
|
||||||
|
// }
|
||||||
|
// Unused - reddit_head is only ever called in the context of a short URL
|
||||||
|
|
||||||
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
|
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
|
||||||
/// will recurse on the URL that Reddit provides in the Location HTTP header
|
/// will recurse on the URL that Reddit provides in the Location HTTP header
|
||||||
/// in its response.
|
/// in its response.
|
||||||
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
|
||||||
// Build Reddit URL from path.
|
// Build Reddit URL from path.
|
||||||
let url = format!("{REDDIT_URL_BASE}{path}");
|
let url = format!("{base_path}{path}");
|
||||||
|
|
||||||
// Construct the hyper client from the HTTPS connector.
|
// Construct the hyper client from the HTTPS connector.
|
||||||
let client: Client<_, Body> = CLIENT.clone();
|
let client: Client<_, Body> = CLIENT.clone();
|
||||||
@ -199,7 +238,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
|||||||
.header("Client-Vendor-Id", vendor_id)
|
.header("Client-Vendor-Id", vendor_id)
|
||||||
.header("X-Reddit-Device-Id", device_id)
|
.header("X-Reddit-Device-Id", device_id)
|
||||||
.header("x-reddit-loid", loid)
|
.header("x-reddit-loid", loid)
|
||||||
.header("Host", "oauth.reddit.com")
|
.header("Host", host)
|
||||||
.header("Authorization", &format!("Bearer {token}"))
|
.header("Authorization", &format!("Bearer {token}"))
|
||||||
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
|
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
|
||||||
.header("Accept-Language", "en-US,en;q=0.5")
|
.header("Accept-Language", "en-US,en;q=0.5")
|
||||||
@ -254,6 +293,8 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
|||||||
.to_string(),
|
.to_string(),
|
||||||
true,
|
true,
|
||||||
quarantine,
|
quarantine,
|
||||||
|
base_path,
|
||||||
|
host,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
};
|
};
|
||||||
@ -432,13 +473,13 @@ async fn test_localization_popular() {
|
|||||||
async fn test_obfuscated_share_link() {
|
async fn test_obfuscated_share_link() {
|
||||||
let share_link = "/r/rust/s/kPgq8WNHRK".into();
|
let share_link = "/r/rust/s/kPgq8WNHRK".into();
|
||||||
// Correct link without share parameters
|
// Correct link without share parameters
|
||||||
let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc".into();
|
let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc/".into();
|
||||||
assert_eq!(canonical_path(share_link).await, Ok(Some(canonical_link)));
|
assert_eq!(canonical_path(share_link, 3).await, Ok(Some(canonical_link)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test(flavor = "multi_thread")]
|
#[tokio::test(flavor = "multi_thread")]
|
||||||
async fn test_share_link_strip_json() {
|
async fn test_share_link_strip_json() {
|
||||||
let link = "/17krzvz".into();
|
let link = "/17krzvz".into();
|
||||||
let canonical_link = "/r/nfl/comments/17krzvz/rapoport_sources_former_no_2_overall_pick/".into();
|
let canonical_link = "/comments/17krzvz".into();
|
||||||
assert_eq!(canonical_path(link).await, Ok(Some(canonical_link)));
|
assert_eq!(canonical_path(link, 3).await, Ok(Some(canonical_link)));
|
||||||
}
|
}
|
||||||
|
@ -341,7 +341,7 @@ async fn main() {
|
|||||||
let sub = req.param("sub").unwrap_or_default();
|
let sub = req.param("sub").unwrap_or_default();
|
||||||
match req.param("id").as_deref() {
|
match req.param("id").as_deref() {
|
||||||
// Share link
|
// Share link
|
||||||
Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}")).await {
|
Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}"), 3).await {
|
||||||
Ok(Some(path)) => Ok(redirect(&path)),
|
Ok(Some(path)) => Ok(redirect(&path)),
|
||||||
Ok(None) => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
Ok(None) => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
||||||
Err(e) => error(req, &e).await,
|
Err(e) => error(req, &e).await,
|
||||||
@ -360,7 +360,7 @@ async fn main() {
|
|||||||
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
|
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
|
||||||
|
|
||||||
// Short link for post
|
// Short link for post
|
||||||
Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}")).await {
|
Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}"), 3).await {
|
||||||
Ok(path_opt) => match path_opt {
|
Ok(path_opt) => match path_opt {
|
||||||
Some(path) => Ok(redirect(&path)),
|
Some(path) => Ok(redirect(&path)),
|
||||||
None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
||||||
|
Loading…
Reference in New Issue
Block a user