208 lines
5.2 KiB
Rust
208 lines
5.2 KiB
Rust
use std::{
|
|
thread,
|
|
io::{
|
|
stdout, stdin,
|
|
BufRead, Write,
|
|
},
|
|
};
|
|
use serde_json::{Value, from_str, to_string};
|
|
use crossbeam_channel::{Receiver, Sender, bounded};
|
|
|
|
use Result;
|
|
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
#[serde(untagged)]
|
|
pub enum RawMsg {
|
|
Request(RawRequest),
|
|
Notification(RawNotification),
|
|
Response(RawResponse),
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct RawRequest {
|
|
pub id: u64,
|
|
pub method: String,
|
|
pub params: Value,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct RawNotification {
|
|
pub method: String,
|
|
pub params: Value,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct RawResponse {
|
|
// JSON RPC allows this to be null if it was impossible
|
|
// to decode the request's id. Ignore this special case
|
|
// and just die horribly.
|
|
pub id: u64,
|
|
#[serde(default)]
|
|
pub result: Value,
|
|
#[serde(default)]
|
|
pub error: Value,
|
|
}
|
|
|
|
struct MsgReceiver {
|
|
chan: Receiver<RawMsg>,
|
|
thread: Option<thread::JoinHandle<Result<()>>>,
|
|
}
|
|
|
|
impl MsgReceiver {
|
|
fn recv(&mut self) -> Result<RawMsg> {
|
|
match self.chan.recv() {
|
|
Some(msg) => Ok(msg),
|
|
None => {
|
|
self.cleanup()?;
|
|
unreachable!()
|
|
}
|
|
}
|
|
}
|
|
|
|
fn cleanup(&mut self) -> Result<()> {
|
|
self.thread
|
|
.take()
|
|
.ok_or_else(|| format_err!("MsgReceiver thread panicked"))?
|
|
.join()
|
|
.map_err(|_| format_err!("MsgReceiver thread panicked"))??;
|
|
bail!("client disconnected")
|
|
}
|
|
|
|
fn stop(self) -> Result<()> {
|
|
// Can't really self.thread.join() here, b/c it might be
|
|
// blocking on read
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
struct MsgSender {
|
|
chan: Sender<RawMsg>,
|
|
thread: thread::JoinHandle<Result<()>>,
|
|
}
|
|
|
|
impl MsgSender {
|
|
fn send(&mut self, msg: RawMsg) {
|
|
self.chan.send(msg)
|
|
}
|
|
|
|
fn stop(self) -> Result<()> {
|
|
drop(self.chan);
|
|
self.thread.join()
|
|
.map_err(|_| format_err!("MsgSender thread panicked"))??;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub struct Io {
|
|
receiver: MsgReceiver,
|
|
sender: MsgSender,
|
|
}
|
|
|
|
impl Io {
|
|
pub fn from_stdio() -> Io {
|
|
let sender = {
|
|
let (tx, rx) = bounded(16);
|
|
MsgSender {
|
|
chan: tx,
|
|
thread: thread::spawn(move || {
|
|
let stdout = stdout();
|
|
let mut stdout = stdout.lock();
|
|
for msg in rx {
|
|
#[derive(Serialize)]
|
|
struct JsonRpc {
|
|
jsonrpc: &'static str,
|
|
#[serde(flatten)]
|
|
msg: RawMsg,
|
|
}
|
|
let text = to_string(&JsonRpc {
|
|
jsonrpc: "2.0",
|
|
msg,
|
|
})?;
|
|
write_msg_text(&mut stdout, &text)?;
|
|
}
|
|
Ok(())
|
|
}),
|
|
}
|
|
};
|
|
let receiver = {
|
|
let (tx, rx) = bounded(16);
|
|
MsgReceiver {
|
|
chan: rx,
|
|
thread: Some(thread::spawn(move || {
|
|
let stdin = stdin();
|
|
let mut stdin = stdin.lock();
|
|
while let Some(text) = read_msg_text(&mut stdin)? {
|
|
let msg: RawMsg = from_str(&text)?;
|
|
tx.send(msg);
|
|
}
|
|
Ok(())
|
|
})),
|
|
}
|
|
};
|
|
Io { receiver, sender }
|
|
}
|
|
|
|
pub fn send(&mut self, msg: RawMsg) {
|
|
self.sender.send(msg)
|
|
}
|
|
|
|
pub fn recv(&mut self) -> Result<RawMsg> {
|
|
self.receiver.recv()
|
|
}
|
|
|
|
pub fn receiver(&mut self) -> &mut Receiver<RawMsg> {
|
|
&mut self.receiver.chan
|
|
}
|
|
|
|
pub fn cleanup_receiver(&mut self) -> Result<()> {
|
|
self.receiver.cleanup()
|
|
}
|
|
|
|
pub fn stop(self) -> Result<()> {
|
|
self.receiver.stop()?;
|
|
self.sender.stop()?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
fn read_msg_text(inp: &mut impl BufRead) -> Result<Option<String>> {
|
|
let mut size = None;
|
|
let mut buf = String::new();
|
|
loop {
|
|
buf.clear();
|
|
if inp.read_line(&mut buf)? == 0 {
|
|
return Ok(None);
|
|
}
|
|
if !buf.ends_with("\r\n") {
|
|
bail!("malformed header: {:?}", buf);
|
|
}
|
|
let buf = &buf[..buf.len() - 2];
|
|
if buf.is_empty() {
|
|
break;
|
|
}
|
|
let mut parts = buf.splitn(2, ": ");
|
|
let header_name = parts.next().unwrap();
|
|
let header_value = parts.next().ok_or_else(|| format_err!("malformed header: {:?}", buf))?;
|
|
if header_name == "Content-Length" {
|
|
size = Some(header_value.parse::<usize>()?);
|
|
}
|
|
}
|
|
let size = size.ok_or_else(|| format_err!("no Content-Length"))?;
|
|
let mut buf = buf.into_bytes();
|
|
buf.resize(size, 0);
|
|
inp.read_exact(&mut buf)?;
|
|
let buf = String::from_utf8(buf)?;
|
|
debug!("< {}", buf);
|
|
Ok(Some(buf))
|
|
}
|
|
|
|
fn write_msg_text(out: &mut impl Write, msg: &str) -> Result<()> {
|
|
debug!("> {}", msg);
|
|
write!(out, "Content-Length: {}\r\n\r\n", msg.len())?;
|
|
out.write_all(msg.as_bytes())?;
|
|
out.flush()?;
|
|
Ok(())
|
|
}
|