use std::{ fs, io::Write, net::{SocketAddr, TcpListener, TcpStream}, path::{Path, PathBuf}, result, str::FromStr, }; use clap::Parser; use crate::{buffer::ByteBuffer, cli::CliArgs}; mod buffer; mod cli; type Result = result::Result; fn main() -> Result<()> { env_logger::init(); let args = CliArgs::parse(); let ip = args.ip.unwrap_or_default(); let port = args.port.unwrap_or_default(); let export = fs::canonicalize(args.export).expect("failed to get absolute export path"); let allowed = args.allowed_devices; let socket = SocketAddr::from_str(&format!("{ip}:{port}")).unwrap(); let listener = TcpListener::bind(socket).unwrap(); for stream in listener.incoming() { match stream { Ok(conn) => handle_connection(conn, export.clone(), allowed.clone())?, Err(e) => eprintln!("Something went wrong while listening {e}"), } } Ok(()) } fn handle_connection( mut conn: TcpStream, export: PathBuf, allowed_devices: Option>, ) -> Result<()> { let mut paths = vec![]; let mut buffer = ByteBuffer::default(); let remote_ip = conn .peer_addr() .expect("Could not get remote IP address") .ip(); if let Some(allowed_devices) = allowed_devices { for allowed in allowed_devices { let allowed_net = ipnet::IpNet::from_str(&allowed).unwrap(); let is_allowed = allowed_net.contains(&remote_ip); if !is_allowed { log::error!("{remote_ip} tried to connect but is not allowed"); return Ok(()); } } } walk_dir(&export, &mut paths); let export_path = export .as_path() .to_str() .expect("invalid export path, aborting"); let files_sent = paths.len(); log::info!("Sending {files_sent} files to {remote_ip}"); buffer.write_usize(files_sent); for file in paths { let mut path = file.replace(export_path, ""); if path.starts_with("/") { path.remove(0); } log::debug!("Sending {path}"); buffer.write_string(&path); match fs::read(file) { Ok(data) => { buffer.write_usize(data.len()); buffer.write_bytes(&data); } Err(_) => { buffer.write_usize(0); eprintln!("No file found"); } }; } let _ = conn.write_all(&buffer); let _ = conn.flush(); Ok(()) } fn walk_dir>(path: P, file_paths: &mut Vec) { for entry in fs::read_dir(path).unwrap().flatten() { if entry.path().is_dir() { walk_dir(entry.path(), file_paths); continue; } let file_path = entry.path().to_str().unwrap().to_string(); file_paths.push(file_path); } } #[cfg(test)] mod tests { use std::io::Read; use super::*; const IP: &str = "0.0.0.0"; const PORT: &str = "9696"; #[test] fn test_connection() { let conn = TcpStream::connect(format!("{IP}:{PORT}")); assert!(conn.is_ok()); let mut conn = conn.unwrap(); let mut buffer = ByteBuffer::default(); let bytes_read = conn.read_to_end(&mut buffer); assert!(bytes_read.is_ok()); let file_name = buffer.read_string(); assert_eq!(file_name, "examples/test.md"); // let content = buffer.read_bytes(buffer.unread_len()); // println!("{}", String::from_utf8(content).unwrap()); } }