Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@ -5,11 +5,13 @@ use hyper::client::HttpConnector;
|
||||
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
|
||||
use hyper_rustls::HttpsConnector;
|
||||
use libflate::gzip;
|
||||
use log::error;
|
||||
use log::{error, trace, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use percent_encoding::{percent_encode, CONTROLS};
|
||||
use serde_json::Value;
|
||||
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::atomic::{AtomicU16, Ordering::SeqCst};
|
||||
use std::{io, result::Result};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
@ -36,6 +38,8 @@ pub static OAUTH_CLIENT: Lazy<RwLock<Oauth>> = Lazy::new(|| {
|
||||
RwLock::new(client)
|
||||
});
|
||||
|
||||
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
|
||||
|
||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||
///
|
||||
@ -170,7 +174,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
||||
// Construct the hyper client from the HTTPS connector.
|
||||
let client: Client<_, Body> = CLIENT.clone();
|
||||
|
||||
let (token, vendor_id, device_id, mut user_agent, loid) = {
|
||||
let (token, vendor_id, device_id, user_agent, loid) = {
|
||||
let client = block_on(OAUTH_CLIENT.read());
|
||||
(
|
||||
client.token.clone(),
|
||||
@ -181,13 +185,6 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
||||
)
|
||||
};
|
||||
|
||||
// Replace "Android" with a tricky word.
|
||||
// Issues: #78/#115, #116
|
||||
// If you include the word "Android", you will get a number of different errors
|
||||
// I guess they don't expect mobile traffic on the endpoints we use
|
||||
// Scrawled on wall for next poor soul: Run the test suite.
|
||||
user_agent = user_agent.replace("Android", "Andr\u{200B}oid");
|
||||
|
||||
// Build request to Reddit. When making a GET, request gzip compression.
|
||||
// (Reddit doesn't do brotli yet.)
|
||||
let builder = Request::builder()
|
||||
@ -314,19 +311,59 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
||||
#[cached(size = 100, time = 30, result = true)]
|
||||
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
||||
// Closure to quickly build errors
|
||||
let err = |msg: &str, e: String| -> Result<Value, String> {
|
||||
let err = |msg: &str, e: String, path: String| -> Result<Value, String> {
|
||||
// eprintln!("{} - {}: {}", url, msg, e);
|
||||
Err(format!("{msg}: {e}"))
|
||||
Err(format!("{msg}: {e} | {path}"))
|
||||
};
|
||||
|
||||
// First, handle rolling over the OAUTH_CLIENT if need be.
|
||||
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst);
|
||||
if current_rate_limit < 10 {
|
||||
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
|
||||
OAUTH_RATELIMIT_REMAINING.store(99, Ordering::SeqCst);
|
||||
tokio::spawn(force_refresh_token());
|
||||
}
|
||||
|
||||
// Fetch the url...
|
||||
match reddit_get(path.clone(), quarantine).await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
|
||||
// Ratelimit remaining
|
||||
if let Some(Ok(remaining)) = response.headers().get("x-ratelimit-remaining").map(|val| val.to_str()) {
|
||||
trace!("Ratelimit remaining: {}", remaining);
|
||||
if let Ok(remaining) = remaining.parse::<f32>().map(|f| f.round() as u16) {
|
||||
OAUTH_RATELIMIT_REMAINING.store(remaining, SeqCst);
|
||||
} else {
|
||||
warn!("Failed to parse rate limit {remaining} from header.");
|
||||
}
|
||||
}
|
||||
|
||||
// Ratelimit used
|
||||
if let Some(Ok(used)) = response.headers().get("x-ratelimit-used").map(|val| val.to_str()) {
|
||||
trace!("Ratelimit used: {}", used);
|
||||
}
|
||||
|
||||
// Ratelimit reset
|
||||
let reset = if let Some(Ok(reset)) = response.headers().get("x-ratelimit-reset").map(|val| val.to_str()) {
|
||||
trace!("Ratelimit reset: {}", reset);
|
||||
Some(reset.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// asynchronously aggregate the chunks of the body
|
||||
match hyper::body::aggregate(response).await {
|
||||
Ok(body) => {
|
||||
let has_remaining = body.has_remaining();
|
||||
|
||||
if !has_remaining {
|
||||
return match reset {
|
||||
Some(val) => Err(format!("Reddit rate limit exceeded. Will reset in: {val}")),
|
||||
None => Err("Reddit rate limit exceeded".to_string()),
|
||||
};
|
||||
}
|
||||
|
||||
// Parse the response from Reddit as JSON
|
||||
match serde_json::from_reader(body.reader()) {
|
||||
Ok(value) => {
|
||||
@ -339,7 +376,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
||||
let () = force_refresh_token().await;
|
||||
return Err("OAuth token has expired. Please refresh the page!".to_string());
|
||||
}
|
||||
Err(format!("Reddit error {} \"{}\": {}", json["error"], json["reason"], json["message"]))
|
||||
Err(format!("Reddit error {} \"{}\": {} | {path}", json["error"], json["reason"], json["message"]))
|
||||
} else {
|
||||
Ok(json)
|
||||
}
|
||||
@ -349,21 +386,24 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
||||
if status.is_server_error() {
|
||||
Err("Reddit is having issues, check if there's an outage".to_string())
|
||||
} else {
|
||||
err("Failed to parse page JSON data", e.to_string())
|
||||
err("Failed to parse page JSON data", e.to_string(), path)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => err("Failed receiving body from Reddit", e.to_string()),
|
||||
Err(e) => err("Failed receiving body from Reddit", e.to_string(), path),
|
||||
}
|
||||
}
|
||||
Err(e) => err("Couldn't send request to Reddit", e),
|
||||
Err(e) => err("Couldn't send request to Reddit", e, path),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
static POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL";
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_localization_popular() {
|
||||
let val = json("/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL".to_string(), false).await.unwrap();
|
||||
let val = json(POPULAR_URL.to_string(), false).await.unwrap();
|
||||
assert_eq!("GLOBAL", val["data"]["geo_filter"].as_str().unwrap());
|
||||
}
|
||||
|
||||
|
@ -52,6 +52,10 @@ pub struct Config {
|
||||
#[serde(alias = "LIBREDDIT_DEFAULT_POST_SORT")]
|
||||
pub(crate) default_post_sort: Option<String>,
|
||||
|
||||
#[serde(rename = "REDLIB_DEFAULT_BLUR_SPOILER")]
|
||||
#[serde(alias = "LIBREDDIT_DEFAULT_BLUR_SPOILER")]
|
||||
pub(crate) default_blur_spoiler: Option<String>,
|
||||
|
||||
#[serde(rename = "REDLIB_DEFAULT_SHOW_NSFW")]
|
||||
#[serde(alias = "LIBREDDIT_DEFAULT_SHOW_NSFW")]
|
||||
pub(crate) default_show_nsfw: Option<String>,
|
||||
@ -88,6 +92,10 @@ pub struct Config {
|
||||
#[serde(alias = "LIBREDDIT_DEFAULT_SUBSCRIPTIONS")]
|
||||
pub(crate) default_subscriptions: Option<String>,
|
||||
|
||||
#[serde(rename = "REDLIB_DEFAULT_FILTERS")]
|
||||
#[serde(alias = "LIBREDDIT_DEFAULT_FILTERS")]
|
||||
pub(crate) default_filters: Option<String>,
|
||||
|
||||
#[serde(rename = "REDLIB_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION")]
|
||||
#[serde(alias = "LIBREDDIT_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION")]
|
||||
pub(crate) default_disable_visit_reddit_confirmation: Option<String>,
|
||||
@ -135,6 +143,7 @@ impl Config {
|
||||
default_post_sort: parse("REDLIB_DEFAULT_POST_SORT"),
|
||||
default_wide: parse("REDLIB_DEFAULT_WIDE"),
|
||||
default_comment_sort: parse("REDLIB_DEFAULT_COMMENT_SORT"),
|
||||
default_blur_spoiler: parse("REDLIB_DEFAULT_BLUR_SPOILER"),
|
||||
default_show_nsfw: parse("REDLIB_DEFAULT_SHOW_NSFW"),
|
||||
default_blur_nsfw: parse("REDLIB_DEFAULT_BLUR_NSFW"),
|
||||
default_use_hls: parse("REDLIB_DEFAULT_USE_HLS"),
|
||||
@ -144,6 +153,7 @@ impl Config {
|
||||
default_hide_sidebar_and_summary: parse("REDLIB_DEFAULT_HIDE_SIDEBAR_AND_SUMMARY"),
|
||||
default_hide_score: parse("REDLIB_DEFAULT_HIDE_SCORE"),
|
||||
default_subscriptions: parse("REDLIB_DEFAULT_SUBSCRIPTIONS"),
|
||||
default_filters: parse("REDLIB_DEFAULT_FILTERS"),
|
||||
default_disable_visit_reddit_confirmation: parse("REDLIB_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION"),
|
||||
banner: parse("REDLIB_BANNER"),
|
||||
robots_disable_indexing: parse("REDLIB_ROBOTS_DISABLE_INDEXING"),
|
||||
@ -161,6 +171,7 @@ fn get_setting_from_config(name: &str, config: &Config) -> Option<String> {
|
||||
"REDLIB_DEFAULT_LAYOUT" => config.default_layout.clone(),
|
||||
"REDLIB_DEFAULT_COMMENT_SORT" => config.default_comment_sort.clone(),
|
||||
"REDLIB_DEFAULT_POST_SORT" => config.default_post_sort.clone(),
|
||||
"REDLIB_DEFAULT_BLUR_SPOILER" => config.default_blur_spoiler.clone(),
|
||||
"REDLIB_DEFAULT_SHOW_NSFW" => config.default_show_nsfw.clone(),
|
||||
"REDLIB_DEFAULT_BLUR_NSFW" => config.default_blur_nsfw.clone(),
|
||||
"REDLIB_DEFAULT_USE_HLS" => config.default_use_hls.clone(),
|
||||
@ -171,6 +182,7 @@ fn get_setting_from_config(name: &str, config: &Config) -> Option<String> {
|
||||
"REDLIB_DEFAULT_HIDE_SIDEBAR_AND_SUMMARY" => config.default_hide_sidebar_and_summary.clone(),
|
||||
"REDLIB_DEFAULT_HIDE_SCORE" => config.default_hide_score.clone(),
|
||||
"REDLIB_DEFAULT_SUBSCRIPTIONS" => config.default_subscriptions.clone(),
|
||||
"REDLIB_DEFAULT_FILTERS" => config.default_filters.clone(),
|
||||
"REDLIB_DEFAULT_DISABLE_VISIT_REDDIT_CONFIRMATION" => config.default_disable_visit_reddit_confirmation.clone(),
|
||||
"REDLIB_BANNER" => config.banner.clone(),
|
||||
"REDLIB_ROBOTS_DISABLE_INDEXING" => config.robots_disable_indexing.clone(),
|
||||
@ -243,6 +255,12 @@ fn test_default_subscriptions() {
|
||||
assert_eq!(get_setting("REDLIB_DEFAULT_SUBSCRIPTIONS"), Some("news+bestof".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[sealed_test(env = [("REDLIB_DEFAULT_FILTERS", "news+bestof")])]
|
||||
fn test_default_filters() {
|
||||
assert_eq!(get_setting("REDLIB_DEFAULT_FILTERS"), Some("news+bestof".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[sealed_test]
|
||||
fn test_pushshift() {
|
||||
|
@ -142,11 +142,13 @@ impl InstanceInfo {
|
||||
["Wide", &convert(&self.config.default_wide)],
|
||||
["Comment sort", &convert(&self.config.default_comment_sort)],
|
||||
["Post sort", &convert(&self.config.default_post_sort)],
|
||||
["Blur Spoiler", &convert(&self.config.default_blur_spoiler)],
|
||||
["Show NSFW", &convert(&self.config.default_show_nsfw)],
|
||||
["Blur NSFW", &convert(&self.config.default_blur_nsfw)],
|
||||
["Use HLS", &convert(&self.config.default_use_hls)],
|
||||
["Hide HLS notification", &convert(&self.config.default_hide_hls_notification)],
|
||||
["Subscriptions", &convert(&self.config.default_subscriptions)],
|
||||
["Filters", &convert(&self.config.default_filters)],
|
||||
])
|
||||
.with_header_row(["Default preferences"]),
|
||||
);
|
||||
@ -175,11 +177,13 @@ impl InstanceInfo {
|
||||
Default wide: {:?}\n
|
||||
Default comment sort: {:?}\n
|
||||
Default post sort: {:?}\n
|
||||
Default blur Spoiler: {:?}\n
|
||||
Default show NSFW: {:?}\n
|
||||
Default blur NSFW: {:?}\n
|
||||
Default use HLS: {:?}\n
|
||||
Default hide HLS notification: {:?}\n
|
||||
Default subscriptions: {:?}\n",
|
||||
Default subscriptions: {:?}\n
|
||||
Default filters: {:?}\n",
|
||||
self.package_name,
|
||||
self.crate_version,
|
||||
self.git_commit,
|
||||
@ -198,11 +202,13 @@ impl InstanceInfo {
|
||||
self.config.default_wide,
|
||||
self.config.default_comment_sort,
|
||||
self.config.default_post_sort,
|
||||
self.config.default_blur_spoiler,
|
||||
self.config.default_show_nsfw,
|
||||
self.config.default_blur_nsfw,
|
||||
self.config.default_use_hls,
|
||||
self.config.default_hide_hls_notification,
|
||||
self.config.default_subscriptions,
|
||||
self.config.default_filters,
|
||||
)
|
||||
}
|
||||
StringType::Html => self.to_table(),
|
||||
|
@ -160,7 +160,7 @@ async fn main() {
|
||||
.long("address")
|
||||
.value_name("ADDRESS")
|
||||
.help("Sets address to listen on")
|
||||
.default_value("0.0.0.0")
|
||||
.default_value("[::]")
|
||||
.num_args(1),
|
||||
)
|
||||
.arg(
|
||||
|
@ -1,12 +1,12 @@
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
use std::{collections::HashMap, sync::atomic::Ordering, time::Duration};
|
||||
|
||||
use crate::{
|
||||
client::{CLIENT, OAUTH_CLIENT},
|
||||
client::{CLIENT, OAUTH_CLIENT, OAUTH_RATELIMIT_REMAINING},
|
||||
oauth_resources::ANDROID_APP_VERSION_LIST,
|
||||
};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use hyper::{client, Body, Method, Request};
|
||||
use log::info;
|
||||
use log::{info, trace};
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
@ -131,6 +131,7 @@ pub async fn token_daemon() {
|
||||
}
|
||||
|
||||
pub async fn force_refresh_token() {
|
||||
trace!("Rolling over refresh token. Current rate limit: {}", OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst));
|
||||
OAUTH_CLIENT.write().await.refresh().await;
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,7 @@ struct SettingsTemplate {
|
||||
|
||||
// CONSTANTS
|
||||
|
||||
const PREFS: [&str; 18] = [
|
||||
const PREFS: [&str; 19] = [
|
||||
"theme",
|
||||
"mascot",
|
||||
"front_page",
|
||||
@ -27,6 +27,7 @@ const PREFS: [&str; 18] = [
|
||||
"wide",
|
||||
"comment_sort",
|
||||
"post_sort",
|
||||
"blur_spoiler",
|
||||
"show_nsfw",
|
||||
"blur_nsfw",
|
||||
"use_hls",
|
||||
|
@ -64,7 +64,7 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
|
||||
let post_sort = req.cookie("post_sort").map_or_else(|| "hot".to_string(), |c| c.value().to_string());
|
||||
let sort = req.param("sort").unwrap_or_else(|| req.param("id").unwrap_or(post_sort));
|
||||
|
||||
let sub_name = req.param("sub").unwrap_or(if front_page == "default" || front_page.is_empty() {
|
||||
let mut sub_name = req.param("sub").unwrap_or(if front_page == "default" || front_page.is_empty() {
|
||||
if subscribed.is_empty() {
|
||||
"popular".to_string()
|
||||
} else {
|
||||
@ -84,6 +84,11 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
|
||||
return Ok(redirect(&["/user/", &sub_name[2..]].concat()));
|
||||
}
|
||||
|
||||
// If multi-sub, replace + with url encoded +
|
||||
if sub_name.contains('+') {
|
||||
sub_name = sub_name.replace('+', "%2B");
|
||||
}
|
||||
|
||||
// Request subreddit metadata
|
||||
let sub = if !sub_name.contains('+') && sub_name != subscribed && sub_name != "popular" && sub_name != "all" {
|
||||
// Regular subreddit
|
||||
|
@ -1,3 +1,4 @@
|
||||
#![allow(dead_code)]
|
||||
use crate::config::get_setting;
|
||||
//
|
||||
// CRATES
|
||||
@ -156,6 +157,7 @@ impl PollOption {
|
||||
|
||||
// Post flags with nsfw and stickied
|
||||
pub struct Flags {
|
||||
pub spoiler: bool,
|
||||
pub nsfw: bool,
|
||||
pub stickied: bool,
|
||||
}
|
||||
@ -402,6 +404,7 @@ impl Post {
|
||||
},
|
||||
},
|
||||
flags: Flags {
|
||||
spoiler: data["spoiler"].as_bool().unwrap_or_default(),
|
||||
nsfw: data["over_18"].as_bool().unwrap_or_default(),
|
||||
stickied: data["stickied"].as_bool().unwrap_or_default() || data["pinned"].as_bool().unwrap_or_default(),
|
||||
},
|
||||
@ -576,6 +579,7 @@ pub struct Preferences {
|
||||
pub front_page: String,
|
||||
pub layout: String,
|
||||
pub wide: String,
|
||||
pub blur_spoiler: String,
|
||||
pub show_nsfw: String,
|
||||
pub blur_nsfw: String,
|
||||
pub hide_hls_notification: String,
|
||||
@ -628,6 +632,7 @@ impl Preferences {
|
||||
front_page: setting(req, "front_page"),
|
||||
layout: setting(req, "layout"),
|
||||
wide: setting(req, "wide"),
|
||||
blur_spoiler: setting(req, "blur_spoiler"),
|
||||
show_nsfw: setting(req, "show_nsfw"),
|
||||
hide_sidebar_and_summary: setting(req, "hide_sidebar_and_summary"),
|
||||
blur_nsfw: setting(req, "blur_nsfw"),
|
||||
@ -749,6 +754,7 @@ pub async fn parse_post(post: &Value) -> Post {
|
||||
},
|
||||
},
|
||||
flags: Flags {
|
||||
spoiler: post["data"]["spoiler"].as_bool().unwrap_or_default(),
|
||||
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
|
||||
stickied: post["data"]["stickied"].as_bool().unwrap_or_default() || post["data"]["pinned"].as_bool().unwrap_or(false),
|
||||
},
|
||||
@ -1046,7 +1052,7 @@ pub fn redirect(path: &str) -> Response<Body> {
|
||||
|
||||
/// Renders a generic error landing page.
|
||||
pub async fn error(req: Request<Body>, msg: &str) -> Result<Response<Body>, String> {
|
||||
error!("Error page rendered: {msg}");
|
||||
error!("Error page rendered: {}", msg.split('|').next().unwrap_or_default());
|
||||
let url = req.uri().to_string();
|
||||
let body = ErrorTemplate {
|
||||
msg: msg.to_string(),
|
||||
|
Reference in New Issue
Block a user