Redirect /:id to canonical URL for post. (#617)
* Redirect /:id to canonical URL for post. This implements redirection of `/:id` (a short-form URL to a post) to the post's canonical URL. Libreddit issues a `HEAD /:id` request to Reddit to obtain the canonical URL and, on success, sends the client an HTTP 302 with the canonical URL as the value of the `Location:` header. This also adds support for short IDs of non-ASCII posts, c/o spikecodes.

Co-authored-by: spikecodes <19519553+spikecodes@users.noreply.github.com>
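
As a sketch of the flow described above (the post ID and subreddit are made up for illustration; `canonical_path` and `redirect` are the functions touched by this commit, and Reddit's exact status code may be any 3xx):

// Hypothetical exchange for a short link /tl63r8:
//
//   client    -> Libreddit:  GET /tl63r8
//   Libreddit -> Reddit:     HEAD https://www.reddit.com/tl63r8
//   Reddit    -> Libreddit:  3xx, Location: https://www.reddit.com/r/rust/comments/tl63r8/some_title/
//   canonical_path strips REDDIT_URL_BASE and yields Ok(Some("/r/rust/comments/tl63r8/some_title/"))
//   Libreddit -> client:     302, Location: /r/rust/comments/tl63r8/some_title/
if let Ok(Some(path)) = canonical_path("/tl63r8".to_string()).await {
	return Ok(redirect(path)); // `redirect` builds the 302 response sent back to the client
}
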
parent 584cd4aac1
commit c6487799ed

src/client.rs (182 lines changed)

@@ -1,13 +1,59 @@
 use cached::proc_macro::cached;
 use futures_lite::{future::Boxed, FutureExt};
-use hyper::{body, body::Buf, client, header, Body, Request, Response, Uri};
+use hyper::{body, body::Buf, client, header, Body, Method, Request, Response, Uri};
 use libflate::gzip;
 use percent_encoding::{percent_encode, CONTROLS};
 use serde_json::Value;
 use std::{io, result::Result};

+use crate::dbg_msg;
 use crate::server::RequestExt;

+const REDDIT_URL_BASE: &str = "https://www.reddit.com";
+
+/// Gets the canonical path for a resource on Reddit. This is accomplished by
+/// making a `HEAD` request to Reddit at the path given in `path`.
+///
+/// This function returns `Ok(Some(path))`, where `path`'s value is identical
+/// to that of the value of the argument `path`, if Reddit responds to our
+/// `HEAD` request with a 2xx-family HTTP code. It will also return an
+/// `Ok(Some(String))` if Reddit responds to our `HEAD` request with a
+/// `Location` header in the response, and the HTTP code is in the 3xx-family;
+/// the `String` will contain the path as reported in `Location`. The return
+/// value is `Ok(None)` if Reddit responded with a 3xx, but did not provide a
+/// `Location` header. An `Err(String)` is returned if Reddit responds with a
+/// 429, or if we were unable to decode the value in the `Location` header.
+#[cached(size = 1024, time = 600, result = true)]
+pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
+	let res = reddit_head(path.clone(), true).await?;
+
+	if res.status() == 429 {
+		return Err("Too many requests.".to_string());
+	};
+
+	// If Reddit responds with a 2xx, then the path is already canonical.
+	if res.status().to_string().starts_with('2') {
+		return Ok(Some(path));
+	}
+
+	// If Reddit responds with anything other than 3xx (except for the 2xx as
+	// above), return a None.
+	if !res.status().to_string().starts_with('3') {
+		return Ok(None);
+	}
+
+	Ok(
+		res
+			.headers()
+			.get(header::LOCATION)
+			.map(|val| percent_encode(val.as_bytes(), CONTROLS)
+			.to_string()
+			.trim_start_matches(REDDIT_URL_BASE)
+			.to_string()
+		),
+	)
+}
+
 pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, String> {
 	let mut url = format!("{}?{}", format, req.uri().query().unwrap_or_default());

@@ -63,21 +109,39 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
 		.map_err(|e| e.to_string())
 }

-fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
+/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
+/// 3xx codes Reddit returns and will automatically redirect.
+fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
+	request(&Method::GET, path, true, quarantine)
+}
+
+/// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
+fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
+	request(&Method::HEAD, path, false, quarantine)
+}
+
+/// Makes a request to Reddit. If `redirect` is `true`, request_with_redirect
+/// will recurse on the URL that Reddit provides in the Location HTTP header
+/// in its response.
+fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
+	// Build Reddit URL from path.
+	let url = format!("{}{}", REDDIT_URL_BASE, path);
+
 	// Prepare the HTTPS connector.
 	let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build();

 	// Construct the hyper client from the HTTPS connector.
 	let client: client::Client<_, hyper::Body> = client::Client::builder().build(https);

-	// Build request
+	// Build request to Reddit. When making a GET, request gzip compression.
+	// (Reddit doesn't do brotli yet.)
 	let builder = Request::builder()
-		.method("GET")
+		.method(method)
 		.uri(&url)
 		.header("User-Agent", format!("web:libreddit:{}", env!("CARGO_PKG_VERSION")))
 		.header("Host", "www.reddit.com")
 		.header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
-		.header("Accept-Encoding", "gzip") // Reddit doesn't do brotli yet.
+		.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
 		.header("Accept-Language", "en-US,en;q=0.5")
 		.header("Connection", "keep-alive")
 		.header("Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%7D" } else { "" })
@@ -87,8 +151,15 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
 	match builder {
 		Ok(req) => match client.request(req).await {
 			Ok(mut response) => {
+				// Reddit may respond with a 3xx. Decide whether or not to
+				// redirect based on caller params.
 				if response.status().to_string().starts_with('3') {
-					request(
+					if !redirect {
+						return Ok(response);
+					};
+
+					return request(
+						method,
 						response
 							.headers()
 							.get("Location")
@@ -98,56 +169,64 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
 							})
 							.unwrap_or_default()
 							.to_string(),
+						true,
 						quarantine,
 					)
-					.await
-				} else {
-					match response.headers().get(header::CONTENT_ENCODING) {
-						// Content not compressed.
-						None => Ok(response),
-
-						// Content gzipped.
-						Some(hdr) => {
-							// Since we requested gzipped content, we expect
-							// to get back gzipped content. If we get
-							// back anything else, that's a problem.
-							if hdr.ne("gzip") {
-								return Err("Reddit response was encoded with an unsupported compressor".to_string());
-							}
-
-							// The body must be something that implements
-							// std::io::Read, hence the conversion to
-							// bytes::buf::Buf and then transformation into a
-							// Reader.
-							let mut decompressed: Vec<u8>;
-							{
-								let mut aggregated_body = match body::aggregate(response.body_mut()).await {
-									Ok(b) => b.reader(),
-									Err(e) => return Err(e.to_string()),
-								};
-
-								let mut decoder = match gzip::Decoder::new(&mut aggregated_body) {
-									Ok(decoder) => decoder,
-									Err(e) => return Err(e.to_string()),
-								};
-
-								decompressed = Vec::<u8>::new();
-								match io::copy(&mut decoder, &mut decompressed) {
-									Ok(_) => {}
-									Err(e) => return Err(e.to_string()),
-								};
-							}
-
-							response.headers_mut().remove(header::CONTENT_ENCODING);
-							response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into());
-							*(response.body_mut()) = Body::from(decompressed);
-
-							Ok(response)
-						}
-					}
-				}
+					.await;
+				};
+
+				match response.headers().get(header::CONTENT_ENCODING) {
+					// Content not compressed.
+					None => Ok(response),
+
+					// Content encoded (hopefully with gzip).
+					Some(hdr) => {
+						match hdr.to_str() {
+							Ok(val) => match val {
+								"gzip" => {}
+								"identity" => return Ok(response),
+								_ => return Err("Reddit response was encoded with an unsupported compressor".to_string()),
+							},
+							Err(_) => return Err("Reddit response was invalid".to_string()),
+						}
+
+						// We get here if the body is gzip-compressed.
+
+						// The body must be something that implements
+						// std::io::Read, hence the conversion to
+						// bytes::buf::Buf and then transformation into a
+						// Reader.
+						let mut decompressed: Vec<u8>;
+						{
+							let mut aggregated_body = match body::aggregate(response.body_mut()).await {
+								Ok(b) => b.reader(),
+								Err(e) => return Err(e.to_string()),
+							};
+
+							let mut decoder = match gzip::Decoder::new(&mut aggregated_body) {
+								Ok(decoder) => decoder,
+								Err(e) => return Err(e.to_string()),
+							};
+
+							decompressed = Vec::<u8>::new();
+							if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
+								return Err(e.to_string());
+							};
+						}
+
+						response.headers_mut().remove(header::CONTENT_ENCODING);
+						response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into());
+						*(response.body_mut()) = Body::from(decompressed);
+
+						Ok(response)
+					}
+				}
 			}
-			Err(e) => Err(e.to_string()),
+			Err(e) => {
+				dbg_msg!("{} {}: {}", method, path, e);
+
+				Err(e.to_string())
+			}
 		},
 		Err(_) => Err("Post url contains non-ASCII characters".to_string()),
 	}
@@ -158,9 +237,6 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
 // Make a request to a Reddit API and parse the JSON response
 #[cached(size = 100, time = 30, result = true)]
 pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
-	// Build Reddit url from path
-	let url = format!("https://www.reddit.com{}", path);
-
 	// Closure to quickly build errors
 	let err = |msg: &str, e: String| -> Result<Value, String> {
 		// eprintln!("{} - {}: {}", url, msg, e);
@@ -168,7 +244,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
 	};

 	// Fetch the url...
-	match request(url.clone(), quarantine).await {
+	match reddit_get(path.clone(), quarantine).await {
 		Ok(response) => {
 			let status = response.status();

@@ -186,7 +262,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
 						.as_str()
 						.unwrap_or_else(|| {
 							json["message"].as_str().unwrap_or_else(|| {
-								eprintln!("{} - Error parsing reddit error", url);
+								eprintln!("{}{} - Error parsing reddit error", REDDIT_URL_BASE, path);
 								"Error parsing reddit error"
 							})
 						})
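
The `#[cached(size = 1024, time = 600, result = true)]` attribute on `canonical_path` above memoizes successful lookups for 600 seconds (up to 1024 entries), so repeated hits on the same short link don't re-issue `HEAD` requests to Reddit. A minimal, standalone sketch of that caching pattern (not repository code; the function and values below are illustrative, and it assumes the `cached` crate is a dependency, as it already is here):

use cached::proc_macro::cached;

// `result = true` caches only Ok(..) values; `time = 600` expires entries
// after 600 seconds; `size = 1024` bounds the number of cached keys.
#[cached(size = 1024, time = 600, result = true)]
fn resolve(key: String) -> Result<Option<String>, String> {
	// Imagine a network round trip here. Within the TTL, a second call with
	// the same `key` returns the memoized Ok value instead of repeating it.
	Ok(Some(format!("/resolved/{}", key)))
}

fn main() {
	let first = resolve("abc".to_string());
	let second = resolve("abc".to_string()); // served from the cache
	assert_eq!(first, second);
}
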

src/main.rs (31 lines changed)

@@ -17,7 +17,7 @@ use futures_lite::FutureExt;
 use hyper::{header::HeaderValue, Body, Request, Response};

 mod client;
-use client::proxy;
+use client::{canonical_path, proxy};
 use server::RequestExt;
 use utils::{error, redirect, ThemeAssets};

@@ -259,9 +259,6 @@ async fn main() {

 	app.at("/r/:sub/:sort").get(|r| subreddit::community(r).boxed());

-	// Comments handler
-	app.at("/comments/:id").get(|r| post::item(r).boxed());
-
 	// Front page
 	app.at("/").get(|r| subreddit::community(r).boxed());

@@ -279,13 +276,25 @@ async fn main() {
 	// Handle about pages
 	app.at("/about").get(|req| error(req, "About pages aren't added yet".to_string()).boxed());

-	app.at("/:id").get(|req: Request<Body>| match req.param("id").as_deref() {
-		// Sort front page
-		Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).boxed(),
-		// Short link for post
-		Some(id) if id.len() > 4 && id.len() < 7 => post::item(req).boxed(),
-		// Error message for unknown pages
-		_ => error(req, "Nothing here".to_string()).boxed(),
+	app.at("/:id").get(|req: Request<Body>| {
+		Box::pin(async move {
+			match req.param("id").as_deref() {
+				// Sort front page
+				Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
+
+				// Short link for post
+				Some(id) if (5..7).contains(&id.len()) => match canonical_path(format!("/{}", id)).await {
+					Ok(path_opt) => match path_opt {
+						Some(path) => Ok(redirect(path)),
+						None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
+					},
+					Err(e) => error(req, e).await,
+				},
+
+				// Error message for unknown pages
+				_ => error(req, "Nothing here".to_string()).await,
+			}
+		})
 	});

 	// Default service in case no routes match

src/utils.rs

@@ -716,10 +716,11 @@ pub fn redirect(path: String) -> Response<Body> {
 		.unwrap_or_default()
 }

-pub async fn error(req: Request<Body>, msg: String) -> Result<Response<Body>, String> {
+/// Renders a generic error landing page.
+pub async fn error(req: Request<Body>, msg: impl ToString) -> Result<Response<Body>, String> {
 	let url = req.uri().to_string();
 	let body = ErrorTemplate {
-		msg,
+		msg: msg.to_string(),
 		prefs: Preferences::new(req),
 		url,
 	}
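
One note on the `error()` change above: widening the parameter from `msg: String` to `msg: impl ToString` lets call sites pass either a string literal or an owned `String` (as the new `/:id` handler does) without an explicit conversion. A minimal, standalone illustration of that pattern (not repository code):

// Any argument implementing ToString is accepted; the function converts once,
// internally, instead of every caller calling .to_string() first.
fn describe(msg: impl ToString) -> String {
	msg.to_string()
}

fn main() {
	let a = describe("Nothing here");                      // &str literal
	let b = describe(String::from("Too many requests."));  // owned String
	assert_eq!(a, "Nothing here");
	assert_eq!(b, "Too many requests.");
}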