-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #139 from jbr/validate-before-continue
don't send 100-continue until the body has been read from
- Loading branch information
Showing
5 changed files
with
174 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use std::fmt; | ||
use std::pin::Pin; | ||
use std::task::{Context, Poll}; | ||
|
||
use async_std::io::{self, BufRead, Read}; | ||
use async_std::sync::Sender; | ||
|
||
pin_project_lite::pin_project! { | ||
/// ReadNotifier forwards [`async_std::io::Read`] and | ||
/// [`async_std::io::BufRead`] to an inner reader. When the | ||
/// ReadNotifier is read from (using `Read`, `ReadExt`, or | ||
/// `BufRead` methods), it sends a single message containing `()` | ||
/// on the channel. | ||
pub(crate) struct ReadNotifier<B> { | ||
#[pin] | ||
reader: B, | ||
sender: Sender<()>, | ||
has_been_read: bool | ||
} | ||
} | ||
|
||
impl<B> fmt::Debug for ReadNotifier<B> { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.debug_struct("ReadNotifier") | ||
.field("read", &self.has_been_read) | ||
.finish() | ||
} | ||
} | ||
|
||
impl<B: BufRead> ReadNotifier<B> { | ||
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self { | ||
Self { | ||
reader, | ||
sender, | ||
has_been_read: false, | ||
} | ||
} | ||
} | ||
|
||
impl<B: BufRead> BufRead for ReadNotifier<B> { | ||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { | ||
self.project().reader.poll_fill_buf(cx) | ||
} | ||
|
||
fn consume(self: Pin<&mut Self>, amt: usize) { | ||
self.project().reader.consume(amt) | ||
} | ||
} | ||
|
||
impl<B: Read> Read for ReadNotifier<B> { | ||
fn poll_read( | ||
self: Pin<&mut Self>, | ||
cx: &mut Context<'_>, | ||
buf: &mut [u8], | ||
) -> Poll<io::Result<usize>> { | ||
let this = self.project(); | ||
|
||
if !*this.has_been_read { | ||
if let Ok(()) = this.sender.try_send(()) { | ||
*this.has_been_read = true; | ||
}; | ||
} | ||
|
||
this.reader.poll_read(cx, buf) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
use async_dup::{Arc, Mutex}; | ||
use async_std::io::{Cursor, SeekFrom}; | ||
use async_std::{prelude::*, task}; | ||
use duplexify::Duplex; | ||
use http_types::Result; | ||
use std::time::Duration; | ||
|
||
const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\ | ||
Host: example.com\r\n\ | ||
Content-Length: 10\r\n\ | ||
Expect: 100-continue\r\n\r\n"; | ||
|
||
const SLEEP_DURATION: Duration = std::time::Duration::from_millis(100); | ||
#[async_std::test] | ||
async fn test_with_expect_when_reading_body() -> Result<()> { | ||
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec(); | ||
let server_str: Vec<u8> = vec![]; | ||
|
||
let mut client = Arc::new(Mutex::new(Cursor::new(client_str))); | ||
let server = Arc::new(Mutex::new(Cursor::new(server_str))); | ||
|
||
let mut request = async_h1::server::decode(Duplex::new(client.clone(), server.clone())) | ||
.await? | ||
.unwrap(); | ||
|
||
task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written | ||
|
||
{ | ||
let lock = server.lock(); | ||
assert_eq!("", std::str::from_utf8(lock.get_ref())?); //we haven't written yet | ||
}; | ||
|
||
let mut buf = vec![0u8; 1]; | ||
let bytes = request.read(&mut buf).await?; //this triggers the 100-continue even though there's nothing to read yet | ||
assert_eq!(bytes, 0); // normally we'd actually be waiting for the end of the buffer, but this lets us test this sequentially | ||
|
||
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel and io | ||
|
||
{ | ||
let lock = server.lock(); | ||
assert_eq!( | ||
"HTTP/1.1 100 Continue\r\n\r\n", | ||
std::str::from_utf8(lock.get_ref())? | ||
); | ||
}; | ||
|
||
client.write_all(b"0123456789").await?; | ||
client | ||
.seek(SeekFrom::Start(REQUEST_WITH_EXPECT.len() as u64)) | ||
.await?; | ||
|
||
assert_eq!("0123456789", request.body_string().await?); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[async_std::test] | ||
async fn test_without_expect_when_not_reading_body() -> Result<()> { | ||
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec(); | ||
let server_str: Vec<u8> = vec![]; | ||
|
||
let client = Arc::new(Mutex::new(Cursor::new(client_str))); | ||
let server = Arc::new(Mutex::new(Cursor::new(server_str))); | ||
|
||
async_h1::server::decode(Duplex::new(client.clone(), server.clone())) | ||
.await? | ||
.unwrap(); | ||
|
||
task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel | ||
|
||
let server_lock = server.lock(); | ||
assert_eq!("", std::str::from_utf8(server_lock.get_ref())?); // we haven't written 100-continue | ||
|
||
Ok(()) | ||
} |