2024-06-29 10:14:47 +12:00
use arc_swap ::ArcSwap ;
2021-03-18 11:30:33 +13:00
use cached ::proc_macro ::cached ;
2023-06-09 06:33:54 +12:00
use futures_lite ::future ::block_on ;
2021-03-18 11:30:33 +13:00
use futures_lite ::{ future ::Boxed , FutureExt } ;
2023-03-09 17:53:23 +13:00
use hyper ::client ::HttpConnector ;
2024-06-29 14:39:42 +12:00
use hyper ::header ::HeaderValue ;
2023-03-09 17:53:23 +13:00
use hyper ::{ body , body ::Buf , client , header , Body , Client , Method , Request , Response , Uri } ;
use hyper_rustls ::HttpsConnector ;
2022-11-04 17:04:34 +13:00
use libflate ::gzip ;
2024-06-27 11:19:30 +12:00
use log ::{ error , trace , warn } ;
2023-03-09 17:53:23 +13:00
use once_cell ::sync ::Lazy ;
2022-05-21 17:48:59 +12:00
use percent_encoding ::{ percent_encode , CONTROLS } ;
2021-03-18 11:30:33 +13:00
use serde_json ::Value ;
2024-06-27 11:19:30 +12:00
use std ::sync ::atomic ::Ordering ;
2024-06-28 15:26:31 +12:00
use std ::sync ::atomic ::{ AtomicBool , AtomicU16 } ;
2023-12-31 04:22:49 +13:00
use std ::{ io , result ::Result } ;
2021-03-18 11:30:33 +13:00
2022-11-05 21:29:04 +13:00
use crate ::dbg_msg ;
2024-01-28 17:31:21 +13:00
use crate ::oauth ::{ force_refresh_token , token_daemon , Oauth } ;
2021-03-18 11:30:33 +13:00
use crate ::server ::RequestExt ;
2023-12-30 13:28:41 +13:00
use crate ::utils ::format_url ;
2021-03-18 11:30:33 +13:00
2023-06-06 12:31:25 +12:00
/// Base URL that `reddit_get` sends its GET requests to.
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
/// `Host` header value that goes with [`REDDIT_URL_BASE`].
const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com";

/// Base URL of Reddit's short-link domain; one of the bases tried by `canonical_path`.
const REDDIT_SHORT_URL_BASE: &str = "https://redd.it";
/// `Host` header value that goes with [`REDDIT_SHORT_URL_BASE`].
const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it";

/// Alternative (non-OAuth) Reddit base URL. Tried first by `canonical_path`
/// and stripped from redirect targets in `request`.
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
/// `Host` header value that goes with [`ALTERNATIVE_REDDIT_URL_BASE`].
const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com";
2022-11-05 21:29:04 +13:00
2023-12-29 06:42:06 +13:00
pub static CLIENT : Lazy < Client < HttpsConnector < HttpConnector > > > = Lazy ::new ( | | {
2024-09-19 03:24:00 +12:00
let https = hyper_rustls ::HttpsConnectorBuilder ::new ( ) . with_native_roots ( ) . https_only ( ) . enable_http1 ( ) . build ( ) ;
2023-02-26 20:33:55 +13:00
client ::Client ::builder ( ) . build ( https )
} ) ;
2024-06-29 10:14:47 +12:00
pub static OAUTH_CLIENT : Lazy < ArcSwap < Oauth > > = Lazy ::new ( | | {
2023-06-09 06:33:54 +12:00
let client = block_on ( Oauth ::new ( ) ) ;
tokio ::spawn ( token_daemon ( ) ) ;
2024-06-29 10:14:47 +12:00
ArcSwap ::new ( client . into ( ) )
2023-06-09 06:33:54 +12:00
} ) ;
2023-06-06 12:31:25 +12:00
2024-06-27 11:19:30 +12:00
/// Local estimate of how many requests are left in the current rate-limit
/// window. Starts at 99; `json` decrements it once per call and triggers a
/// token refresh when it drops below 10.
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);

/// Flag read by `json` to avoid spawning another `force_refresh_token()` while
/// one is already in progress. NOTE(review): presumably set and cleared by the
/// `oauth` module; nothing in this file writes it.
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
2024-09-21 15:57:18 +12:00
// (base URL, Host header) pairs that `canonical_path` tries, in this order,
// when issuing its HEAD request: the alternative Reddit base first, then the
// short-link base.
static URL_PAIRS : [ ( & str , & str ) ; 2 ] = [
( ALTERNATIVE_REDDIT_URL_BASE , ALTERNATIVE_REDDIT_URL_BASE_HOST ) ,
( REDDIT_SHORT_URL_BASE , REDDIT_SHORT_URL_BASE_HOST ) ,
] ;
2022-11-05 21:29:04 +13:00
/// Gets the canonical path for a resource on Reddit. This is accomplished by
/// making a `HEAD` request to Reddit at the path given in `path`.
///
/// This function returns `Ok(Some(path))`, where `path`'s value is identical
/// to that of the value of the argument `path`, if Reddit responds to our
/// `HEAD` request with a 2xx-family HTTP code. It will also return an
/// `Ok(Some(String))` if Reddit responds to our `HEAD` request with a
/// `Location` header in the response, and the HTTP code is in the 3xx-family;
/// the `String` will contain the path as reported in `Location`. The return
/// value is `Ok(None)` if Reddit responded with a 3xx, but did not provide a
/// `Location` header. An `Err(String)` is returned if Reddit responds with a
/// 429, or if we were unable to decode the value in the `Location` header.
#[ cached(size = 1024, time = 600, result = true) ]
2024-09-22 07:44:27 +12:00
#[ async_recursion::async_recursion ]
2024-09-21 15:57:18 +12:00
pub async fn canonical_path ( path : String , tries : i8 ) -> Result < Option < String > , String > {
if tries = = 0 {
return Ok ( None ) ;
}
// for each URL pair, try the HEAD request
let res = {
// for url base and host in URL_PAIRS, try reddit_short_head(path.clone(), true, url_base, url_base_host) and if it succeeds, set res. else, res = None
let mut res = None ;
for ( url_base , url_base_host ) in URL_PAIRS {
res = reddit_short_head ( path . clone ( ) , true , url_base , url_base_host ) . await . ok ( ) ;
if let Some ( res ) = & res {
if ! res . status ( ) . is_client_error ( ) {
break ;
}
}
}
res
} ;
let res = res . ok_or_else ( | | " Unable to make HEAD request to Reddit. " . to_string ( ) ) ? ;
2023-10-28 02:05:22 +13:00
let status = res . status ( ) . as_u16 ( ) ;
2024-09-17 08:16:08 +12:00
let policy_error = res . headers ( ) . get ( header ::RETRY_AFTER ) . is_some ( ) ;
2022-11-05 21:29:04 +13:00
2023-10-28 02:05:22 +13:00
match status {
// If Reddit responds with a 2xx, then the path is already canonical.
2023-12-27 12:27:25 +13:00
200 ..= 299 = > Ok ( Some ( path ) ) ,
2022-11-05 21:29:04 +13:00
2023-12-29 12:21:07 +13:00
// If Reddit responds with a 301, then the path is redirected.
301 = > match res . headers ( ) . get ( header ::LOCATION ) {
2023-12-30 13:28:41 +13:00
Some ( val ) = > {
2024-01-20 13:06:05 +13:00
let Ok ( original ) = val . to_str ( ) else {
2024-01-20 12:57:19 +13:00
return Err ( " Unable to decode Location header. " . to_string ( ) ) ;
} ;
2024-09-21 15:57:18 +12:00
2023-12-30 13:34:57 +13:00
// We need to strip the .json suffix from the original path.
// In addition, we want to remove share parameters.
// Cut it off here instead of letting it propagate all the way
// to main.rs
let stripped_uri = original . strip_suffix ( " .json " ) . unwrap_or ( original ) . split ( '?' ) . next ( ) . unwrap_or_default ( ) ;
2023-12-30 13:28:41 +13:00
// The reason why we now have to format_url, is because the new OAuth
// endpoints seem to return full paths, instead of relative paths.
// So we need to strip the .json suffix from the original path, and
// also remove all Reddit domain parts with format_url.
// Otherwise, it will literally redirect to Reddit.com.
2023-12-30 13:34:57 +13:00
let uri = format_url ( stripped_uri ) ;
2024-09-21 15:57:18 +12:00
// Decrement tries and try again
2024-09-22 07:44:27 +12:00
canonical_path ( uri , tries - 1 ) . await
2023-12-30 13:28:41 +13:00
}
2023-12-29 12:21:07 +13:00
None = > Ok ( None ) ,
} ,
// If Reddit responds with anything other than 3xx (except for the 2xx and 301
// as above), return a None.
2023-12-27 12:27:25 +13:00
300 ..= 399 = > Ok ( None ) ,
2022-11-05 21:29:04 +13:00
2024-09-17 08:16:08 +12:00
// Rate limiting
429 = > Err ( " Too many requests. " . to_string ( ) ) ,
// Special condition rate limiting - https://github.com/redlib-org/redlib/issues/229
403 if policy_error = > Err ( " Too many requests. " . to_string ( ) ) ,
2023-10-28 02:05:22 +13:00
_ = > Ok (
res
. headers ( )
. get ( header ::LOCATION )
. map ( | val | percent_encode ( val . as_bytes ( ) , CONTROLS ) . to_string ( ) . trim_start_matches ( REDDIT_URL_BASE ) . to_string ( ) ) ,
) ,
}
2022-11-05 21:29:04 +13:00
}
2021-03-18 11:30:33 +13:00
pub async fn proxy ( req : Request < Body > , format : & str ) -> Result < Response < Body > , String > {
2024-01-20 14:16:17 +13:00
let mut url = format! ( " {format} ? {} " , req . uri ( ) . query ( ) . unwrap_or_default ( ) ) ;
2021-03-18 11:30:33 +13:00
2021-05-21 07:24:06 +12:00
// For each parameter in request
2024-01-20 14:16:17 +13:00
for ( name , value ) in & req . params ( ) {
2021-05-21 07:24:06 +12:00
// Fill the parameter value in the url
2024-01-20 14:16:17 +13:00
url = url . replace ( & format! ( " {{ {name} }} " ) , value ) ;
2021-03-18 11:30:33 +13:00
}
2021-05-10 13:25:52 +12:00
stream ( & url , & req ) . await
2021-03-18 11:30:33 +13:00
}
2021-05-10 13:25:52 +12:00
async fn stream ( url : & str , req : & Request < Body > ) -> Result < Response < Body > , String > {
2021-03-18 11:30:33 +13:00
// First parameter is target URL (mandatory).
2024-01-20 14:16:17 +13:00
let parsed_uri = url . parse ::< Uri > ( ) . map_err ( | _ | " Couldn't parse URL " . to_string ( ) ) ? ;
2021-03-18 11:30:33 +13:00
// Build the hyper client from the HTTPS connector.
2024-01-20 14:16:17 +13:00
let client : Client < _ , Body > = CLIENT . clone ( ) ;
2021-03-18 11:30:33 +13:00
2024-01-20 14:16:17 +13:00
let mut builder = Request ::get ( parsed_uri ) ;
2021-05-10 13:25:52 +12:00
// Copy useful headers from original request
for & key in & [ " Range " , " If-Modified-Since " , " Cache-Control " ] {
2021-05-21 07:24:06 +12:00
if let Some ( value ) = req . headers ( ) . get ( key ) {
2021-05-10 13:25:52 +12:00
builder = builder . header ( key , value ) ;
}
}
2021-05-21 07:24:06 +12:00
let stream_request = builder . body ( Body ::empty ( ) ) . map_err ( | _ | " Couldn't build empty body in stream " . to_string ( ) ) ? ;
2021-05-10 13:25:52 +12:00
2021-03-18 11:30:33 +13:00
client
2021-05-10 13:25:52 +12:00
. request ( stream_request )
2021-03-18 11:30:33 +13:00
. await
. map ( | mut res | {
let mut rm = | key : & str | res . headers_mut ( ) . remove ( key ) ;
rm ( " access-control-expose-headers " ) ;
rm ( " server " ) ;
rm ( " vary " ) ;
rm ( " etag " ) ;
rm ( " x-cdn " ) ;
rm ( " x-cdn-client-region " ) ;
rm ( " x-cdn-name " ) ;
rm ( " x-cdn-server-region " ) ;
2021-05-10 13:25:52 +12:00
rm ( " x-reddit-cdn " ) ;
rm ( " x-reddit-video-features " ) ;
2023-07-09 14:20:58 +12:00
rm ( " Nel " ) ;
rm ( " Report-To " ) ;
2021-03-18 11:30:33 +13:00
res
} )
. map_err ( | e | e . to_string ( ) )
}
2022-11-05 21:29:04 +13:00
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
/// 3xx codes Reddit returns and will automatically redirect.
fn reddit_get ( path : String , quarantine : bool ) -> Boxed < Result < Response < Body > , String > > {
2024-09-21 15:57:18 +12:00
request ( & Method ::GET , path , true , quarantine , REDDIT_URL_BASE , REDDIT_URL_BASE_HOST )
2022-11-05 21:29:04 +13:00
}
2024-09-21 15:57:18 +12:00
/// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects.
fn reddit_short_head ( path : String , quarantine : bool , base_path : & 'static str , host : & 'static str ) -> Boxed < Result < Response < Body > , String > > {
request ( & Method ::HEAD , path , false , quarantine , base_path , host )
2022-11-05 21:29:04 +13:00
}
2024-09-21 15:57:18 +12:00
// /// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
// fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
// request(&Method::HEAD, path, false, quarantine, false)
// }
// Unused - reddit_head is only ever called in the context of a short URL
2024-01-20 14:16:17 +13:00
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
2022-11-05 21:29:04 +13:00
/// will recurse on the URL that Reddit provides in the Location HTTP header
/// in its response.
2024-09-21 15:57:18 +12:00
fn request ( method : & 'static Method , path : String , redirect : bool , quarantine : bool , base_path : & 'static str , host : & 'static str ) -> Boxed < Result < Response < Body > , String > > {
2022-11-05 21:29:04 +13:00
// Build Reddit URL from path.
2024-09-21 15:57:18 +12:00
let url = format! ( " {base_path} {path} " ) ;
2022-11-05 21:29:04 +13:00
2021-05-21 07:24:06 +12:00
// Construct the hyper client from the HTTPS connector.
2024-01-20 14:16:17 +13:00
let client : Client < _ , Body > = CLIENT . clone ( ) ;
2021-03-18 11:30:33 +13:00
2024-06-26 11:28:41 +12:00
let ( token , vendor_id , device_id , user_agent , loid ) = {
2024-06-29 10:14:47 +12:00
let client = OAUTH_CLIENT . load_full ( ) ;
2023-06-06 12:39:56 +12:00
(
client . token . clone ( ) ,
2023-06-09 06:33:54 +12:00
client . headers_map . get ( " Client-Vendor-Id " ) . cloned ( ) . unwrap_or_default ( ) ,
client . headers_map . get ( " X-Reddit-Device-Id " ) . cloned ( ) . unwrap_or_default ( ) ,
client . headers_map . get ( " User-Agent " ) . cloned ( ) . unwrap_or_default ( ) ,
2023-06-07 07:28:36 +12:00
client . headers_map . get ( " x-reddit-loid " ) . cloned ( ) . unwrap_or_default ( ) ,
2023-06-06 12:39:56 +12:00
)
} ;
2024-05-30 10:36:56 +12:00
2022-11-05 21:29:04 +13:00
// Build request to Reddit. When making a GET, request gzip compression.
// (Reddit doesn't do brotli yet.)
2021-03-18 17:26:06 +13:00
let builder = Request ::builder ( )
2022-11-05 21:29:04 +13:00
. method ( method )
2021-03-18 17:40:55 +13:00
. uri ( & url )
2023-06-07 07:05:20 +12:00
. header ( " User-Agent " , user_agent )
2023-06-06 12:39:56 +12:00
. header ( " Client-Vendor-Id " , vendor_id )
. header ( " X-Reddit-Device-Id " , device_id )
2023-06-07 07:05:20 +12:00
. header ( " x-reddit-loid " , loid )
2024-09-21 15:57:18 +12:00
. header ( " Host " , host )
2024-01-20 14:16:17 +13:00
. header ( " Authorization " , & format! ( " Bearer {token} " ) )
2022-11-05 21:29:04 +13:00
. header ( " Accept-Encoding " , if method = = Method ::GET { " gzip " } else { " identity " } )
2021-03-18 17:40:55 +13:00
. header ( " Accept-Language " , " en-US,en;q=0.5 " )
. header ( " Connection " , " keep-alive " )
2023-03-09 17:53:23 +13:00
. header (
" Cookie " ,
if quarantine {
" _options=%7B%22pref_quarantine_optin%22%3A%20true%2C%20%22pref_gated_sr_optin%22%3A%20true%7D "
} else {
" "
} ,
)
2021-03-18 17:40:55 +13:00
. body ( Body ::empty ( ) ) ;
2021-03-18 11:30:33 +13:00
async move {
2021-03-18 17:26:06 +13:00
match builder {
Ok ( req ) = > match client . request ( req ) . await {
2022-11-04 17:04:34 +13:00
Ok ( mut response ) = > {
2022-11-05 21:29:04 +13:00
// Reddit may respond with a 3xx. Decide whether or not to
// redirect based on caller params.
2023-10-28 02:05:22 +13:00
if response . status ( ) . is_redirection ( ) {
2022-11-05 21:29:04 +13:00
if ! redirect {
return Ok ( response ) ;
} ;
2024-06-29 14:39:42 +12:00
let location_header = response . headers ( ) . get ( header ::LOCATION ) ;
if location_header = = Some ( & HeaderValue ::from_static ( " https://www.reddit.com/ " ) ) {
return Err ( " Reddit response was invalid " . to_string ( ) ) ;
}
2022-11-05 21:29:04 +13:00
return request (
method ,
2024-06-29 14:39:42 +12:00
location_header
2021-11-30 19:29:41 +13:00
. map ( | val | {
2022-11-22 04:58:40 +13:00
// We need to make adjustments to the URI
// we get back from Reddit. Namely, we
// must:
//
// 1. Remove the authority (e.g.
// https://www.reddit.com) that may be
// present, so that we recurse on the
// path (and query parameters) as
// required.
//
// 2. Percent-encode the path.
2024-06-29 14:39:42 +12:00
let new_path = percent_encode ( val . as_bytes ( ) , CONTROLS )
. to_string ( )
. trim_start_matches ( REDDIT_URL_BASE )
. trim_start_matches ( ALTERNATIVE_REDDIT_URL_BASE )
. to_string ( ) ;
2024-01-20 14:16:17 +13:00
format! ( " {new_path} {} raw_json=1 " , if new_path . contains ( '?' ) { " & " } else { " ? " } )
2021-11-30 19:29:41 +13:00
} )
2021-03-18 17:26:06 +13:00
. unwrap_or_default ( )
. to_string ( ) ,
2022-11-05 21:29:04 +13:00
true ,
2021-05-17 03:53:39 +12:00
quarantine ,
2024-09-21 15:57:18 +12:00
base_path ,
host ,
2021-03-18 17:26:06 +13:00
)
2022-11-05 21:29:04 +13:00
. await ;
} ;
match response . headers ( ) . get ( header ::CONTENT_ENCODING ) {
// Content not compressed.
None = > Ok ( response ) ,
// Content encoded (hopefully with gzip).
Some ( hdr ) = > {
match hdr . to_str ( ) {
Ok ( val ) = > match val {
" gzip " = > { }
" identity " = > return Ok ( response ) ,
_ = > return Err ( " Reddit response was encoded with an unsupported compressor " . to_string ( ) ) ,
} ,
Err ( _ ) = > return Err ( " Reddit response was invalid " . to_string ( ) ) ,
}
// We get here if the body is gzip-compressed.
// The body must be something that implements
// std::io::Read, hence the conversion to
// bytes::buf::Buf and then transformation into a
// Reader.
let mut decompressed : Vec < u8 > ;
{
let mut aggregated_body = match body ::aggregate ( response . body_mut ( ) ) . await {
Ok ( b ) = > b . reader ( ) ,
Err ( e ) = > return Err ( e . to_string ( ) ) ,
} ;
let mut decoder = match gzip ::Decoder ::new ( & mut aggregated_body ) {
Ok ( decoder ) = > decoder ,
Err ( e ) = > return Err ( e . to_string ( ) ) ,
} ;
decompressed = Vec ::< u8 > ::new ( ) ;
if let Err ( e ) = io ::copy ( & mut decoder , & mut decompressed ) {
return Err ( e . to_string ( ) ) ;
} ;
2022-11-04 17:04:34 +13:00
}
2022-11-05 21:29:04 +13:00
response . headers_mut ( ) . remove ( header ::CONTENT_ENCODING ) ;
response . headers_mut ( ) . insert ( header ::CONTENT_LENGTH , decompressed . len ( ) . into ( ) ) ;
* ( response . body_mut ( ) ) = Body ::from ( decompressed ) ;
Ok ( response )
2022-11-04 17:04:34 +13:00
}
2021-03-18 17:26:06 +13:00
}
2021-03-18 11:30:33 +13:00
}
2022-11-05 21:29:04 +13:00
Err ( e ) = > {
2024-06-29 14:39:42 +12:00
dbg_msg! ( " {method} {REDDIT_URL_BASE}{path}: {} " , e ) ;
2022-11-05 21:29:04 +13:00
Err ( e . to_string ( ) )
}
2021-03-18 17:26:06 +13:00
} ,
2021-03-18 17:40:55 +13:00
Err ( _ ) = > Err ( " Post url contains non-ASCII characters " . to_string ( ) ) ,
2021-03-18 11:30:33 +13:00
}
}
. boxed ( )
}
// Make a request to a Reddit API and parse the JSON response
2024-06-27 11:19:30 +12:00
#[ cached(size = 100, time = 30, result = true) ]
2021-05-17 03:53:39 +12:00
pub async fn json ( path : String , quarantine : bool ) -> Result < Value , String > {
2021-03-18 11:30:33 +13:00
// Closure to quickly build errors
2024-06-20 06:45:32 +12:00
let err = | msg : & str , e : String , path : String | -> Result < Value , String > {
2021-03-18 17:26:06 +13:00
// eprintln!("{} - {}: {}", url, msg, e);
2024-06-20 06:45:32 +12:00
Err ( format! ( " {msg} : {e} | {path} " ) )
2021-03-18 11:30:33 +13:00
} ;
2024-06-27 11:19:30 +12:00
// First, handle rolling over the OAUTH_CLIENT if need be.
2024-06-27 15:41:26 +12:00
let current_rate_limit = OAUTH_RATELIMIT_REMAINING . load ( Ordering ::SeqCst ) ;
2024-06-28 15:26:31 +12:00
let is_rolling_over = OAUTH_IS_ROLLING_OVER . load ( Ordering ::SeqCst ) ;
if current_rate_limit < 10 & & ! is_rolling_over {
2024-06-27 11:19:30 +12:00
warn! ( " Rate limit {current_rate_limit} is low. Spawning force_refresh_token() " ) ;
tokio ::spawn ( force_refresh_token ( ) ) ;
}
2024-06-28 15:26:31 +12:00
OAUTH_RATELIMIT_REMAINING . fetch_sub ( 1 , Ordering ::SeqCst ) ;
2024-06-27 11:19:30 +12:00
2021-03-18 11:30:33 +13:00
// Fetch the url...
2022-11-05 21:29:04 +13:00
match reddit_get ( path . clone ( ) , quarantine ) . await {
2021-03-18 11:30:33 +13:00
Ok ( response ) = > {
2021-08-12 15:49:42 +12:00
let status = response . status ( ) ;
2024-06-29 16:20:19 +12:00
let reset : Option < String > = if let ( Some ( remaining ) , Some ( reset ) , Some ( used ) ) = (
response . headers ( ) . get ( " x-ratelimit-remaining " ) . and_then ( | val | val . to_str ( ) . ok ( ) . map ( | s | s . to_string ( ) ) ) ,
response . headers ( ) . get ( " x-ratelimit-reset " ) . and_then ( | val | val . to_str ( ) . ok ( ) . map ( | s | s . to_string ( ) ) ) ,
response . headers ( ) . get ( " x-ratelimit-used " ) . and_then ( | val | val . to_str ( ) . ok ( ) . map ( | s | s . to_string ( ) ) ) ,
) {
2024-06-28 15:26:31 +12:00
trace! (
2024-06-30 02:44:33 +12:00
" Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. Resets in {reset}. Rollover: {}. Ratelimit used: {used} " ,
2024-06-29 16:20:19 +12:00
if is_rolling_over { " yes " } else { " no " } ,
2024-06-28 15:26:31 +12:00
) ;
2024-06-29 16:20:19 +12:00
Some ( reset )
2024-06-27 00:05:22 +12:00
} else {
None
} ;
2021-03-18 11:30:33 +13:00
// asynchronously aggregate the chunks of the body
match hyper ::body ::aggregate ( response ) . await {
Ok ( body ) = > {
2024-06-27 00:05:22 +12:00
let has_remaining = body . has_remaining ( ) ;
if ! has_remaining {
2024-06-28 15:29:50 +12:00
// Rate limited, so spawn a force_refresh_token()
tokio ::spawn ( force_refresh_token ( ) ) ;
2024-06-27 00:05:22 +12:00
return match reset {
2024-06-28 15:29:50 +12:00
Some ( val ) = > Err ( format! (
" Reddit rate limit exceeded. Try refreshing in a few seconds. \
Rate limit will reset in : { val } "
) ) ,
2024-06-27 00:05:22 +12:00
None = > Err ( " Reddit rate limit exceeded " . to_string ( ) ) ,
} ;
}
2021-03-18 11:30:33 +13:00
// Parse the response from Reddit as JSON
match serde_json ::from_reader ( body . reader ( ) ) {
Ok ( value ) = > {
let json : Value = value ;
2024-09-25 15:13:36 +12:00
// If user is suspended
if let Some ( data ) = json . get ( " data " ) {
if let Some ( is_suspended ) = data . get ( " is_suspended " ) . and_then ( Value ::as_bool ) {
if is_suspended {
return Err ( " suspended " . into ( ) ) ;
}
}
}
2021-03-18 11:30:33 +13:00
// If Reddit returned an error
if json [ " error " ] . is_i64 ( ) {
2024-02-03 08:53:15 +13:00
// OAuth token has expired; http status 401
if json [ " message " ] = = " Unauthorized " {
error! ( " Forcing a token refresh " ) ;
let ( ) = force_refresh_token ( ) . await ;
return Err ( " OAuth token has expired. Please refresh the page! " . to_string ( ) ) ;
}
2024-09-25 15:13:36 +12:00
2024-07-22 06:22:54 +12:00
// Handle quarantined
if json [ " reason " ] = = " quarantined " {
return Err ( " quarantined " . into ( ) ) ;
}
// Handle gated
if json [ " reason " ] = = " gated " {
return Err ( " gated " . into ( ) ) ;
}
2024-09-25 15:01:28 +12:00
// Handle private subs
if json [ " reason " ] = = " private " {
return Err ( " private " . into ( ) ) ;
}
// Handle banned subs
if json [ " reason " ] = = " banned " {
return Err ( " banned " . into ( ) ) ;
}
2024-06-20 06:45:32 +12:00
Err ( format! ( " Reddit error {} \" {} \" : {} | {path} " , json [ " error " ] , json [ " reason " ] , json [ " message " ] ) )
2021-03-18 11:30:33 +13:00
} else {
Ok ( json )
}
}
2021-08-12 15:49:42 +12:00
Err ( e ) = > {
2024-02-03 08:53:15 +13:00
error! ( " Got an invalid response from reddit {e}. Status code: {status} " ) ;
2021-08-12 15:49:42 +12:00
if status . is_server_error ( ) {
Err ( " Reddit is having issues, check if there's an outage " . to_string ( ) )
} else {
2024-06-20 06:45:32 +12:00
err ( " Failed to parse page JSON data " , e . to_string ( ) , path )
2021-08-12 15:49:42 +12:00
}
}
2021-03-18 11:30:33 +13:00
}
}
2024-06-20 06:45:32 +12:00
Err ( e ) = > err ( " Failed receiving body from Reddit " , e . to_string ( ) , path ) ,
2021-03-18 11:30:33 +13:00
}
}
2024-06-20 06:45:32 +12:00
Err ( e ) = > err ( " Couldn't send request to Reddit " , e , path ) ,
2021-03-18 11:30:33 +13:00
}
}
2023-12-29 04:40:17 +13:00
2024-06-27 11:19:30 +12:00
#[ cfg(test) ]
static POPULAR_URL : & str = " /r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL " ;
2023-12-31 15:33:27 +13:00
#[ tokio::test(flavor = " multi_thread " ) ]
2023-12-29 04:40:17 +13:00
async fn test_localization_popular ( ) {
2024-06-27 11:19:30 +12:00
let val = json ( POPULAR_URL . to_string ( ) , false ) . await . unwrap ( ) ;
2023-12-29 04:40:17 +13:00
assert_eq! ( " GLOBAL " , val [ " data " ] [ " geo_filter " ] . as_str ( ) . unwrap ( ) ) ;
}
2023-12-29 12:21:07 +13:00
2023-12-31 15:33:27 +13:00
#[ tokio::test(flavor = " multi_thread " ) ]
2023-12-29 12:21:07 +13:00
async fn test_obfuscated_share_link ( ) {
let share_link = " /r/rust/s/kPgq8WNHRK " . into ( ) ;
2023-12-30 13:34:57 +13:00
// Correct link without share parameters
2024-09-21 15:57:18 +12:00
let canonical_link = " /r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc/ " . into ( ) ;
assert_eq! ( canonical_path ( share_link , 3 ) . await , Ok ( Some ( canonical_link ) ) ) ;
2023-12-29 12:21:07 +13:00
}
2023-12-30 13:28:41 +13:00
2023-12-31 15:33:27 +13:00
#[ tokio::test(flavor = " multi_thread " ) ]
2023-12-30 13:28:41 +13:00
async fn test_share_link_strip_json ( ) {
let link = " /17krzvz " . into ( ) ;
2024-09-21 15:57:18 +12:00
let canonical_link = " /comments/17krzvz " . into ( ) ;
assert_eq! ( canonical_path ( link , 3 ) . await , Ok ( Some ( canonical_link ) ) ) ;
2023-12-30 13:28:41 +13:00
}
2024-09-25 15:01:28 +12:00
/// Live-network test: fetching a private subreddit's `about.json` yields the
/// `"private"` error.
#[tokio::test(flavor = "multi_thread")]
async fn test_private_sub() {
    let link = json("/r/suicide/about.json?raw_json=1".into(), true).await;
    assert!(link.is_err());
    assert_eq!(link, Err("private".into()));
}
/// Live-network test: fetching a banned subreddit's `about.json` yields the
/// `"banned"` error.
#[tokio::test(flavor = "multi_thread")]
async fn test_banned_sub() {
    let link = json("/r/aaa/about.json?raw_json=1".into(), true).await;
    assert!(link.is_err());
    assert_eq!(link, Err("banned".into()));
}
/// Live-network test: fetching a gated subreddit's `about.json` without the
/// quarantine opt-in yields the `"gated"` error.
#[tokio::test(flavor = "multi_thread")]
async fn test_gated_sub() {
    // quarantine to false to specifically catch when we _don't_ catch it
    let link = json("/r/drugs/about.json?raw_json=1".into(), false).await;
    assert!(link.is_err());
    assert_eq!(link, Err("gated".into()));
}