]> OzVa Git service - ozva-cloud/commitdiff
feat: tls handshake timeout (#368)
authorsigoden <sigoden@gmail.com>
Fri, 8 Mar 2024 02:29:12 +0000 (10:29 +0800)
committerGitHub <noreply@github.com>
Fri, 8 Mar 2024 02:29:12 +0000 (10:29 +0800)
src/main.rs

index 6d669f659fb58573a12b203a7dff3945c1bb4f0a..298de0429ba174453622722e614a891e47ccf10c 100644 (file)
@@ -29,6 +29,8 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
+use std::time::Duration;
+use tokio::time::timeout;
 use tokio::{net::TcpListener, task::JoinHandle};
 #[cfg(feature = "tls")]
 use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
@@ -91,12 +93,19 @@ fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
                         config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
                         let config = Arc::new(config);
                         let tls_accepter = TlsAcceptor::from(config);
+                        let handshake_timeout = Duration::from_secs(10);
 
                         let handle = tokio::spawn(async move {
                             loop {
-                                let (cnx, addr) = listener.accept().await.unwrap();
-                                let Ok(stream) = tls_accepter.accept(cnx).await else {
-                                    warn!("During tls handshake connection from {}", addr);
+                                let Ok((stream, addr)) = listener.accept().await else {
+                                    continue;
+                                };
+                                let Some(stream) =
+                                    timeout(handshake_timeout, tls_accepter.accept(stream))
+                                        .await
+                                        .ok()
+                                        .and_then(|v| v.ok())
+                                else {
                                     continue;
                                 };
                                 let stream = TokioIo::new(stream);
@@ -113,8 +122,10 @@ fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
                     (None, None) => {
                         let handle = tokio::spawn(async move {
                             loop {
-                                let (cnx, addr) = listener.accept().await.unwrap();
-                                let stream = TokioIo::new(cnx);
+                                let Ok((stream, addr)) = listener.accept().await else {
+                                    continue;
+                                };
+                                let stream = TokioIo::new(stream);
                                 tokio::spawn(handle_stream(
                                     server_handle.clone(),
                                     stream,
@@ -139,8 +150,10 @@ fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
                         .with_context(|| format!("Failed to bind `{}`", path.display()))?;
                     let handle = tokio::spawn(async move {
                         loop {
-                            let (cnx, _) = listener.accept().await.unwrap();
-                            let stream = TokioIo::new(cnx);
+                            let Ok((stream, _addr)) = listener.accept().await else {
+                                continue;
+                            };
+                            let stream = TokioIo::new(stream);
                             tokio::spawn(handle_stream(server_handle.clone(), stream, None));
                         }
                     });
@@ -160,18 +173,15 @@ where
     let hyper_service =
         service_fn(move |request: Request<Incoming>| handle.clone().call(request, addr));
 
-    let ret = Builder::new(TokioExecutor::new())
+    match Builder::new(TokioExecutor::new())
         .serve_connection_with_upgrades(stream, hyper_service)
-        .await;
-
-    if let Err(err) = ret {
-        let scope = match addr {
-            Some(addr) => format!(" from {}", addr),
-            None => String::new(),
-        };
-        match err.downcast_ref::<std::io::Error>() {
-            Some(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}
-            _ => warn!("Serving connection{}: {}", scope, err),
+        .await
+    {
+        Ok(()) => {}
+        Err(_err) => {
+            // This error only appears when the client doesn't send a request and terminate the connection.
+            //
+            // If client sends one request then terminate connection whenever, it doesn't appear.
         }
     }
 }