Merge remote-tracking branch 'upstream/main'

This commit is contained in:
2024-06-28 12:40:15 +12:00
19 changed files with 255 additions and 120 deletions

View File

@ -5,11 +5,13 @@ use hyper::client::HttpConnector;
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
use hyper_rustls::HttpsConnector;
use libflate::gzip;
use log::error;
use log::{error, trace, warn};
use once_cell::sync::Lazy;
use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicU16, Ordering::SeqCst};
use std::{io, result::Result};
use tokio::sync::RwLock;
@ -36,6 +38,8 @@ pub static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
RwLock::new(client)
});
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
/// Gets the canonical path for a resource on Reddit. This is accomplished by
/// making a `HEAD` request to Reddit at the path given in `path`.
///
@ -170,7 +174,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// Construct the hyper client from the HTTPS connector.
let client: Client<_, Body> = CLIENT.clone();
let (token, vendor_id, device_id, mut user_agent, loid) = {
let (token, vendor_id, device_id, user_agent, loid) = {
let client = block_on(OAUTH_CLIENT.read());
(
client.token.clone(),
@ -181,13 +185,6 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
)
};
// Replace "Android" with a tricky word.
// Issues: #78/#115, #116
// If you include the word "Android", you will get a number of different errors
// I guess they don't expect mobile traffic on the endpoints we use
// Scrawled on wall for next poor soul: Run the test suite.
user_agent = user_agent.replace("Android", "Andr\u{200B}oid");
// Build request to Reddit. When making a GET, request gzip compression.
// (Reddit doesn't do brotli yet.)
let builder = Request::builder()
@ -314,19 +311,59 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
#[cached(size = 100, time = 30, result = true)]
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
// Closure to quickly build errors
let err = |msg: &str, e: String| -> Result<Value, String> {
let err = |msg: &str, e: String, path: String| -> Result<Value, String> {
// eprintln!("{} - {}: {}", url, msg, e);
Err(format!("{msg}: {e}"))
Err(format!("{msg}: {e} | {path}"))
};
// First, handle rolling over the OAUTH_CLIENT if need be.
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst);
if current_rate_limit < 10 {
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
tokio::spawn(force_refresh_token());
}
// Fetch the url...
match reddit_get(path.clone(), quarantine).await {
Ok(response) => {
let status = response.status();
// Ratelimit remaining
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
trace!("Ratelimit remaining: {}", remaining);
if let Ok(remaining) = remaining.parse::<f32>().map(|f| f.round() as u16) {
OAUTH_RATELIMIT_REMAINING.store(remaining, SeqCst);
} else {
warn!("Failed to parse rate limit {remaining} from header.");
}
}
// Ratelimit used
if let Some(Ok(used)) = response.headers().get("x-ratelimit-used").map(|val| val.to_str()) {
trace!("Ratelimit used: {}", used);
}
// Ratelimit reset
let reset = if let Some(Ok(reset)) = response.headers().get("x-ratelimit-reset").map(|val| val.to_str()) {
trace!("Ratelimit reset: {}", reset);
Some(reset.to_string())
} else {
None
};
// asynchronously aggregate the chunks of the body
match hyper::body::aggregate(response).await {
Ok(body) => {
let has_remaining = body.has_remaining();
if !has_remaining {
return match reset {
Some(val) => Err(format!("Reddit rate limit exceeded. Will reset in: {val}")),
None => Err("Reddit rate limit exceeded".to_string()),
};
}
// Parse the response from Reddit as JSON
match serde_json::from_reader(body.reader()) {
Ok(value) => {
@ -339,7 +376,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
let () = force_refresh_token().await;
return Err("OAuth token has expired. Please refresh the page!".to_string());
}
Err(format!("Reddit error {} \"{}\": {}", json["error"], json["reason"], json["message"]))
Err(format!("Reddit error {} \"{}\": {} | {path}", json["error"], json["reason"], json["message"]))
} else {
Ok(json)
}
@ -349,21 +386,24 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
if status.is_server_error() {
Err("Reddit is having issues, check if there's an outage".to_string())
} else {
err("Failed to parse page JSON data", e.to_string())
err("Failed to parse page JSON data", e.to_string(), path)
}
}
}
}
Err(e) => err("Failed receiving body from Reddit", e.to_string()),
Err(e) => err("Failed receiving body from Reddit", e.to_string(), path),
}
}
Err(e) => err("Couldn't send request to Reddit", e),
Err(e) => err("Couldn't send request to Reddit", e, path),
}
}
#[cfg(test)]
static POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL";
#[tokio::test(flavor = "multi_thread")]
async fn test_localization_popular() {
let val = json("/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL".to_string(), false).await.unwrap();
let val = json(POPULAR_URL.to_string(), false).await.unwrap();
assert_eq!("GLOBAL", val["data"]["geo_filter"].as_str().unwrap());
}

View File

@ -52,6 +52,10 @@ pub struct Config {
#[serde(alias = "LIBREDDIT_DEFAULT_POST_SORT")]
pub(crate) default_post_sort: Option<String>,
#[serde(rename = "REDLIB_DEFAULT_BLUR_SPOILER")]
#[serde(alias = "LIBREDDIT_DEFAULT_BLUR_SPOILER")]
pub(crate) default_blur_spoiler: Option<String>,
#[serde(rename = "REDLIB_DEFAULT_SHOW_NSFW")]
#[serde(alias = "LIBREDDIT_DEFAULT_SHOW_NSFW")]
pub(crate) default_show_nsfw: Option<String>,
@ -88,6 +92,10 @@ pub struct Config {
#[serde(alias = "LIBREDDIT_DEFAULT_SUBSCRIPTIONS")]
pub(crate) default_subscriptions: Option<String>,
#[serde(rename = "REDLIB_DEFAULT_FILTERS")]
#[serde(alias = "LIBREDDIT_DEFAULT_FILTERS")]
pub(crate) default_filters: Option<String>,
#[serde(rename = "REDLIB_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION")]
#[serde(alias = "LIBREDDIT_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION")]
pub(crate) default_disable_visit_reddit_confirmation: Option<String>,
@ -135,6 +143,7 @@ impl Config {
default_post_sort: parse("REDLIB_DEFAULT_POST_SORT"),
default_wide: parse("REDLIB_DEFAULT_WIDE"),
default_comment_sort: parse("REDLIB_DEFAULT_COMMENT_SORT"),
default_blur_spoiler: parse("REDLIB_DEFAULT_BLUR_SPOILER"),
default_show_nsfw: parse("REDLIB_DEFAULT_SHOW_NSFW"),
default_blur_nsfw: parse("REDLIB_DEFAULT_BLUR_NSFW"),
default_use_hls: parse("REDLIB_DEFAULT_USE_HLS"),
@ -144,6 +153,7 @@ impl Config {
default_hide_sidebar_and_summary: parse("REDLIB_DEFAULT_HIDE_SIDEBAR_AND_SUMMARY"),
default_hide_score: parse("REDLIB_DEFAULT_HIDE_SCORE"),
default_subscriptions: parse("REDLIB_DEFAULT_SUBSCRIPTIONS"),
default_filters: parse("REDLIB_DEFAULT_FILTERS"),
default_disable_visit_reddit_confirmation: parse("REDLIB_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION"),
banner: parse("REDLIB_BANNER"),
robots_disable_indexing: parse("REDLIB_ROBOTS_DISABLE_INDEXING"),
@ -161,6 +171,7 @@ fn get_setting_from_config(name: &str, config: &Config) -> Option<String> {
"REDLIB_DEFAULT_LAYOUT" => config.default_layout.clone(),
"REDLIB_DEFAULT_COMMENT_SORT" => config.default_comment_sort.clone(),
"REDLIB_DEFAULT_POST_SORT" => config.default_post_sort.clone(),
"REDLIB_DEFAULT_BLUR_SPOILER" => config.default_blur_spoiler.clone(),
"REDLIB_DEFAULT_SHOW_NSFW" => config.default_show_nsfw.clone(),
"REDLIB_DEFAULT_BLUR_NSFW" => config.default_blur_nsfw.clone(),
"REDLIB_DEFAULT_USE_HLS" => config.default_use_hls.clone(),
@ -171,6 +182,7 @@ fn get_setting_from_config(name: &str, config: &Config) -> Option<String> {
"REDLIB_DEFAULT_HIDE_SIDEBAR_AND_SUMMARY" => config.default_hide_sidebar_and_summary.clone(),
"REDLIB_DEFAULT_HIDE_SCORE" => config.default_hide_score.clone(),
"REDLIB_DEFAULT_SUBSCRIPTIONS" => config.default_subscriptions.clone(),
"REDLIB_DEFAULT_FILTERS" => config.default_filters.clone(),
"REDLIB_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION" => config.default_disable_visit_reddit_confirmation.clone(),
"REDLIB_BANNER" => config.banner.clone(),
"REDLIB_ROBOTS_DISABLE_INDEXING" => config.robots_disable_indexing.clone(),
@ -243,6 +255,12 @@ fn test_default_subscriptions() {
assert_eq!(get_setting("REDLIB_DEFAULT_SUBSCRIPTIONS"), Some("news+bestof".into()));
}
#[test]
#[sealed_test(env = [("REDLIB_DEFAULT_FILTERS", "news+bestof")])]
fn test_default_filters() {
assert_eq!(get_setting("REDLIB_DEFAULT_FILTERS"), Some("news+bestof".into()));
}
#[test]
#[sealed_test]
fn test_pushshift() {

View File

@ -142,11 +142,13 @@ impl InstanceInfo {
["Wide", &convert(&self.config.default_wide)],
["Comment sort", &convert(&self.config.default_comment_sort)],
["Post sort", &convert(&self.config.default_post_sort)],
["Blur Spoiler", &convert(&self.config.default_blur_spoiler)],
["Show NSFW", &convert(&self.config.default_show_nsfw)],
["Blur NSFW", &convert(&self.config.default_blur_nsfw)],
["Use HLS", &convert(&self.config.default_use_hls)],
["Hide HLS notification", &convert(&self.config.default_hide_hls_notification)],
["Subscriptions", &convert(&self.config.default_subscriptions)],
["Filters", &convert(&self.config.default_filters)],
])
.with_header_row(["Default preferences"]),
);
@ -175,11 +177,13 @@ impl InstanceInfo {
Default wide: {:?}\n
Default comment sort: {:?}\n
Default post sort: {:?}\n
Default blur Spoiler: {:?}\n
Default show NSFW: {:?}\n
Default blur NSFW: {:?}\n
Default use HLS: {:?}\n
Default hide HLS notification: {:?}\n
Default subscriptions: {:?}\n",
Default subscriptions: {:?}\n
Default filters: {:?}\n",
self.package_name,
self.crate_version,
self.git_commit,
@ -198,11 +202,13 @@ impl InstanceInfo {
self.config.default_wide,
self.config.default_comment_sort,
self.config.default_post_sort,
self.config.default_blur_spoiler,
self.config.default_show_nsfw,
self.config.default_blur_nsfw,
self.config.default_use_hls,
self.config.default_hide_hls_notification,
self.config.default_subscriptions,
self.config.default_filters,
)
}
StringType::Html => self.to_table(),

View File

@ -160,7 +160,7 @@ async fn main() {
.long("address")
.value_name("ADDRESS")
.help("Sets address to listen on")
.default_value("0.0.0.0")
.default_value("[::]")
.num_args(1),
)
.arg(

View File

@ -1,12 +1,12 @@
use std::{collections::HashMap, time::Duration};
use std::{collections::HashMap, sync::atomic::Ordering, time::Duration};
use crate::{
client::{CLIENT, OAUTH_CLIENT},
client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING},
oauth_resources::ANDROID_APP_VERSION_LIST,
};
use base64::{engine::general_purpose, Engine as _};
use hyper::{client, Body, Method, Request};
use log::info;
use log::{info, trace};
use serde_json::json;
@ -131,6 +131,7 @@ pub async fn token_daemon() {
}
pub async fn force_refresh_token() {
trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst));
OAUTH_CLIENT.write().await.refresh().await;
}

View File

@ -19,7 +19,7 @@ struct SettingsTemplate {
// CONSTANTS
const PREFS: [&str; 18] = [
const PREFS: [&str; 19] = [
"theme",
"mascot",
"front_page",
@ -27,6 +27,7 @@ const PREFS: [&str; 18] = [
"wide",
"comment_sort",
"post_sort",
"blur_spoiler",
"show_nsfw",
"blur_nsfw",
"use_hls",

View File

@ -64,7 +64,7 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
let post_sort = req.cookie("post_sort").map_or_else(|| "hot".to_string(), |c| c.value().to_string());
let sort = req.param("sort").unwrap_or_else(|| req.param("id").unwrap_or(post_sort));
let sub_name = req.param("sub").unwrap_or(if front_page == "default" || front_page.is_empty() {
let mut sub_name = req.param("sub").unwrap_or(if front_page == "default" || front_page.is_empty() {
if subscribed.is_empty() {
"popular".to_string()
} else {
@ -84,6 +84,11 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
return Ok(redirect(&["/user/", &sub_name[2..]].concat()));
}
// If multi-sub, replace + with url encoded +
if sub_name.contains('+') {
sub_name = sub_name.replace('+', "%2B");
}
// Request subreddit metadata
let sub = if !sub_name.contains('+') && sub_name != subscribed && sub_name != "popular" && sub_name != "all" {
// Regular subreddit

View File

@ -1,3 +1,4 @@
#![allow(dead_code)]
use crate::config::get_setting;
//
// CRATES
@ -156,6 +157,7 @@ impl PollOption {
// Post flags with nsfw and stickied
pub struct Flags {
pub spoiler: bool,
pub nsfw: bool,
pub stickied: bool,
}
@ -402,6 +404,7 @@ impl Post {
},
},
flags: Flags {
spoiler: data["spoiler"].as_bool().unwrap_or_default(),
nsfw: data["over_18"].as_bool().unwrap_or_default(),
stickied: data["stickied"].as_bool().unwrap_or_default() || data["pinned"].as_bool().unwrap_or_default(),
},
@ -576,6 +579,7 @@ pub struct Preferences {
pub front_page: String,
pub layout: String,
pub wide: String,
pub blur_spoiler: String,
pub show_nsfw: String,
pub blur_nsfw: String,
pub hide_hls_notification: String,
@ -628,6 +632,7 @@ impl Preferences {
front_page: setting(req, "front_page"),
layout: setting(req, "layout"),
wide: setting(req, "wide"),
blur_spoiler: setting(req, "blur_spoiler"),
show_nsfw: setting(req, "show_nsfw"),
hide_sidebar_and_summary: setting(req, "hide_sidebar_and_summary"),
blur_nsfw: setting(req, "blur_nsfw"),
@ -749,6 +754,7 @@ pub async fn parse_post(post: &Value) -> Post {
},
},
flags: Flags {
spoiler: post["data"]["spoiler"].as_bool().unwrap_or_default(),
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
stickied: post["data"]["stickied"].as_bool().unwrap_or_default() || post["data"]["pinned"].as_bool().unwrap_or(false),
},
@ -1046,7 +1052,7 @@ pub fn redirect(path: &str) -> Response<Body> {
/// Renders a generic error landing page.
pub async fn error(req: Request<Body>, msg: &str) -> Result<Response<Body>, String> {
error!("Error page rendered: {msg}");
error!("Error page rendered: {}", msg.split('|').next().unwrap_or_default());
let url = req.uri().to_string();
let body = ErrorTemplate {
msg: msg.to_string(),