diff --git a/src/main.rs b/src/main.rs index 201b1ce..cea0c48 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,8 +17,8 @@ use { Profile, providers::{Env, Format as _, Toml}, }, - http::ContentType, - response::{self, Responder, stream::ByteStream}, + http::{ContentType, uri::Origin}, + response::{self, Redirect, Responder, stream::ByteStream}, serde::Serialize, }, rocket_dyn_templates::{Template, context}, @@ -28,6 +28,7 @@ use { enum FileView { Folder(Template), + Redirect(Redirect), File(ByteStream>), } @@ -38,6 +39,7 @@ impl<'r> Responder<'r, 'r> for FileView { r.set_header(ContentType::HTML); r }), + Self::Redirect(redirect) => redirect.respond_to(req), Self::File(stream) => stream.respond_to(req), } } @@ -75,12 +77,16 @@ impl From for Error { } #[rocket::get("/")] -async fn index_root(state: &State) -> Result { - index(None, state).await +async fn index_root(uri: &Origin<'_>, state: &State) -> Result { + index(None, uri, state).await } #[rocket::get("/")] -async fn index(path: Option, state: &State) -> Result { +async fn index( + path: Option, + uri: &Origin<'_>, + state: &State, +) -> Result { let object_path = if let Some(url_path) = path.as_ref() { let s = url_path.to_str().ok_or(Error::InvalidRequest( "Path cannot be converted to UTF-8".into(), @@ -91,31 +97,42 @@ async fn index(path: Option, state: &State) -> Result Result { diff --git a/tests/integration.rs b/tests/integration.rs index 01c64cd..83f2170 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -8,10 +8,7 @@ use { scraper::{Html, Selector}, }; -#[test_log::test(tokio::test(flavor = "multi_thread"))] -async fn serves_files() -> anyhow::Result<()> { - let test = common::Test::new().await?; - +async fn create_sample_files(test: &common::Test) -> anyhow::Result<()> { test.bucket .put( &ObjectStorePath::from("file.txt"), @@ -26,6 +23,14 @@ async fn serves_files() -> anyhow::Result<()> { ) .await?; + Ok(()) +} + +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn serves_files() -> anyhow::Result<()> { + let test = common::Test::new().await?; + create_sample_files(&test).await?; + let resp = reqwest::get(test.base_url.join("file.txt")?).await?; assert_eq!(resp.bytes().await?, "I am a file"); @@ -38,19 +43,7 @@ async fn serves_files() -> anyhow::Result<()> { #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn serves_top_level_folder() -> anyhow::Result<()> { let test = common::Test::new().await?; - - test.bucket - .put( - &ObjectStorePath::from("file.txt"), - PutPayload::from_static("I am a file".as_bytes()), - ) - .await?; - test.bucket - .put( - &ObjectStorePath::from("folder/file.txt"), - PutPayload::from_static("I am a file in a folder".as_bytes()), - ) - .await?; + create_sample_files(&test).await?; // Check that a file in the toplevel is listed: let resp = reqwest::get(test.base_url.clone()).await?; @@ -72,8 +65,11 @@ async fn serves_top_level_folder() -> anyhow::Result<()> { let selector = Selector::parse(r#"table > tbody > tr:nth-child(1) > td:first-child > a"#).unwrap(); for item in document.select(&selector) { - assert_eq!(item.attr("href"), Some("folder")); - assert_eq!(item.text().next(), Some("folder")); + // Folders should be listed ending with a slash, + // or HTTP gets confused. This is also due to the + // normalization we do on the path in the main program. + assert_eq!(item.attr("href"), Some("folder/")); + assert_eq!(item.text().next(), Some("folder/")); } let selector = @@ -89,20 +85,7 @@ async fn serves_top_level_folder() -> anyhow::Result<()> { #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn serves_second_level_folder() -> anyhow::Result<()> { let test = common::Test::new().await?; - - test.bucket - .put( - &ObjectStorePath::from("file.txt"), - PutPayload::from_static("I am a file".as_bytes()), - ) - .await?; - - test.bucket - .put( - &ObjectStorePath::from("folder/file.txt"), - PutPayload::from_static("I am a file in a folder".as_bytes()), - ) - .await?; + create_sample_files(&test).await?; // Check that a file in the second level is listed: let resp = reqwest::get(test.base_url.join("folder/")?).await?; @@ -140,19 +123,7 @@ async fn serves_second_level_folder() -> anyhow::Result<()> { #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn serves_second_level_folder_without_ending_slash() -> anyhow::Result<()> { let test = common::Test::new().await?; - - test.bucket - .put( - &ObjectStorePath::from("file.txt"), - PutPayload::from_static("I am a file".as_bytes()), - ) - .await?; - test.bucket - .put( - &ObjectStorePath::from("folder/file.txt"), - PutPayload::from_static("I am a file in a folder".as_bytes()), - ) - .await?; + create_sample_files(&test).await?; // Check that a file in the second level is listed even without an ending slash: let resp = reqwest::get(test.base_url.join("folder")?).await?; @@ -161,6 +132,10 @@ async fn serves_second_level_folder_without_ending_slash() -> anyhow::Result<()> "Request failed with {}", resp.status() ); + + // Ensure we were redirected to a URL ending with a slash + assert!(resp.url().path().ends_with("/")); + let text = resp.text().await?; println!("{}", &text); let document = Html::parse_document(&text);