diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9a0bf13..01e0a83 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,11 +6,16 @@ jobs: build: strategy: matrix: + platform: [ ubuntu-latest ] toolchain: [ stable, beta ] include: - toolchain: stable check-fmt: true - runs-on: ubuntu-latest + - toolchain: 1.48.0 + platform: ubuntu-latest + msrv: true + + runs-on: ${{ matrix.platform }} steps: - name: Checkout source code uses: actions/checkout@v2 @@ -20,6 +25,19 @@ jobs: toolchain: ${{ matrix.toolchain }} override: true profile: minimal + - name: Pin tokio for MSRV + if: matrix.msrv + run: cargo update -p tokio --precise "1.14.1" --verbose + - name: Pin serde for MSRV + if: matrix.msrv + run: cargo update -p serde --precise "1.0.156" --verbose + - name: Pin log for MSRV + if: matrix.msrv + run: cargo update -p log --precise "0.4.18" --verbose + - name: Cargo check + run: cargo check --release + - name: Check documentation + run: cargo doc --release - name: Build on Rust ${{ matrix.toolchain }} run: cargo build --verbose --color always - name: Check formatting diff --git a/.gitignore b/.gitignore index 4fffb2f..6d6c90f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target /Cargo.lock +.vscode \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 238b4a1..5cacf34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,11 @@ description = "Types and primitives to integrate a spec-compliant LSP with an LD # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -lightning = { version = "0.0.114", features = ["max_level_trace", "std"] } -lightning-invoice = { version = "0.22" } -lightning-net-tokio = { version = "0.0.114" } +lightning = { git = "https://github.com/lightningdevkit/rust-lightning.git", rev = "498f2331459d8031031ef151a44c90d700aa8c7e", features = ["max_level_trace", "std"] } +lightning-invoice = { git = "https://github.com/lightningdevkit/rust-lightning.git", rev = "498f2331459d8031031ef151a44c90d700aa8c7e" } +lightning-net-tokio = { git = "https://github.com/lightningdevkit/rust-lightning.git", rev = "498f2331459d8031031ef151a44c90d700aa8c7e" } -bitcoin = "0.29.2" +bitcoin = "0.29.0" + +serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] } +serde_json = "1.0" diff --git a/src/lib.rs b/src/lib.rs index d4fd329..028f820 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,8 +15,12 @@ #![deny(private_intra_doc_links)] #![allow(bare_trait_objects)] #![allow(ellipsis_inclusive_range_patterns)] +#![allow(clippy::drop_non_drop)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] mod channel_request; mod jit_channel; mod transport; +mod utils; + +pub use transport::message_handler::{LiquidityManager, LiquidityProviderConfig}; diff --git a/src/transport/message_handler.rs b/src/transport/message_handler.rs new file mode 100644 index 0000000..6f58583 --- /dev/null +++ b/src/transport/message_handler.rs @@ -0,0 +1,165 @@ +use crate::transport::msgs::{LSPSMessage, RawLSPSMessage, LSPS_MESSAGE_TYPE}; +use crate::transport::protocol::LSPS0MessageHandler; + +use bitcoin::secp256k1::PublicKey; +use lightning::ln::features::{InitFeatures, NodeFeatures}; +use lightning::ln::msgs::{ErrorAction, LightningError}; +use lightning::ln::peer_handler::CustomMessageHandler; +use lightning::ln::wire::CustomMessageReader; +use lightning::sign::EntropySource; +use lightning::util::logger::Level; +use lightning::util::ser::Readable; +use std::collections::HashMap; +use std::convert::TryFrom; +use std::io; +use std::ops::Deref; +use std::sync::{Arc, Mutex}; + +const LSPS_FEATURE_BIT: usize = 729; + +/// A trait used to implement a specific LSPS protocol. +/// +/// The messages the protocol uses need to be able to be mapped +/// from and into [`LSPSMessage`]. +pub(crate) trait ProtocolMessageHandler { + type ProtocolMessage: TryFrom + Into; + const PROTOCOL_NUMBER: Option; + + fn handle_message( + &self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey, + ) -> Result<(), LightningError>; +} + +/// A configuration for [`LiquidityManager`]. +/// +/// Allows end-user to configure options when using the [`LiquidityManager`] +/// to provide liquidity services to clients. +pub struct LiquidityProviderConfig; + +/// The main interface into LSP functionality. +/// +/// Should be used as a [`CustomMessageHandler`] for your +/// [`lightning::ln::peer_handler::PeerManager`]'s [`lightning::ln::peer_handler::MessageHandler`]. +pub struct LiquidityManager +where + ES::Target: EntropySource, +{ + pending_messages: Arc>>, + request_id_to_method_map: Mutex>, + lsps0_message_handler: LSPS0MessageHandler, + provider_config: Option, +} + +impl LiquidityManager +where + ES::Target: EntropySource, +{ + /// Constructor for the LiquidityManager + /// + /// Sets up the required protocol message handlers based on the given [`LiquidityProviderConfig`]. + pub fn new(entropy_source: ES, provider_config: Option) -> Self { + let pending_messages = Arc::new(Mutex::new(vec![])); + + let lsps0_message_handler = + LSPS0MessageHandler::new(entropy_source, vec![], Arc::clone(&pending_messages)); + + Self { + pending_messages, + request_id_to_method_map: Mutex::new(HashMap::new()), + lsps0_message_handler, + provider_config, + } + } + + fn handle_lsps_message( + &self, msg: LSPSMessage, sender_node_id: &PublicKey, + ) -> Result<(), lightning::ln::msgs::LightningError> { + match msg { + LSPSMessage::Invalid => { + return Err(LightningError { err: format!("{} did not understand a message we previously sent, maybe they don't support a protocol we are trying to use?", sender_node_id), action: ErrorAction::IgnoreAndLog(Level::Error)}); + } + LSPSMessage::LSPS0(msg) => { + self.lsps0_message_handler.handle_message(msg, sender_node_id)?; + } + } + Ok(()) + } + + fn enqueue_message(&self, node_id: PublicKey, msg: LSPSMessage) { + let mut pending_msgs = self.pending_messages.lock().unwrap(); + pending_msgs.push((node_id, msg)); + } +} + +impl CustomMessageReader for LiquidityManager +where + ES::Target: EntropySource, +{ + type CustomMessage = RawLSPSMessage; + + fn read( + &self, message_type: u16, buffer: &mut R, + ) -> Result, lightning::ln::msgs::DecodeError> { + match message_type { + LSPS_MESSAGE_TYPE => Ok(Some(RawLSPSMessage::read(buffer)?)), + _ => Ok(None), + } + } +} + +impl CustomMessageHandler for LiquidityManager +where + ES::Target: EntropySource, +{ + fn handle_custom_message( + &self, msg: Self::CustomMessage, sender_node_id: &PublicKey, + ) -> Result<(), lightning::ln::msgs::LightningError> { + let mut request_id_to_method_map = self.request_id_to_method_map.lock().unwrap(); + + match LSPSMessage::from_str_with_id_map(&msg.payload, &mut request_id_to_method_map) { + Ok(msg) => self.handle_lsps_message(msg, sender_node_id), + Err(_) => { + self.enqueue_message(*sender_node_id, LSPSMessage::Invalid); + Ok(()) + } + } + } + + fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { + let mut request_id_to_method_map = self.request_id_to_method_map.lock().unwrap(); + self.pending_messages + .lock() + .unwrap() + .drain(..) + .map(|(public_key, lsps_message)| { + if let Some((request_id, method_name)) = lsps_message.get_request_id_and_method() { + request_id_to_method_map.insert(request_id, method_name); + } + ( + public_key, + RawLSPSMessage { payload: serde_json::to_string(&lsps_message).unwrap() }, + ) + }) + .collect() + } + + fn provided_node_features(&self) -> NodeFeatures { + let mut features = NodeFeatures::empty(); + + if self.provider_config.is_some() { + features.set_optional_custom_bit(LSPS_FEATURE_BIT).unwrap(); + } + + features + } + + fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures { + let mut features = InitFeatures::empty(); + + if self.provider_config.is_some() { + features.set_optional_custom_bit(LSPS_FEATURE_BIT).unwrap(); + } + + features + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 1e4d1d7..fef2f5f 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -8,3 +8,7 @@ // licenses. //! Types and primitives that implement the LSPS0: Transport Layer specification. + +pub mod message_handler; +pub mod msgs; +pub mod protocol; diff --git a/src/transport/msgs.rs b/src/transport/msgs.rs new file mode 100644 index 0000000..fcd166c --- /dev/null +++ b/src/transport/msgs.rs @@ -0,0 +1,394 @@ +use lightning::impl_writeable_msg; +use lightning::ln::wire; +use serde::de; +use serde::de::{MapAccess, Visitor}; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Deserializer, Serialize}; +use std::collections::HashMap; +use std::convert::TryFrom; +use std::fmt; + +const LSPS_MESSAGE_SERIALIZED_STRUCT_NAME: &str = "LSPSMessage"; +const JSONRPC_FIELD_KEY: &str = "jsonrpc"; +const JSONRPC_FIELD_VALUE: &str = "2.0"; +const JSONRPC_METHOD_FIELD_KEY: &str = "method"; +const JSONRPC_ID_FIELD_KEY: &str = "id"; +const JSONRPC_PARAMS_FIELD_KEY: &str = "params"; +const JSONRPC_RESULT_FIELD_KEY: &str = "result"; +const JSONRPC_ERROR_FIELD_KEY: &str = "error"; +const JSONRPC_INVALID_MESSAGE_ERROR_CODE: i32 = -32700; +const JSONRPC_INVALID_MESSAGE_ERROR_MESSAGE: &str = "parse error"; +const LSPS0_LISTPROTOCOLS_METHOD_NAME: &str = "lsps0.listprotocols"; + +pub const LSPS_MESSAGE_TYPE: u16 = 37913; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct RawLSPSMessage { + pub payload: String, +} + +impl_writeable_msg!(RawLSPSMessage, { payload }, {}); + +impl wire::Type for RawLSPSMessage { + fn type_id(&self) -> u16 { + LSPS_MESSAGE_TYPE + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RequestId(pub String); + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +pub struct ResponseError { + pub code: i32, + pub message: String, + pub data: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Default)] +#[serde(default)] +pub struct ListProtocolsRequest {} + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +pub struct ListProtocolsResponse { + pub protocols: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum LSPS0Request { + ListProtocols(ListProtocolsRequest), +} + +impl LSPS0Request { + pub fn method(&self) -> &str { + match self { + LSPS0Request::ListProtocols(_) => "lsps0.listprotocols", + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum LSPS0Response { + ListProtocols(ListProtocolsResponse), + ListProtocolsError(ResponseError), +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum LSPS0Message { + Request(RequestId, LSPS0Request), + Response(RequestId, LSPS0Response), +} + +impl TryFrom for LSPS0Message { + type Error = (); + + fn try_from(message: LSPSMessage) -> Result { + match message { + LSPSMessage::Invalid => Err(()), + LSPSMessage::LSPS0(message) => Ok(message), + } + } +} + +impl From for LSPSMessage { + fn from(message: LSPS0Message) -> Self { + LSPSMessage::LSPS0(message) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum LSPSMessage { + Invalid, + LSPS0(LSPS0Message), +} + +impl LSPSMessage { + pub fn from_str_with_id_map( + json_str: &str, request_id_to_method: &mut HashMap, + ) -> Result { + let deserializer = &mut serde_json::Deserializer::from_str(json_str); + let visitor = LSPSMessageVisitor { request_id_to_method }; + deserializer.deserialize_any(visitor) + } + + pub fn get_request_id_and_method(&self) -> Option<(String, String)> { + match self { + LSPSMessage::LSPS0(LSPS0Message::Request(request_id, request)) => { + Some((request_id.0.clone(), request.method().to_string())) + } + _ => None, + } + } +} + +impl Serialize for LSPSMessage { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut jsonrpc_object = + serializer.serialize_struct(LSPS_MESSAGE_SERIALIZED_STRUCT_NAME, 3)?; + + jsonrpc_object.serialize_field(JSONRPC_FIELD_KEY, JSONRPC_FIELD_VALUE)?; + + match self { + LSPSMessage::LSPS0(LSPS0Message::Request(request_id, request)) => { + jsonrpc_object.serialize_field(JSONRPC_METHOD_FIELD_KEY, request.method())?; + jsonrpc_object.serialize_field(JSONRPC_ID_FIELD_KEY, &request_id.0)?; + + match request { + LSPS0Request::ListProtocols(params) => { + jsonrpc_object.serialize_field(JSONRPC_PARAMS_FIELD_KEY, params)? + } + }; + } + LSPSMessage::LSPS0(LSPS0Message::Response(request_id, response)) => { + jsonrpc_object.serialize_field(JSONRPC_ID_FIELD_KEY, &request_id.0)?; + + match response { + LSPS0Response::ListProtocols(result) => { + jsonrpc_object.serialize_field(JSONRPC_RESULT_FIELD_KEY, result)?; + } + LSPS0Response::ListProtocolsError(error) => { + jsonrpc_object.serialize_field(JSONRPC_ERROR_FIELD_KEY, error)?; + } + } + } + LSPSMessage::Invalid => { + let error = ResponseError { + code: JSONRPC_INVALID_MESSAGE_ERROR_CODE, + message: JSONRPC_INVALID_MESSAGE_ERROR_MESSAGE.to_string(), + data: None, + }; + + jsonrpc_object.serialize_field(JSONRPC_ID_FIELD_KEY, &serde_json::Value::Null)?; + jsonrpc_object.serialize_field(JSONRPC_ERROR_FIELD_KEY, &error)?; + } + } + + jsonrpc_object.end() + } +} + +struct LSPSMessageVisitor<'a> { + request_id_to_method: &'a mut HashMap, +} + +impl<'de, 'a> Visitor<'de> for LSPSMessageVisitor<'a> { + type Value = LSPSMessage; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("JSON-RPC object") + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let mut id: Option = None; + let mut method: Option<&str> = None; + let mut params = None; + let mut result = None; + let mut error: Option = None; + + while let Some(key) = map.next_key()? { + match key { + "id" => { + id = Some(map.next_value()?); + } + "method" => { + method = Some(map.next_value()?); + } + "params" => { + params = Some(map.next_value()?); + } + "result" => { + result = Some(map.next_value()?); + } + "error" => { + error = Some(map.next_value()?); + } + _ => { + let _: serde_json::Value = map.next_value()?; + } + } + } + + match (id, method) { + (Some(id), Some(method)) => match method { + LSPS0_LISTPROTOCOLS_METHOD_NAME => { + self.request_id_to_method.insert(id.clone(), method.to_string()); + + Ok(LSPSMessage::LSPS0(LSPS0Message::Request( + RequestId(id), + LSPS0Request::ListProtocols(ListProtocolsRequest {}), + ))) + } + _ => Err(de::Error::custom(format!( + "Received request with unknown method: {}", + method + ))), + }, + (Some(id), None) => match self.request_id_to_method.get(&id) { + Some(method) => match method.as_str() { + LSPS0_LISTPROTOCOLS_METHOD_NAME => { + if let Some(error) = error { + Ok(LSPSMessage::LSPS0(LSPS0Message::Response( + RequestId(id), + LSPS0Response::ListProtocolsError(error), + ))) + } else if let Some(result) = result { + let list_protocols_response = + serde_json::from_value(result).map_err(de::Error::custom)?; + Ok(LSPSMessage::LSPS0(LSPS0Message::Response( + RequestId(id), + LSPS0Response::ListProtocols(list_protocols_response), + ))) + } else { + Err(de::Error::custom("Received invalid JSON-RPC object: one of method, result, or error required")) + } + } + _ => Err(de::Error::custom(format!( + "Received response for an unknown request method: {}", + method + ))), + }, + None => Err(de::Error::custom(format!( + "Received response for unknown request id: {}", + id + ))), + }, + (None, Some(method)) => { + Err(de::Error::custom(format!("Received unknown notification: {}", method))) + } + (None, None) => Err(de::Error::custom( + "Received invalid JSON-RPC object: one of method or id required", + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserializes_request() { + let json = r#"{ + "jsonrpc": "2.0", + "id": "request:id:xyz123", + "method": "lsps0.listprotocols" + }"#; + + let mut request_id_method_map = HashMap::new(); + + let msg = LSPSMessage::from_str_with_id_map(json, &mut request_id_method_map); + assert!(msg.is_ok()); + let msg = msg.unwrap(); + assert_eq!( + msg, + LSPSMessage::LSPS0(LSPS0Message::Request( + RequestId("request:id:xyz123".to_string()), + LSPS0Request::ListProtocols(ListProtocolsRequest {}) + )) + ); + } + + #[test] + fn serializes_request() { + let request = LSPSMessage::LSPS0(LSPS0Message::Request( + RequestId("request:id:xyz123".to_string()), + LSPS0Request::ListProtocols(ListProtocolsRequest {}), + )); + let json = serde_json::to_string(&request).unwrap(); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","method":"lsps0.listprotocols","id":"request:id:xyz123","params":{}}"# + ); + } + + #[test] + fn deserializes_success_response() { + let json = r#"{ + "jsonrpc": "2.0", + "id": "request:id:xyz123", + "result": { + "protocols": [1,2,3] + } + }"#; + let mut request_id_to_method_map = HashMap::new(); + request_id_to_method_map + .insert("request:id:xyz123".to_string(), "lsps0.listprotocols".to_string()); + + let response = + LSPSMessage::from_str_with_id_map(json, &mut request_id_to_method_map).unwrap(); + + assert_eq!( + response, + LSPSMessage::LSPS0(LSPS0Message::Response( + RequestId("request:id:xyz123".to_string()), + LSPS0Response::ListProtocols(ListProtocolsResponse { protocols: vec![1, 2, 3] }) + )) + ); + } + + #[test] + fn deserializes_error_response() { + let json = r#"{ + "jsonrpc": "2.0", + "id": "request:id:xyz123", + "error": { + "code": -32617, + "message": "Unknown Error" + } + }"#; + let mut request_id_to_method_map = HashMap::new(); + request_id_to_method_map + .insert("request:id:xyz123".to_string(), "lsps0.listprotocols".to_string()); + + let response = + LSPSMessage::from_str_with_id_map(json, &mut request_id_to_method_map).unwrap(); + + assert_eq!( + response, + LSPSMessage::LSPS0(LSPS0Message::Response( + RequestId("request:id:xyz123".to_string()), + LSPS0Response::ListProtocolsError(ResponseError { + code: -32617, + message: "Unknown Error".to_string(), + data: None + }) + )) + ); + } + + #[test] + fn deserialize_fails_with_unknown_request_id() { + let json = r#"{ + "jsonrpc": "2.0", + "id": "request:id:xyz124", + "result": { + "protocols": [1,2,3] + } + }"#; + let mut request_id_to_method_map = HashMap::new(); + request_id_to_method_map + .insert("request:id:xyz123".to_string(), "lsps0.listprotocols".to_string()); + + let response = LSPSMessage::from_str_with_id_map(json, &mut request_id_to_method_map); + assert!(response.is_err()); + } + + #[test] + fn serializes_response() { + let response = LSPSMessage::LSPS0(LSPS0Message::Response( + RequestId("request:id:xyz123".to_string()), + LSPS0Response::ListProtocols(ListProtocolsResponse { protocols: vec![1, 2, 3] }), + )); + let json = serde_json::to_string(&response).unwrap(); + assert_eq!( + json, + r#"{"jsonrpc":"2.0","id":"request:id:xyz123","result":{"protocols":[1,2,3]}}"# + ); + } +} diff --git a/src/transport/protocol.rs b/src/transport/protocol.rs new file mode 100644 index 0000000..bc9c118 --- /dev/null +++ b/src/transport/protocol.rs @@ -0,0 +1,184 @@ +use bitcoin::secp256k1::PublicKey; +use lightning::ln::msgs::{ErrorAction, LightningError}; +use lightning::sign::EntropySource; +use lightning::util::logger::Level; +use std::ops::Deref; +use std::sync::{Arc, Mutex}; + +use crate::transport::message_handler::ProtocolMessageHandler; +use crate::transport::msgs::{ + LSPS0Message, LSPS0Request, LSPS0Response, LSPSMessage, ListProtocolsRequest, + ListProtocolsResponse, RequestId, ResponseError, +}; +use crate::utils; + +pub struct LSPS0MessageHandler +where + ES::Target: EntropySource, +{ + entropy_source: ES, + pending_messages: Arc>>, + protocols: Vec, +} + +impl LSPS0MessageHandler +where + ES::Target: EntropySource, +{ + pub fn new( + entropy_source: ES, protocols: Vec, + pending_messages: Arc>>, + ) -> Self { + Self { entropy_source, protocols, pending_messages } + } + + pub fn list_protocols(&self, counterparty_node_id: PublicKey) { + let msg = LSPS0Message::Request( + utils::generate_request_id(&self.entropy_source), + LSPS0Request::ListProtocols(ListProtocolsRequest {}), + ); + + self.enqueue_message(counterparty_node_id, msg); + } + + fn enqueue_message(&self, counterparty_node_id: PublicKey, message: LSPS0Message) { + self.pending_messages.lock().unwrap().push((counterparty_node_id, message.into())); + } + + fn handle_request( + &self, request_id: RequestId, request: LSPS0Request, counterparty_node_id: &PublicKey, + ) -> Result<(), lightning::ln::msgs::LightningError> { + match request { + LSPS0Request::ListProtocols(_) => { + let msg = LSPS0Message::Response( + request_id, + LSPS0Response::ListProtocols(ListProtocolsResponse { + protocols: self.protocols.clone(), + }), + ); + self.enqueue_message(*counterparty_node_id, msg); + Ok(()) + } + } + } + + fn handle_response( + &self, response: LSPS0Response, counterparty_node_id: &PublicKey, + ) -> Result<(), LightningError> { + match response { + LSPS0Response::ListProtocols(ListProtocolsResponse { protocols }) => Ok(()), + LSPS0Response::ListProtocolsError(ResponseError { code, message, data, .. }) => { + Err(LightningError { + err: format!( + "ListProtocols error received. code = {}, message = {}, data = {:?}", + code, message, data + ), + action: ErrorAction::IgnoreAndLog(Level::Info), + }) + } + } + } +} + +impl ProtocolMessageHandler for LSPS0MessageHandler +where + ES::Target: EntropySource, +{ + type ProtocolMessage = LSPS0Message; + const PROTOCOL_NUMBER: Option = None; + + fn handle_message( + &self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey, + ) -> Result<(), LightningError> { + match message { + LSPS0Message::Request(request_id, request) => { + self.handle_request(request_id, request, counterparty_node_id) + } + LSPS0Message::Response(_, response) => { + self.handle_response(response, counterparty_node_id) + } + } + } +} + +#[cfg(test)] +mod tests { + + use std::sync::Arc; + + use super::*; + + struct TestEntropy {} + impl EntropySource for TestEntropy { + fn get_secure_random_bytes(&self) -> [u8; 32] { + [0; 32] + } + } + + #[test] + fn test_handle_list_protocols_request() { + let entropy = Arc::new(TestEntropy {}); + let protocols: Vec = vec![]; + let pending_messages = Arc::new(Mutex::new(vec![])); + + let lsps0_handler = + Arc::new(LSPS0MessageHandler::new(entropy, protocols, pending_messages.clone())); + + let list_protocols_request = LSPS0Message::Request( + RequestId("xyz123".to_string()), + LSPS0Request::ListProtocols(ListProtocolsRequest {}), + ); + let counterparty_node_id = utils::parse_pubkey( + "027100442c3b79f606f80f322d98d499eefcb060599efc5d4ecb00209c2cb54190", + ) + .unwrap(); + + lsps0_handler.handle_message(list_protocols_request, &counterparty_node_id).unwrap(); + let pending_messages = pending_messages.lock().unwrap(); + + assert_eq!(pending_messages.len(), 1); + + let (pubkey, message) = &pending_messages[0]; + + assert_eq!(*pubkey, counterparty_node_id); + assert_eq!( + *message, + LSPSMessage::LSPS0(LSPS0Message::Response( + RequestId("xyz123".to_string()), + LSPS0Response::ListProtocols(ListProtocolsResponse { protocols: vec![] }) + )) + ); + } + + #[test] + fn test_list_protocols() { + let pending_messages = Arc::new(Mutex::new(vec![])); + + let lsps0_handler = Arc::new(LSPS0MessageHandler::new( + Arc::new(TestEntropy {}), + vec![1, 2, 3], + pending_messages.clone(), + )); + + let counterparty_node_id = utils::parse_pubkey( + "027100442c3b79f606f80f322d98d499eefcb060599efc5d4ecb00209c2cb54190", + ) + .unwrap(); + + lsps0_handler.list_protocols(counterparty_node_id); + let pending_messages = pending_messages.lock().unwrap(); + + assert_eq!(pending_messages.len(), 1); + + let (pubkey, message) = &pending_messages[0]; + + assert_eq!(*pubkey, counterparty_node_id); + assert_eq!( + *message, + LSPSMessage::LSPS0(LSPS0Message::Request( + RequestId("00000000000000000000000000000000".to_string()), + LSPS0Request::ListProtocols(ListProtocolsRequest {}) + )) + ); + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..067ce0b --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,69 @@ +use bitcoin::secp256k1::PublicKey; +use lightning::sign::EntropySource; +use std::{fmt::Write, ops::Deref}; + +use crate::transport::msgs::RequestId; + +pub(crate) fn generate_request_id(entropy_source: &ES) -> RequestId +where + ES::Target: EntropySource, +{ + let bytes = entropy_source.get_secure_random_bytes(); + RequestId(hex_str(&bytes[0..16])) +} + +#[inline] +pub fn hex_str(value: &[u8]) -> String { + let mut res = String::with_capacity(2 * value.len()); + for v in value { + write!(&mut res, "{:02x}", v).expect("Unable to write"); + } + res +} + +pub fn to_vec(hex: &str) -> Option> { + let mut out = Vec::with_capacity(hex.len() / 2); + + let mut b = 0; + for (idx, c) in hex.as_bytes().iter().enumerate() { + b <<= 4; + match *c { + b'A'..=b'F' => b |= c - b'A' + 10, + b'a'..=b'f' => b |= c - b'a' + 10, + b'0'..=b'9' => b |= c - b'0', + _ => return None, + } + if (idx & 1) == 1 { + out.push(b); + b = 0; + } + } + + Some(out) +} + +pub fn to_compressed_pubkey(hex: &str) -> Option { + if hex.len() != 33 * 2 { + return None; + } + let data = match to_vec(&hex[0..33 * 2]) { + Some(bytes) => bytes, + None => return None, + }; + match PublicKey::from_slice(&data) { + Ok(pk) => Some(pk), + Err(_) => None, + } +} + +pub fn parse_pubkey(pubkey_str: &str) -> Result { + let pubkey = to_compressed_pubkey(pubkey_str); + if pubkey.is_none() { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "ERROR: unable to parse given pubkey for node", + )); + } + + Ok(pubkey.unwrap()) +}