diff --git a/src/pty.rs b/src/pty.rs index 07dd195..69ee3bd 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -1,10 +1,20 @@ use nix::pty::{openpty, OpenptyResult}; +use nix::sys::signal::{signal, SigHandler, Signal}; +use nix::sys::wait::{waitpid, WaitStatus}; use nix::unistd::{execvp, fork, ForkResult, Pid}; use std::ffi::{CStr, CString}; use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd}; +use std::sync::atomic::{AtomicBool, Ordering}; use crate::error::{Error, Result}; +static SIGWINCH_RECEIVED: AtomicBool = AtomicBool::new(false); + +extern "C" fn sigwinch_handler(_: libc::c_int) { + // SAFETY: AtomicBool::store is async-signal-safe. + SIGWINCH_RECEIVED.store(true, Ordering::Relaxed); +} + pub struct PtySpawner { pub master: OwnedFd, pub child_pid: Pid, @@ -79,6 +89,144 @@ impl PtySpawner { } } } + + /// Forward SIGWINCH to the child PTY, relay I/O between the master fd and + /// stdin/stdout, wait for the child to exit, and return its exit code. + pub fn relay(&self) -> Result { + // Install SIGWINCH handler — sigwinch_handler only touches SIGWINCH_RECEIVED, + // which is async-signal-safe. + unsafe { + signal(Signal::SIGWINCH, SigHandler::Handler(sigwinch_handler)) + .map_err(|e| Error::Internal(anyhow::anyhow!("signal(SIGWINCH) failed: {e}")))?; + } + + let master_fd = self.master.as_raw_fd(); + let mut buf = [0u8; 4096]; + let mut stdin_open = true; + + 'relay: loop { + // Apply any pending window-size change to the master PTY. + if SIGWINCH_RECEIVED.swap(false, Ordering::Relaxed) { + let ws = get_winsize(libc::STDIN_FILENO); + // SAFETY: master_fd is a valid PTY master fd; TIOCSWINSZ is a write ioctl. + unsafe { + libc::ioctl(master_fd, libc::TIOCSWINSZ, &ws); + } + } + + let mut fds = [ + libc::pollfd { + fd: master_fd, + events: libc::POLLIN, + revents: 0, + }, + libc::pollfd { + fd: libc::STDIN_FILENO, + events: if stdin_open { libc::POLLIN } else { 0 }, + revents: 0, + }, + ]; + + // 100 ms timeout so SIGWINCH is handled promptly even if poll is not interrupted. + let ret = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, 100) }; + + if ret < 0 { + if nix::errno::Errno::last() == nix::errno::Errno::EINTR { + continue; + } + break 'relay; + } + + // Drain PTY master output → caller's stdout. + let master_rev = fds[0].revents; + if master_rev & libc::POLLIN != 0 { + let n = unsafe { + libc::read(master_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) + }; + if n < 0 { + if nix::errno::Errno::last() == nix::errno::Errno::EINTR { + continue; + } + break 'relay; // EIO when child has closed the slave side + } + if n == 0 { + break 'relay; + } + let mut off = 0usize; + let n = n as usize; + while off < n { + let w = unsafe { + libc::write( + libc::STDOUT_FILENO, + buf[off..].as_ptr() as *const libc::c_void, + n - off, + ) + }; + if w <= 0 { + break 'relay; + } + off += w as usize; + } + } + if master_rev & (libc::POLLHUP | libc::POLLERR) != 0 { + break 'relay; + } + + // Forward caller's stdin → PTY master (child input). + if stdin_open { + let stdin_rev = fds[1].revents; + if stdin_rev & libc::POLLIN != 0 { + let n = unsafe { + libc::read( + libc::STDIN_FILENO, + buf.as_mut_ptr() as *mut libc::c_void, + buf.len(), + ) + }; + if n <= 0 { + stdin_open = false; + } else { + let mut off = 0usize; + let n = n as usize; + while off < n { + let w = unsafe { + libc::write( + master_fd, + buf[off..].as_ptr() as *const libc::c_void, + n - off, + ) + }; + if w <= 0 { + break 'relay; + } + off += w as usize; + } + } + } + if stdin_rev & libc::POLLHUP != 0 { + stdin_open = false; + } + } + } + + // Restore default SIGWINCH handling. + unsafe { + let _ = signal(Signal::SIGWINCH, SigHandler::SigDfl); + } + + // Wait for child exit and surface the exit code. + loop { + match waitpid(self.child_pid, None) { + Ok(WaitStatus::Exited(_, code)) => return Ok(code), + Ok(WaitStatus::Signaled(_, sig, _)) => return Ok(128 + sig as i32), + Ok(_) => continue, + Err(nix::errno::Errno::EINTR) => continue, + Err(e) => { + return Err(Error::Internal(anyhow::anyhow!("waitpid failed: {e}"))) + } + } + } + } } #[cfg(test)] @@ -97,4 +245,55 @@ mod tests { other => panic!("unexpected wait status: {other:?}"), } } + + #[test] + fn master_fd_carries_child_stdout() { + let cmd = CString::new("echo").unwrap(); + let args = vec![CString::new("hello").unwrap()]; + let spawner = PtySpawner::spawn(&cmd, &args).expect("spawn should succeed"); + + let master_fd = spawner.master.as_raw_fd(); + let mut output = Vec::new(); + let mut buf = [0u8; 256]; + + loop { + let n = unsafe { + libc::read(master_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) + }; + if n <= 0 { + break; + } + output.extend_from_slice(&buf[..n as usize]); + } + + let _ = waitpid(spawner.child_pid, None); + + // PTY translates \n → \r\n; verify the text is present. + let text = String::from_utf8_lossy(&output); + assert!( + text.contains("hello"), + "expected 'hello' in PTY output, got: {text:?}" + ); + } + + #[test] + fn relay_echo_exits_zero_and_produces_output() { + let cmd = CString::new("echo").unwrap(); + let args = vec![CString::new("relay-test").unwrap()]; + let spawner = PtySpawner::spawn(&cmd, &args).expect("spawn should succeed"); + let code = spawner.relay().expect("relay should succeed"); + assert_eq!(code, 0, "echo should exit with code 0"); + } + + #[test] + fn relay_surfaces_nonzero_exit_code() { + let cmd = CString::new("/bin/sh").unwrap(); + let args = vec![ + CString::new("-c").unwrap(), + CString::new("exit 42").unwrap(), + ]; + let spawner = PtySpawner::spawn(&cmd, &args).expect("spawn should succeed"); + let code = spawner.relay().expect("relay should succeed"); + assert_eq!(code, 42, "exit code should be 42"); + } }