diff --git a/crates/ra_proc_macro/src/lib.rs b/crates/ra_proc_macro/src/lib.rs index a0a478dc8cf..51fbb046a97 100644 --- a/crates/ra_proc_macro/src/lib.rs +++ b/crates/ra_proc_macro/src/lib.rs @@ -9,7 +9,7 @@ mod rpc; mod process; pub mod msg; -use process::ProcMacroProcessSrv; +use process::{ProcMacroProcessSrv, ProcMacroProcessThread}; use ra_tt::{SmolStr, Subtree}; use rpc::ProcMacroKind; use std::{ @@ -45,21 +45,23 @@ impl ra_tt::TokenExpander for ProcMacroProcessExpander { } } -#[derive(Debug, Clone)] +#[derive(Debug)] enum ProcMacroClientKind { - Process { process: Arc }, + Process { process: Arc, thread: ProcMacroProcessThread }, Dummy, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ProcMacroClient { kind: ProcMacroClientKind, } impl ProcMacroClient { pub fn extern_process(process_path: &Path) -> Result { - let process = ProcMacroProcessSrv::run(process_path)?; - Ok(ProcMacroClient { kind: ProcMacroClientKind::Process { process: Arc::new(process) } }) + let (thread, process) = ProcMacroProcessSrv::run(process_path)?; + Ok(ProcMacroClient { + kind: ProcMacroClientKind::Process { process: Arc::new(process), thread }, + }) } pub fn dummy() -> ProcMacroClient { @@ -72,7 +74,7 @@ impl ProcMacroClient { ) -> Vec<(SmolStr, Arc)> { match &self.kind { ProcMacroClientKind::Dummy => vec![], - ProcMacroClientKind::Process { process } => { + ProcMacroClientKind::Process { process, .. } => { let macros = match process.find_proc_macros(dylib_path) { Err(err) => { eprintln!("Fail to find proc macro. Error: {:#?}", err); diff --git a/crates/ra_proc_macro/src/process.rs b/crates/ra_proc_macro/src/process.rs index 6a3fe2e2025..d028b365c55 100644 --- a/crates/ra_proc_macro/src/process.rs +++ b/crates/ra_proc_macro/src/process.rs @@ -11,7 +11,7 @@ use std::{ io::{self, Write}, path::{Path, PathBuf}, process::{Child, Command, Stdio}, - thread::spawn, + thread::{spawn, JoinHandle}, }; #[derive(Debug, Default)] @@ -19,9 +19,15 @@ pub(crate) struct ProcMacroProcessSrv { inner: Option, } -struct Task { - req: Message, - result_tx: Sender, +#[derive(Debug)] +pub(crate) struct ProcMacroProcessThread { + handle: Option>, + sender: Sender, +} + +enum Task { + Request { req: Message, result_tx: Sender }, + Close, } #[derive(Debug)] @@ -60,16 +66,33 @@ impl Process { } } +impl std::ops::Drop for ProcMacroProcessThread { + fn drop(&mut self) { + if let Some(handle) = self.handle.take() { + let _ = self.sender.send(Task::Close); + + // Join the thread, it should finish shortly. We don't really care + // whether it panicked, so it is safe to ignore the result + let _ = handle.join(); + } + } +} + impl ProcMacroProcessSrv { - pub fn run(process_path: &Path) -> Result { + pub fn run( + process_path: &Path, + ) -> Result<(ProcMacroProcessThread, ProcMacroProcessSrv), io::Error> { let process = Process::run(process_path)?; let (task_tx, task_rx) = bounded(0); - - let _ = spawn(move || { + let handle = spawn(move || { client_loop(task_rx, process); }); - Ok(ProcMacroProcessSrv { inner: Some(Handle { sender: task_tx }) }) + + let srv = ProcMacroProcessSrv { inner: Some(Handle { sender: task_tx.clone() }) }; + let thread = ProcMacroProcessThread { handle: Some(handle), sender: task_tx }; + + Ok((thread, srv)) } pub fn find_proc_macros( @@ -117,7 +140,12 @@ impl ProcMacroProcessSrv { let (result_tx, result_rx) = bounded(0); - handle.sender.send(Task { req: req.into(), result_tx }).unwrap(); + handle.sender.send(Task::Request { req: req.into(), result_tx }).map_err(|err| { + ra_tt::ExpansionError::Unknown(format!( + "Fail to send task in channel, reason : {:#?} ", + err + )) + })?; let response = result_rx.recv().unwrap(); match response { @@ -155,7 +183,12 @@ fn client_loop(task_rx: Receiver, mut process: Process) { Err(_) => break, }; - let res = match send_message(&mut stdin, &mut stdout, task.req) { + let (req, result_tx) = match task { + Task::Request { req, result_tx } => (req, result_tx), + Task::Close => break, + }; + + let res = match send_message(&mut stdin, &mut stdout, req) { Ok(res) => res, Err(_err) => { let res = Response { @@ -167,7 +200,7 @@ fn client_loop(task_rx: Receiver, mut process: Process) { data: None, }), }; - if task.result_tx.send(res.into()).is_err() { + if result_tx.send(res.into()).is_err() { break; } // Restart the process @@ -185,7 +218,7 @@ fn client_loop(task_rx: Receiver, mut process: Process) { }; if let Some(res) = res { - if task.result_tx.send(res).is_err() { + if result_tx.send(res).is_err() { break; } }