Get one SbotConnection per stream call

This commit is contained in:
notplants 2022-01-05 13:58:48 -05:00
parent 22417e4d82
commit 26e2809c9a
6 changed files with 245 additions and 217 deletions

View File

@ -1,7 +1,7 @@
use std::process;
use golgi::messages::SsbMessageContent;
use golgi::error::GolgiError;
use golgi::messages::SsbMessageContent;
use golgi::sbot::Sbot;
async fn run() -> Result<(), GolgiError> {
@ -32,7 +32,9 @@ async fn run() -> Result<(), GolgiError> {
let post_msg_ref = sbot_client.publish(post).await?;
println!("{}", post_msg_ref);
let post_msg_ref = sbot_client.publish_description("this is a description").await?;
let post_msg_ref = sbot_client
.publish_description("this is a description")
.await?;
println!("description: {}", post_msg_ref);
Ok(())

View File

@ -1,12 +1,11 @@
use std::process;
use golgi::error::GolgiError;
use golgi::sbot::Sbot;
use async_std::stream::{StreamExt};
use async_std::stream::StreamExt;
use futures::{pin_mut, TryStreamExt};
use golgi::messages::{SsbMessageContentType, SsbMessageValue};
use golgi::error::GolgiError;
use golgi::messages::{SsbMessageContentType, SsbMessageValue};
use golgi::sbot::Sbot;
async fn run() -> Result<(), GolgiError> {
let mut sbot_client = Sbot::init(None, None).await?;
@ -17,7 +16,9 @@ async fn run() -> Result<(), GolgiError> {
let author = "@L/z54cbc8V1kL1/MiBhpEKuN3QJkSoZYNaukny3ghIs=.ed25519";
// create a history stream
let history_stream = sbot_client.create_history_stream(author.to_string()).await?;
let history_stream = sbot_client
.create_history_stream(author.to_string())
.await?;
// loop through the results until the end of the stream
pin_mut!(history_stream); // needed for iteration
@ -26,7 +27,7 @@ async fn run() -> Result<(), GolgiError> {
match res {
Ok(value) => {
println!("value: {:?}", value);
},
}
Err(err) => {
println!("err: {:?}", err);
}
@ -36,34 +37,26 @@ async fn run() -> Result<(), GolgiError> {
// create a history stream and convert it into a Vec<SsbMessageValue> using try_collect
// (if there is any error in the results, it will be raised)
let mut history_stream = sbot_client.create_history_stream(author.to_string()).await?;
let results : Vec<SsbMessageValue> = history_stream.try_collect().await?;
let history_stream = sbot_client
.create_history_stream(author.to_string())
.await?;
let results: Vec<SsbMessageValue> = history_stream.try_collect().await?;
for x in results {
println!("x: {:?}", x);
}
// example to create a history stream and use a map to convert stream of SsbMessageValue
// into a stream of KeyTypeTuple (local struct for storing message_key and message_type)
#[derive(Debug)]
struct KeyTypeTuple {
message_key: String,
message_type: SsbMessageContentType,
};
let mut history_stream = sbot_client.create_history_stream(author.to_string()).await?;
let type_stream = history_stream.map(|msg| {
match msg {
Ok(val) => {
let message_type = val.get_message_type()?;
let tuple = KeyTypeTuple {
message_key: val.signature,
message_type: message_type,
};
Ok(tuple)
}
Err(err) => {
Err(err)
}
// into a stream of tuples of (String, SsbMessageContentType)
let history_stream = sbot_client
.create_history_stream(author.to_string())
.await?;
let type_stream = history_stream.map(|msg| match msg {
Ok(val) => {
let message_type = val.get_message_type()?;
let tuple: (String, SsbMessageContentType) = (val.signature, message_type);
Ok(tuple)
}
Err(err) => Err(err),
});
pin_mut!(type_stream); // needed for iteration
println!("looping through type stream");
@ -71,7 +64,7 @@ async fn run() -> Result<(), GolgiError> {
match res {
Ok(value) => {
println!("value: {:?}", value);
},
}
Err(err) => {
println!("err: {:?}", err);
}

View File

@ -57,7 +57,7 @@ impl std::error::Error for GolgiError {
GolgiError::Sbot(_) => None,
GolgiError::SerdeJson(ref err) => Some(err),
GolgiError::ContentType(_) => None,
GolgiError::Utf8Parse{ ref source} => Some(source),
GolgiError::Utf8Parse { ref source } => Some(source),
}
}
}
@ -82,7 +82,9 @@ impl std::fmt::Display for GolgiError {
"Failed to decode typed message from ssb message content: {}",
err
),
GolgiError::Utf8Parse{ source } => write!(f, "Failed to deserialize UTF8 from bytes: {}", source),
GolgiError::Utf8Parse { source } => {
write!(f, "Failed to deserialize UTF8 from bytes: {}", source)
}
}
}
}
@ -125,6 +127,6 @@ impl From<JsonError> for GolgiError {
impl From<Utf8Error> for GolgiError {
fn from(err: Utf8Error) -> Self {
GolgiError::Utf8Parse { source: err}
GolgiError::Utf8Parse { source: err }
}
}

View File

@ -1,9 +1,9 @@
//! Message types and conversion methods for `golgi`.
use std::fmt::Debug;
use kuska_ssb::api::dto::content::TypedMessage;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt::Debug;
use crate::error::GolgiError;
@ -36,29 +36,30 @@ pub enum SsbMessageContentType {
Vote,
Post,
Contact,
Unrecognized
Unrecognized,
}
impl SsbMessageValue {
/// Gets the type field of the message content as an enum, if found.
/// if no type field is found or the type field is not a string, it returns an Err(GolgiError::ContentType)
/// if a type field is found but with an unknown string it returns an Ok(SsbMessageContentType::Unrecognized)
pub fn get_message_type(&self) -> Result<SsbMessageContentType, GolgiError> {
let msg_type = self
.content
.get("type")
.ok_or(GolgiError::ContentType("type field not found".to_string()))?;
let mtype_str: &str = msg_type.as_str().ok_or(GolgiError::ContentType("type field value is not a string as expected".to_string()))?;
let enum_type = match mtype_str {
"about" => SsbMessageContentType::About,
"post" => SsbMessageContentType::Post,
"vote" => SsbMessageContentType::Vote,
"contact" => SsbMessageContentType::Contact,
_ => SsbMessageContentType::Unrecognized,
};
Ok(enum_type)
}
/// Gets the type field of the message content as an enum, if found.
/// if no type field is found or the type field is not a string, it returns an Err(GolgiError::ContentType)
/// if a type field is found but with an unknown string it returns an Ok(SsbMessageContentType::Unrecognized)
pub fn get_message_type(&self) -> Result<SsbMessageContentType, GolgiError> {
let msg_type = self
.content
.get("type")
.ok_or(GolgiError::ContentType("type field not found".to_string()))?;
let mtype_str: &str = msg_type.as_str().ok_or(GolgiError::ContentType(
"type field value is not a string as expected".to_string(),
))?;
let enum_type = match mtype_str {
"about" => SsbMessageContentType::About,
"post" => SsbMessageContentType::Post,
"vote" => SsbMessageContentType::Vote,
"contact" => SsbMessageContentType::Contact,
_ => SsbMessageContentType::Unrecognized,
};
Ok(enum_type)
}
/// Helper function which returns true if this message is of the given type,
/// and false if the type does not match or is not found
@ -68,9 +69,7 @@ impl SsbMessageValue {
Ok(mtype) => {
matches!(mtype, _message_type)
}
Err(_err) => {
false
}
Err(_err) => false,
}
}

View File

@ -1,29 +1,27 @@
//! Sbot type and associated methods.
use std::fmt::Debug;
use async_std::net::TcpStream;
use async_std::stream::{Stream, StreamExt};
use async_stream::stream;
use async_std::{net::TcpStream, stream::Stream};
use kuska_handshake::async_std::BoxStream;
use kuska_sodiumoxide::crypto::{auth, sign::ed25519};
use kuska_ssb::{
api::{
dto::{CreateHistoryStreamIn},
ApiCaller,
},
api::{dto::CreateHistoryStreamIn, ApiCaller},
discovery, keystore,
keystore::OwnedIdentity,
rpc::{RpcReader, RpcWriter, RecvMsg},
rpc::{RpcReader, RpcWriter},
};
use crate::error::GolgiError;
use crate::messages::{SsbMessageKVT, SsbMessageContent, SsbMessageValue, SsbMessageContentType};
use crate::messages::{SsbMessageContent, SsbMessageContentType, SsbMessageKVT, SsbMessageValue};
use crate::utils;
use crate::utils::get_source_stream;
// re-export types from kuska
pub use kuska_ssb::api::dto::content::SubsetQuery;
use kuska_ssb::rpc::RequestNo;
/// A struct representing a connection with a running sbot.
/// A client and an rpc_reader can together be used to make requests to the sbot
/// and read the responses.
/// Note there can be multiple SbotConnection at the same time.
pub struct SbotConnection {
client: ApiCaller<TcpStream>,
rpc_reader: RpcReader<TcpStream>,
@ -38,9 +36,9 @@ pub struct Sbot {
address: String,
// aka caps key (scuttleverse identifier)
network_id: auth::Key,
client: ApiCaller<TcpStream>,
rpc_reader: RpcReader<TcpStream>,
sbot_connections: Vec<SbotConnection>,
// the primary connection with sbot which can be re-used for non-stream calls
// note that stream calls will each need their own SbotConnection
sbot_connection: SbotConnection,
}
impl Sbot {
@ -64,6 +62,40 @@ impl Sbot {
.await
.expect("couldn't read local secret");
let sbot_connection =
Sbot::_get_sbot_connection_helper(address.clone(), network_id.clone(), pk, sk.clone())
.await?;
Ok(Self {
id,
public_key: pk,
private_key: sk,
address,
network_id,
sbot_connection,
})
}
/// Creates a new connection with the sbot,
/// using the address, network_id, public_key and private_key supplied when Sbot was initialized.
///
/// Note that a single Sbot can have multiple SbotConnection at the same time.
pub async fn get_sbot_connection(&self) -> Result<SbotConnection, GolgiError> {
let address = self.address.clone();
let network_id = self.network_id.clone();
let public_key = self.public_key;
let private_key = self.private_key.clone();
Sbot::_get_sbot_connection_helper(address, network_id, public_key, private_key).await
}
/// Private helper function which creates a new connection with sbot,
/// but with all variables passed as arguments.
async fn _get_sbot_connection_helper(
address: String,
network_id: auth::Key,
public_key: ed25519::PublicKey,
private_key: ed25519::SecretKey,
) -> Result<SbotConnection, GolgiError> {
let socket = TcpStream::connect(&address)
.await
.map_err(|source| GolgiError::Io {
@ -74,9 +106,9 @@ impl Sbot {
let handshake = kuska_handshake::async_std::handshake_client(
&mut &socket,
network_id.clone(),
pk,
sk.clone(),
pk,
public_key,
private_key.clone(),
public_key,
)
.await
.map_err(GolgiError::Handshake)?;
@ -86,76 +118,41 @@ impl Sbot {
let rpc_reader = RpcReader::new(box_stream_read);
let client = ApiCaller::new(RpcWriter::new(box_stream_write));
let mut sbot_connections = Vec::new();
Ok(Self {
id,
public_key: pk,
private_key: sk,
address,
network_id,
client,
rpc_reader,
sbot_connections
})
}
pub async fn get_sbot_connection(&self, ip_port: Option<String>, net_id: Option<String>) -> Result<SbotConnection, GolgiError> {
let address = if ip_port.is_none() {
"127.0.0.1:8008".to_string()
} else {
ip_port.unwrap()
};
let network_id = if net_id.is_none() {
discovery::ssb_net_id()
} else {
auth::Key::from_slice(&hex::decode(net_id.unwrap()).unwrap()).unwrap()
};
let socket = TcpStream::connect(&address)
.await
.map_err(|source| GolgiError::Io {
source,
context: "socket error; failed to initiate tcp stream connection".to_string(),
})?;
let handshake = kuska_handshake::async_std::handshake_client(
&mut &socket,
network_id.clone(),
self.public_key,
self.private_key.clone(),
self.public_key,
)
.await
.map_err(GolgiError::Handshake)?;
let (box_stream_read, box_stream_write) =
BoxStream::from_handshake(socket.clone(), socket, handshake, 0x8000).split_read_write();
let rpc_reader = RpcReader::new(box_stream_read);
let client = ApiCaller::new(RpcWriter::new(box_stream_write));
let sbot_connection = SbotConnection {
rpc_reader,
client,
};
let sbot_connection = SbotConnection { rpc_reader, client };
Ok(sbot_connection)
}
/// Call the `partialReplication getSubset` RPC method and return a vector
/// of messages as KVTs (key, value, timestamp).
// TODO: add args for `descending` and `page` (max number of msgs in response)
pub async fn get_subset(&mut self, query: SubsetQuery) -> Result<Vec<SsbMessageKVT>, GolgiError> {
let req_id = self.client.getsubset_req_send(query).await?;
pub async fn get_subset(
&mut self,
query: SubsetQuery,
) -> Result<Vec<SsbMessageKVT>, GolgiError> {
let req_id = self
.sbot_connection
.client
.getsubset_req_send(query)
.await?;
utils::get_source_until_eof(&mut self.rpc_reader, req_id, utils::kvt_res_parse).await
utils::get_source_until_eof(
&mut self.sbot_connection.rpc_reader,
req_id,
utils::kvt_res_parse,
)
.await
}
/// Call the `whoami` RPC method and return an `id`.
pub async fn whoami(&mut self) -> Result<String, GolgiError> {
let req_id = self.client.whoami_req_send().await?;
let req_id = self.sbot_connection.client.whoami_req_send().await?;
utils::get_async(&mut self.rpc_reader, req_id, utils::string_res_parse).await
utils::get_async(
&mut self.sbot_connection.rpc_reader,
req_id,
utils::string_res_parse,
)
.await
}
/// Call the `publish` RPC method and return a message reference.
@ -166,9 +163,14 @@ impl Sbot {
/// `Channel` and `Vote`. See the `kuska_ssb` documentation for further details such as field
/// names and accepted values for each variant.
pub async fn publish(&mut self, msg: SsbMessageContent) -> Result<String, GolgiError> {
let req_id = self.client.publish_req_send(msg).await?;
let req_id = self.sbot_connection.client.publish_req_send(msg).await?;
utils::get_async(&mut self.rpc_reader, req_id, utils::string_res_parse).await
utils::get_async(
&mut self.sbot_connection.rpc_reader,
req_id,
utils::string_res_parse,
)
.await
}
/// Wrapper for publish which constructs and publishes a post message appropriately from a string.
@ -223,8 +225,11 @@ impl Sbot {
}
/// Get the about messages for a particular user in order of recency.
pub async fn get_about_messages(&mut self, ssb_id: &str) -> Result<Vec<SsbMessageValue>, GolgiError> {
let query = SubsetQuery::Author{
pub async fn get_about_messages(
&mut self,
ssb_id: &str,
) -> Result<Vec<SsbMessageValue>, GolgiError> {
let query = SubsetQuery::Author {
op: "author".to_string(),
feed: ssb_id.to_string(),
};
@ -234,19 +239,22 @@ impl Sbot {
// change this subset query to filter by type about in addition to author
// and remove this filter section
// filter down to about messages
let mut about_messages: Vec<SsbMessageValue> = messages.into_iter().filter(|msg| {
msg.is_message_type(SsbMessageContentType::About)
}).collect();
let mut about_messages: Vec<SsbMessageValue> = messages
.into_iter()
.filter(|msg| msg.is_message_type(SsbMessageContentType::About))
.collect();
// TODO: use subset query to order messages instead of doing it this way
about_messages.sort_by(|a, b| {
b.timestamp.partial_cmp(&a.timestamp).unwrap()
});
about_messages.sort_by(|a, b| b.timestamp.partial_cmp(&a.timestamp).unwrap());
// return about messages
Ok(about_messages)
}
/// Get value of latest about message with given key from given user
pub async fn get_latest_about_message(&mut self, ssb_id: &str, key: &str) -> Result<Option<String>, GolgiError> {
pub async fn get_latest_about_message(
&mut self,
ssb_id: &str,
key: &str,
) -> Result<Option<String>, GolgiError> {
// vector of about messages with most recent at the front of the vector
let about_messages = self.get_about_messages(ssb_id).await?;
// iterate through the vector looking for most recent about message with the given key
@ -269,10 +277,22 @@ impl Sbot {
Ok(latest_about)
}
/// Get latest about name from given user
///
/// # Arguments
///
/// * `ssb_id` - A reference to a string slice which represents the ssb user
/// to lookup the about name for.
pub async fn get_name(&mut self, ssb_id: &str) -> Result<Option<String>, GolgiError> {
self.get_latest_about_message(ssb_id, "name").await
}
/// Get lateset about description from given user
///
/// # Arguments
///
/// * `ssb_id` - A reference to a string slice which represents the ssb user
/// to lookup the about description for.
pub async fn get_description(&mut self, ssb_id: &str) -> Result<Option<String>, GolgiError> {
self.get_latest_about_message(ssb_id, "description").await
}
@ -283,64 +303,18 @@ impl Sbot {
&mut self,
id: String,
) -> Result<impl Stream<Item = Result<SsbMessageValue, GolgiError>>, GolgiError> {
let mut sbot_connection = self.get_sbot_connection(None, None).await.unwrap();
let mut sbot_connection = self.get_sbot_connection().await.unwrap();
let args = CreateHistoryStreamIn::new(id);
let req_id = sbot_connection.client.create_history_stream_req_send(&args).await?;
let history_stream = Sbot::get_source_stream(sbot_connection.rpc_reader, req_id, utils::ssb_message_res_parse).await;
let req_id = sbot_connection
.client
.create_history_stream_req_send(&args)
.await?;
let history_stream = get_source_stream(
sbot_connection.rpc_reader,
req_id,
utils::ssb_message_res_parse,
)
.await;
Ok(history_stream)
}
/// Takes in an rpc request number, and a handling function (parsing results of type T),
/// and produces an async_std::stream::Stream
/// of results of type T where the handling functions is called
/// on all rpc_reader responses which match the request number
///
/// # Arguments
///
/// * `req_no` - A `RequestNo` of the response to listen for
/// * `f` - A function which takes in an array of u8 and returns a Result<T, GolgiError>.
/// This is a function which parses the response from the RpcReader. T is a generic type,
/// so this parse function can return multiple possible types (String, json, custom struct etc.)
pub async fn get_source_stream<'a, F, T>(mut rpc_reader: RpcReader<TcpStream>, req_no: RequestNo, f: F) -> impl Stream<Item = Result<T, GolgiError>>
where
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug + serde::Deserialize<'a>,
{
// we use the async_stream::stream macro to allow for creating a stream which calls async functions
// see https://users.rust-lang.org/t/how-to-create-async-std-stream-which-calls-async-function-in-poll-next/69760
let source_stream = stream! {
loop {
// get the next message from the rpc_reader
let (id, msg) = rpc_reader.recv().await?;
let x : i32 = id.clone();
// check if the next message from rpc_reader matches the req_no we are looking for
// if it matches, then this rpc response is for the given request
// and if it doesn't match, then we ignore it
if x == req_no {
match msg {
RecvMsg::RpcResponse(_type, body) => {
// parse an item of type T from the message body using the provided
// function for parsing
let item = f(&body)?;
// return Ok(item) as the next value in the stream
yield Ok(item)
}
RecvMsg::ErrorResponse(message) => {
// if an error is received
// return an Err(err) as the next value in the stream
yield Err(GolgiError::Sbot(message.to_string()));
}
// if we find a CancelStreamResponse
// this is the end of the stream
RecvMsg::CancelStreamRespose() => break,
// if we find an unknown response, we just continue the loop
_ => {}
}
}
}
};
// finally return the stream object
source_stream
}
}

View File

@ -1,5 +1,8 @@
//! Utility methods for `golgi`.
use async_std::io::Read;
use async_std::net::TcpStream;
use async_std::stream::Stream;
use async_stream::stream;
use std::fmt::Debug;
use kuska_ssb::rpc::{RecvMsg, RequestNo, RpcReader};
@ -64,10 +67,10 @@ pub async fn get_async<'a, R, T, F>(
req_no: RequestNo,
f: F,
) -> Result<T, GolgiError>
where
R: Read + Unpin,
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug,
where
R: Read + Unpin,
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug,
{
loop {
let (id, msg) = rpc_reader.recv().await?;
@ -108,10 +111,10 @@ pub async fn get_source_until_eof<'a, R, T, F>(
req_no: RequestNo,
f: F,
) -> Result<Vec<T>, GolgiError>
where
R: Read + Unpin,
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug,
where
R: Read + Unpin,
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug,
{
let mut messages: Vec<T> = Vec::new();
loop {
@ -140,8 +143,6 @@ pub async fn get_source_until_eof<'a, R, T, F>(
Ok(messages)
}
/// Takes in an rpc request number, and a handling function,
/// and calls the handling function on all responses which match the request number,
/// and prints out the result of the handling function.
@ -160,10 +161,10 @@ pub async fn print_source_until_eof<'a, R, T, F>(
req_no: RequestNo,
f: F,
) -> Result<(), GolgiError>
where
R: Read + Unpin,
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug + serde::Deserialize<'a>,
where
R: Read + Unpin,
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug + serde::Deserialize<'a>,
{
loop {
let (id, msg) = rpc_reader.recv().await?;
@ -182,4 +183,61 @@ pub async fn print_source_until_eof<'a, R, T, F>(
}
}
Ok(())
}
}
/// Takes in an rpc request number, and a handling function (parsing results of type T),
/// and produces an async_std::stream::Stream
/// of results of type T where the handling function is called
/// on all rpc_reader responses which match the request number.
///
/// # Arguments
///
/// * `req_no` - A `RequestNo` of the response to listen for
/// * `f` - A function which takes in an array of u8 and returns a Result<T, GolgiError>.
/// This is a function which parses the response from the RpcReader. T is a generic type,
/// so this parse function can return multiple possible types (String, json, custom struct etc.)
pub async fn get_source_stream<'a, F, T>(
mut rpc_reader: RpcReader<TcpStream>,
req_no: RequestNo,
f: F,
) -> impl Stream<Item = Result<T, GolgiError>>
where
F: Fn(&[u8]) -> Result<T, GolgiError>,
T: Debug + serde::Deserialize<'a>,
{
// we use the async_stream::stream macro to allow for creating a stream which calls async functions
// see https://users.rust-lang.org/t/how-to-create-async-std-stream-which-calls-async-function-in-poll-next/69760
let source_stream = stream! {
loop {
// get the next message from the rpc_reader
let (id, msg) = rpc_reader.recv().await?;
let x : i32 = id.clone();
// check if the next message from rpc_reader matches the req_no we are looking for
// if it matches, then this rpc response is for the given request
// and if it doesn't match, then we ignore it
if x == req_no {
match msg {
RecvMsg::RpcResponse(_type, body) => {
// parse an item of type T from the message body using the provided
// function for parsing
let item = f(&body)?;
// return Ok(item) as the next value in the stream
yield Ok(item)
}
RecvMsg::ErrorResponse(message) => {
// if an error is received
// return an Err(err) as the next value in the stream
yield Err(GolgiError::Sbot(message.to_string()));
}
// if we find a CancelStreamResponse
// this is the end of the stream
RecvMsg::CancelStreamRespose() => break,
// if we find an unknown response, we just continue the loop
_ => {}
}
}
}
};
// finally return the stream object
source_stream
}