diff --git a/src/main.rs b/src/main.rs index 0faa37f..61fa989 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use std::{ fs, io::Write, net::{SocketAddr, TcpListener, TcpStream}, - path::Path, + path::{Path, PathBuf}, result, str::FromStr, }; @@ -23,7 +23,7 @@ fn main() -> Result<()> { let ip = args.ip.unwrap_or_default(); let port = args.port.unwrap_or_default(); - let export = args.export; + let export = fs::canonicalize(args.export).expect("failed to get absolute export path"); let allowed = args.allowed_devices; let socket = SocketAddr::from_str(&format!("{ip}:{port}")).unwrap(); @@ -41,7 +41,7 @@ fn main() -> Result<()> { fn handle_connection( mut conn: TcpStream, - export: String, + export: PathBuf, allowed_devices: Option>, ) -> Result<()> { let mut paths = vec![]; @@ -66,6 +66,11 @@ fn handle_connection( walk_dir(&export, &mut paths); + let export_path = export + .as_path() + .to_str() + .expect("invalid export path, aborting"); + let files_sent = paths.len(); log::info!("Sending {files_sent} files to {remote_ip}"); @@ -73,7 +78,7 @@ fn handle_connection( buffer.write_usize(files_sent); for file in paths { - let mut path = file.replace(&export, ""); + let mut path = file.replace(export_path, ""); if path.starts_with("/") { path.remove(0); }