ARROW-622: [Python] Add coerce_timestamps option to parquet.write_table, deprecate timestamps_to_ms argument

Requires PARQUET-1078 apache/parquet-cpp#380

cc @xhochy @fjetter @cpcloud, could you have a look? This needs to go into 0.6.0.
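
As a quick illustration of the change (a minimal sketch assuming a build with this patch; the DataFrame and file name are made up):

```python
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

# pandas timestamps are nanosecond resolution by default
df = pd.DataFrame({'ts': pd.date_range('2001-01-01', periods=3, freq='s')})

# Old, now deprecated: coerce during the pandas-to-Arrow conversion
# table = pa.Table.from_pandas(df, timestamps_to_ms=True)

# New: convert as-is, then coerce at write time
table = pa.Table.from_pandas(df)
pq.write_table(table, 'example.parquet', version='2.0',
               coerce_timestamps='ms')  # or 'us'; None applies no coercion
```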

Author: Wes McKinney <[email protected]>

Closes apache#944 from wesm/ARROW-622 and squashes the following commits:

3a21dfe [Wes McKinney] Add test to exhaust more paths of coerce_timestamps, error handling
45bbf5b [Wes McKinney] Add coerce_timestamps to write_metadata
172a9e1 [Wes McKinney] Implement coerce_timestamps option
wesm committed Aug 7, 2017
1 parent 7a4026a commit 0b91cad
Showing 5 changed files with 91 additions and 28 deletions.
python/pyarrow/_parquet.pxd (3 additions, 1 deletion)

@@ -21,7 +21,8 @@ from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport (CArray, CSchema, CStatus,
                                         CTable, CMemoryPool,
                                         CKeyValueMetadata,
-                                        RandomAccessFile, OutputStream)
+                                        RandomAccessFile, OutputStream,
+                                        TimeUnit)
 
 
 cdef extern from "parquet/api/schema.h" namespace "parquet::schema" nogil:
@@ -266,5 +267,6 @@ cdef extern from "parquet/arrow/writer.h" namespace "parquet::arrow" nogil:
             Builder()
             Builder* disable_deprecated_int96_timestamps()
             Builder* enable_deprecated_int96_timestamps()
+            Builder* coerce_timestamps(TimeUnit unit)
             shared_ptr[ArrowWriterProperties] build()
         c_bool support_deprecated_int96_timestamps()
python/pyarrow/_parquet.pyx (15 additions, 1 deletion)

@@ -547,14 +547,16 @@ cdef class ParquetWriter:
     cdef readonly:
         object use_dictionary
         object use_deprecated_int96_timestamps
+        object coerce_timestamps
         object compression
         object version
         int row_group_size
 
     def __cinit__(self, where, Schema schema, use_dictionary=None,
                   compression=None, version=None,
                   MemoryPool memory_pool=None,
-                  use_deprecated_int96_timestamps=False):
+                  use_deprecated_int96_timestamps=False,
+                  coerce_timestamps=None):
         cdef:
             shared_ptr[FileOutputStream] filestream
             shared_ptr[WriterProperties] properties
@@ -574,6 +576,7 @@ cdef class ParquetWriter:
         self.compression = compression
         self.version = version
         self.use_deprecated_int96_timestamps = use_deprecated_int96_timestamps
+        self.coerce_timestamps = coerce_timestamps
 
         cdef WriterProperties.Builder properties_builder
         self._set_version(&properties_builder)
@@ -583,6 +586,7 @@ cdef class ParquetWriter:
 
         cdef ArrowWriterProperties.Builder arrow_properties_builder
         self._set_int96_support(&arrow_properties_builder)
+        self._set_coerce_timestamps(&arrow_properties_builder)
         arrow_properties = arrow_properties_builder.build()
 
         pool = maybe_unbox_memory_pool(memory_pool)
@@ -598,6 +602,16 @@ cdef class ParquetWriter:
         else:
             props.disable_deprecated_int96_timestamps()
 
+    cdef int _set_coerce_timestamps(
+            self, ArrowWriterProperties.Builder* props) except -1:
+        if self.coerce_timestamps == 'ms':
+            props.coerce_timestamps(TimeUnit_MILLI)
+        elif self.coerce_timestamps == 'us':
+            props.coerce_timestamps(TimeUnit_MICRO)
+        elif self.coerce_timestamps is not None:
+            raise ValueError('Invalid value for coerce_timestamps: {0}'
+                             .format(self.coerce_timestamps))
+
     cdef void _set_version(self, WriterProperties.Builder* props):
         if self.version is not None:
             if self.version == "1.0":
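
Since ParquetWriter itself now accepts the option, the lower-level writer API can also be used directly. A short sketch under the assumption of this patched build (the DataFrame and file name are illustrative):

```python
import pandas as pd
import pyarrow as pa
from pyarrow.parquet import ParquetWriter

df = pd.DataFrame({'ts': pd.date_range('2001-01-01', periods=3, freq='s')})
table = pa.Table.from_pandas(df)

# ParquetWriter threads coerce_timestamps through to the Arrow writer
# properties builder shown above ('ms' -> TimeUnit_MILLI, 'us' -> TimeUnit_MICRO)
writer = ParquetWriter('example.parquet', table.schema,
                       version='2.0', coerce_timestamps='ms')
writer.write_table(table)
writer.close()
```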
python/pyarrow/pandas_compat.py (4 additions, 0 deletions)

@@ -243,6 +243,10 @@ def dataframe_to_arrays(df, timestamps_to_ms, schema, preserve_index):
 
 
 def maybe_coerce_datetime64(values, dtype, type_, timestamps_to_ms=False):
+    if timestamps_to_ms:
+        import warnings
+        warnings.warn('timestamps_to_ms=True is deprecated', FutureWarning)
+
     from pyarrow.compat import DatetimeTZDtype
 
     if values.dtype.type != np.datetime64:
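
With that change the deprecated flag keeps working but announces itself. A small sketch of what callers now see (assuming this patched build, where from_pandas conversion of datetime64 columns passes through maybe_coerce_datetime64):

```python
import warnings

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({'ts': pd.date_range('2001-01-01', periods=3, freq='s')})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # Still converts to milliseconds, but now emits a FutureWarning
    table = pa.Table.from_pandas(df, timestamps_to_ms=True)

assert any(issubclass(w.category, FutureWarning) for w in caught)
```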
python/pyarrow/parquet.py (18 additions, 4 deletions)

@@ -757,7 +757,8 @@ def read_pandas(source, columns=None, nthreads=1, metadata=None):
 
 
 def write_table(table, where, row_group_size=None, version='1.0',
                 use_dictionary=True, compression='snappy',
-                use_deprecated_int96_timestamps=False, **kwargs):
+                use_deprecated_int96_timestamps=False,
+                coerce_timestamps=None, **kwargs):
     """
     Write a Table to Parquet format
@@ -773,6 +774,11 @@ def write_table(table, where, row_group_size=None, version='1.0',
     use_dictionary : bool or list
         Specify if we should use dictionary encoding in general or only for
         some columns.
+    use_deprecated_int96_timestamps : boolean, default False
+        Write nanosecond resolution timestamps to INT96 Parquet format
+    coerce_timestamps : string, default None
+        Cast timestamps to a particular resolution.
+        Valid values: {None, 'ms', 'us'}
     compression : str or dict
         Specify the compression codec, either on a general basis or per-column.
     """
@@ -781,7 +787,8 @@ def write_table(table, where, row_group_size=None, version='1.0',
         use_dictionary=use_dictionary,
         compression=compression,
         version=version,
-        use_deprecated_int96_timestamps=use_deprecated_int96_timestamps)
+        use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
+        coerce_timestamps=coerce_timestamps)
 
     writer = None
     try:
@@ -801,7 +808,8 @@ def write_table(table, where, row_group_size=None, version='1.0',
 
 
 def write_metadata(schema, where, version='1.0',
-                   use_deprecated_int96_timestamps=False):
+                   use_deprecated_int96_timestamps=False,
+                   coerce_timestamps=None):
     """
     Write metadata-only Parquet file from schema
@@ -811,10 +819,16 @@ def write_metadata(schema, where, version='1.0',
     where: string or pyarrow.io.NativeFile
     version : {"1.0", "2.0"}, default "1.0"
         The Parquet format version, defaults to 1.0
+    use_deprecated_int96_timestamps : boolean, default False
+        Write nanosecond resolution timestamps to INT96 Parquet format
+    coerce_timestamps : string, default None
+        Cast timestamps to a particular resolution.
+        Valid values: {None, 'ms', 'us'}
     """
     options = dict(
         version=version,
-        use_deprecated_int96_timestamps=use_deprecated_int96_timestamps
+        use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
+        coerce_timestamps=coerce_timestamps
     )
     writer = ParquetWriter(where, schema, **options)
     writer.close()
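
Both public entry points now accept the option, and invalid units are rejected up front. A brief sketch (file names illustrative, assuming this patched build):

```python
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

df = pd.DataFrame({'ts': pd.date_range('2001-01-01', periods=3, freq='s')})
table = pa.Table.from_pandas(df)

# Coerce to microseconds when writing the data file...
pq.write_table(table, 'data.parquet', version='2.0', coerce_timestamps='us')

# ...and keep a metadata-only file consistent with it
pq.write_metadata(table.schema, 'data_common_metadata',
                  version='2.0', coerce_timestamps='us')

# Only None, 'ms', and 'us' are accepted
try:
    pq.write_table(table, 'data.parquet', coerce_timestamps='ns')
except ValueError as exc:
    print(exc)  # Invalid value for coerce_timestamps: ns
```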
python/pyarrow/tests/test_parquet.py (51 additions, 22 deletions)

@@ -100,10 +100,11 @@ def test_pandas_parquet_2_0_rountrip(tmpdir):
     df = alltypes_sample(size=10000)
 
     filename = tmpdir.join('pandas_rountrip.parquet')
-    arrow_table = pa.Table.from_pandas(df, timestamps_to_ms=True)
+    arrow_table = pa.Table.from_pandas(df)
     assert b'pandas' in arrow_table.schema.metadata
 
-    _write_table(arrow_table, filename.strpath, version="2.0")
+    _write_table(arrow_table, filename.strpath, version="2.0",
+                 coerce_timestamps='ms')
     table_read = pq.read_pandas(filename.strpath)
     assert b'pandas' in table_read.schema.metadata
 
@@ -120,10 +121,11 @@ def test_pandas_parquet_custom_metadata(tmpdir):
     df = alltypes_sample(size=10000)
 
     filename = tmpdir.join('pandas_rountrip.parquet')
-    arrow_table = pa.Table.from_pandas(df, timestamps_to_ms=True)
+    arrow_table = pa.Table.from_pandas(df)
     assert b'pandas' in arrow_table.schema.metadata
 
-    _write_table(arrow_table, filename.strpath, version="2.0")
+    _write_table(arrow_table, filename.strpath, version="2.0",
+                 coerce_timestamps='ms')
 
     md = pq.read_metadata(filename.strpath).metadata
     assert b'pandas' in md
@@ -139,13 +141,12 @@ def test_pandas_parquet_2_0_rountrip_read_pandas_no_index_written(tmpdir):
     df = alltypes_sample(size=10000)
 
     filename = tmpdir.join('pandas_rountrip.parquet')
-    arrow_table = pa.Table.from_pandas(
-        df, timestamps_to_ms=True, preserve_index=False
-    )
+    arrow_table = pa.Table.from_pandas(df, preserve_index=False)
     js = json.loads(arrow_table.schema.metadata[b'pandas'].decode('utf8'))
     assert not js['index_columns']
 
-    _write_table(arrow_table, filename.strpath, version="2.0")
+    _write_table(arrow_table, filename.strpath, version="2.0",
+                 coerce_timestamps='ms')
     table_read = pq.read_pandas(filename.strpath)
 
     js = json.loads(table_read.schema.metadata[b'pandas'].decode('utf8'))
@@ -340,10 +341,11 @@ def test_pandas_parquet_configuration_options(tmpdir):
 def make_sample_file(df):
     import pyarrow.parquet as pq
 
-    a_table = pa.Table.from_pandas(df, timestamps_to_ms=True)
+    a_table = pa.Table.from_pandas(df)
 
     buf = io.BytesIO()
-    _write_table(a_table, buf, compression='SNAPPY', version='2.0')
+    _write_table(a_table, buf, compression='SNAPPY', version='2.0',
+                 coerce_timestamps='ms')
 
     buf.seek(0)
     return pq.ParquetFile(buf)
@@ -418,22 +420,47 @@ def test_column_of_arrays(tmpdir):
     df, schema = dataframe_with_arrays()
 
     filename = tmpdir.join('pandas_rountrip.parquet')
-    arrow_table = pa.Table.from_pandas(df, timestamps_to_ms=True,
-                                       schema=schema)
-    _write_table(arrow_table, filename.strpath, version="2.0")
+    arrow_table = pa.Table.from_pandas(df, schema=schema)
+    _write_table(arrow_table, filename.strpath, version="2.0",
+                 coerce_timestamps='ms')
     table_read = _read_table(filename.strpath)
     df_read = table_read.to_pandas()
     tm.assert_frame_equal(df, df_read)
 
 
+@parquet
+def test_coerce_timestamps(tmpdir):
+    # ARROW-622
+    df, schema = dataframe_with_arrays()
+
+    filename = tmpdir.join('pandas_rountrip.parquet')
+    arrow_table = pa.Table.from_pandas(df, schema=schema)
+
+    _write_table(arrow_table, filename.strpath, version="2.0",
+                 coerce_timestamps='us')
+    table_read = _read_table(filename.strpath)
+    df_read = table_read.to_pandas()
+
+    df_expected = df.copy()
+    for i, x in enumerate(df_expected['datetime64']):
+        if isinstance(x, np.ndarray):
+            df_expected['datetime64'][i] = x.astype('M8[us]')
+
+    tm.assert_frame_equal(df_expected, df_read)
+
+    with pytest.raises(ValueError):
+        _write_table(arrow_table, filename.strpath, version="2.0",
+                     coerce_timestamps='unknown')
+
+
 @parquet
 def test_column_of_lists(tmpdir):
     df, schema = dataframe_with_lists()
 
     filename = tmpdir.join('pandas_rountrip.parquet')
-    arrow_table = pa.Table.from_pandas(df, timestamps_to_ms=True,
-                                       schema=schema)
-    _write_table(arrow_table, filename.strpath, version="2.0")
+    arrow_table = pa.Table.from_pandas(df, schema=schema)
+    _write_table(arrow_table, filename.strpath, version="2.0",
+                 coerce_timestamps='ms')
     table_read = _read_table(filename.strpath)
     df_read = table_read.to_pandas()
     tm.assert_frame_equal(df, df_read)
@@ -469,12 +496,14 @@ def test_date_time_types():
 
     t7 = pa.timestamp('ns')
     start = pd.Timestamp('2001-01-01').value
-    data7 = np.array([start, start + 1, start + 2], dtype='int64')
+    data7 = np.array([start, start + 1000, start + 2000],
+                     dtype='int64')
     a7 = pa.Array.from_pandas(data7, type=t7)
 
     t7_us = pa.timestamp('us')
     start = pd.Timestamp('2001-01-01').value
-    data7_us = np.array([start, start + 1, start + 2], dtype='int64') // 1000
+    data7_us = np.array([start, start + 1000, start + 2000],
+                        dtype='int64') // 1000
     a7_us = pa.Array.from_pandas(data7_us, type=t7_us)
 
     table = pa.Table.from_arrays([a1, a2, a3, a4, a5, a6, a7],
@@ -547,7 +576,7 @@ def _check_roundtrip(table, expected=None, **params):
 def test_multithreaded_read():
     df = alltypes_sample(size=10000)
 
-    table = pa.Table.from_pandas(df, timestamps_to_ms=True)
+    table = pa.Table.from_pandas(df)
 
     buf = io.BytesIO()
     _write_table(table, buf, compression='SNAPPY', version='2.0')
@@ -585,7 +614,7 @@ def test_pass_separate_metadata():
     # ARROW-471
     df = alltypes_sample(size=10000)
 
-    a_table = pa.Table.from_pandas(df, timestamps_to_ms=True)
+    a_table = pa.Table.from_pandas(df)
 
     buf = io.BytesIO()
     _write_table(a_table, buf, compression='snappy', version='2.0')
@@ -608,7 +637,7 @@ def test_read_single_row_group():
     N, K = 10000, 4
     df = alltypes_sample(size=N)
 
-    a_table = pa.Table.from_pandas(df, timestamps_to_ms=True)
+    a_table = pa.Table.from_pandas(df)
 
     buf = io.BytesIO()
     _write_table(a_table, buf, row_group_size=N / K,
@@ -631,7 +660,7 @@ def test_read_single_row_group_with_column_subset():
 
     N, K = 10000, 4
     df = alltypes_sample(size=N)
-    a_table = pa.Table.from_pandas(df, timestamps_to_ms=True)
+    a_table = pa.Table.from_pandas(df)
 
     buf = io.BytesIO()
     _write_table(a_table, buf, row_group_size=N / K,
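
One detail worth calling out in the test_date_time_types hunk above: the offsets change from +1/+2 to +1000/+2000 nanoseconds because the smaller steps would collapse to identical values after the // 1000 conversion to microseconds. A quick check of that arithmetic:

```python
import numpy as np
import pandas as pd

start = pd.Timestamp('2001-01-01').value  # nanoseconds since the epoch

old = np.array([start, start + 1, start + 2], dtype='int64') // 1000
new = np.array([start, start + 1000, start + 2000], dtype='int64') // 1000

# +1 ns and +2 ns vanish under integer division to microseconds...
assert old[0] == old[1] == old[2]
# ...while +1000 ns steps survive as three distinct microsecond values
assert len(set(new)) == 3
```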
