Skip to content

Commit

Permalink
fix: vaex.from_arrow_dataset was not implemented/tested, fixes vaexio…
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels authored Sep 8, 2022
1 parent c26bb15 commit a0da9b9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
2 changes: 1 addition & 1 deletion packages/vaex-core/vaex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def from_arrow_table(table) -> vaex.dataframe.DataFrame:


def from_arrow_dataset(arrow_dataset) -> vaex.dataframe.DataFrame:
    '''Create a DataFrame from an Apache Arrow dataset.

    :param arrow_dataset: Arrow dataset object (e.g. the result of ``pyarrow.dataset.dataset(path)``);
        it is wrapped in a :class:`vaex.arrow.dataset.DatasetArrow`.
    :return: the DataFrame that :func:`from_dataset` builds for the wrapped dataset.
    '''
    # Imported inside the function, so vaex.arrow.dataset is only loaded when this is called.
    import vaex.arrow.dataset
    return from_dataset(vaex.arrow.dataset.DatasetArrow(arrow_dataset))

Expand Down
33 changes: 33 additions & 0 deletions packages/vaex-core/vaex/arrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,39 @@ def __getstate__(self):
return state



class DatasetArrow(DatasetArrowBase):
    '''Dataset backed by an already constructed Arrow dataset object.

    The wrapped object is kept in ``_arrow_ds``; its fragments are scanned in
    :meth:`_create_dataset` to collect the partition keys.
    '''
    snake_name = "arrow-dataset"

    def __init__(self, ds, max_rows_read=1024**2*10):
        # Store the Arrow dataset before delegating to the base class.
        # NOTE(review): presumably the base initializer relies on it being set - confirm.
        self._arrow_ds = ds
        super().__init__(max_rows_read=max_rows_read)

    @property
    def _fingerprint(self):
        '''Fingerprint of this dataset, which is simply its ``_id``.'''
        return self._id

    def hashed(self):
        '''Not supported for this dataset type; always raises NotImplementedError.'''
        raise NotImplementedError()

    def _create_columns(self):
        '''Create the columns via the base class and reset the per-column ids.'''
        super()._create_columns()
        # Per-column fingerprints (vaex.cache.fingerprint(self._fingerprint, name))
        # are not computed here; the id mapping is left empty.
        self._ids = frozendict()

    def _create_dataset(self):
        '''Collect the partition keys of every fragment in the Arrow dataset.

        Afterwards ``_partitions`` maps each partition field name to an Arrow
        array of its distinct values (in order of first appearance), and
        ``_partition_keys`` maps a fragment path to ``{field name: index}``,
        where the index points into the corresponding ``_partitions`` array.
        '''
        # field name -> list of distinct values (turned into an arrow array below)
        distinct_values = self._partitions = defaultdict(list)
        # fragment path -> field name -> index into the values of that field
        indices_by_path = self._partition_keys = defaultdict(dict)

        for fragment in self._arrow_ds.get_fragments():
            fragment_keys = pa.dataset._get_partition_keys(fragment.partition_expression)
            for name, value in fragment_keys.items():
                seen = distinct_values[name]
                if value not in seen:
                    seen.append(value)
                indices_by_path[fragment.path][name] = seen.index(value)

        self._partitions = {
            name: pa.array(seen) for name, seen in distinct_values.items()
        }



class DatasetArrowFileBase(vaex.dataset.Dataset):
def __init__(self, path, fs_options, fs=None):
super().__init__()
Expand Down
14 changes: 14 additions & 0 deletions tests/from_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import pytest
import vaex
from pathlib import Path

HERE = Path(__file__).parent


def test_from_records():
df = vaex.from_records([
Expand Down Expand Up @@ -26,3 +31,12 @@ def test_from_records():
], array_type="numpy")
assert df.a.tolist() == [[1, 1], [11, 12], [13, 14]]
assert df.a.shape == (3, 2)


def test_from_arrow_dataset():
    '''A DataFrame built from an Arrow dataset is usable and has a stable fingerprint.'''
    import pyarrow.dataset

    parquet_path = HERE / 'data' / 'sample_arrow_dict.parquet'
    arrow_ds = pyarrow.dataset.dataset(parquet_path)
    df = vaex.from_arrow_dataset(arrow_ds)

    # The column data must be readable through the DataFrame.
    assert df.col1.sum() == 45
    # Computing the fingerprint twice must succeed and give the same result.
    assert df.fingerprint() == df.fingerprint()

0 comments on commit a0da9b9

Please sign in to comment.