fix(client): Handle invalid reddit response of base URL location

This commit is contained in:
Matthew Esposito 2024-06-28 22:39:42 -04:00
parent ea87ec33a1
commit 0f7eba717e

View File

@ -3,6 +3,7 @@ use cached::proc_macro::cached;
use futures_lite::future::block_on; use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt}; use futures_lite::{future::Boxed, FutureExt};
use hyper::client::HttpConnector; use hyper::client::HttpConnector;
use hyper::header::HeaderValue;
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri}; use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
use hyper_rustls::HttpsConnector; use hyper_rustls::HttpsConnector;
use libflate::gzip; use libflate::gzip;
@ -21,6 +22,7 @@ use crate::server::RequestExt;
use crate::utils::format_url; use crate::utils::format_url;
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| { pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
let https = hyper_rustls::HttpsConnectorBuilder::new() let https = hyper_rustls::HttpsConnectorBuilder::new()
@ -221,12 +223,13 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
if !redirect { if !redirect {
return Ok(response); return Ok(response);
}; };
let location_header = response.headers().get(header::LOCATION);
if location_header == Some(&HeaderValue::from_static("https://www.reddit.com/")) {
return Err("Reddit response was invalid".to_string());
}
return request( return request(
method, method,
response location_header
.headers()
.get(header::LOCATION)
.map(|val| { .map(|val| {
// We need to make adjustments to the URI // We need to make adjustments to the URI
// we get back from Reddit. Namely, we // we get back from Reddit. Namely, we
@ -239,7 +242,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// required. // required.
// //
// 2. Percent-encode the path. // 2. Percent-encode the path.
let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string(); let new_path = percent_encode(val.as_bytes(), CONTROLS)
.to_string()
.trim_start_matches(REDDIT_URL_BASE)
.trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE)
.to_string();
format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" }) format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" })
}) })
.unwrap_or_default() .unwrap_or_default()
@ -298,7 +305,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
} }
} }
Err(e) => { Err(e) => {
dbg_msg!("{} {}: {}", method, path, e); dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);
Err(e.to_string()) Err(e.to_string())
} }