diff --git a/src/main.rs b/src/main.rs index 2bd7c66..e7f4bc2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod routes; mod sbot; +mod utils; use rocket::{launch, routes}; use rocket_dyn_templates::Template; diff --git a/src/routes.rs b/src/routes.rs index b5ad904..88afc23 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,9 +1,10 @@ -use log::info; +use log::{info, warn}; -use rocket::{form::Form, get, post, response::Redirect, uri, FromForm}; +use rocket::{form::Form, get, post, response::{Redirect, Flash}, uri, FromForm, request::FlashMessage}; use rocket_dyn_templates::{context, Template}; use crate::sbot; +use crate::utils; #[derive(FromForm)] pub struct PeerForm { @@ -11,22 +12,62 @@ pub struct PeerForm { } #[get("/")] -pub async fn home() -> Template { +pub async fn home(flash: Option>) -> Template { let whoami = match sbot::whoami().await { Ok(id) => id, Err(e) => format!("whoami call failed: {}", e), }; - Template::render("base", context! { whoami }) + Template::render("base", context! { whoami, flash }) } #[post("/subscribe", data = "")] -pub async fn subscribe_form(peer: Form) -> Redirect { +pub async fn subscribe_form(peer: Form) -> Result> { info!("Subscribing to peer {}", &peer.public_key); - Redirect::to(uri!(home)) + if let Err(e) = utils::validate_public_key(&peer.public_key) { + let validation_err_msg = format!("Public key {} is invalid: {}", &peer.public_key, e); + warn!("{}", validation_err_msg); + return Err(Flash::error(Redirect::to(uri!(home)), validation_err_msg)); + } else { + info!("Public key {} is valid.", &peer.public_key); + if let Ok(whoami) = sbot::whoami().await { + match sbot::is_following(&whoami, &peer.public_key).await { + Ok(status) if status.as_str() == "false" => { + info!("Not currently following peer {}", &peer.public_key); + }, + Ok(status) if status.as_str() == "true" => { + info!("Already following peer {}. No further action needed here.", &peer.public_key); + }, + _ => (), + } + } else { + warn!("Received an error during `whoami` RPC call. Please ensure the go-sbot is running and try again.") + } + } + Ok(Redirect::to(uri!(home))) } #[post("/unsubscribe", data = "")] -pub async fn unsubscribe_form(peer: Form) -> Redirect { +pub async fn unsubscribe_form(peer: Form) -> Result> { info!("Unsubscribing to peer {}", &peer.public_key); - Redirect::to(uri!(home)) + if let Err(e) = utils::validate_public_key(&peer.public_key) { + let validation_err_msg = format!("Public key {} is invalid: {}", &peer.public_key, e); + warn!("{}", validation_err_msg); + return Err(Flash::error(Redirect::to(uri!(home)), validation_err_msg)); + } else { + info!("Public key {} is valid.", &peer.public_key); + if let Ok(whoami) = sbot::whoami().await { + match sbot::is_following(&whoami, &peer.public_key).await { + Ok(status) if status.as_str() == "false" => { + info!("Currently following peer {}", &peer.public_key); + }, + Ok(status) if status.as_str() == "true" => { + info!("Already not following peer {}. No further action needed here.", &peer.public_key); + }, + _ => (), + } + } else { + warn!("Received an error during `whoami` RPC call. Please ensure the go-sbot is running and try again.") + } + } + Ok(Redirect::to(uri!(home))) } diff --git a/src/sbot.rs b/src/sbot.rs index 2e7939d..6683afa 100644 --- a/src/sbot.rs +++ b/src/sbot.rs @@ -1,6 +1,6 @@ use std::env; -use golgi::{sbot::Keystore, Sbot}; +use golgi::{sbot::Keystore, Sbot, api::friends::RelationshipQuery}; pub async fn init_sbot() -> Result { @@ -20,3 +20,15 @@ pub async fn whoami() -> Result { sbort.whoami().await.map_err(|e| e.to_string()) } +pub async fn is_following(public_key_a: &str, public_key_b: &str) -> Result { + let mut sbot = init_sbot().await?; + + let query = RelationshipQuery { + source: public_key_a.to_string(), + dest: public_key_b.to_string(), + }; + + sbot.friends_is_following(query) + .await + .map_err(|e| e.to_string()) +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..a155c9b --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,22 @@ +pub fn validate_public_key(public_key: &str) -> Result<(), String> { + if !public_key.starts_with('@') { + return Err("Expected '@' sigil as first character".to_string()); + } + + let dot_index = match public_key.rfind('.') { + Some(index) => index, + None => return Err("could not find '.' character".to_string()), + }; + + if !&public_key.ends_with(".ed25519") { + return Err("hashing algorithm must be ed25519".to_string()); + } + + let base64_str = &public_key[1..dot_index]; + + if base64_str.len() != 44 { + return Err("base64 data length is incorrect".to_string()); + } + + Ok(()) +} diff --git a/templates/base.html.tera b/templates/base.html.tera index e10610f..47b6487 100644 --- a/templates/base.html.tera +++ b/templates/base.html.tera @@ -14,5 +14,8 @@ + {% if flash and flash.kind == "error" %} +

[ {{ flash.message }} ]

+ {% endif %}