Skip to content

Commit

Permalink
Introduce the Merge trait
Browse files Browse the repository at this point in the history
  • Loading branch information
Kerollmops committed Aug 8, 2021
1 parent cf659ce commit 6e52e5a
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 58 deletions.
70 changes: 44 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,31 @@
//! use std::convert::TryInto;
//! use std::io::Cursor;
//!
//! use grenad::{MergerBuilder, Reader, Writer};
//! use grenad::{Merge, MergerBuilder, Reader, Writer};
//!
//! // This merge function:
//! // This merger:
//! // - parses u32s from native-endian bytes,
//! // - sums them with wrapping arithmetic, and
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! #[derive(Clone, Copy)]
//! struct WrappingSumU32s;
//!
//! impl Merge for WrappingSumU32s {
//! type Error = TryFromSliceError;
//! type Output = [u8; 4];
//!
//! fn merge<I, A>(&self, key: &[u8], values: I) -> Result<Self::Output, Self::Error>
//! where
//! I: IntoIterator<Item = A>,
//! A: AsRef<[u8]>
//! {
//! let mut output: u32 = 0;
//! for value in values {
//! let num = value.as_ref().try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(output.to_ne_bytes())
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -82,7 +90,7 @@
//!
//! // We create a merger that will sum our u32s when necessary,
//! // and we add our readers to the list of readers to merge.
//! let merger_builder = MergerBuilder::new(wrapping_sum_u32s);
//! let merger_builder = MergerBuilder::new(WrappingSumU32s);
//! let merger = merger_builder.add(readera).add(readerb).add(readerc).build();
//!
//! // We can iterate over the entries in key-order.
Expand All @@ -106,28 +114,36 @@
//! use std::borrow::Cow;
//! use std::convert::TryInto;
//!
//! use grenad::{CursorVec, SorterBuilder};
//! use grenad::{Merge, CursorVec, SorterBuilder};
//!
//! // This merge function:
//! // This merger:
//! // - parses u32s from native-endian bytes,
//! // - sums them with wrapping arithmetic, and
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! #[derive(Clone, Copy)]
//! struct WrappingSumU32s;
//!
//! impl Merge for WrappingSumU32s {
//! type Error = TryFromSliceError;
//! type Output = [u8; 4];
//!
//! fn merge<I, A>(&self, key: &[u8], values: I) -> Result<Self::Output, Self::Error>
//! where
//! I: IntoIterator<Item = A>,
//! A: AsRef<[u8]>
//! {
//! let mut output: u32 = 0;
//! for value in values {
//! let num = value.as_ref().try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(output.to_ne_bytes())
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // We create a sorter that will sum our u32s when necessary.
//! let mut sorter = SorterBuilder::new(wrapping_sum_u32s).chunk_creator(CursorVec).build();
//! let mut sorter = SorterBuilder::new(WrappingSumU32s).chunk_creator(CursorVec).build();
//!
//! // We insert multiple entries with the same key but different values
//! // in arbitrary order, the sorter will take care of merging them for us.
Expand Down Expand Up @@ -158,6 +174,7 @@ extern crate quickcheck;
mod block_builder;
mod compression;
mod error;
mod merge;
mod merger;
mod reader;
mod sorter;
Expand All @@ -166,6 +183,7 @@ mod writer;

pub use self::compression::CompressionType;
pub use self::error::Error;
pub use self::merge::Merge;
pub use self::merger::{Merger, MergerBuilder, MergerIter};
pub use self::reader::Reader;
#[cfg(feature = "tempfile")]
Expand Down
22 changes: 22 additions & 0 deletions src/merge.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/// Describes how several values that share the same key are combined into a single value.
///
/// A type implementing this trait is what gets passed to `MergerBuilder::new` and
/// `SorterBuilder::new`; it is invoked whenever entries with an identical key must be
/// reduced to one entry.
pub trait Merge {
/// The error returned when the values cannot be merged.
type Error;
/// The merged value, which must be viewable as a byte slice.
type Output: AsRef<[u8]>;

/// Merges all the `values` associated with `key` into one output value.
///
/// `values` can be any iterable whose items can be viewed as byte slices.
///
/// # Errors
///
/// Returns `Self::Error` when the implementation is unable to merge the given values.
fn merge<I, A>(&self, key: &[u8], values: I) -> Result<Self::Output, Self::Error>
where
I: IntoIterator<Item = A>,
A: AsRef<[u8]>;
}

/// A shared reference to a merger is itself a merger: every call is forwarded
/// to the referenced value, reusing its error and output types.
impl<M: Merge> Merge for &M {
    type Error = M::Error;
    type Output = M::Output;

    /// Delegates to the `merge` implementation of the referenced `M`.
    fn merge<I, A>(&self, key: &[u8], values: I) -> Result<Self::Output, Self::Error>
    where
        I: IntoIterator<Item = A>,
        A: AsRef<[u8]>,
    {
        // `self` is a `&&M`; peel off one reference level to reach the inner `&M`.
        let inner: &M = *self;
        M::merge(inner, key, values)
    }
}
34 changes: 15 additions & 19 deletions src/merger.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::io;
use std::iter::once;

use crate::{Error, Reader, Writer};
use crate::{Error, Merge, Reader, Writer};

/// A struct that is used to configure a [`Merger`] with the sources to merge.
pub struct MergerBuilder<R, MF> {
Expand Down Expand Up @@ -83,7 +82,7 @@ impl<R, MF> Merger<R, MF> {
}
}

impl<R: io::Read, MF> Merger<R, MF> {
impl<R: io::Read, MF: Merge> Merger<R, MF> {
/// Consumes this [`Merger`] and outputs a stream of the merged entries in key-order.
pub fn into_merger_iter(self) -> Result<MergerIter<R, MF>, Error> {
let mut heap = BinaryHeap::new();
Expand All @@ -97,7 +96,7 @@ impl<R: io::Read, MF> Merger<R, MF> {
merge: self.merge,
heap,
current_key: Vec::new(),
merged_value: Vec::new(),
merged_value: None,
tmp_entries: Vec::new(),
})
}
Expand All @@ -106,7 +105,7 @@ impl<R: io::Read, MF> Merger<R, MF> {
impl<R, MF, U> Merger<R, MF>
where
R: io::Read,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: Merge<Error = U>,
{
/// Consumes this [`Merger`] and streams the entries to the [`Writer`] given in parameter.
pub fn write_into<W: io::Write>(self, writer: &mut Writer<W>) -> Result<(), Error<U>> {
Expand All @@ -119,19 +118,19 @@ where
}

/// An iterator that yields the merged entries in key-order.
pub struct MergerIter<R, MF> {
pub struct MergerIter<R, MF: Merge> {
merge: MF,
heap: BinaryHeap<Entry<R>>,
current_key: Vec<u8>,
merged_value: Vec<u8>,
merged_value: Option<MF::Output>,
/// We keep this buffer to avoid allocating a vec every time.
tmp_entries: Vec<Entry<R>>,
}

impl<R, MF, U> MergerIter<R, MF>
where
R: io::Read,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: Merge<Error = U>,
{
/// Yield the entries in key-order where values have been merged when needed.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error<U>> {
Expand All @@ -158,21 +157,15 @@ where
}
}

/// Extract the currently pointed values from the entries.
// Extract the currently pointed values from the entries.
let other_values = self.tmp_entries.iter().filter_map(|e| e.iter.current().map(|(_, v)| v));
let values: Vec<_> = once(first_value).chain(other_values).map(Cow::Borrowed).collect();
let values = once(first_value).chain(other_values);

match (self.merge)(&first_key, &values) {
match self.merge.merge(&first_key, values) {
Ok(value) => {
self.current_key.clear();
self.current_key.extend_from_slice(first_key);
match value {
Cow::Owned(value) => self.merged_value = value,
Cow::Borrowed(value) => {
self.merged_value.clear();
self.merged_value.extend_from_slice(value);
}
}
self.merged_value = Some(value);
}
Err(e) => return Err(Error::Merge(e)),
}
Expand All @@ -184,6 +177,9 @@ where
}
}

Ok(Some((&self.current_key, &self.merged_value)))
match self.merged_value.as_ref() {
Some(value) => Ok(Some((&self.current_key, value.as_ref()))),
None => Ok(None),
}
}
}
42 changes: 29 additions & 13 deletions src/sorter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::alloc::{alloc, dealloc, Layout};
use std::borrow::Cow;
use std::convert::Infallible;
use std::fs::File;
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
Expand All @@ -15,7 +14,7 @@ const MIN_SORTER_MEMORY: usize = 10_485_760; // 10MB
const DEFAULT_NB_CHUNKS: usize = 25;
const MIN_NB_CHUNKS: usize = 1;

use crate::{CompressionType, Error, Merger, MergerIter, Reader, Writer, WriterBuilder};
use crate::{CompressionType, Error, Merge, Merger, MergerIter, Reader, Writer, WriterBuilder};

/// A struct that is used to configure a [`Sorter`] to better fit your needs.
#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -323,7 +322,7 @@ impl<MF> Sorter<MF, DefaultChunkCreator> {

impl<MF, CC, U> Sorter<MF, CC>
where
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: Merge<Error = U>,
CC: ChunkCreator,
{
/// Insert an entry into the [`Sorter`] making sure that conflicts
Expand Down Expand Up @@ -367,22 +366,23 @@ where
let mut current = None;
for (key, value) in self.entries.iter() {
match current.as_mut() {
None => current = Some((key, vec![Cow::Borrowed(value)])),
None => current = Some((key, vec![value])),
Some((current_key, vals)) => {
if current_key != &key {
let merged_val = (self.merge)(current_key, vals).map_err(Error::Merge)?;
writer.insert(&current_key, &merged_val)?;
let merged_val =
self.merge.merge(current_key, vals.iter()).map_err(Error::Merge)?;
writer.insert(&current_key, merged_val)?;
vals.clear();
*current_key = key;
}
vals.push(Cow::Borrowed(value));
vals.push(value);
}
}
}

if let Some((key, vals)) = current.take() {
let merged_val = (self.merge)(key, &vals).map_err(Error::Merge)?;
writer.insert(&key, &merged_val)?;
let merged_val = self.merge.merge(key, vals).map_err(Error::Merge)?;
writer.insert(&key, merged_val)?;
}

let chunk = writer.into_inner()?;
Expand Down Expand Up @@ -518,13 +518,29 @@ mod tests {

use super::*;

fn merge<'a>(_key: &[u8], vals: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Infallible> {
Ok(vals.iter().map(AsRef::as_ref).flatten().cloned().collect())
/// Test merger that glues every value of a key together, in the order received.
#[derive(Clone, Copy)]
struct ConcatBytes;

impl Merge for ConcatBytes {
    type Error = Infallible;
    type Output = Vec<u8>;

    /// Returns the concatenation of all `values`; the key is ignored and merging never fails.
    fn merge<I, A>(&self, _key: &[u8], values: I) -> Result<Self::Output, Self::Error>
    where
        I: IntoIterator<Item = A>,
        A: AsRef<[u8]>,
    {
        let concatenated = values.into_iter().fold(Vec::new(), |mut acc, chunk| {
            acc.extend_from_slice(chunk.as_ref());
            acc
        });
        Ok(concatenated)
    }
}

#[test]
fn simple_cursorvec() {
let mut sorter = SorterBuilder::new(merge)
let mut sorter = SorterBuilder::new(ConcatBytes)
.chunk_compression_type(CompressionType::Snappy)
.chunk_creator(CursorVec)
.build();
Expand All @@ -551,7 +567,7 @@ mod tests {

#[test]
fn hard_cursorvec() {
let mut sorter = SorterBuilder::new(merge)
let mut sorter = SorterBuilder::new(ConcatBytes)
.dump_threshold(1024) // 1KiB
.allow_realloc(false)
.chunk_compression_type(CompressionType::Snappy)
Expand Down

0 comments on commit 6e52e5a

Please sign in to comment.