]> OzVa Git service - ozva-cloud/commitdiff
feat: support unix sockets (#145)
authorsigoden <sigoden@gmail.com>
Fri, 11 Nov 2022 00:57:44 +0000 (08:57 +0800)
committerGitHub <noreply@github.com>
Fri, 11 Nov 2022 00:57:44 +0000 (08:57 +0800)
README.md
src/args.rs
src/main.rs
src/server.rs
src/unix.rs [new file with mode: 0644]

index 4aef7a4319a19f7d778610531803b120b5e26713..1dfabaf6e5c7db887b306d39421e6b84a8c77188 100644 (file)
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ ARGS:
     <root>    Specific path to serve [default: .]
 
 OPTIONS:
-    -b, --bind <addr>...         Specify bind address
+    -b, --bind <addr>...         Specify bind address or unix socket
     -p, --port <port>            Specify port to listen on [default: 5000]
         --path-prefix <path>     Specify a path prefix
         --hidden <value>         Hide paths from directory listings, separated by `,`
@@ -123,10 +123,15 @@ Require username/password
 dufs -a /@admin:123
 ```
 
-Listen on a specific port
+Listen on specific host:ip 
 
 ```
-dufs -p 80
+dufs -b 127.0.0.1 -p 80
+```
+
+Listen on unix socket
+```
+dufs -b /tmp/dufs.socket
 ```
 
 Use https
index 358a23eceb30b7479a7eb1cf1836773e79ee35f1..9799d096173f349ac757a512a88a31f976914e9c 100644 (file)
@@ -28,7 +28,7 @@ pub fn build_cli() -> Command<'static> {
             Arg::new("bind")
                 .short('b')
                 .long("bind")
-                .help("Specify bind address")
+                .help("Specify bind address or unix socket")
                 .multiple_values(true)
                 .value_delimiter(',')
                 .action(ArgAction::Append)
@@ -168,7 +168,7 @@ pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
 
 #[derive(Debug)]
 pub struct Args {
-    pub addrs: Vec<IpAddr>,
+    pub addrs: Vec<BindAddr>,
     pub port: u16,
     pub path: PathBuf,
     pub path_is_file: bool,
@@ -204,7 +204,7 @@ impl Args {
             .values_of("bind")
             .map(|v| v.collect())
             .unwrap_or_else(|| vec!["0.0.0.0", "::"]);
-        let addrs: Vec<IpAddr> = Args::parse_addrs(&addrs)?;
+        let addrs: Vec<BindAddr> = Args::parse_addrs(&addrs)?;
         let path = Args::parse_path(matches.value_of_os("root").unwrap_or_default())?;
         let path_is_file = path.metadata()?.is_file();
         let path_prefix = matches
@@ -281,23 +281,27 @@ impl Args {
         })
     }
 
-    fn parse_addrs(addrs: &[&str]) -> BoxResult<Vec<IpAddr>> {
-        let mut ip_addrs = vec![];
+    fn parse_addrs(addrs: &[&str]) -> BoxResult<Vec<BindAddr>> {
+        let mut bind_addrs = vec![];
         let mut invalid_addrs = vec![];
         for addr in addrs {
             match addr.parse::<IpAddr>() {
                 Ok(v) => {
-                    ip_addrs.push(v);
+                    bind_addrs.push(BindAddr::Address(v));
                 }
                 Err(_) => {
-                    invalid_addrs.push(*addr);
+                    if cfg!(unix) {
+                        bind_addrs.push(BindAddr::Path(PathBuf::from(addr)));
+                    } else {
+                        invalid_addrs.push(*addr);
+                    }
                 }
             }
         }
         if !invalid_addrs.is_empty() {
             return Err(format!("Invalid bind address `{}`", invalid_addrs.join(",")).into());
         }
-        Ok(ip_addrs)
+        Ok(bind_addrs)
     }
 
     fn parse_path<P: AsRef<Path>>(path: P) -> BoxResult<PathBuf> {
@@ -322,3 +326,9 @@ impl Args {
         Ok(path)
     }
 }
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+pub enum BindAddr {
+    Address(IpAddr),
+    Path(PathBuf),
+}
index ed85e8a3cb87ead53be6a843bc4db712061e5910..5df341556bcb9c81e676eeb7eeea065481303d1c 100644 (file)
@@ -6,6 +6,8 @@ mod server;
 mod streamer;
 #[cfg(feature = "tls")]
 mod tls;
+#[cfg(unix)]
+mod unix;
 mod utils;
 
 #[macro_use]
@@ -20,6 +22,7 @@ use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 
+use args::BindAddr;
 use clap_complete::Shell;
 use futures::future::join_all;
 use tokio::net::TcpListener;
@@ -75,11 +78,9 @@ fn serve(
     let inner = Arc::new(Server::new(args.clone(), running));
     let mut handles = vec![];
     let port = args.port;
-    for ip in args.addrs.iter() {
+    for bind_addr in args.addrs.iter() {
         let inner = inner.clone();
-        let incoming = create_addr_incoming(SocketAddr::new(*ip, port))
-            .map_err(|e| format!("Failed to bind `{}:{}`, {}", ip, port, e))?;
-        let serve_func = move |remote_addr: SocketAddr| {
+        let serve_func = move |remote_addr: Option<SocketAddr>| {
             let inner = inner.clone();
             async move {
                 Ok::<_, hyper::Error>(service_fn(move |req: Request| {
@@ -88,35 +89,57 @@ fn serve(
                 }))
             }
         };
-        match args.tls.as_ref() {
-            #[cfg(feature = "tls")]
-            Some((certs, key)) => {
-                let config = ServerConfig::builder()
-                    .with_safe_defaults()
-                    .with_no_client_auth()
-                    .with_single_cert(certs.clone(), key.clone())?;
-                let config = Arc::new(config);
-                let accepter = TlsAcceptor::new(config.clone(), incoming);
-                let new_service = make_service_fn(move |socket: &TlsStream| {
-                    let remote_addr = socket.remote_addr();
-                    serve_func(remote_addr)
-                });
-                let server = tokio::spawn(hyper::Server::builder(accepter).serve(new_service));
-                handles.push(server);
+        match bind_addr {
+            BindAddr::Address(ip) => {
+                let incoming = create_addr_incoming(SocketAddr::new(*ip, port))
+                    .map_err(|e| format!("Failed to bind `{}:{}`, {}", ip, port, e))?;
+                match args.tls.as_ref() {
+                    #[cfg(feature = "tls")]
+                    Some((certs, key)) => {
+                        let config = ServerConfig::builder()
+                            .with_safe_defaults()
+                            .with_no_client_auth()
+                            .with_single_cert(certs.clone(), key.clone())?;
+                        let config = Arc::new(config);
+                        let accepter = TlsAcceptor::new(config.clone(), incoming);
+                        let new_service = make_service_fn(move |socket: &TlsStream| {
+                            let remote_addr = socket.remote_addr();
+                            serve_func(Some(remote_addr))
+                        });
+                        let server =
+                            tokio::spawn(hyper::Server::builder(accepter).serve(new_service));
+                        handles.push(server);
+                    }
+                    #[cfg(not(feature = "tls"))]
+                    Some(_) => {
+                        unreachable!()
+                    }
+                    None => {
+                        let new_service = make_service_fn(move |socket: &AddrStream| {
+                            let remote_addr = socket.remote_addr();
+                            serve_func(Some(remote_addr))
+                        });
+                        let server =
+                            tokio::spawn(hyper::Server::builder(incoming).serve(new_service));
+                        handles.push(server);
+                    }
+                };
             }
-            #[cfg(not(feature = "tls"))]
-            Some(_) => {
-                unreachable!()
-            }
-            None => {
-                let new_service = make_service_fn(move |socket: &AddrStream| {
-                    let remote_addr = socket.remote_addr();
-                    serve_func(remote_addr)
-                });
-                let server = tokio::spawn(hyper::Server::builder(incoming).serve(new_service));
-                handles.push(server);
+            BindAddr::Path(path) => {
+                if path.exists() {
+                    std::fs::remove_file(path)?;
+                }
+                #[cfg(unix)]
+                {
+                    let listener = tokio::net::UnixListener::bind(path)
+                        .map_err(|e| format!("Failed to bind `{}`, {}", path.display(), e))?;
+                    let acceptor = unix::UnixAcceptor::from_listener(listener);
+                    let new_service = make_service_fn(move |_| serve_func(None));
+                    let server = tokio::spawn(hyper::Server::builder(acceptor).serve(new_service));
+                    handles.push(server);
+                }
             }
-        };
+        }
     }
     Ok(handles)
 }
@@ -137,17 +160,22 @@ fn create_addr_incoming(addr: SocketAddr) -> BoxResult<AddrIncoming> {
 }
 
 fn print_listening(args: Arc<Args>) -> BoxResult<()> {
-    let mut addrs = vec![];
+    let mut bind_addrs = vec![];
     let (mut ipv4, mut ipv6) = (false, false);
-    for ip in args.addrs.iter() {
-        if ip.is_unspecified() {
-            if ip.is_ipv6() {
-                ipv6 = true;
-            } else {
-                ipv4 = true;
+    for bind_addr in args.addrs.iter() {
+        match bind_addr {
+            BindAddr::Address(ip) => {
+                if ip.is_unspecified() {
+                    if ip.is_ipv6() {
+                        ipv6 = true;
+                    } else {
+                        ipv4 = true;
+                    }
+                } else {
+                    bind_addrs.push(bind_addr.clone());
+                }
             }
-        } else {
-            addrs.push(*ip);
+            _ => bind_addrs.push(bind_addr.clone()),
         }
     }
     if ipv4 || ipv6 {
@@ -156,25 +184,27 @@ fn print_listening(args: Arc<Args>) -> BoxResult<()> {
         for iface in ifaces.into_iter() {
             let local_ip = iface.ip();
             if ipv4 && local_ip.is_ipv4() {
-                addrs.push(local_ip)
+                bind_addrs.push(BindAddr::Address(local_ip))
             }
             if ipv6 && local_ip.is_ipv6() {
-                addrs.push(local_ip)
+                bind_addrs.push(BindAddr::Address(local_ip))
             }
         }
     }
-    addrs.sort_unstable();
-    let urls = addrs
+    bind_addrs.sort_unstable();
+    let urls = bind_addrs
         .into_iter()
-        .map(|addr| match addr {
-            IpAddr::V4(_) => format!("{}:{}", addr, args.port),
-            IpAddr::V6(_) => format!("[{}]:{}", addr, args.port),
-        })
-        .map(|addr| match &args.tls {
-            Some(_) => format!("https://{}", addr),
-            None => format!("http://{}", addr),
+        .map(|bind_addr| match bind_addr {
+            BindAddr::Address(addr) => {
+                let addr = match addr {
+                    IpAddr::V4(_) => format!("{}:{}", addr, args.port),
+                    IpAddr::V6(_) => format!("[{}]:{}", addr, args.port),
+                };
+                let protocol = if args.tls.is_some() { "https" } else { "http" };
+                format!("{}://{}{}", protocol, addr, args.uri_prefix)
+            }
+            BindAddr::Path(path) => path.display().to_string(),
         })
-        .map(|url| format!("{}{}", url, args.uri_prefix))
         .collect::<Vec<_>>();
 
     if urls.len() == 1 {
index a32b2195c1e737ac90c2785d2fd05ef4f5e8c139..291ec0ae5874372ce618d99e2853d56995125417 100644 (file)
@@ -84,13 +84,15 @@ impl Server {
     pub async fn call(
         self: Arc<Self>,
         req: Request,
-        addr: SocketAddr,
+        addr: Option<SocketAddr>,
     ) -> Result<Response, hyper::Error> {
         let uri = req.uri().clone();
         let assets_prefix = self.assets_prefix.clone();
         let enable_cors = self.args.enable_cors;
         let mut http_log_data = self.args.log_http.data(&req, &self.args);
-        http_log_data.insert("remote_addr".to_string(), addr.ip().to_string());
+        if let Some(addr) = addr {
+            http_log_data.insert("remote_addr".to_string(), addr.ip().to_string());
+        }
 
         let mut res = match self.clone().handle(req).await {
             Ok(res) => {
diff --git a/src/unix.rs b/src/unix.rs
new file mode 100644 (file)
index 0000000..b8b1710
--- /dev/null
@@ -0,0 +1,31 @@
+use hyper::server::accept::Accept;
+use tokio::net::UnixListener;
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pub struct UnixAcceptor {
+    inner: UnixListener,
+}
+
+impl UnixAcceptor {
+    pub fn from_listener(listener: UnixListener) -> Self {
+        Self { inner: listener }
+    }
+}
+
+impl Accept for UnixAcceptor {
+    type Conn = tokio::net::UnixStream;
+    type Error = std::io::Error;
+
+    fn poll_accept(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
+        match self.inner.poll_accept(cx) {
+            Poll::Pending => Poll::Pending,
+            Poll::Ready(Ok((socket, _addr))) => Poll::Ready(Some(Ok(socket))),
+            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
+        }
+    }
+}