]> OzVa Git service - ozva-cloud/commitdiff
feat: support gracefully shutdown server
authorsigoden <sigoden@gmail.com>
Fri, 3 Jun 2022 02:59:54 +0000 (10:59 +0800)
committersigoden <sigoden@gmail.com>
Fri, 3 Jun 2022 03:00:12 +0000 (11:00 +0800)
Cargo.lock
Cargo.toml
src/server.rs

index 5849b225a54c1faf98e37659c352ddf1b6fb7fcd..e62110a2f9a736f42ed196bc075d3a3cc97fae4a 100644 (file)
@@ -882,6 +882,15 @@ dependencies = [
  "digest",
 ]
 
+[[package]]
+name = "signal-hook-registry"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "slab"
 version = "0.4.6"
@@ -965,6 +974,7 @@ dependencies = [
  "num_cpus",
  "once_cell",
  "pin-project-lite",
+ "signal-hook-registry",
  "socket2",
  "tokio-macros",
  "winapi 0.3.9",
index c068edea142bf0aa09ccc8d5d2a98b4a29a00c16..3e46802787501840e5294dce1e494d1e2c477ebf 100644 (file)
@@ -14,7 +14,7 @@ keywords = ["static", "file", "server", "http", "cli"]
 [dependencies]
 clap = { version = "3", default-features = false, features = ["std", "cargo"] }
 chrono = "0.4"
-tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util"]}
+tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "signal"]}
 tokio-rustls = "0.23"
 tokio-stream = { version = "0.1", features = ["net"] }
 tokio-util = { version = "0.7",  features = ["codec", "io-util"] }
index de53d375eb6d371cf7ec618d6ef335f21b5a8f85..d1f38e777507d177a81f4dab80b4e0654675cbc8 100644 (file)
@@ -53,52 +53,65 @@ macro_rules! status {
 }
 
 pub async fn serve(args: Args) -> BoxResult<()> {
+    match args.tls.as_ref() {
+        Some(_) => serve_https(args).await,
+        None => serve_http(args).await,
+    }
+}
+
+pub async fn serve_https(args: Args) -> BoxResult<()> {
     let args = Arc::new(args);
     let socket_addr = args.address()?;
+    let (certs, key) = args.tls.clone().unwrap();
     let inner = Arc::new(InnerService::new(args.clone()));
-    if let Some((certs, key)) = args.tls.as_ref() {
-        let config = ServerConfig::builder()
-            .with_safe_defaults()
-            .with_no_client_auth()
-            .with_single_cert(certs.clone(), key.clone())?;
-        let tls_acceptor = TlsAcceptor::from(Arc::new(config));
-        let arc_acceptor = Arc::new(tls_acceptor);
-        let listener = TcpListener::bind(&socket_addr).await?;
-        let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
-        let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async {
-            match socket {
-                Ok(stream) => match arc_acceptor.clone().accept(stream).await {
-                    Ok(val) => Some(Ok::<_, Infallible>(val)),
-                    Err(_) => None,
-                },
+    let config = ServerConfig::builder()
+        .with_safe_defaults()
+        .with_no_client_auth()
+        .with_single_cert(certs, key)?;
+    let tls_acceptor = TlsAcceptor::from(Arc::new(config));
+    let arc_acceptor = Arc::new(tls_acceptor);
+    let listener = TcpListener::bind(&socket_addr).await?;
+    let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
+    let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async {
+        match socket {
+            Ok(stream) => match arc_acceptor.clone().accept(stream).await {
+                Ok(val) => Some(Ok::<_, Infallible>(val)),
                 Err(_) => None,
-            }
-        }));
-        let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| {
-            let inner = inner.clone();
-            async move {
-                Ok::<_, Infallible>(service_fn(move |req| {
-                    let inner = inner.clone();
-                    inner.call(req)
-                }))
-            }
-        }));
-        print_listening(args.address.as_str(), args.port, true);
-        server.await?;
-    } else {
-        let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| {
-            let inner = inner.clone();
-            async move {
-                Ok::<_, Infallible>(service_fn(move |req| {
-                    let inner = inner.clone();
-                    inner.call(req)
-                }))
-            }
-        }));
-        print_listening(args.address.as_str(), args.port, false);
-        server.await?;
-    }
+            },
+            Err(_) => None,
+        }
+    }));
+    let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| {
+        let inner = inner.clone();
+        async move {
+            Ok::<_, Infallible>(service_fn(move |req| {
+                let inner = inner.clone();
+                inner.call(req)
+            }))
+        }
+    }));
+    print_listening(args.address.as_str(), args.port, true);
+    let graceful = server.with_graceful_shutdown(shutdown_signal());
+    graceful.await?;
+    Ok(())
+}
 
+pub async fn serve_http(args: Args) -> BoxResult<()> {
+    let args = Arc::new(args);
+    let socket_addr = args.address()?;
+    let inner = Arc::new(InnerService::new(args.clone()));
+    let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| {
+        let inner = inner.clone();
+        async move {
+            Ok::<_, Infallible>(service_fn(move |req| {
+                let inner = inner.clone();
+                inner.call(req)
+            }))
+        }
+    }));
+    print_listening(args.address.as_str(), args.port, false);
+    let graceful = server.with_graceful_shutdown(shutdown_signal());
+    graceful.await?;
     Ok(())
 }
 
@@ -751,3 +764,9 @@ fn retrive_listening_addrs(address: &str) -> Vec<String> {
     }
     vec![address.to_owned()]
 }
+
+async fn shutdown_signal() {
+    tokio::signal::ctrl_c()
+        .await
+        .expect("Failed to install CTRL+C signal handler")
+}