Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] allow specifying custom reqwest client #53

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,20 @@ fn main() {

```rust
use std::collections::HashMap;
use std::collections::HashMap;
use reqwest::header::HeaderMap;

fn main() {
let mut reader = oneio::get_remote_reader(
let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())]))
.try_into().expect("invalid headers");

let client = reqwest::blocking::Client::builder()
.default_headers(headers)
.danger_accept_invalid_certs(true)
.build().unwrap();
let mut reader = oneio::get_http_reader(
"https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN",
HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])
Some(client),
).unwrap();
let mut text = "".to_string();
reader.read_to_string(&mut text).unwrap();
Expand Down
9 changes: 7 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,14 @@
//! Read remote content with custom headers
//! ```no_run
//! use std::collections::HashMap;
//! let mut reader = oneio::get_remote_reader(
//! use reqwest::header::HeaderMap;
//! let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])).try_into().expect("invalid headers");
//! let client = reqwest::blocking::Client::builder()
//! .default_headers(headers)
//! .build().unwrap();
//! let mut reader = oneio::get_http_reader(
//! "https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN",
//! HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])
//! Some(client),
//! ).unwrap();
//! let mut text = "".to_string();
//! reader.read_to_string(&mut text).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/oneio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::fs::File;
use std::io::{BufWriter, Read, Write};
use std::path::Path;

fn get_writer_raw(path: &str) -> Result<BufWriter<File>, OneIoError> {
pub fn get_writer_raw(path: &str) -> Result<BufWriter<File>, OneIoError> {
let path = Path::new(path);
if let Some(prefix) = path.parent() {
std::fs::create_dir_all(prefix)?;
Expand All @@ -27,7 +27,7 @@ fn get_writer_raw(path: &str) -> Result<BufWriter<File>, OneIoError> {
Ok(output_file)
}

fn get_reader_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
pub fn get_reader_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
#[cfg(feature = "remote")]
let raw_reader: Box<dyn Read + Send> = remote::get_reader_raw_remote(path)?;
#[cfg(not(feature = "remote"))]
Expand Down
97 changes: 58 additions & 39 deletions src/oneio/remote.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::oneio::compressions::OneIOCompression;
use crate::oneio::{compressions, get_writer_raw};
use crate::OneIoError;
use std::collections::HashMap;
use reqwest::blocking::Client;
use std::io::Read;

fn get_protocol(path: &str) -> Option<String> {
Expand All @@ -12,7 +12,7 @@ fn get_protocol(path: &str) -> Option<String> {
Some(parts[0].to_string())
}

fn get_remote_ftp_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
fn get_ftp_reader_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
if !path.starts_with("ftp://") {
return Err(OneIoError::NotSupported(path.to_string()));
}
Expand All @@ -31,51 +31,72 @@ fn get_remote_ftp_raw(path: &str) -> Result<Box<dyn Read + Send>, OneIoError> {
Ok(reader)
}

fn get_remote_http_raw(
fn get_http_reader_raw(
path: &str,
header: HashMap<String, String>,
opt_client: Option<Client>,
) -> Result<reqwest::blocking::Response, OneIoError> {
let mut headers: reqwest::header::HeaderMap = (&header).try_into().expect("invalid headers");
headers.insert(
reqwest::header::USER_AGENT,
reqwest::header::HeaderValue::from_static("oneio"),
);
headers.insert(
reqwest::header::CONTENT_LENGTH,
reqwest::header::HeaderValue::from_static("0"),
);
#[cfg(feature = "cli")]
headers.insert(
reqwest::header::CACHE_CONTROL,
reqwest::header::HeaderValue::from_static("no-cache"),
);
let client = reqwest::blocking::Client::builder()
.default_headers(headers)
.build()?;
let client = match opt_client {
Some(c) => c,
None => {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
reqwest::header::HeaderValue::from_static("oneio"),
);
headers.insert(
reqwest::header::CONTENT_LENGTH,
reqwest::header::HeaderValue::from_static("0"),
);
#[cfg(feature = "cli")]
headers.insert(
reqwest::header::CACHE_CONTROL,
reqwest::header::HeaderValue::from_static("no-cache"),
);
Client::builder().default_headers(headers).build()?
}
};
let res = client
.execute(client.get(path).build()?)?
.error_for_status()?;
Ok(res)
}

/// Get a reader for remote content with the capability to specify headers.
/// Get a reader for remote content with the capability to specify headers, and customer reqwest options.
///
/// Example usage:
/// Example usage with custom header fields:
/// ```no_run
/// use std::collections::HashMap;
/// let mut reader = oneio::get_remote_reader(
/// use reqwest::header::HeaderMap;
/// let headers: HeaderMap = (&HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])).try_into().expect("invalid headers");
/// let client = reqwest::blocking::Client::builder()
/// .default_headers(headers)
/// .build().unwrap();
/// let mut reader = oneio::get_http_reader(
/// "https://SOME_REMOTE_RESOURCE_PROTECTED_BY_ACCESS_TOKEN",
/// HashMap::from([("X-Custom-Auth-Key".to_string(), "TOKEN".to_string())])
/// Some(client),
/// ).unwrap();
/// let mut text = "".to_string();
/// reader.read_to_string(&mut text).unwrap();
/// println!("{}", text);
/// ```
///
/// Example with customer builder that allows invalid certificates (bad practice):
/// ```no_run
/// use std::collections::HashMap;
/// let client = reqwest::blocking::ClientBuilder::new().danger_accept_invalid_certs(true).build().unwrap();
/// let mut reader = oneio::get_http_reader(
/// "https://example.com",
/// Some(client)
/// ).unwrap();
/// let mut text = "".to_string();
/// reader.read_to_string(&mut text).unwrap();
/// println!("{}", text);
/// ```
pub fn get_remote_reader(
pub fn get_http_reader(
path: &str,
header: HashMap<String, String>,
opt_client: Option<Client>,
) -> Result<Box<dyn Read + Send>, OneIoError> {
let raw_reader: Box<dyn Read + Send> = Box::new(get_remote_http_raw(path, header)?);
let raw_reader: Box<dyn Read + Send> = Box::new(get_http_reader_raw(path, opt_client)?);
let file_type = *path.split('.').collect::<Vec<&str>>().last().unwrap();
match file_type {
#[cfg(feature = "gz")]
Expand Down Expand Up @@ -117,17 +138,15 @@ pub fn get_remote_reader(
/// fn main() -> Result<(), OneIoError> {
/// let remote_path = "https://example.com/file.txt";
/// let local_path = "path/to/save/file.txt";
/// let header: Option<HashMap<String, String>> = None;
///
/// download(remote_path, local_path, header)?;
/// download(remote_path, local_path, None)?;
///
/// Ok(())
/// }
/// ```
pub fn download(
remote_path: &str,
local_path: &str,
header: Option<HashMap<String, String>>,
opt_client: Option<Client>,
) -> Result<(), OneIoError> {
match get_protocol(remote_path) {
None => {
Expand All @@ -136,12 +155,12 @@ pub fn download(
Some(protocol) => match protocol.as_str() {
"http" | "https" => {
let mut writer = get_writer_raw(local_path)?;
let mut response = get_remote_http_raw(remote_path, header.unwrap_or_default())?;
let mut response = get_http_reader_raw(remote_path, opt_client)?;
response.copy_to(&mut writer)?;
}
"ftp" => {
let mut writer = get_writer_raw(local_path)?;
let mut reader = get_remote_ftp_raw(remote_path)?;
let mut reader = get_ftp_reader_raw(remote_path)?;
std::io::copy(&mut reader, &mut writer)?;
}
#[cfg(feature = "s3")]
Expand Down Expand Up @@ -179,20 +198,20 @@ pub fn download(
/// let local_path = "/path/to/save/file.txt";
/// let retry = 3;
///
/// match download_with_retry(remote_path, local_path, None, retry) {
/// match download_with_retry(remote_path, local_path, retry, None) {
/// Ok(_) => println!("File downloaded successfully"),
/// Err(e) => eprintln!("Error downloading file: {:?}", e),
/// }
/// ```
pub fn download_with_retry(
remote_path: &str,
local_path: &str,
header: Option<HashMap<String, String>>,
retry: usize,
opt_client: Option<Client>,
) -> Result<(), OneIoError> {
let mut retry = retry;
loop {
match download(remote_path, local_path, header.clone()) {
match download(remote_path, local_path, opt_client.clone()) {
Ok(_) => {
return Ok(());
}
Expand All @@ -212,11 +231,11 @@ pub(crate) fn get_reader_raw_remote(path: &str) -> Result<Box<dyn Read + Send>,
let raw_reader: Box<dyn Read + Send> = match get_protocol(path) {
Some(protocol) => match protocol.as_str() {
"http" | "https" => {
let response = get_remote_http_raw(path, HashMap::new())?;
let response = get_http_reader_raw(path, None)?;
Box::new(response)
}
"ftp" => {
let response = get_remote_ftp_raw(path)?;
let response = get_ftp_reader_raw(path)?;
Box::new(response)
}
#[cfg(feature = "s3")]
Expand Down
Loading