diff --git a/crates/ra_proc_macro/src/msg.rs b/crates/ra_proc_macro/src/msg.rs index 2fb065d327f..aa95bcc8f7b 100644 --- a/crates/ra_proc_macro/src/msg.rs +++ b/crates/ra_proc_macro/src/msg.rs @@ -1,218 +1,93 @@ -//! A simplified version of lsp base protocol for rpc +//! Defines messages for cross-process message based on `ndjson` wire protocol use std::{ - fmt, + convert::TryFrom, io::{self, BufRead, Write}, }; +use crate::{ + rpc::{ListMacrosResult, ListMacrosTask}, + ExpansionResult, ExpansionTask, +}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(untagged)] -pub enum Message { - Request(Request), - Response(Response), +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum Request { + ListMacro(ListMacrosTask), + ExpansionMacro(ExpansionTask), } -impl From for Message { - fn from(request: Request) -> Message { - Message::Request(request) - } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum Response { + Error(ResponseError), + ListMacro(ListMacrosResult), + ExpansionMacro(ExpansionResult), } -impl From for Message { - fn from(response: Response) -> Message { - Message::Response(response) - } -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[serde(transparent)] -pub struct RequestId(IdRepr); - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[serde(untagged)] -enum IdRepr { - U64(u64), - String(String), -} - -impl From for RequestId { - fn from(id: u64) -> RequestId { - RequestId(IdRepr::U64(id)) - } -} - -impl From for RequestId { - fn from(id: String) -> RequestId { - RequestId(IdRepr::String(id)) - } -} - -impl fmt::Display for RequestId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.0 { - IdRepr::U64(it) => fmt::Display::fmt(it, f), - IdRepr::String(it) => fmt::Display::fmt(it, f), +macro_rules! impl_try_from_response { + ($ty:ty, $tag:ident) => { + impl TryFrom for $ty { + type Error = &'static str; + fn try_from(value: Response) -> Result { + match value { + Response::$tag(res) => Ok(res), + _ => Err("Fail to convert from response"), + } + } } - } + }; } -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Request { - pub id: RequestId, - pub method: String, - pub params: serde_json::Value, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Response { - // JSON RPC allows this to be null if it was impossible - // to decode the request's id. Ignore this special case - // and just die horribly. - pub id: RequestId, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} +impl_try_from_response!(ListMacrosResult, ListMacro); +impl_try_from_response!(ExpansionResult, ExpansionMacro); #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ResponseError { - pub code: i32, + pub code: ErrorCode, pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Clone, Copy, Debug)] -#[allow(unused)] -pub enum ErrorCode { - // Defined by JSON RPC - ParseError = -32700, - InvalidRequest = -32600, - MethodNotFound = -32601, - InvalidParams = -32602, - InternalError = -32603, - ServerErrorStart = -32099, - ServerErrorEnd = -32000, - ServerNotInitialized = -32002, - UnknownErrorCode = -32001, - - // Defined by protocol - ExpansionError = -32900, } #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Notification { - pub method: String, - pub params: serde_json::Value, +pub enum ErrorCode { + ServerErrorEnd, + ExpansionError, } -impl Message { - pub fn read(r: &mut impl BufRead) -> io::Result> { - let text = match read_msg_text(r)? { +pub trait Message: Sized + Serialize + DeserializeOwned { + fn read(r: &mut impl BufRead) -> io::Result> { + let text = match read_json(r)? { None => return Ok(None), Some(text) => text, }; let msg = serde_json::from_str(&text)?; Ok(Some(msg)) } - pub fn write(self, w: &mut impl Write) -> io::Result<()> { - #[derive(Serialize)] - struct JsonRpc { - jsonrpc: &'static str, - #[serde(flatten)] - msg: Message, - } - let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?; - write_msg_text(w, &text) + fn write(self, w: &mut impl Write) -> io::Result<()> { + let text = serde_json::to_string(&self)?; + write_json(w, &text) } } -impl Response { - pub fn new_ok(id: RequestId, result: R) -> Response { - Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None } - } - pub fn new_err(id: RequestId, code: i32, message: String) -> Response { - let error = ResponseError { code, message, data: None }; - Response { id, result: None, error: Some(error) } - } -} +impl Message for Request {} +impl Message for Response {} -impl Request { - pub fn new(id: RequestId, method: String, params: P) -> Request { - Request { id, method, params: serde_json::to_value(params).unwrap() } - } - pub fn extract(self, method: &str) -> Result<(RequestId, P), Request> { - if self.method == method { - let params = serde_json::from_value(self.params).unwrap_or_else(|err| { - panic!("Invalid request\nMethod: {}\n error: {}", method, err) - }); - Ok((self.id, params)) - } else { - Err(self) - } - } -} - -impl Notification { - pub fn new(method: String, params: impl Serialize) -> Notification { - Notification { method, params: serde_json::to_value(params).unwrap() } - } - pub fn extract(self, method: &str) -> Result { - if self.method == method { - let params = serde_json::from_value(self.params).unwrap(); - Ok(params) - } else { - Err(self) - } - } -} - -fn read_msg_text(inp: &mut impl BufRead) -> io::Result> { - fn invalid_data(error: impl Into>) -> io::Error { - io::Error::new(io::ErrorKind::InvalidData, error) - } - macro_rules! invalid_data { - ($($tt:tt)*) => (invalid_data(format!($($tt)*))) - } - - let mut size = None; +fn read_json(inp: &mut impl BufRead) -> io::Result> { let mut buf = String::new(); - loop { - buf.clear(); - if inp.read_line(&mut buf)? == 0 { - return Ok(None); - } - if !buf.ends_with("\r\n") { - return Err(invalid_data!("malformed header: {:?}", buf)); - } - let buf = &buf[..buf.len() - 2]; - if buf.is_empty() { - break; - } - let mut parts = buf.splitn(2, ": "); - let header_name = parts.next().unwrap(); - let header_value = - parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; - if header_name == "Content-Length" { - size = Some(header_value.parse::().map_err(invalid_data)?); - } + if inp.read_line(&mut buf)? == 0 { + return Ok(None); } - let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; - let mut buf = buf.into_bytes(); - buf.resize(size, 0); - inp.read_exact(&mut buf)?; - let buf = String::from_utf8(buf).map_err(invalid_data)?; - log::debug!("< {}", buf); - Ok(Some(buf)) + // Remove ending '\n' + let buf = &buf[..buf.len() - 1]; + if buf.is_empty() { + return Ok(None); + } + Ok(Some(buf.to_string())) } -fn write_msg_text(out: &mut impl Write, msg: &str) -> io::Result<()> { +fn write_json(out: &mut impl Write, msg: &str) -> io::Result<()> { log::debug!("> {}", msg); - write!(out, "Content-Length: {}\r\n\r\n", msg.len())?; out.write_all(msg.as_bytes())?; + out.write_all(b"\n")?; out.flush()?; Ok(()) } diff --git a/crates/ra_proc_macro/src/process.rs b/crates/ra_proc_macro/src/process.rs index daae9a7e0f6..2b1f8535a14 100644 --- a/crates/ra_proc_macro/src/process.rs +++ b/crates/ra_proc_macro/src/process.rs @@ -3,11 +3,12 @@ use crossbeam_channel::{bounded, Receiver, Sender}; use ra_tt::Subtree; -use crate::msg::{ErrorCode, Message, Request, Response, ResponseError}; +use crate::msg::{ErrorCode, Request, Response, ResponseError, Message}; use crate::rpc::{ExpansionResult, ExpansionTask, ListMacrosResult, ListMacrosTask, ProcMacroKind}; use io::{BufRead, BufReader}; use std::{ + convert::{TryFrom, TryInto}, io::{self, Write}, path::{Path, PathBuf}, process::{Child, Command, Stdio}, @@ -26,7 +27,7 @@ pub(crate) struct ProcMacroProcessThread { } enum Task { - Request { req: Message, result_tx: Sender }, + Request { req: Request, result_tx: Sender }, Close, } @@ -96,7 +97,7 @@ pub fn find_proc_macros( ) -> Result, ra_tt::ExpansionError> { let task = ListMacrosTask { lib: dylib_path.to_path_buf() }; - let result: ListMacrosResult = self.send_task("list_macros", task)?; + let result: ListMacrosResult = self.send_task(Request::ListMacro(task))?; Ok(result.macros) } @@ -113,26 +114,19 @@ pub fn custom_derive( lib: dylib_path.to_path_buf(), }; - let result: ExpansionResult = self.send_task("custom_derive", task)?; + let result: ExpansionResult = self.send_task(Request::ExpansionMacro(task))?; Ok(result.expansion) } - pub fn send_task<'a, T, R>(&self, method: &str, task: T) -> Result + pub fn send_task(&self, req: Request) -> Result where - T: serde::Serialize, - R: serde::de::DeserializeOwned + Default, + R: TryFrom, { let sender = match &self.inner { None => return Err(ra_tt::ExpansionError::Unknown("No sender is found.".to_string())), Some(it) => it, }; - let msg = serde_json::to_value(task).unwrap(); - - // FIXME: use a proper request id - let id = 0; - let req = Request { id: id.into(), method: method.into(), params: msg }; - let (result_tx, result_rx) = bounded(0); sender.send(Task::Request { req: req.into(), result_tx }).map_err(|err| { @@ -141,27 +135,18 @@ pub fn send_task<'a, T, R>(&self, method: &str, task: T) -> Result { - return Err(ra_tt::ExpansionError::Unknown( - "Return request from ra_proc_srv".into(), + let res = result_rx.recv().unwrap(); + match res { + Response::Error(err) => { + return Err(ra_tt::ExpansionError::ExpansionError(err.message)); + } + _ => Ok(res.try_into().map_err(|err| { + ra_tt::ExpansionError::Unknown(format!( + "Fail to get response, reason : {:#?} ", + err )) - } - Message::Response(res) => { - if let Some(err) = res.error { - return Err(ra_tt::ExpansionError::ExpansionError(err.message)); - } - match res.result { - None => Ok(R::default()), - Some(res) => { - let result: R = serde_json::from_value(res) - .map_err(|err| ra_tt::ExpansionError::JsonError(err.to_string()))?; - Ok(result) - } - } - } + })?), } } } @@ -183,18 +168,13 @@ fn client_loop(task_rx: Receiver, mut process: Process) { Task::Close => break, }; - let res = match send_message(&mut stdin, &mut stdout, req) { + let res = match send_request(&mut stdin, &mut stdout, req) { Ok(res) => res, Err(_err) => { - let res = Response { - id: 0.into(), - result: None, - error: Some(ResponseError { - code: ErrorCode::ServerErrorEnd as i32, - message: "Server closed".into(), - data: None, - }), - }; + let res = Response::Error(ResponseError { + code: ErrorCode::ServerErrorEnd, + message: "Server closed".into(), + }); if result_tx.send(res.into()).is_err() { break; } @@ -222,11 +202,11 @@ fn client_loop(task_rx: Receiver, mut process: Process) { let _ = process.child.kill(); } -fn send_message( +fn send_request( mut writer: &mut impl Write, mut reader: &mut impl BufRead, - msg: Message, -) -> Result, io::Error> { - msg.write(&mut writer)?; - Ok(Message::read(&mut reader)?) + req: Request, +) -> Result, io::Error> { + req.write(&mut writer)?; + Ok(Response::read(&mut reader)?) } diff --git a/crates/ra_proc_macro/src/rpc.rs b/crates/ra_proc_macro/src/rpc.rs index f88d91f782c..fc8b04e2889 100644 --- a/crates/ra_proc_macro/src/rpc.rs +++ b/crates/ra_proc_macro/src/rpc.rs @@ -1,4 +1,10 @@ //! Data struture serialization related stuffs for RPC +//! +//! Define all necessary rpc serialization data structure, +//! which include ra_tt related data and some task messages. +//! Although adding Serialize and Deserialize trait to ra_tt directly seem to be much easier, +//! we deliberately duplicate the ra_tt struct with #[serde(with = "XXDef")] +//! for separation of code responsibility. use ra_tt::{ Delimiter, DelimiterKind, Ident, Leaf, Literal, Punct, SmolStr, Spacing, Subtree, TokenId,