OzVa Git service - ozva-cloud/commitdiff
fix: range request (#44)
author sigoden <sigoden@gmail.com>
Thu, 16 Jun 2022 02:24:32 +0000 (10:24 +0800)
committer GitHub <noreply@github.com>
Thu, 16 Jun 2022 02:24:32 +0000 (10:24 +0800)
close #43

Cargo.lock
Cargo.toml
src/main.rs
src/server.rs
src/streamer.rs [new file with mode: 0644]
tests/range.rs [new file with mode: 0644]

diff --git a/Cargo.lock b/Cargo.lock
index 7738f2372cba70a837719aa22f094dbaf1d5167f..b4acd97f5756840eb04a2999beb6a764210b98b5 100644 (file)
@@ -188,6 +188,27 @@ dependencies = [
  "wasm-bindgen-futures",
 ]
 
+[[package]]
+name = "async-stream"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
+dependencies = [
+ "async-stream-impl",
+ "futures-core",
+]
+
+[[package]]
+name = "async-stream-impl"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "async-task"
 version = "4.2.0"
@@ -554,6 +575,7 @@ version = "0.17.0"
 dependencies = [
  "assert_cmd",
  "assert_fs",
+ "async-stream",
  "async-walkdir",
  "async_zip",
  "base64",
diff --git a/Cargo.toml b/Cargo.toml
index 6d1665149ab88420ba1f25813f7ec4af216fa934..86708941d6a726c9fe98005f123adc975033325f 100644 (file)
@@ -15,7 +15,7 @@ clap = { version = "3", default-features = false, features = ["std"] }
 chrono = "0.4"
 tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "signal"]}
 tokio-rustls = "0.23"
-tokio-util = { version = "0.7",  features = ["codec", "io-util"] }
+tokio-util = { version = "0.7",  features = ["io-util"] }
 hyper = { version = "0.14", features = ["http1", "server", "tcp", "stream"] }
 percent-encoding = "2.1"
 serde = { version = "1", features = ["derive"] }
@@ -37,6 +37,7 @@ xml-rs = "0.8"
 env_logger = { version = "0.9", default-features = false, features = ["humantime"] }
 log = "0.4"
 socket2 = "0.4"
+async-stream = "0.3"
 
 [dev-dependencies]
 assert_cmd = "2"
diff --git a/src/main.rs b/src/main.rs
index a44acd9fb3807e9c2024176c1a2a2ac4df35c497..30d4ac8f4bfefc4f2f2d552751e844aed114bcfd 100644 (file)
@@ -1,6 +1,7 @@
 mod args;
 mod auth;
 mod server;
+mod streamer;
 mod tls;
 
 #[macro_use]
diff --git a/src/server.rs b/src/server.rs
index 2aee8435115257e6e8d325ac24cc6aceee4dc911..23b899b770de5a8c3cbc08fdf87103aa54ded230 100644 (file)
@@ -1,4 +1,5 @@
 use crate::auth::{generate_www_auth, valid_digest};
+use crate::streamer::Streamer;
 use crate::{Args, BoxResult};
 use xml::escape::escape_str_pcdata;
 
@@ -10,26 +11,26 @@ use futures::stream::StreamExt;
 use futures::TryStreamExt;
 use headers::{
     AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders,
-    AccessControlAllowOrigin, Connection, ContentLength, ContentRange, ContentType, ETag,
-    HeaderMap, HeaderMapExt, IfModifiedSince, IfNoneMatch, IfRange, LastModified, Range,
+    AccessControlAllowOrigin, Connection, ContentLength, ContentType, ETag, HeaderMap,
+    HeaderMapExt, IfModifiedSince, IfNoneMatch, IfRange, LastModified, Range,
 };
 use hyper::header::{
-    HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_TYPE, ORIGIN, RANGE,
-    WWW_AUTHENTICATE,
+    HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_RANGE,
+    CONTENT_TYPE, ORIGIN, RANGE, WWW_AUTHENTICATE,
 };
 use hyper::{Body, Method, StatusCode, Uri};
 use percent_encoding::percent_decode;
 use serde::Serialize;
 use std::fs::Metadata;
+use std::io::SeekFrom;
 use std::net::SocketAddr;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::time::SystemTime;
 use tokio::fs::File;
-use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWrite};
+use tokio::io::{AsyncSeekExt, AsyncWrite};
 use tokio::{fs, io};
-use tokio_util::codec::{BytesCodec, FramedRead};
-use tokio_util::io::{ReaderStream, StreamReader};
+use tokio_util::io::StreamReader;
 use uuid::Uuid;
 
 pub type Request = hyper::Request<Body>;
@@ -40,7 +41,7 @@ const INDEX_CSS: &str = include_str!("../assets/index.css");
 const INDEX_JS: &str = include_str!("../assets/index.js");
 const FAVICON_ICO: &[u8] = include_bytes!("../assets/favicon.ico");
 const INDEX_NAME: &str = "index.html";
-const BUF_SIZE: usize = 1024 * 16;
+const BUF_SIZE: usize = 65536;
 
 pub struct Server {
     args: Arc<Args>,
@@ -353,8 +354,8 @@ impl Server {
                 error!("Failed to zip {}, {}", path.display(), e);
             }
         });
-        let stream = ReaderStream::new(reader);
-        *res.body_mut() = Body::wrap_stream(stream);
+        let reader = Streamer::new(reader, BUF_SIZE);
+        *res.body_mut() = Body::wrap_stream(reader.into_stream());
         Ok(())
     }
 
@@ -425,7 +426,7 @@ impl Server {
     ) -> BoxResult<()> {
         let (file, meta) = tokio::join!(fs::File::open(path), fs::metadata(path),);
         let (mut file, meta) = (file?, meta?);
-        let mut maybe_range = true;
+        let mut use_range = true;
         if let Some((etag, last_modified)) = extract_cache_headers(&meta) {
             let cached = {
                 if let Some(if_none_match) = headers.typed_get::<IfNoneMatch>() {
@@ -436,55 +437,77 @@ impl Server {
                     false
                 }
             };
-            res.headers_mut().typed_insert(last_modified);
-            res.headers_mut().typed_insert(etag.clone());
             if cached {
                 *res.status_mut() = StatusCode::NOT_MODIFIED;
                 return Ok(());
             }
+
+            res.headers_mut().typed_insert(last_modified);
+            res.headers_mut().typed_insert(etag.clone());
+
             if headers.typed_get::<Range>().is_some() {
-                maybe_range = headers
+                use_range = headers
                     .typed_get::<IfRange>()
                     .map(|if_range| !if_range.is_modified(Some(&etag), Some(&last_modified)))
                     // Always be fresh if there is no validators
                     .unwrap_or(true);
             } else {
-                maybe_range = false;
+                use_range = false;
             }
         }
-        let file_range = if maybe_range {
-            if let Some(content_range) = headers
-                .typed_get::<Range>()
-                .and_then(|range| to_content_range(&range, meta.len()))
-            {
-                res.headers_mut().typed_insert(content_range.clone());
-                *res.status_mut() = StatusCode::PARTIAL_CONTENT;
-                content_range.bytes_range()
-            } else {
-                None
-            }
+
+        let range = if use_range {
+            parse_range(headers)
         } else {
             None
         };
+
         if let Some(mime) = mime_guess::from_path(&path).first() {
             res.headers_mut().typed_insert(ContentType::from(mime));
+        } else {
+            res.headers_mut().insert(
+                CONTENT_TYPE,
+                HeaderValue::from_static("application/octet-stream"),
+            );
         }
+
         res.headers_mut().typed_insert(AcceptRanges::bytes());
-        res.headers_mut()
-            .typed_insert(ContentLength(meta.len() as u64));
-        if head_only {
-            return Ok(());
-        }
 
-        let body = if let Some((begin, end)) = file_range {
-            file.seek(io::SeekFrom::Start(begin)).await?;
-            let stream = FramedRead::new(file.take(end - begin + 1), BytesCodec::new());
-            Body::wrap_stream(stream)
+        let size = meta.len();
+
+        if let Some(range) = range {
+            if range
+                .end
+                .map_or_else(|| range.start < size, |v| v >= range.start)
+                && file.seek(SeekFrom::Start(range.start)).await.is_ok()
+            {
+                let end = range.end.unwrap_or(size - 1).min(size - 1);
+                let part_size = end - range.start + 1;
+                let reader = Streamer::new(file, BUF_SIZE);
+                *res.status_mut() = StatusCode::PARTIAL_CONTENT;
+                let content_range = format!("bytes {}-{}/{}", range.start, end, size);
+                res.headers_mut()
+                    .insert(CONTENT_RANGE, content_range.parse().unwrap());
+                res.headers_mut()
+                    .insert(CONTENT_LENGTH, format!("{}", part_size).parse().unwrap());
+                if head_only {
+                    return Ok(());
+                }
+                *res.body_mut() = Body::wrap_stream(reader.into_stream_sized(part_size));
+            } else {
+                *res.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE;
+                res.headers_mut()
+                    .insert(CONTENT_RANGE, format!("bytes */{}", size).parse().unwrap());
+            }
         } else {
-            let stream = FramedRead::new(file, BytesCodec::new());
-            Body::wrap_stream(stream)
-        };
-        *res.body_mut() = body;
+            res.headers_mut()
+                .insert(CONTENT_LENGTH, format!("{}", size).parse().unwrap());
+            if head_only {
+                return Ok(());
+            }
+            let reader = Streamer::new(file, BUF_SIZE);
+            *res.body_mut() = Body::wrap_stream(reader.into_stream());
+        }
         Ok(())
     }
 
@@ -965,32 +988,34 @@ fn extract_cache_headers(meta: &Metadata) -> Option<(ETag, LastModified)> {
     Some((etag, last_modified))
 }
 
-fn to_content_range(range: &Range, complete_length: u64) -> Option<ContentRange> {
-    use core::ops::Bound::{Included, Unbounded};
-    let mut iter = range.iter();
-    let bounds = iter.next();
+#[derive(Debug)]
+struct RangeValue {
+    start: u64,
+    end: Option<u64>,
+}
 
-    if iter.next().is_some() {
-        // Found multiple byte-range-spec. Drop.
-        return None;
+fn parse_range(headers: &HeaderMap<HeaderValue>) -> Option<RangeValue> {
+    let range_hdr = headers.get(RANGE)?;
+    let hdr = range_hdr.to_str().ok()?;
+    let mut sp = hdr.splitn(2, '=');
+    let units = sp.next().unwrap();
+    if units == "bytes" {
+        let range = sp.next()?;
+        let mut sp_range = range.splitn(2, '-');
+        let start: u64 = sp_range.next().unwrap().parse().ok()?;
+        let end: Option<u64> = if let Some(end) = sp_range.next() {
+            if end.is_empty() {
+                None
+            } else {
+                Some(end.parse().ok()?)
+            }
+        } else {
+            None
+        };
+        Some(RangeValue { start, end })
+    } else {
+        None
     }
-
-    bounds.and_then(|b| match b {
-        (Included(start), Included(end)) if start <= end && start < complete_length => {
-            ContentRange::bytes(
-                start..=end.min(complete_length.saturating_sub(1)),
-                complete_length,
-            )
-            .ok()
-        }
-        (Included(start), Unbounded) if start < complete_length => {
-            ContentRange::bytes(start.., complete_length).ok()
-        }
-        (Unbounded, Included(end)) if end > 0 => {
-            ContentRange::bytes(complete_length.saturating_sub(end).., complete_length).ok()
-        }
-        _ => None,
-    })
 }
 
 fn encode_uri(v: &str) -> String {
diff --git a/src/streamer.rs b/src/streamer.rs
new file mode 100644 (file)
index 0000000..163b36f
--- /dev/null
@@ -0,0 +1,68 @@
+use async_stream::stream;
+use futures::{Stream, StreamExt};
+use std::io::Error;
+use std::pin::Pin;
+use tokio::io::{AsyncRead, AsyncReadExt};
+
+pub struct Streamer<R>
+where
+    R: AsyncRead + Unpin + Send + 'static,
+{
+    reader: R,
+    buf_size: usize,
+}
+
+impl<R> Streamer<R>
+where
+    R: AsyncRead + Unpin + Send + 'static,
+{
+    #[inline]
+    pub fn new(reader: R, buf_size: usize) -> Self {
+        Self { reader, buf_size }
+    }
+    pub fn into_stream(
+        mut self,
+    ) -> Pin<Box<impl ?Sized + Stream<Item = Result<Vec<u8>, Error>> + 'static>> {
+        let stream = stream! {
+            loop {
+                let mut buf = vec![0; self.buf_size];
+                let r = self.reader.read(&mut buf).await?;
+                if r == 0 {
+                    break
+                }
+                buf.truncate(r);
+                yield Ok(buf);
+            }
+        };
+        stream.boxed()
+    }
+    // allow truncation as truncated remaining is always less than buf_size: usize
+    pub fn into_stream_sized(
+        mut self,
+        max_length: u64,
+    ) -> Pin<Box<impl ?Sized + Stream<Item = Result<Vec<u8>, Error>> + 'static>> {
+        let stream = stream! {
+        let mut remaining = max_length;
+            loop {
+                if remaining == 0 {
+                    break;
+                }
+                let bs = if remaining >= self.buf_size as u64 {
+                    self.buf_size
+                } else {
+                    remaining as usize
+                };
+                let mut buf = vec![0; bs];
+                let r = self.reader.read(&mut buf).await?;
+                if r == 0 {
+                    break;
+                } else {
+                    buf.truncate(r);
+                    yield Ok(buf);
+                }
+                remaining -= r as u64;
+            }
+        };
+        stream.boxed()
+    }
+}
diff --git a/tests/range.rs b/tests/range.rs
new file mode 100644 (file)
index 0000000..a2c9c50
--- /dev/null
@@ -0,0 +1,45 @@
+mod fixtures;
+mod utils;
+
+use fixtures::{server, Error, TestServer};
+use headers::HeaderValue;
+use rstest::rstest;
+
+#[rstest]
+fn get_file_range(server: TestServer) -> Result<(), Error> {
+    let resp = fetch!(b"GET", format!("{}index.html", server.url()))
+        .header("range", HeaderValue::from_static("bytes=0-6"))
+        .send()?;
+    assert_eq!(resp.status(), 206);
+    assert_eq!(resp.headers().get("content-range").unwrap(), "bytes 0-6/18");
+    assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes");
+    assert_eq!(resp.headers().get("content-length").unwrap(), "7");
+    assert_eq!(resp.text()?, "This is");
+    Ok(())
+}
+
+#[rstest]
+fn get_file_range_beyond(server: TestServer) -> Result<(), Error> {
+    let resp = fetch!(b"GET", format!("{}index.html", server.url()))
+        .header("range", HeaderValue::from_static("bytes=12-20"))
+        .send()?;
+    assert_eq!(resp.status(), 206);
+    assert_eq!(
+        resp.headers().get("content-range").unwrap(),
+        "bytes 12-17/18"
+    );
+    assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes");
+    assert_eq!(resp.headers().get("content-length").unwrap(), "6");
+    assert_eq!(resp.text()?, "x.html");
+    Ok(())
+}
+
+#[rstest]
+fn get_file_range_invalid(server: TestServer) -> Result<(), Error> {
+    let resp = fetch!(b"GET", format!("{}index.html", server.url()))
+        .header("range", HeaderValue::from_static("bytes=20-"))
+        .send()?;
+    assert_eq!(resp.status(), 416);
+    assert_eq!(resp.headers().get("content-range").unwrap(), "bytes */18");
+    Ok(())
+}