use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
use clap_complete::{generate, Generator, Shell};
use serde::{Deserialize, Deserializer};
+use smart_default::SmartDefault;
use std::env;
use std::net::IpAddr;
use std::path::{Path, PathBuf};
.env("DUFS_HIDDEN")
.hide_env(true)
.long("hidden")
+ .action(ArgAction::Append)
+ .value_delimiter(',')
.help("Hide paths from directory listings, e.g. tmp,*.log,*.lock")
.value_name("value"),
)
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}
-#[derive(Debug, Deserialize, Default)]
+#[derive(Debug, Deserialize, SmartDefault, PartialEq)]
#[serde(default)]
#[serde(rename_all = "kebab-case")]
pub struct Args {
#[serde(default = "default_serve_path")]
+ #[default(default_serve_path())]
pub serve_path: PathBuf,
#[serde(deserialize_with = "deserialize_bind_addrs")]
#[serde(rename = "bind")]
#[serde(default = "default_addrs")]
+ #[default(default_addrs())]
pub addrs: Vec<BindAddr>,
#[serde(default = "default_port")]
+ #[default(default_port())]
pub port: u16,
#[serde(skip)]
pub path_is_file: bool,
pub path_prefix: String,
#[serde(skip)]
pub uri_prefix: String,
+ #[serde(deserialize_with = "deserialize_string_or_vec")]
pub hidden: Vec<String>,
#[serde(deserialize_with = "deserialize_access_control")]
pub auth: AccessControl,
/// If a parsing error occurred, exit the process and print out informative
/// error message to user.
pub fn parse(matches: ArgMatches) -> Result<Args> {
- let mut args = Self {
- serve_path: default_serve_path(),
- addrs: default_addrs(),
- port: default_port(),
- ..Default::default()
- };
+ let mut args = Self::default();
if let Some(config_path) = matches.get_one::<PathBuf>("config") {
let contents = std::fs::read_to_string(config_path)
if let Some(path) = matches.get_one::<PathBuf>("serve-path") {
args.serve_path = path.clone()
}
+
args.serve_path = Self::sanitize_path(args.serve_path)?;
if let Some(port) = matches.get_one::<u16>("port") {
format!("/{}/", &encode_uri(&args.path_prefix))
};
- if let Some(hidden) = matches
- .get_one::<String>("hidden")
- .map(|v| v.split(',').map(|x| x.to_string()).collect())
- {
- args.hidden = hidden;
+ if let Some(hidden) = matches.get_many::<String>("hidden") {
+ args.hidden = hidden.cloned().collect();
+ } else {
+ let mut hidden = vec![];
+ std::mem::swap(&mut args.hidden, &mut hidden);
+ args.hidden = hidden
+ .into_iter()
+ .flat_map(|v| v.split(',').map(|v| v.to_string()).collect::<Vec<String>>())
+ .collect();
}
if !args.enable_cors {
where
D: Deserializer<'de>,
{
- let addrs: Vec<&str> = Vec::deserialize(deserializer)?;
- BindAddr::parse_addrs(&addrs).map_err(serde::de::Error::custom)
+ struct StringOrVec;
+
+ impl<'de> serde::de::Visitor<'de> for StringOrVec {
+ type Value = Vec<BindAddr>;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+ formatter.write_str("string or list of strings")
+ }
+
+ fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
+ where
+ E: serde::de::Error,
+ {
+ BindAddr::parse_addrs(&[s]).map_err(serde::de::Error::custom)
+ }
+
+ fn visit_seq<S>(self, seq: S) -> Result<Self::Value, S::Error>
+ where
+ S: serde::de::SeqAccess<'de>,
+ {
+ let addrs: Vec<&'de str> =
+ Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))?;
+ BindAddr::parse_addrs(&addrs).map_err(serde::de::Error::custom)
+ }
+ }
+
+ deserializer.deserialize_any(StringOrVec)
+}
+
+fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
+where
+ D: Deserializer<'de>,
+{
+ struct StringOrVec;
+
+ impl<'de> serde::de::Visitor<'de> for StringOrVec {
+ type Value = Vec<String>;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+ formatter.write_str("string or list of strings")
+ }
+
+ fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
+ where
+ E: serde::de::Error,
+ {
+ Ok(vec![s.to_owned()])
+ }
+
+ fn visit_seq<S>(self, seq: S) -> Result<Self::Value, S::Error>
+ where
+ S: serde::de::SeqAccess<'de>,
+ {
+ Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
+ }
+ }
+
+ deserializer.deserialize_any(StringOrVec)
}
fn deserialize_access_control<'de, D>(deserializer: D) -> Result<AccessControl, D::Error>
fn default_port() -> u16 {
5000
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use assert_fs::prelude::*;
+
+ #[test]
+ fn test_default() {
+ let cli = build_cli();
+ let matches = cli.try_get_matches_from(vec![""]).unwrap();
+ let args = Args::parse(matches).unwrap();
+ let cwd = Args::sanitize_path(std::env::current_dir().unwrap()).unwrap();
+ assert_eq!(args.serve_path, cwd);
+ assert_eq!(args.port, default_port());
+ assert_eq!(args.addrs, default_addrs());
+ }
+
+ #[test]
+ fn test_args_from_cli1() {
+ let tmpdir = assert_fs::TempDir::new().unwrap();
+ let cli = build_cli();
+ let matches = cli
+ .try_get_matches_from(vec![
+ "",
+ "--hidden",
+ "tmp,*.log,*.lock",
+ &tmpdir.to_string_lossy(),
+ ])
+ .unwrap();
+ let args = Args::parse(matches).unwrap();
+ assert_eq!(args.serve_path, Args::sanitize_path(&tmpdir).unwrap());
+ assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]);
+ }
+
+ #[test]
+ fn test_args_from_cli2() {
+ let cli = build_cli();
+ let matches = cli
+ .try_get_matches_from(vec![
+ "", "--hidden", "tmp", "--hidden", "*.log", "--hidden", "*.lock",
+ ])
+ .unwrap();
+ let args = Args::parse(matches).unwrap();
+ assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]);
+ }
+
+ #[test]
+ fn test_args_from_empty_config_file() {
+ let tmpdir = assert_fs::TempDir::new().unwrap();
+ let config_file = tmpdir.child("config.yaml");
+ config_file.write_str("").unwrap();
+
+ let cli = build_cli();
+ let matches = cli
+ .try_get_matches_from(vec!["", "-c", &config_file.to_string_lossy()])
+ .unwrap();
+ let args = Args::parse(matches).unwrap();
+ let cwd = Args::sanitize_path(std::env::current_dir().unwrap()).unwrap();
+ assert_eq!(args.serve_path, cwd);
+ assert_eq!(args.port, default_port());
+ assert_eq!(args.addrs, default_addrs());
+ }
+
+ #[test]
+ fn test_args_from_config_file1() {
+ let tmpdir = assert_fs::TempDir::new().unwrap();
+ let config_file = tmpdir.child("config.yaml");
+ let contents = format!(
+ r#"
+serve-path: {}
+bind: 0.0.0.0
+port: 3000
+allow-upload: true
+hidden: tmp,*.log,*.lock
+"#,
+ tmpdir.display()
+ );
+ config_file.write_str(&contents).unwrap();
+
+ let cli = build_cli();
+ let matches = cli
+ .try_get_matches_from(vec!["", "-c", &config_file.to_string_lossy()])
+ .unwrap();
+ let args = Args::parse(matches).unwrap();
+ assert_eq!(args.serve_path, Args::sanitize_path(&tmpdir).unwrap());
+ assert_eq!(
+ args.addrs,
+ vec![BindAddr::Address("0.0.0.0".parse().unwrap())]
+ );
+ assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]);
+ assert_eq!(args.port, 3000);
+ assert!(args.allow_upload);
+ }
+
+ #[test]
+ fn test_args_from_config_file2() {
+ let tmpdir = assert_fs::TempDir::new().unwrap();
+ let config_file = tmpdir.child("config.yaml");
+ let contents = r#"
+bind:
+ - 127.0.0.1
+ - 192.168.8.10
+hidden:
+ - tmp
+ - '*.log'
+ - '*.lock'
+"#;
+ config_file.write_str(contents).unwrap();
+
+ let cli = build_cli();
+ let matches = cli
+ .try_get_matches_from(vec!["", "-c", &config_file.to_string_lossy()])
+ .unwrap();
+ let args = Args::parse(matches).unwrap();
+ assert_eq!(
+ args.addrs,
+ vec![
+ BindAddr::Address("127.0.0.1".parse().unwrap()),
+ BindAddr::Address("192.168.8.10".parse().unwrap())
+ ]
+ );
+ assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]);
+ }
+}