From d7c0e2488afc23f771750c18626ab85a8e505b3b Mon Sep 17 00:00:00 2001 From: Wiktor Kwapisiewicz Date: Thu, 22 Feb 2024 09:17:14 +0100 Subject: [PATCH 1/3] Make AgentError implement std::error::Error Signed-off-by: Wiktor Kwapisiewicz --- src/error.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/error.rs b/src/error.rs index 4d0c313..5e2293a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -19,3 +19,15 @@ impl From for AgentError { AgentError::IO(e) } } + +impl std::fmt::Display for AgentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AgentError::User => write!(f, "Agent: User error"), + AgentError::Proto(proto) => write!(f, "Agent: Protocol error: {}", proto), + AgentError::IO(error) => write!(f, "Agent: I/O error: {}", error), + } + } +} + +impl std::error::Error for AgentError {} From 9b755d81e446f90cd994f60c7dc786d2c32b3337 Mon Sep 17 00:00:00 2001 From: Wiktor Kwapisiewicz Date: Thu, 22 Feb 2024 09:25:17 +0100 Subject: [PATCH 2/3] Use AgentError instead of associated Error type Signed-off-by: Wiktor Kwapisiewicz --- README.md | 5 ++--- examples/key_storage.rs | 5 ++--- src/agent.rs | 9 +++------ 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 7559f6a..78f0dcd 100644 --- a/README.md +++ b/README.md @@ -15,15 +15,14 @@ use async_trait::async_trait; use tokio::net::UnixListener; use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{Message, SignRequest}; struct MyAgent; #[async_trait] impl Agent for MyAgent { - type Error = (); - - async fn handle(&self, message: Message) -> Result { + async fn handle(&self, message: Message) -> Result { match message { Message::SignRequest(request) => { // get the signature by signing `request.data` diff --git a/examples/key_storage.rs b/examples/key_storage.rs index 17b0962..4f9b974 100644 --- a/examples/key_storage.rs +++ b/examples/key_storage.rs @@ -3,6 +3,7 @@ use log::info; use tokio::net::UnixListener; use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{self, Message, SignRequest}; use ssh_agent_lib::proto::private_key::{PrivateKey, RsaPrivateKey}; use ssh_agent_lib::proto::public_key::PublicKey; @@ -147,9 +148,7 @@ impl KeyStorage { #[async_trait] impl Agent for KeyStorage { - type Error = (); - - async fn handle(&self, message: Message) -> Result { + async fn handle(&self, message: Message) -> Result { self.handle_message(message).or_else(|error| { println!("Error handling message - {:?}", error); Ok(Message::Failure) diff --git a/src/agent.rs b/src/agent.rs index 0244866..51e5567 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -7,7 +7,6 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use tokio_util::codec::{Decoder, Encoder, Framed}; -use std::error::Error; use std::fmt; use std::io; use std::marker::Unpin; @@ -111,11 +110,9 @@ impl ListeningSocket for TcpListener { #[async_trait] pub trait Agent: 'static + Sync + Send + Sized { - type Error: fmt::Debug + Send + Sync; + async fn handle(&self, message: Message) -> Result; - async fn handle(&self, message: Message) -> Result; - - async fn listen(self, socket: S) -> Result<(), Box> + async fn listen(self, socket: S) -> Result<(), AgentError> where S: ListeningSocket + fmt::Debug + Send, { @@ -136,7 +133,7 @@ pub trait Agent: 'static + Sync + Send + Sized { } Err(e) => { error!("Failed to accept socket; error = {:?}", e); - return Err(Box::new(e)); + return Err(AgentError::IO(e)); } } } From f4e068b516a5e5c7f8476fc8afe597525e10ade6 Mon Sep 17 00:00:00 2001 From: Wiktor Kwapisiewicz Date: Thu, 22 Feb 2024 10:37:00 +0100 Subject: [PATCH 3/3] Split Agent and Session for per-session use-cases Signed-off-by: Wiktor Kwapisiewicz --- README.md | 7 ++-- examples/key_storage.rs | 12 ++++-- src/agent.rs | 88 ++++++++++++++++++++--------------------- 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 78f0dcd..10d1453 100644 --- a/README.md +++ b/README.md @@ -14,15 +14,16 @@ processes requests. use async_trait::async_trait; use tokio::net::UnixListener; -use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::agent::{Session, Agent}; use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{Message, SignRequest}; +#[derive(Default)] struct MyAgent; #[async_trait] -impl Agent for MyAgent { - async fn handle(&self, message: Message) -> Result { +impl Session for MyAgent { + async fn handle(&mut self, message: Message) -> Result { match message { Message::SignRequest(request) => { // get the signature by signing `request.data` diff --git a/examples/key_storage.rs b/examples/key_storage.rs index 4f9b974..88f455f 100644 --- a/examples/key_storage.rs +++ b/examples/key_storage.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use log::info; use tokio::net::UnixListener; -use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::agent::{Agent, Session}; use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{self, Message, SignRequest}; use ssh_agent_lib::proto::private_key::{PrivateKey, RsaPrivateKey}; @@ -147,8 +147,8 @@ impl KeyStorage { } #[async_trait] -impl Agent for KeyStorage { - async fn handle(&self, message: Message) -> Result { +impl Session for KeyStorage { + async fn handle(&mut self, message: Message) -> Result { self.handle_message(message).or_else(|error| { println!("Error handling message - {:?}", error); Ok(Message::Failure) @@ -156,6 +156,12 @@ impl Agent for KeyStorage { } } +impl Agent for KeyStorage { + fn new_session(&mut self) -> impl Session { + KeyStorage::new() + } +} + fn rsa_openssl_from_ssh(ssh_rsa: &RsaPrivateKey) -> Result, Box> { let n = BigNum::from_slice(&ssh_rsa.n)?; let e = BigNum::from_slice(&ssh_rsa.e)?; diff --git a/src/agent.rs b/src/agent.rs index 51e5567..7cd2a33 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -11,13 +11,13 @@ use std::fmt; use std::io; use std::marker::Unpin; use std::mem::size_of; -use std::sync::Arc; use super::error::AgentError; use super::proto::message::Message; use super::proto::{from_bytes, to_bytes}; -struct MessageCodec; +#[derive(Debug)] +pub struct MessageCodec; impl Decoder for MessageCodec { type Item = Message; @@ -52,39 +52,6 @@ impl Encoder for MessageCodec { } } -struct Session { - agent: Arc, - adapter: Framed, -} - -impl Session -where - A: Agent, - S: AsyncRead + AsyncWrite + Unpin, -{ - fn new(agent: Arc, socket: S) -> Self { - let adapter = Framed::new(socket, MessageCodec); - Self { agent, adapter } - } - - async fn handle_socket(&mut self) -> Result<(), AgentError> { - loop { - if let Some(incoming_message) = self.adapter.try_next().await? { - let response = self.agent.handle(incoming_message).await.map_err(|e| { - error!("Error handling message; error = {:?}", e); - AgentError::User - })?; - - self.adapter.send(response).await?; - } else { - // Reached EOF of the stream (client disconnected), - // we can close the socket and exit the handler. - return Ok(()); - } - } - } -} - #[async_trait] pub trait ListeningSocket { type Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static; @@ -109,24 +76,48 @@ impl ListeningSocket for TcpListener { } #[async_trait] -pub trait Agent: 'static + Sync + Send + Sized { - async fn handle(&self, message: Message) -> Result; +pub trait Session: 'static + Sync + Send + Sized { + async fn handle(&mut self, message: Message) -> Result; - async fn listen(self, socket: S) -> Result<(), AgentError> + async fn handle_socket( + &mut self, + mut adapter: Framed, + ) -> Result<(), AgentError> where S: ListeningSocket + fmt::Debug + Send, { - info!("Listening; socket = {:?}", socket); - let arc_self = Arc::new(self); + loop { + if let Some(incoming_message) = adapter.try_next().await? { + let response = self.handle(incoming_message).await.map_err(|e| { + error!("Error handling message; error = {:?}", e); + AgentError::User + })?; + + adapter.send(response).await?; + } else { + // Reached EOF of the stream (client disconnected), + // we can close the socket and exit the handler. + return Ok(()); + } + } + } +} +#[async_trait] +pub trait Agent: 'static + Sync + Send + Sized { + fn new_session(&mut self) -> impl Session; + async fn listen(mut self, socket: S) -> Result<(), AgentError> + where + S: ListeningSocket + fmt::Debug + Send, + { + info!("Listening; socket = {:?}", socket); loop { match socket.accept().await { Ok(socket) => { - let agent = arc_self.clone(); - let mut session = Session::new(agent, socket); - + let mut session = self.new_session(); tokio::spawn(async move { - if let Err(e) = session.handle_socket().await { + let adapter = Framed::new(socket, MessageCodec); + if let Err(e) = session.handle_socket::(adapter).await { error!("Agent protocol error; error = {:?}", e); } }); @@ -139,3 +130,12 @@ pub trait Agent: 'static + Sync + Send + Sized { } } } + +impl Agent for T +where + T: Default + Session, +{ + fn new_session(&mut self) -> impl Session { + Self::default() + } +}