]> OzVa Git service - ozva-cloud/commitdiff
feat: support ipv6 (#25)
authorsigoden <sigoden@gmail.com>
Mon, 6 Jun 2022 02:52:12 +0000 (10:52 +0800)
committerGitHub <noreply@github.com>
Mon, 6 Jun 2022 02:52:12 +0000 (10:52 +0800)
src/args.rs
src/server.rs

index 55e70e8e290f58c7ab6ba90eea805c1ce4c3f7d7..6466d65a71e29a663505d77cb650cea308b373c7 100644 (file)
@@ -1,7 +1,7 @@
 use clap::crate_description;
 use clap::{Arg, ArgMatches};
 use rustls::{Certificate, PrivateKey};
-use std::net::SocketAddr;
+use std::net::{IpAddr, SocketAddr};
 use std::path::{Path, PathBuf};
 use std::{env, fs, io};
 
@@ -111,8 +111,7 @@ pub fn matches() -> ArgMatches {
 
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub struct Args {
-    pub address: String,
-    pub port: u16,
+    pub addr: SocketAddr,
     pub path: PathBuf,
     pub path_prefix: String,
     pub uri_prefix: String,
@@ -133,8 +132,9 @@ impl Args {
     /// If a parsing error ocurred, exit the process and print out informative
     /// error message to user.
     pub fn parse(matches: ArgMatches) -> BoxResult<Args> {
-        let address = matches.value_of("address").unwrap_or_default().to_owned();
+        let ip = matches.value_of("address").unwrap_or_default();
         let port = matches.value_of_t::<u16>("port")?;
+        let addr = to_addr(ip, port)?;
         let path = Args::parse_path(matches.value_of_os("path").unwrap_or_default())?;
         let path_prefix = matches
             .value_of("path-prefix")
@@ -166,8 +166,7 @@ impl Args {
         };
 
         Ok(Args {
-            address,
-            port,
+            addr,
             path,
             path_prefix,
             uri_prefix,
@@ -197,17 +196,15 @@ impl Args {
             })
             .map_err(|err| format!("Failed to access path `{}`: {}", path.display(), err,).into())
     }
+}
 
-    /// Construct socket address from arguments.
-    pub fn address(&self) -> BoxResult<SocketAddr> {
-        format!("{}:{}", self.address, self.port)
-            .parse()
-            .map_err(|_| format!("Invalid bind address `{}:{}`", self.address, self.port).into())
-    }
+fn to_addr(ip: &str, port: u16) -> BoxResult<SocketAddr> {
+    let ip: IpAddr = ip.parse()?;
+    Ok(SocketAddr::new(ip, port))
 }
 
 // Load public certificate from file.
-pub fn load_certs(filename: &str) -> BoxResult<Vec<Certificate>> {
+fn load_certs(filename: &str) -> BoxResult<Vec<Certificate>> {
     // Open certificate file.
     let certfile =
         fs::File::open(&filename).map_err(|e| format!("Failed to open {}: {}", &filename, e))?;
@@ -222,7 +219,7 @@ pub fn load_certs(filename: &str) -> BoxResult<Vec<Certificate>> {
 }
 
 // Load private key from file.
-pub fn load_private_key(filename: &str) -> BoxResult<PrivateKey> {
+fn load_private_key(filename: &str) -> BoxResult<PrivateKey> {
     // Open keyfile.
     let keyfile =
         fs::File::open(&filename).map_err(|e| format!("Failed to open {}: {}", &filename, e))?;
index 28dfd92c8e97e88bc006658c3517e812eadd46a5..5faff80bc710a96aca2017942641d01dae9a2f78 100644 (file)
@@ -25,7 +25,7 @@ use rustls::ServerConfig;
 use serde::Serialize;
 use std::convert::Infallible;
 use std::fs::Metadata;
-use std::net::IpAddr;
+use std::net::{IpAddr, SocketAddr};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::time::SystemTime;
@@ -56,7 +56,6 @@ macro_rules! status {
 
 pub async fn serve(args: Args) -> BoxResult<()> {
     let args = Arc::new(args);
-    let socket_addr = args.address()?;
     let inner = Arc::new(InnerService::new(args.clone()));
     match args.tls.clone() {
         Some((certs, key)) => {
@@ -66,7 +65,7 @@ pub async fn serve(args: Args) -> BoxResult<()> {
                 .with_single_cert(certs, key)?;
             let tls_acceptor = TlsAcceptor::from(Arc::new(config));
             let arc_acceptor = Arc::new(tls_acceptor);
-            let listener = TcpListener::bind(&socket_addr).await?;
+            let listener = TcpListener::bind(&args.addr).await?;
             let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
             let incoming =
                 hyper::server::accept::from_stream(incoming.filter_map(|socket| async {
@@ -87,11 +86,11 @@ pub async fn serve(args: Args) -> BoxResult<()> {
                     }))
                 }
             }));
-            print_listening(args.address.as_str(), args.port, &args.uri_prefix, true);
+            print_listening(&args.addr, &args.uri_prefix, true);
             server.await?;
         }
         None => {
-            let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| {
+            let server = hyper::Server::try_bind(&args.addr)?.serve(make_service_fn(move |_| {
                 let inner = inner.clone();
                 async move {
                     Ok::<_, Infallible>(service_fn(move |req| {
@@ -100,7 +99,7 @@ pub async fn serve(args: Args) -> BoxResult<()> {
                     }))
                 }
             }));
-            print_listening(args.address.as_str(), args.port, &args.uri_prefix, false);
+            print_listening(&args.addr, &args.uri_prefix, false);
             server.await?;
         }
     }
@@ -974,37 +973,45 @@ fn to_content_range(range: &Range, complete_length: u64) -> Option<ContentRange>
     })
 }
 
-fn print_listening(address: &str, port: u16, prefix: &str, tls: bool) {
+fn print_listening(addr: &SocketAddr, prefix: &str, tls: bool) {
     let prefix = encode_uri(prefix.trim_end_matches('/'));
-    let addrs = retrieve_listening_addrs(address);
+    let addrs = retrieve_listening_addrs(addr);
     let protocol = if tls { "https" } else { "http" };
     if addrs.len() == 1 {
-        eprintln!(
-            "Listening on {}://{}:{}{}",
-            protocol, addrs[0], port, prefix
-        );
+        eprintln!("Listening on {}://{}{}", protocol, addr, prefix);
     } else {
         eprintln!("Listening on:");
         for addr in addrs {
-            eprintln!("  {}://{}:{}{}", protocol, addr, port, prefix);
+            eprintln!("  {}://{}{}", protocol, addr, prefix);
         }
         eprintln!();
     }
 }
 
-fn retrieve_listening_addrs(address: &str) -> Vec<String> {
-    if address == "0.0.0.0" {
+fn retrieve_listening_addrs(addr: &SocketAddr) -> Vec<SocketAddr> {
+    let ip = addr.ip();
+    let port = addr.port();
+    if ip.is_unspecified() {
         if let Ok(interfaces) = get_if_addrs() {
             let mut ifaces: Vec<IpAddr> = interfaces
                 .into_iter()
                 .map(|v| v.ip())
-                .filter(|v| v.is_ipv4())
+                .filter(|v| {
+                    if ip.is_ipv4() {
+                        v.is_ipv4()
+                    } else {
+                        v.is_ipv6()
+                    }
+                })
                 .collect();
             ifaces.sort();
-            return ifaces.into_iter().map(|v| v.to_string()).collect();
+            return ifaces
+                .into_iter()
+                .map(|v| SocketAddr::new(v, port))
+                .collect();
         }
     }
-    vec![address.to_owned()]
+    vec![addr.to_owned()]
 }
 
 fn encode_uri(v: &str) -> String {