diff --git a/Cargo.lock b/Cargo.lock index 5a42e8f1249..548c1814c7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1480,6 +1480,7 @@ version = "0.0.0" dependencies = [ "anyhow", "crossbeam", + "dashmap", "hdrhistogram", "insta", "lazy_static", diff --git a/crates/rome_cli/Cargo.toml b/crates/rome_cli/Cargo.toml index dee5dcfdb90..c5fd27067cc 100644 --- a/crates/rome_cli/Cargo.toml +++ b/crates/rome_cli/Cargo.toml @@ -33,6 +33,7 @@ serde = { version = "1.0.133", features = ["derive"] } serde_json = { version = "1.0.74" } tokio = { version = "1.15.0", features = ["io-std", "io-util", "net", "time", "rt", "rt-multi-thread", "macros"] } anyhow = "1.0.52" +dashmap = "5.2.0" [target.'cfg(unix)'.dependencies] libc = "0.2.127" diff --git a/crates/rome_cli/src/execute.rs b/crates/rome_cli/src/execute.rs index 1f1c01518d3..85135c857ed 100644 --- a/crates/rome_cli/src/execute.rs +++ b/crates/rome_cli/src/execute.rs @@ -131,7 +131,7 @@ pub(crate) fn execute_mode(mode: Execution, mut session: CliSession) -> Result<( let can_format = workspace.supports_feature(SupportsFeatureParams { path: rome_path.clone(), feature: FeatureName::Format, - }); + })?; if can_format { workspace.open_file(OpenFileParams { path: rome_path.clone(), diff --git a/crates/rome_cli/src/service/mod.rs b/crates/rome_cli/src/service/mod.rs index 4420d3cbc4d..339e8e40f0c 100644 --- a/crates/rome_cli/src/service/mod.rs +++ b/crates/rome_cli/src/service/mod.rs @@ -4,14 +4,39 @@ //! is based on the [Language Server Protocol](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#baseProtocol), //! a simplified derivative of the HTTP protocol -use std::{io, panic::RefUnwindSafe, str::FromStr}; +use std::{ + any::type_name, + borrow::Cow, + io, + panic::RefUnwindSafe, + str::{from_utf8, FromStr}, + sync::Arc, + time::Duration, +}; use anyhow::{bail, ensure, Context, Error}; -use rome_service::{workspace::WorkspaceTransport, TransportError}; +use dashmap::DashMap; +use rome_service::{ + workspace::{TransportRequest, WorkspaceTransport}, + TransportError, +}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::{ + from_slice, from_str, to_vec, + value::{to_raw_value, RawValue}, + Value, +}; use tokio::{ - io::{split, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}, + io::{ + AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, + BufReader, BufWriter, + }, runtime::Runtime, - sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + sync::{ + mpsc::{channel, Receiver, Sender}, + oneshot, + }, + time::sleep, }; #[cfg(windows)] @@ -32,155 +57,333 @@ pub(crate) use self::unix::{ensure_daemon, print_socket, run_daemon}; /// [WorkspaceTransport] instance if the socket is currently active pub fn open_transport(runtime: Runtime) -> io::Result> { match runtime.block_on(open_socket()) { - Ok(Some(socket)) => Ok(Some(SocketTransport::open(runtime, socket))), + Ok(Some((read, write))) => Ok(Some(SocketTransport::open(runtime, read, write))), Ok(None) => Ok(None), Err(err) => Err(err), } } +type JsonRpcResult = Result, TransportError>; + /// Implementation of [WorkspaceTransport] for types implementing [AsyncRead] /// and [AsyncWrite] +/// +/// This structs holds an instance of the `tokio` runtime, as well as the +/// following fields: +/// - `write_send` is a sender handle to the "write channel", an MPSC channel +/// that's used to queue up requests to be sent to the server (for simplicity +/// the requests are pushed to the channel as serialized byte buffers) +/// - `pending_requests` is handle to a shared hashmap where the keys are `u64` +/// corresponding to request IDs, and the values are sender handles to oneshot +/// channel instances that can be consumed to fullfill the associated request +/// +/// Creating a new `SocketTransport` instance requires providing a `tokio` +/// runtime instance as well as the "read half" and "write half" of the socket +/// object to be used by this transport instance. These two objects implement +/// [AsyncRead] and [AsyncWrite] respectively, and should generally map to the +/// same underlying I/O object but are represented as separate so they can be +/// used concurrently +/// +/// This concurrent handling of I/O is implemented useing two "background tasks": +/// - the [write_task] pulls outgoing messages from the "write channel" and +/// writes them to the "write half" of the socket +/// - the [read_task] reads incoming messages from the "read half" of the +/// socket, then looks up a request with an ID corresponding to the received +/// message in the "pending requests" map. If a pending request is found, it's +/// fullfilled with the content of the message that was just received +/// +/// In addition to these, a new "foreground task" is created for each request. +/// Each foreground task creates a oneshot channel and stores it in the pending +/// requests map using the request ID as a key, then serialize the content of +/// the request and send it over the write channel. Finally, the task blocks +/// the current thread until a response is received over the oneshot channel +/// from the read task, or the request times out pub struct SocketTransport { runtime: Runtime, - read_recv: UnboundedReceiver>, - write_send: UnboundedSender>, + write_send: Sender>, + pending_requests: PendingRequests, } +type PendingRequests = Arc>>; + impl SocketTransport { - pub fn open(runtime: Runtime, socket: T) -> Self + pub fn open(runtime: Runtime, socket_read: R, socket_write: W) -> Self where - T: AsyncRead + AsyncWrite + Send + 'static, + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, { - let (socket_read, mut socket_write) = split(socket); - let mut socket_read = BufReader::new(socket_read); - - let (read_send, read_recv) = mpsc::unbounded_channel(); - let (write_send, mut write_recv) = mpsc::unbounded_channel::>(); - - let read_task = async move { - loop { - let mut length = None; - let mut line = String::new(); - - loop { - match socket_read - .read_line(&mut line) - .await - .context("failed to read header line from the socket")? - { - // A read of 0 bytes means the connection was closed - 0 => { - bail!("the connection to the remote workspace was unexpectedly closed"); - } - // A read of two bytes corresponds to the "\r\n" sequence - // that indicates the end of the header section - 2 => { - if line != "\r\n" { - bail!("unexpected byte sequence received from the remote workspace, got {line:?} expected \"\\r\\n\""); - } - - break; - } - _ => { - let header: TransportHeader = line - .parse() - .context("failed to parse header from the remote workspace")?; - - match header { - TransportHeader::ContentLength(value) => { - length = Some(value); - } - TransportHeader::ContentType => {} - TransportHeader::Unknown(name) => { - eprintln!("ignoring unknown header {name:?}"); - } - } - - line.clear(); - } - } - } + /// Capacity of the "write channel", once this many requests have been + /// queued up, calls to `write_send.send` will block the sending task + /// until enough capacity is available again + /// + /// Note that this does not limit how many requests can be in flight at + /// a given time, it only serves as a loose rate-limit on how many new + /// requests can be sent to the server within a given time frame + const WRITE_CHANNEL_CAPACITY: usize = 16; - let length = length.context("incoming response from the remote workspace is missing the Content-Length header")?; + let (write_send, write_recv) = channel(WRITE_CHANNEL_CAPACITY); - let mut result = vec![0u8; length]; - socket_read.read_exact(&mut result).await.with_context(|| { - format!("failed to read message of {length} bytes from the socket") - })?; + let pending_requests: Arc>> = Arc::default(); + let pending_requests_2 = Arc::clone(&pending_requests); - // Send the received message over the transport channel, or - // exit the task if the channel was closed - if read_send.send(result).is_err() { - break; - } - } + let socket_read = BufReader::new(socket_read); + let socket_write = BufWriter::new(socket_write); - Ok(()) - }; + runtime.spawn(write_task(write_recv, socket_write)); + runtime.spawn(read_task(socket_read, pending_requests)); - let write_task = async move { - while let Some(message) = write_recv.recv().await { - socket_write.write_all(b"Content-Length: ").await?; + Self { + runtime, + write_send, + pending_requests: pending_requests_2, + } + } +} - let length = message.len().to_string(); - socket_write.write_all(length.as_bytes()).await?; - socket_write.write_all(b"\r\n").await?; +// Allow the socket to be recovered across panic boundaries +impl RefUnwindSafe for SocketTransport {} - socket_write - .write_all(b"Content-Type: application/vscode-jsonrpc; charset=utf-8\r\n") - .await?; +impl WorkspaceTransport for SocketTransport { + fn request(&self, request: TransportRequest

) -> Result + where + P: Serialize, + R: DeserializeOwned, + { + let (send, recv) = oneshot::channel(); - socket_write.write_all(b"\r\n").await?; + self.pending_requests.insert(request.id, send); + + let request = JsonRpcRequest { + jsonrpc: Cow::Borrowed("2.0"), + id: request.id, + method: Cow::Borrowed(request.method), + params: request.params, + }; - socket_write.write_all(&message).await?; + let request = to_vec(&request).map_err(|err| { + TransportError::SerdeError(format!( + "failed to serialize {} into byte buffer: {err}", + type_name::

() + )) + })?; + + let response = self.runtime.block_on(async move { + self.write_send + .send(request) + .await + .map_err(|_| TransportError::ChannelClosed)?; + + tokio::select! { + result = recv => { + match result { + Ok(Ok(response)) => Ok(response), + Ok(Err(error)) => Err(error), + Err(_) => Err(TransportError::ChannelClosed), + } + } + _ = sleep(Duration::from_secs(15)) => { + Err(TransportError::Timeout) + } } + })?; - Ok::<(), Error>(()) - }; + let response = response.get(); + let result = from_str(response).map_err(|err| { + TransportError::SerdeError(format!( + "failed to deserialize {} from {response:?}: {err}", + type_name::() + )) + })?; - runtime.spawn(async move { - if let Err(err) = read_task.await { - eprintln!( - "{:?}", - err.context("remote connection read task exited with an error") - ); + Ok(result) + } +} + +async fn read_task(mut socket_read: BufReader, pending_requests: PendingRequests) +where + R: AsyncRead + Unpin, +{ + loop { + let message = read_message(&mut socket_read).await; + let message = match message { + Ok(message) => { + let response = from_slice(&message).with_context(|| { + if let Ok(message) = from_utf8(&message) { + format!("failed to deserialize JSON-RPC response from {message:?}") + } else { + format!("failed to deserialize JSON-RPC response from {message:?}") + } + }); + + response.map(|response| (message, response)) } - }); + Err(err) => Err(err), + }; - runtime.spawn(async move { - if let Err(err) = write_task.await { + let (message, response): (_, JsonRpcResponse) = match message { + Ok(message) => message, + Err(err) => { eprintln!( "{:?}", err.context("remote connection write task exited with an error") ); + break; } - }); + }; - Self { - runtime, - read_recv, - write_send, + if let Some((_, channel)) = pending_requests.remove(&response.id) { + let response = match (response.result, response.error) { + (Some(result), None) => Ok(result), + (None, Some(err)) => Err(TransportError::RPCError(err.message)), + + // Both result and error will be None if the request + // returns a null-ish result, in this case create a + // "null" RawValue as the result + // + // SAFETY: Calling `to_raw_value` with a static "null" + // JSON Value will always succeed + (None, None) => Ok(to_raw_value(&Value::Null).unwrap()), + + _ => { + let message = if let Ok(message) = from_utf8(&message) { + format!("invalid response {message:?}") + } else { + format!("invalid response {message:?}") + }; + + Err(TransportError::SerdeError(message)) + } + }; + + channel.send(response).ok(); } } } -// Allow the socket to be recovered across panic boundaries -impl RefUnwindSafe for SocketTransport {} +async fn read_message(mut socket_read: R) -> Result, Error> +where + R: AsyncBufRead + Unpin, +{ + let mut length = None; + let mut line = String::new(); + + loop { + match socket_read + .read_line(&mut line) + .await + .context("failed to read header line from the socket")? + { + // A read of 0 bytes means the connection was closed + 0 => { + bail!("the connection to the remote workspace was unexpectedly closed"); + } + // A read of two bytes corresponds to the "\r\n" sequence + // that indicates the end of the header section + 2 => { + if line != "\r\n" { + bail!("unexpected byte sequence received from the remote workspace, got {line:?} expected \"\\r\\n\""); + } -impl WorkspaceTransport for SocketTransport { - fn send(&mut self, request: Vec) -> Result<(), TransportError> { - self.write_send - .send(request) - .map_err(|_| TransportError::ChannelClosed) + break; + } + _ => { + let header: TransportHeader = line + .parse() + .context("failed to parse header from the remote workspace")?; + + match header { + TransportHeader::ContentLength(value) => { + length = Some(value); + } + TransportHeader::ContentType => {} + TransportHeader::Unknown(name) => { + eprintln!("ignoring unknown header {name:?}"); + } + } + + line.clear(); + } + } } - fn receive(&mut self) -> Result, TransportError> { - let read_recv = &mut self.read_recv; - self.runtime - .block_on(async move { read_recv.recv().await.ok_or(TransportError::ChannelClosed) }) + let length = length.context( + "incoming response from the remote workspace is missing the Content-Length header", + )?; + + let mut result = vec![0u8; length]; + socket_read + .read_exact(&mut result) + .await + .with_context(|| format!("failed to read message of {length} bytes from the socket"))?; + + Ok(result) +} + +async fn write_task(mut write_recv: Receiver>, mut socket_write: BufWriter) +where + W: AsyncWrite + Unpin, +{ + while let Some(message) = write_recv.recv().await { + if let Err(err) = write_message(&mut socket_write, message).await { + eprintln!( + "{:?}", + err.context("remote connection read task exited with an error") + ); + break; + } } } +async fn write_message(mut socket_write: W, message: Vec) -> Result<(), Error> +where + W: AsyncWrite + Unpin, +{ + socket_write.write_all(b"Content-Length: ").await?; + + let length = message.len().to_string(); + socket_write.write_all(length.as_bytes()).await?; + socket_write.write_all(b"\r\n").await?; + + socket_write + .write_all(b"Content-Type: application/vscode-jsonrpc; charset=utf-8\r\n") + .await?; + + socket_write.write_all(b"\r\n").await?; + + socket_write.write_all(&message).await?; + + socket_write.flush().await?; + + Ok(()) +} + +#[derive(Debug, Serialize)] +struct JsonRpcRequest

{ + jsonrpc: Cow<'static, str>, + id: u64, + method: Cow<'static, str>, + params: P, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +struct JsonRpcResponse { + #[allow(dead_code)] + jsonrpc: Cow<'static, str>, + id: u64, + result: Option>, + error: Option, +} + +#[derive(Debug, Deserialize)] +struct JsonRpcError { + #[allow(dead_code)] + code: i64, + message: String, + #[allow(dead_code)] + data: Option>, +} + enum TransportHeader { ContentLength(usize), ContentType, diff --git a/crates/rome_cli/src/service/unix.rs b/crates/rome_cli/src/service/unix.rs index 327af8c0b75..b9b0cfdbc4b 100644 --- a/crates/rome_cli/src/service/unix.rs +++ b/crates/rome_cli/src/service/unix.rs @@ -8,8 +8,11 @@ use std::{ use rome_lsp::{ServerConnection, ServerFactory}; use tokio::{ - io::{split, Interest}, - net::{UnixListener, UnixStream}, + io::Interest, + net::{ + unix::{OwnedReadHalf, OwnedWriteHalf}, + UnixListener, UnixStream, + }, process::{Child, Command}, time, }; @@ -61,9 +64,9 @@ fn spawn_daemon() -> io::Result { /// Open a connection to the daemon server process, returning [None] if the /// server is not running -pub(crate) async fn open_socket() -> io::Result> { +pub(crate) async fn open_socket() -> io::Result> { match try_connect().await { - Ok(socket) => Ok(Some(socket)), + Ok(socket) => Ok(Some(socket.into_split())), Err(err) // The OS will return `ConnectionRefused` if the socket file exists // but no server process is listening on it @@ -169,6 +172,6 @@ pub(crate) async fn run_daemon(factory: ServerFactory) -> io::Result /// Async task driving a single client connection async fn run_server(connection: ServerConnection, stream: UnixStream) { - let (read, write) = split(stream); + let (read, write) = stream.into_split(); connection.accept(read, write).await; } diff --git a/crates/rome_cli/src/service/windows.rs b/crates/rome_cli/src/service/windows.rs index 873e406675e..b873cd46bdb 100644 --- a/crates/rome_cli/src/service/windows.rs +++ b/crates/rome_cli/src/service/windows.rs @@ -4,13 +4,16 @@ use std::{ io::{self, ErrorKind}, mem::swap, os::windows::process::CommandExt, + pin::Pin, process::Command, + sync::Arc, + task::{Context, Poll}, time::Duration, }; use rome_lsp::{ServerConnection, ServerFactory}; use tokio::{ - io::split, + io::{AsyncRead, AsyncWrite, ReadBuf}, net::windows::named_pipe::{ClientOptions, NamedPipeClient, NamedPipeServer, ServerOptions}, time, }; @@ -58,14 +61,84 @@ fn spawn_daemon() -> io::Result<()> { /// Open a connection to the daemon server process, returning [None] if the /// server is not running -pub(crate) async fn open_socket() -> io::Result> { +pub(crate) async fn open_socket() -> io::Result> { match try_connect().await { - Ok(socket) => Ok(Some(socket)), + Ok(socket) => { + let inner = Arc::new(socket); + Ok(Some(( + ClientReadHalf { + inner: inner.clone(), + }, + ClientWriteHalf { inner }, + ))) + } Err(err) if err.kind() == ErrorKind::NotFound => Ok(None), Err(err) => Err(err), } } +pub(crate) struct ClientReadHalf { + inner: Arc, +} + +impl AsyncRead for ClientReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + match self.inner.poll_read_ready(cx) { + Poll::Ready(Ok(())) => match self.inner.try_read(buf.initialize_unfilled()) { + Ok(count) => { + buf.advance(count); + return Poll::Ready(Ok(())); + } + + Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(err) => return Poll::Ready(Err(err)), + }, + + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + } + } +} + +pub(crate) struct ClientWriteHalf { + inner: Arc, +} + +impl AsyncWrite for ClientWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match self.inner.poll_write_ready(cx) { + Poll::Ready(Ok(())) => match self.inner.try_write(buf) { + Ok(count) => return Poll::Ready(Ok(count)), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(err) => return Poll::Ready(Err(err)), + }, + + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + /// Ensure the server daemon is running and ready to receive connections /// /// Returns false if the daemon process was already running or true if it had @@ -116,6 +189,72 @@ pub(crate) async fn run_daemon(factory: ServerFactory) -> io::Result /// Async task driving a single client connection async fn run_server(connection: ServerConnection, stream: NamedPipeServer) { - let (read, write) = split(stream); + let inner = Arc::new(stream); + let read = ServerReadHalf { + inner: inner.clone(), + }; + let write = ServerWriteHalf { inner }; connection.accept(read, write).await; } + +struct ServerReadHalf { + inner: Arc, +} + +impl AsyncRead for ServerReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + match self.inner.poll_read_ready(cx) { + Poll::Ready(Ok(())) => match self.inner.try_read(buf.initialize_unfilled()) { + Ok(count) => { + buf.advance(count); + return Poll::Ready(Ok(())); + } + + Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(err) => return Poll::Ready(Err(err)), + }, + + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + } + } +} + +struct ServerWriteHalf { + inner: Arc, +} + +impl AsyncWrite for ServerWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match self.inner.poll_write_ready(cx) { + Poll::Ready(Ok(())) => match self.inner.try_write(buf) { + Ok(count) => return Poll::Ready(Ok(count)), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(err) => return Poll::Ready(Err(err)), + }, + + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} diff --git a/crates/rome_cli/src/traversal.rs b/crates/rome_cli/src/traversal.rs index 636140ceb08..d2f0422dd1c 100644 --- a/crates/rome_cli/src/traversal.rs +++ b/crates/rome_cli/src/traversal.rs @@ -18,7 +18,7 @@ use rome_service::{ workspace::{ FeatureName, FileGuard, Language, OpenFileParams, RuleCategories, SupportsFeatureParams, }, - Workspace, + RomeError, Workspace, }; use std::{ collections::HashMap, @@ -434,7 +434,7 @@ impl<'ctx, 'app> TraversalOptions<'ctx, 'app> { self.messages.send(msg.into()).ok(); } - fn can_format(&self, rome_path: &RomePath) -> bool { + fn can_format(&self, rome_path: &RomePath) -> Result { self.workspace.supports_feature(SupportsFeatureParams { path: rome_path.clone(), feature: FeatureName::Format, @@ -447,7 +447,7 @@ impl<'ctx, 'app> TraversalOptions<'ctx, 'app> { .ok(); } - fn can_lint(&self, rome_path: &RomePath) -> bool { + fn can_lint(&self, rome_path: &RomePath) -> Result { self.workspace.supports_feature(SupportsFeatureParams { path: rome_path.clone(), feature: FeatureName::Lint, @@ -470,10 +470,20 @@ impl<'ctx, 'app> TraversalContext for TraversalOptions<'ctx, 'app> { } fn can_handle(&self, rome_path: &RomePath) -> bool { - match self.execution.traversal_mode() { + let result = match self.execution.traversal_mode() { TraversalMode::Check { .. } => self.can_lint(rome_path), - TraversalMode::CI { .. } => self.can_lint(rome_path) || self.can_format(rome_path), + TraversalMode::CI { .. } => self + .can_lint(rome_path) + .and_then(|can_lint| Ok(can_lint || self.can_format(rome_path)?)), TraversalMode::Format { .. } => self.can_format(rome_path), + }; + + match result { + Ok(result) => result, + Err(err) => { + self.push_diagnostic(rome_path.file_id(), "IO", err.to_string()); + false + } } } @@ -543,10 +553,14 @@ type FileResult = Result; fn process_file(ctx: &TraversalOptions, path: &Path, file_id: FileId) -> FileResult { tracing::trace_span!("process_file", path = ?path).in_scope(move || { let rome_path = RomePath::new(path, file_id); - let can_format = ctx.can_format(&rome_path); - let can_lint = ctx.can_lint(&rome_path); + let can_format = ctx + .can_format(&rome_path) + .with_file_id_and_code(file_id, "IO")?; + let can_lint = ctx + .can_lint(&rome_path) + .with_file_id_and_code(file_id, "IO")?; let can_handle = match ctx.execution.traversal_mode() { - TraversalMode::Check { .. } => ctx.can_lint(&rome_path), + TraversalMode::Check { .. } => can_lint, TraversalMode::CI { .. } => can_lint || can_format, TraversalMode::Format { .. } => can_format, }; diff --git a/crates/rome_cli/tests/main.rs b/crates/rome_cli/tests/main.rs index d9b7b321bff..309ade47aa9 100644 --- a/crates/rome_cli/tests/main.rs +++ b/crates/rome_cli/tests/main.rs @@ -1727,7 +1727,8 @@ fn run_cli<'app>( let (stdin, stdout) = split(server); runtime.spawn(connection.accept(stdin, stdout)); - let transport = SocketTransport::open(runtime, client); + let (client_read, client_write) = split(client); + let transport = SocketTransport::open(runtime, client_read, client_write); let workspace = workspace::client(transport).unwrap(); let app = App::new(fs, console, WorkspaceRef::Owned(workspace)); diff --git a/crates/rome_js_formatter/tests/spec_test.rs b/crates/rome_js_formatter/tests/spec_test.rs index 6609a7d4780..899b529cf5d 100644 --- a/crates/rome_js_formatter/tests/spec_test.rs +++ b/crates/rome_js_formatter/tests/spec_test.rs @@ -205,10 +205,13 @@ pub fn run(spec_input_file: &str, _expected_file: &str, test_directory: &str, fi ); let mut rome_path = RomePath::new(file_path, 0); - let can_format = app.workspace.supports_feature(SupportsFeatureParams { - path: rome_path.clone(), - feature: FeatureName::Format, - }); + let can_format = app + .workspace + .supports_feature(SupportsFeatureParams { + path: rome_path.clone(), + feature: FeatureName::Format, + }) + .unwrap(); if can_format { let mut snapshot_content = SnapshotContent::default(); diff --git a/crates/rome_lsp/src/handlers/analysis.rs b/crates/rome_lsp/src/handlers/analysis.rs index e900be6853c..20e5523fa1a 100644 --- a/crates/rome_lsp/src/handlers/analysis.rs +++ b/crates/rome_lsp/src/handlers/analysis.rs @@ -31,7 +31,7 @@ pub(crate) fn code_actions( let linter_enabled = &session.workspace.supports_feature(SupportsFeatureParams { path: rome_path, feature: FeatureName::Lint, - }); + })?; if !linter_enabled { return Ok(Some(Vec::new())); } diff --git a/crates/rome_lsp/src/server.rs b/crates/rome_lsp/src/server.rs index 36bc9358429..21983ab68f4 100644 --- a/crates/rome_lsp/src/server.rs +++ b/crates/rome_lsp/src/server.rs @@ -1,4 +1,4 @@ -use std::{panic::catch_unwind, sync::Arc}; +use std::sync::Arc; use crate::capabilities::server_capabilities; use crate::requests::syntax_tree::{SyntaxTreePayload, SYNTAX_TREE_REQUEST}; @@ -6,9 +6,11 @@ use crate::session::Session; use crate::utils::{into_lsp_error, panic_to_lsp_error}; use crate::{handlers, requests}; use futures::future::ready; +use futures::FutureExt; use rome_service::{workspace, Workspace}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::Notify; +use tokio::task::spawn_blocking; use tower_lsp::jsonrpc::Result as LspResult; use tower_lsp::{lsp_types::*, ClientSocket}; use tower_lsp::{Client, LanguageServer, LspService, Server}; @@ -178,13 +180,27 @@ macro_rules! workspace_method { $builder = $builder.custom_method( concat!("rome/", stringify!($method)), |server: &LSPServer, params| { - let workspace = &server.session.workspace; - let result = catch_unwind(move || workspace.$method(params)); - - ready(match result { - Ok(Ok(result)) => Ok(result), - Ok(Err(err)) => Err(into_lsp_error(err)), - Err(err) => Err(panic_to_lsp_error(err)), + let span = tracing::trace_span!(concat!("rome/", stringify!($method)), params = ?params).or_current(); + + let workspace = server.session.workspace.clone(); + let result = spawn_blocking(move || { + let _guard = span.entered(); + workspace.$method(params) + }); + + result.map(move |result| { + // The type of `result` is `Result, JoinError>`, + // where the inner result is the return value of `$method` while the + // outer one is added by `spawn_blocking` to catch panics or + // cancellations of the task + match result { + Ok(Ok(result)) => Ok(result), + Ok(Err(err)) => Err(into_lsp_error(err)), + Err(err) => match err.try_into_panic() { + Ok(err) => Err(panic_to_lsp_error(err)), + Err(err) => Err(into_lsp_error(err)), + }, + } }) }, ); @@ -212,11 +228,7 @@ impl ServerFactory { ready(Ok(Some(()))) }); - // supports_feature is special because it returns a bool instead of a Result - builder = builder.custom_method("rome/supports_feature", |server: &LSPServer, params| { - ready(Ok(server.session.workspace.supports_feature(params))) - }); - + workspace_method!(builder, supports_feature); workspace_method!(builder, update_settings); workspace_method!(builder, open_file); workspace_method!(builder, get_syntax_tree); diff --git a/crates/rome_lsp/src/session.rs b/crates/rome_lsp/src/session.rs index 268d18ec061..4cfdbf271cf 100644 --- a/crates/rome_lsp/src/session.rs +++ b/crates/rome_lsp/src/session.rs @@ -118,7 +118,7 @@ impl Session { let lint_enabled = self.workspace.supports_feature(SupportsFeatureParams { feature: FeatureName::Lint, path: rome_path.clone(), - }); + })?; let diagnostics = if lint_enabled { let result = self.workspace.pull_diagnostics(PullDiagnosticsParams { diff --git a/crates/rome_service/src/lib.rs b/crates/rome_service/src/lib.rs index a6e49651919..010a6fe17b4 100644 --- a/crates/rome_service/src/lib.rs +++ b/crates/rome_service/src/lib.rs @@ -174,6 +174,8 @@ impl From for RomeError { pub enum TransportError { /// Error emitted by the transport layer if the connection was lost due to an I/O error ChannelClosed, + /// Error emitted by the transport layer if a request timed out + Timeout, /// Error caused by a serialization or deserialization issue SerdeError(String), /// Generic error type for RPC errors that can't be deserialized into RomeError @@ -189,17 +191,14 @@ impl Display for TransportError { TransportError::ChannelClosed => fmt.write_str( "a request to the remote workspace failed because the connection was interrupted", ), + TransportError::Timeout => { + fmt.write_str("the request to the remote workspace timed out") + } TransportError::RPCError(err) => fmt.write_str(err), } } } -impl From for TransportError { - fn from(err: serde_json::Error) -> Self { - TransportError::SerdeError(err.to_string()) - } -} - impl Default for App<'static> { fn default() -> Self { Self::with_filesystem_and_console( diff --git a/crates/rome_service/src/workspace.rs b/crates/rome_service/src/workspace.rs index 86ad59311dd..5330e94ceb8 100644 --- a/crates/rome_service/src/workspace.rs +++ b/crates/rome_service/src/workspace.rs @@ -63,7 +63,7 @@ use rome_rowan::TextRangeSchema; use rome_text_edit::Indel; use std::{borrow::Cow, panic::RefUnwindSafe, sync::Arc}; -pub use self::client::{WorkspaceClient, WorkspaceTransport}; +pub use self::client::{TransportRequest, WorkspaceClient, WorkspaceTransport}; pub use crate::file_handlers::Language; mod client; @@ -259,7 +259,7 @@ pub trait Workspace: Send + Sync + RefUnwindSafe { /// Checks whether a certain feature is supported. There are different conditions: /// - Rome doesn't recognize a file, so it can provide the feature; /// - the feature is disabled inside the configuration; - fn supports_feature(&self, params: SupportsFeatureParams) -> bool; + fn supports_feature(&self, params: SupportsFeatureParams) -> Result; /// Update the global settings for this workspace fn update_settings(&self, params: UpdateSettingsParams) -> Result<(), RomeError>; diff --git a/crates/rome_service/src/workspace/client.rs b/crates/rome_service/src/workspace/client.rs index 9eb34e3409c..cb6dd6df7a3 100644 --- a/crates/rome_service/src/workspace/client.rs +++ b/crates/rome_service/src/workspace/client.rs @@ -1,14 +1,11 @@ use std::{ panic::RefUnwindSafe, - sync::{ - atomic::{AtomicU64, Ordering}, - Mutex, - }, + sync::atomic::{AtomicU64, Ordering}, }; use rome_formatter::Printed; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::{from_slice, json, to_vec}; +use serde_json::json; use crate::{RomeError, TransportError, Workspace}; @@ -21,46 +18,22 @@ use super::{ }; pub struct WorkspaceClient { - transport: Mutex, + transport: T, request_id: AtomicU64, } pub trait WorkspaceTransport { - fn send(&mut self, request: Vec) -> Result<(), TransportError>; - fn receive(&mut self) -> Result, TransportError>; -} - -#[derive(Debug, Serialize)] -struct JsonRpcRequest

{ - jsonrpc: &'static str, - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - method: &'static str, - params: P, -} - -#[derive(Debug, Deserialize)] -struct JsonRpcResponse<'a, R> { - #[allow(dead_code)] - jsonrpc: &'a str, - id: u64, - #[serde(flatten)] - status: JsonRpcResult, -} - -#[derive(Debug, Deserialize)] -#[serde(untagged)] -enum JsonRpcResult { - Ok { result: R }, - Err { error: JsonRpcError }, + fn request(&self, request: TransportRequest

) -> Result + where + P: Serialize, + R: DeserializeOwned; } -#[derive(Debug, Deserialize)] -struct JsonRpcError { - #[allow(dead_code)] - code: i64, - message: String, - data: Option, +#[derive(Debug)] +pub struct TransportRequest

{ + pub id: u64, + pub method: &'static str, + pub params: P, } #[derive(Debug, Deserialize)] @@ -72,7 +45,7 @@ where { pub fn new(transport: T) -> Result { let client = Self { - transport: Mutex::new(transport), + transport, request_id: AtomicU64::new(0), }; @@ -99,32 +72,12 @@ where P: Serialize, R: DeserializeOwned, { - let mut transport = self.transport.lock().unwrap(); - let id = self.request_id.fetch_add(1, Ordering::Relaxed); - let request = JsonRpcRequest { - jsonrpc: "2.0", - id: Some(id), - method, - params, - }; - - let request = to_vec(&request).map_err(TransportError::from)?; - transport.send(request)?; - - let response = transport.receive()?; - let response: JsonRpcResponse = from_slice(&response).map_err(TransportError::from)?; + let request = TransportRequest { id, method, params }; - // This should be true since we don't allow concurrent requests yet - assert_eq!(response.id, id); + let response = self.transport.request(request)?; - match response.status { - JsonRpcResult::Ok { result } => Ok(result), - JsonRpcResult::Err { error } => match error.data { - Some(error) => Err(error), - None => Err(RomeError::from(TransportError::RPCError(error.message))), - }, - } + Ok(response) } pub fn shutdown(self) -> Result<(), RomeError> { @@ -136,9 +89,8 @@ impl Workspace for WorkspaceClient where T: WorkspaceTransport + RefUnwindSafe + Send + Sync, { - fn supports_feature(&self, params: SupportsFeatureParams) -> bool { + fn supports_feature(&self, params: SupportsFeatureParams) -> Result { self.request("rome/supports_feature", params) - .unwrap_or(false) } fn update_settings(&self, params: UpdateSettingsParams) -> Result<(), RomeError> { diff --git a/crates/rome_service/src/workspace/server.rs b/crates/rome_service/src/workspace/server.rs index e659323a3bb..6fddcf77744 100644 --- a/crates/rome_service/src/workspace/server.rs +++ b/crates/rome_service/src/workspace/server.rs @@ -162,15 +162,16 @@ impl WorkspaceServer { } impl Workspace for WorkspaceServer { - fn supports_feature(&self, params: SupportsFeatureParams) -> bool { + fn supports_feature(&self, params: SupportsFeatureParams) -> Result { let capabilities = self.get_capabilities(¶ms.path); let settings = self.settings.read().unwrap(); - match params.feature { + let result = match params.feature { FeatureName::Format => { capabilities.formatter.format.is_some() && settings.formatter().enabled } FeatureName::Lint => capabilities.analyzer.lint.is_some() && settings.linter().enabled, - } + }; + Ok(result) } /// Update the global settings for this workspace diff --git a/crates/rome_wasm/src/lib.rs b/crates/rome_wasm/src/lib.rs index 3eb05c22fce..c2480063f73 100644 --- a/crates/rome_wasm/src/lib.rs +++ b/crates/rome_wasm/src/lib.rs @@ -38,7 +38,7 @@ impl Workspace { #[wasm_bindgen(js_name = supportsFeature)] pub fn supports_feature(&self, params: ISupportsFeatureParams) -> Result { let params: SupportsFeatureParams = params.into_serde().map_err(into_error)?; - Ok(self.inner.supports_feature(params)) + self.inner.supports_feature(params).map_err(into_error) } #[wasm_bindgen(js_name = updateSettings)]