Skip to content

Commit

Permalink
Merge pull request #139 from jbr/validate-before-continue
Browse files Browse the repository at this point in the history
don't send 100-continue until the body has been read from
  • Loading branch information
yoshuawuyts authored Sep 6, 2020
2 parents 9738d53 + de54e57 commit 1641e22
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 52 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ log = "0.4"

[dev-dependencies]
pretty_assertions = "0.6.1"
async-std = { version = "1.4.0", features = ["unstable", "attributes"] }
async-std = { version = "1.6.2", features = ["unstable", "attributes"] }
tempfile = "3.1.0"
async-test = "1.0.0"
duplexify = "1.2.1"
async-dup = "1.2.1"
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ const MAX_HEAD_LENGTH: usize = 8 * 1024;

mod chunked;
mod date;
mod read_notifier;

pub mod client;
pub mod server;
Expand Down
66 changes: 66 additions & 0 deletions src/read_notifier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};

use async_std::io::{self, BufRead, Read};
use async_std::sync::Sender;

pin_project_lite::pin_project! {
/// ReadNotifier forwards [`async_std::io::Read`] and
/// [`async_std::io::BufRead`] to an inner reader. When the
/// ReadNotifier is read from (using `Read`, `ReadExt`, or
/// `BufRead` methods), it sends a single message containing `()`
/// on the channel.
pub(crate) struct ReadNotifier<B> {
#[pin]
reader: B,
sender: Sender<()>,
has_been_read: bool
}
}

impl<B> fmt::Debug for ReadNotifier<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ReadNotifier")
.field("read", &self.has_been_read)
.finish()
}
}

impl<B: BufRead> ReadNotifier<B> {
pub(crate) fn new(reader: B, sender: Sender<()>) -> Self {
Self {
reader,
sender,
has_been_read: false,
}
}
}

impl<B: BufRead> BufRead for ReadNotifier<B> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
self.project().reader.poll_fill_buf(cx)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().reader.consume(amt)
}
}

impl<B: Read> Read for ReadNotifier<B> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.project();

if !*this.has_been_read {
if let Ok(()) = this.sender.try_send(()) {
*this.has_been_read = true;
};
}

this.reader.poll_read(cx, buf)
}
}
80 changes: 29 additions & 51 deletions src/server/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
use std::str::FromStr;

use async_std::io::{BufReader, Read, Write};
use async_std::prelude::*;
use async_std::{prelude::*, sync, task};
use http_types::headers::{CONTENT_LENGTH, EXPECT, TRANSFER_ENCODING};
use http_types::{ensure, ensure_eq, format_err};
use http_types::{Body, Method, Request, Url};

use crate::chunked::ChunkedDecoder;
use crate::read_notifier::ReadNotifier;
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};

const LF: u8 = b'\n';

/// The number returned from httparse when the request is HTTP 1.1
const HTTP_1_1_VERSION: u8 = 1;

const CONTINUE_HEADER_VALUE: &str = "100-continue";
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";

/// Decode an HTTP request on the server.
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<Request>>
where
Expand Down Expand Up @@ -76,8 +80,6 @@ where
req.insert_header(header.name, std::str::from_utf8(header.value)?);
}

handle_100_continue(&req, &mut io).await?;

let content_length = req.header(CONTENT_LENGTH);
let transfer_encoding = req.header(TRANSFER_ENCODING);

Expand All @@ -86,11 +88,32 @@ where
"Unexpected Content-Length header"
);

// Establish a channel to wait for the body to be read. This
// allows us to avoid sending 100-continue in situations that
// respond without reading the body, saving clients from uploading
// their body.
let (body_read_sender, body_read_receiver) = sync::channel(1);

if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
task::spawn(async move {
// If the client expects a 100-continue header, spawn a
// task to wait for the first read attempt on the body.
if let Ok(()) = body_read_receiver.recv().await {
io.write_all(CONTINUE_RESPONSE).await.ok();
};
// Since the sender is moved into the Body, this task will
// finish when the client disconnects, whether or not
// 100-continue was sent.
});
}

// Check for Transfer-Encoding
if let Some(encoding) = transfer_encoding {
if encoding.last().as_str() == "chunked" {
let trailer_sender = req.send_trailers();
let reader = BufReader::new(ChunkedDecoder::new(reader, trailer_sender));
let reader = ChunkedDecoder::new(reader, trailer_sender);
let reader = BufReader::new(reader);
let reader = ReadNotifier::new(reader, body_read_sender);
req.set_body(Body::from_reader(reader, None));
return Ok(Some(req));
}
Expand All @@ -100,7 +123,8 @@ where
// Check for Content-Length.
if let Some(len) = content_length {
let len = len.last().as_str().parse::<usize>()?;
req.set_body(Body::from_reader(reader.take(len as u64), Some(len)));
let reader = ReadNotifier::new(reader.take(len as u64), body_read_sender);
req.set_body(Body::from_reader(reader, Some(len)));
}

Ok(Some(req))
Expand Down Expand Up @@ -129,20 +153,6 @@ fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<
}
}

const EXPECT_HEADER_VALUE: &str = "100-continue";
const EXPECT_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";

async fn handle_100_continue<IO>(req: &Request, io: &mut IO) -> http_types::Result<()>
where
IO: Write + Unpin,
{
if let Some(EXPECT_HEADER_VALUE) = req.header(EXPECT).map(|h| h.as_str()) {
io.write_all(EXPECT_RESPONSE).await?;
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -207,36 +217,4 @@ mod tests {
},
)
}

#[test]
fn handle_100_continue_does_nothing_with_no_expect_header() {
let request = Request::new(Method::Get, Url::parse("x:").unwrap());
let mut io = async_std::io::Cursor::new(vec![]);
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
assert!(result.is_ok());
}

#[test]
fn handle_100_continue_sends_header_if_expects_is_exactly_right() {
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
request.append_header("expect", "100-continue");
let mut io = async_std::io::Cursor::new(vec![]);
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
assert_eq!(
std::str::from_utf8(&io.into_inner()).unwrap(),
"HTTP/1.1 100 Continue\r\n\r\n"
);
assert!(result.is_ok());
}

#[test]
fn handle_100_continue_does_nothing_if_expects_header_is_wrong() {
let mut request = Request::new(Method::Get, Url::parse("x:").unwrap());
request.append_header("expect", "110-extensions-not-allowed");
let mut io = async_std::io::Cursor::new(vec![]);
let result = async_std::task::block_on(handle_100_continue(&request, &mut io));
assert_eq!(std::str::from_utf8(&io.into_inner()).unwrap(), "");
assert!(result.is_ok());
}
}
75 changes: 75 additions & 0 deletions tests/continue.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use async_dup::{Arc, Mutex};
use async_std::io::{Cursor, SeekFrom};
use async_std::{prelude::*, task};
use duplexify::Duplex;
use http_types::Result;
use std::time::Duration;

const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
Host: example.com\r\n\
Content-Length: 10\r\n\
Expect: 100-continue\r\n\r\n";

const SLEEP_DURATION: Duration = std::time::Duration::from_millis(100);
#[async_std::test]
async fn test_with_expect_when_reading_body() -> Result<()> {
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
let server_str: Vec<u8> = vec![];

let mut client = Arc::new(Mutex::new(Cursor::new(client_str)));
let server = Arc::new(Mutex::new(Cursor::new(server_str)));

let mut request = async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written

{
let lock = server.lock();
assert_eq!("", std::str::from_utf8(lock.get_ref())?); //we haven't written yet
};

let mut buf = vec![0u8; 1];
let bytes = request.read(&mut buf).await?; //this triggers the 100-continue even though there's nothing to read yet
assert_eq!(bytes, 0); // normally we'd actually be waiting for the end of the buffer, but this lets us test this sequentially

task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel and io

{
let lock = server.lock();
assert_eq!(
"HTTP/1.1 100 Continue\r\n\r\n",
std::str::from_utf8(lock.get_ref())?
);
};

client.write_all(b"0123456789").await?;
client
.seek(SeekFrom::Start(REQUEST_WITH_EXPECT.len() as u64))
.await?;

assert_eq!("0123456789", request.body_string().await?);

Ok(())
}

#[async_std::test]
async fn test_without_expect_when_not_reading_body() -> Result<()> {
let client_str: Vec<u8> = REQUEST_WITH_EXPECT.to_vec();
let server_str: Vec<u8> = vec![];

let client = Arc::new(Mutex::new(Cursor::new(client_str)));
let server = Arc::new(Mutex::new(Cursor::new(server_str)));

async_h1::server::decode(Duplex::new(client.clone(), server.clone()))
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel

let server_lock = server.lock();
assert_eq!("", std::str::from_utf8(server_lock.get_ref())?); // we haven't written 100-continue

Ok(())
}

0 comments on commit 1641e22

Please sign in to comment.