fix(client): fix failing tests, retries for canonical_path
This commit is contained in:
parent
793047f63f
commit
7156be6ad0
@ -23,7 +23,13 @@ use crate::server::RequestExt;
|
||||
use crate::utils::format_url;
|
||||
|
||||
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
|
||||
const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com";
|
||||
|
||||
const REDDIT_SHORT_URL_BASE: &str = "https://redd.it";
|
||||
const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it";
|
||||
|
||||
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
|
||||
const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com";
|
||||
|
||||
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
|
||||
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http1().build();
|
||||
@ -40,6 +46,11 @@ pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
||||
|
||||
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
static URL_PAIRS: [(&str, &str); 2] = [
|
||||
(ALTERNATIVE_REDDIT_URL_BASE, ALTERNATIVE_REDDIT_URL_BASE_HOST),
|
||||
(REDDIT_SHORT_URL_BASE, REDDIT_SHORT_URL_BASE_HOST),
|
||||
];
|
||||
|
||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||
///
|
||||
@ -53,8 +64,27 @@ pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
|
||||
/// `Location` header. An `Err(String)` is returned if Reddit responds with a
|
||||
/// 429, or if we were unable to decode the value in the `Location` header.
|
||||
#[cached(size = 1024, time = 600, result = true)]
|
||||
pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
||||
let res = reddit_head(path.clone(), true).await?;
|
||||
pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, String> {
|
||||
if tries == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// for each URL pair, try the HEAD request
|
||||
let res = {
|
||||
// for url base and host in URL_PAIRS, try reddit_short_head(path.clone(), true, url_base, url_base_host) and if it succeeds, set res. else, res = None
|
||||
let mut res = None;
|
||||
for (url_base, url_base_host) in URL_PAIRS {
|
||||
res = reddit_short_head(path.clone(), true, url_base, url_base_host).await.ok();
|
||||
if let Some(res) = &res {
|
||||
if !res.status().is_client_error() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
res
|
||||
};
|
||||
|
||||
let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?;
|
||||
let status = res.status().as_u16();
|
||||
let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
|
||||
|
||||
@ -68,6 +98,7 @@ pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
||||
let Ok(original) = val.to_str() else {
|
||||
return Err("Unable to decode Location header.".to_string());
|
||||
};
|
||||
|
||||
// We need to strip the .json suffix from the original path.
|
||||
// In addition, we want to remove share parameters.
|
||||
// Cut it off here instead of letting it propagate all the way
|
||||
@ -80,7 +111,9 @@ pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
||||
// also remove all Reddit domain parts with format_url.
|
||||
// Otherwise, it will literally redirect to Reddit.com.
|
||||
let uri = format_url(stripped_uri);
|
||||
Ok(Some(uri))
|
||||
|
||||
// Decrement tries and try again
|
||||
Box::pin(canonical_path(uri, tries - 1)).await
|
||||
}
|
||||
None => Ok(None),
|
||||
},
|
||||
@ -161,20 +194,26 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
|
||||
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
|
||||
/// 3xx codes Reddit returns and will automatically redirect.
|
||||
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
request(&Method::GET, path, true, quarantine)
|
||||
request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST)
|
||||
}
|
||||
|
||||
/// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
|
||||
fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
request(&Method::HEAD, path, false, quarantine)
|
||||
/// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects.
|
||||
fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
|
||||
request(&Method::HEAD, path, false, quarantine, base_path, host)
|
||||
}
|
||||
|
||||
// /// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
|
||||
// fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
// request(&Method::HEAD, path, false, quarantine, false)
|
||||
// }
|
||||
// Unused - reddit_head is only ever called in the context of a short URL
|
||||
|
||||
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
|
||||
/// will recurse on the URL that Reddit provides in the Location HTTP header
|
||||
/// in its response.
|
||||
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
|
||||
// Build Reddit URL from path.
|
||||
let url = format!("{REDDIT_URL_BASE}{path}");
|
||||
let url = format!("{base_path}{path}");
|
||||
|
||||
// Construct the hyper client from the HTTPS connector.
|
||||
let client: Client<_, Body> = CLIENT.clone();
|
||||
@ -199,7 +238,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
||||
.header("Client-Vendor-Id", vendor_id)
|
||||
.header("X-Reddit-Device-Id", device_id)
|
||||
.header("x-reddit-loid", loid)
|
||||
.header("Host", "oauth.reddit.com")
|
||||
.header("Host", host)
|
||||
.header("Authorization", &format!("Bearer {token}"))
|
||||
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
|
||||
.header("Accept-Language", "en-US,en;q=0.5")
|
||||
@ -254,6 +293,8 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
||||
.to_string(),
|
||||
true,
|
||||
quarantine,
|
||||
base_path,
|
||||
host,
|
||||
)
|
||||
.await;
|
||||
};
|
||||
@ -432,13 +473,13 @@ async fn test_localization_popular() {
|
||||
async fn test_obfuscated_share_link() {
|
||||
let share_link = "/r/rust/s/kPgq8WNHRK".into();
|
||||
// Correct link without share parameters
|
||||
let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc".into();
|
||||
assert_eq!(canonical_path(share_link).await, Ok(Some(canonical_link)));
|
||||
let canonical_link = "/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc/".into();
|
||||
assert_eq!(canonical_path(share_link, 3).await, Ok(Some(canonical_link)));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_share_link_strip_json() {
|
||||
let link = "/17krzvz".into();
|
||||
let canonical_link = "/r/nfl/comments/17krzvz/rapoport_sources_former_no_2_overall_pick/".into();
|
||||
assert_eq!(canonical_path(link).await, Ok(Some(canonical_link)));
|
||||
let canonical_link = "/comments/17krzvz".into();
|
||||
assert_eq!(canonical_path(link, 3).await, Ok(Some(canonical_link)));
|
||||
}
|
||||
|
@ -341,7 +341,7 @@ async fn main() {
|
||||
let sub = req.param("sub").unwrap_or_default();
|
||||
match req.param("id").as_deref() {
|
||||
// Share link
|
||||
Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}")).await {
|
||||
Some(id) if (8..12).contains(&id.len()) => match canonical_path(format!("/r/{sub}/s/{id}"), 3).await {
|
||||
Ok(Some(path)) => Ok(redirect(&path)),
|
||||
Ok(None) => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
||||
Err(e) => error(req, &e).await,
|
||||
@ -360,7 +360,7 @@ async fn main() {
|
||||
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
|
||||
|
||||
// Short link for post
|
||||
Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}")).await {
|
||||
Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{id}"), 3).await {
|
||||
Ok(path_opt) => match path_opt {
|
||||
Some(path) => Ok(redirect(&path)),
|
||||
None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
||||
|
Loading…
x
Reference in New Issue
Block a user