
Commit 6e5f7be

ARROW-1435: [Python] Properly handle time zone metadata in Parquet round trips

cc @jreback. Various bugs are fixed here, but the bottom line is that this enables tz-aware pandas data to be faithfully round-tripped to Parquet format. We will also need to implement compatibility tests for this in pandas.

An example DataFrame that could not be properly written before:

```python
import datetime

import pandas as pd

s = pd.Series([datetime.datetime(2017, 9, 6)])
s = s.dt.tz_localize('utc')
s.index = s
# Both a column and an index to hit both use cases
df = pd.DataFrame({'tz_aware': s}, index=s)
```
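As a rough sketch of the round trip this enables, using the public `pyarrow.parquet` API (the in-memory `BytesIO` sink and `coerce_timestamps='ms'` mirror the new test added below; this is illustrative, not part of the commit):

```python
import datetime
import io

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

s = pd.Series([datetime.datetime(2017, 9, 6)]).dt.tz_localize('utc')
s.index = s
df = pd.DataFrame({'tz_aware': s}, index=s)

buf = io.BytesIO()
# Parquet stores millisecond/microsecond timestamps, so coerce the
# nanosecond pandas values on write
pq.write_table(pa.Table.from_pandas(df), buf, coerce_timestamps='ms')
buf.seek(0)

df_read = pq.read_pandas(buf).to_pandas()
assert df_read['tz_aware'].dt.tz is not None  # the zone survives the trip
```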

Author: Wes McKinney <[email protected]>

Closes apache#1054 from wesm/ARROW-1435 and squashes the following commits:

6519945 [Wes McKinney] Add test for a non-UTC time zone too
20bb6dc [Wes McKinney] Get round trip for tz-aware index to Parquet working. Handle time zones in Column.to_pandas
f92abaa [Wes McKinney] Fix, initial test passing
6701cf0 [Wes McKinney] Initial cut at fixing tz aware columns to/from Parquet
wesm committed Sep 7, 2017
1 parent 3033eac commit 6e5f7be
Showing 9 changed files with 144 additions and 29 deletions.
cpp/src/arrow/python/pandas_to_arrow.cc (4 additions, 2 deletions)

@@ -347,8 +347,9 @@ class PandasConverter {
    }

    BufferVector buffers = {null_bitmap_, data};
-   return PushArray(
-       std::make_shared<ArrayData>(type_, length_, std::move(buffers), null_count, 0));
+   auto arr_data = std::make_shared<ArrayData>(type_, length_, std::move(buffers),
+                                               null_count, 0);
+   return PushArray(arr_data);
  }

  template <typename T>
@@ -1158,6 +1159,7 @@ Status PandasToArrow(MemoryPool* pool, PyObject* ao, PyObject* mo,
  PandasConverter converter(pool, ao, mo, type);
  RETURN_NOT_OK(converter.Convert());
  *out = converter.result()[0];
+ DCHECK(*out);
  return Status::OK();
}
python/pyarrow/__init__.py (2 additions, 1 deletion)

@@ -65,7 +65,8 @@
                          FloatValue, DoubleValue, ListValue,
                          BinaryValue, StringValue, FixedSizeBinaryValue,
                          DecimalValue,
-                         Date32Value, Date64Value, TimestampValue)
+                         Date32Value, Date64Value, TimestampValue,
+                         TimestampType)

from pyarrow.lib import (HdfsFile, NativeFile, PythonFile,
                         FixedSizeBufferWriter,
python/pyarrow/pandas_compat.py (45 additions, 8 deletions)

@@ -172,8 +172,8 @@ def construct_metadata(df, column_names, index_levels, preserve_index, types):
        dict
    """
    ncolumns = len(column_names)
-   df_types = types[:ncolumns]
-   index_types = types[ncolumns:ncolumns + len(index_levels)]
+   df_types = types[:ncolumns - len(index_levels)]
+   index_types = types[ncolumns - len(index_levels):]

    column_metadata = [
        get_column_metadata(df[col_name], name=sanitized_name,
@@ -269,13 +269,15 @@ def maybe_coerce_datetime64(values, dtype, type_, timestamps_to_ms=False):
    return values, type_


+def make_datetimetz(tz):
+   from pyarrow.compat import DatetimeTZDtype
+   return DatetimeTZDtype('ns', tz=tz)
+
+
def table_to_blockmanager(options, table, memory_pool, nthreads=1):
    import pandas.core.internals as _int
-   from pyarrow.compat import DatetimeTZDtype
    import pyarrow.lib as lib

-   block_table = table
-
    index_columns = []
    index_arrays = []
    index_names = []
@@ -286,20 +288,24 @@ def table_to_blockmanager(options, table, memory_pool, nthreads=1):
    if metadata is not None and b'pandas' in metadata:
        pandas_metadata = json.loads(metadata[b'pandas'].decode('utf8'))
        index_columns = pandas_metadata['index_columns']
+       table = _add_any_metadata(table, pandas_metadata)
+
+   block_table = table

    for name in index_columns:
        i = schema.get_field_index(name)
        if i != -1:
            col = table.column(i)
            index_name = (None if is_unnamed_index_level(name)
                          else name)
-           values = col.to_pandas().values
+           col_pandas = col.to_pandas()
+           values = col_pandas.values
            if not values.flags.writeable:
                # ARROW-1054: in pandas 0.19.2, factorize will reject
                # non-writeable arrays when calling MultiIndex.from_arrays
                values = values.copy()

-           index_arrays.append(values)
+           index_arrays.append(pd.Series(values, dtype=col_pandas.dtype))
            index_names.append(index_name)
            block_table = block_table.remove_column(
                block_table.schema.get_field_index(name)
@@ -319,7 +325,7 @@ def table_to_blockmanager(options, table, memory_pool, nthreads=1):
                klass=_int.CategoricalBlock,
                fastpath=True)
        elif 'timezone' in item:
-           dtype = DatetimeTZDtype('ns', tz=item['timezone'])
+           dtype = make_datetimetz(item['timezone'])
            block = _int.make_block(block_arr, placement=placement,
                                    klass=_int.DatetimeTZBlock,
                                    dtype=dtype, fastpath=True)
@@ -340,3 +346,34 @@ def table_to_blockmanager(options, table, memory_pool, nthreads=1):
    ]

    return _int.BlockManager(blocks, axes)
+
+
+def _add_any_metadata(table, pandas_metadata):
+   modified_columns = {}
+
+   schema = table.schema
+
+   # Add time zones
+   for i, col_meta in enumerate(pandas_metadata['columns']):
+       if col_meta['pandas_type'] == 'datetimetz':
+           col = table[i]
+           converted = col.to_pandas()
+           tz = col_meta['metadata']['timezone']
+           tz_aware_type = pa.timestamp('ns', tz=tz)
+           with_metadata = pa.Array.from_pandas(converted.values,
+                                                type=tz_aware_type)
+
+           field = pa.field(schema[i].name, tz_aware_type)
+           modified_columns[i] = pa.Column.from_array(field,
+                                                      with_metadata)
+
+   if len(modified_columns) > 0:
+       columns = []
+       for i in range(len(table.schema)):
+           if i in modified_columns:
+               columns.append(modified_columns[i])
+           else:
+               columns.append(table[i])
+       return pa.Table.from_arrays(columns)
+   else:
+       return table
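For reference, `_add_any_metadata` works off the `'pandas'` metadata blob stored on the Arrow schema. A hedged sketch of the shape it consults, where the keys match the lookups above but the concrete values are illustrative:

```python
# Illustrative only: keys mirror the lookups in table_to_blockmanager
# and _add_any_metadata; the values shown are hypothetical.
pandas_metadata = {
    'index_columns': ['tz_aware'],
    'columns': [
        {'name': 'tz_aware',
         'pandas_type': 'datetimetz',
         'metadata': {'timezone': 'US/Eastern'}},
        # ... one entry per data column and index level
    ],
}
```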
python/pyarrow/scalar.pxi (11 additions, 4 deletions)

@@ -212,11 +212,18 @@ else:

cdef class TimestampValue(ArrayValue):

+   property value:
+
+       def __get__(self):
+           cdef CTimestampArray* ap = <CTimestampArray*> self.sp_array.get()
+           cdef CTimestampType* dtype = <CTimestampType*> ap.type().get()
+           return ap.Value(self.index)
+
    def as_py(self):
-       cdef:
-           CTimestampArray* ap = <CTimestampArray*> self.sp_array.get()
-           CTimestampType* dtype = <CTimestampType*> ap.type().get()
-           int64_t value = ap.Value(self.index)
+       cdef CTimestampArray* ap = <CTimestampArray*> self.sp_array.get()
+       cdef CTimestampType* dtype = <CTimestampType*> ap.type().get()
+
+       value = self.value

        if not dtype.timezone().empty():
            import pytz
python/pyarrow/table.pxi (11 additions, 1 deletion)

@@ -150,6 +150,8 @@ cdef class Column:

        if isinstance(field_or_name, Field):
            boxed_field = field_or_name
+           if arr.type != boxed_field.type:
+               raise ValueError('Passed field type does not match array')
        else:
            boxed_field = field(field_or_name, arr.type)

@@ -176,7 +178,15 @@ cdef class Column:
                    self.sp_column,
                    self, &out))

-       return pd.Series(wrap_array_output(out), name=self.name)
+       values = wrap_array_output(out)
+       result = pd.Series(values, name=self.name)
+
+       if isinstance(self.type, TimestampType):
+           if self.type.tz is not None:
+               result = (result.dt.tz_localize('utc')
+                         .dt.tz_convert(self.type.tz))
+
+       return result

    def equals(self, Column other):
        """
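Arrow keeps tz-aware timestamps as UTC-normalized values, so the Series built from `wrap_array_output` holds naive UTC data; the localize-then-convert step above re-attaches the original zone. The equivalent operation at the pandas level (a sketch, not the Cython code itself):

```python
import pandas as pd

# Values come off the wire naive, but are UTC by convention
naive_utc = pd.Series(pd.to_datetime(['2017-09-06 04:00:00']))
restored = naive_utc.dt.tz_localize('utc').dt.tz_convert('US/Eastern')
# restored[0] == Timestamp('2017-09-06 00:00:00-0400', tz='US/Eastern')
```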
python/pyarrow/tests/test_convert_pandas.py (24 additions, 9 deletions)

@@ -81,6 +81,16 @@ def _check_pandas_roundtrip(self, df, expected=None, nthreads=1,
            expected = df
        tm.assert_frame_equal(result, expected, check_dtype=check_dtype)

+   def _check_series_roundtrip(self, s, type_=None):
+       arr = pa.Array.from_pandas(s, type=type_)
+
+       result = pd.Series(arr.to_pandas(), name=s.name)
+       if isinstance(arr.type, pa.TimestampType) and arr.type.tz is not None:
+           result = (result.dt.tz_localize('utc')
+                     .dt.tz_convert(arr.type.tz))
+
+       tm.assert_series_equal(s, result)
+
    def _check_array_roundtrip(self, values, expected=None, mask=None,
                               timestamps_to_ms=False, type=None):
        arr = pa.Array.from_pandas(values, timestamps_to_ms=timestamps_to_ms,
@@ -347,9 +357,7 @@ def test_timestamps_notimezone_no_nulls(self):
        field = pa.field('datetime64', pa.timestamp('ns'))
        schema = pa.schema([field])
        self._check_pandas_roundtrip(
-           df,
-           timestamps_to_ms=False,
-           expected_schema=schema,
+           df, expected_schema=schema,
        )

    def test_timestamps_to_ms_explicit_schema(self):
@@ -389,9 +397,7 @@ def test_timestamps_notimezone_nulls(self):
        field = pa.field('datetime64', pa.timestamp('ns'))
        schema = pa.schema([field])
        self._check_pandas_roundtrip(
-           df,
-           timestamps_to_ms=False,
-           expected_schema=schema,
+           df, expected_schema=schema,
        )

    def test_timestamps_with_timezone(self):
@@ -406,6 +412,8 @@ def test_timestamps_with_timezone(self):
                           .to_frame())
        self._check_pandas_roundtrip(df, timestamps_to_ms=True)

+       self._check_series_roundtrip(df['datetime64'])
+
        # drop-in a null and ns instead of ms
        df = pd.DataFrame({
            'datetime64': np.array([
@@ -417,7 +425,15 @@ def test_timestamps_with_timezone(self):
        })
        df['datetime64'] = (df['datetime64'].dt.tz_localize('US/Eastern')
                            .to_frame())
-       self._check_pandas_roundtrip(df, timestamps_to_ms=False)
+       self._check_pandas_roundtrip(df)
+
+   def test_timestamp_with_tz_to_pandas_type(self):
+       from pyarrow.compat import DatetimeTZDtype
+
+       tz = 'America/Los_Angeles'
+       t = pa.timestamp('ns', tz=tz)
+
+       assert t.to_pandas_dtype() == DatetimeTZDtype('ns', tz=tz)

    def test_date_infer(self):
        df = pd.DataFrame({
@@ -586,8 +602,7 @@ def test_nested_lists_all_none(self):

    def test_threaded_conversion(self):
        df = _alltypes_example()
-       self._check_pandas_roundtrip(df, nthreads=2,
-                                    timestamps_to_ms=False)
+       self._check_pandas_roundtrip(df, nthreads=2)

    def test_category(self):
        repeats = 5
python/pyarrow/tests/test_parquet.py (28 additions, 1 deletion)

@@ -22,7 +22,7 @@
import json
import pytest

-from pyarrow.compat import guid, u
+from pyarrow.compat import guid, u, BytesIO
from pyarrow.filesystem import LocalFileSystem
import pyarrow as pa
from .pandas_examples import dataframe_with_arrays, dataframe_with_lists
@@ -114,6 +114,33 @@ def test_pandas_parquet_2_0_rountrip(tmpdir):
    tm.assert_frame_equal(df, df_read)


+@parquet
+def test_pandas_parquet_datetime_tz():
+   import pyarrow.parquet as pq
+
+   s = pd.Series([datetime.datetime(2017, 9, 6)])
+   s = s.dt.tz_localize('utc')
+
+   s.index = s
+
+   # Both a column and an index to hit both use cases
+   df = pd.DataFrame({'tz_aware': s,
+                      'tz_eastern': s.dt.tz_convert('US/Eastern')},
+                     index=s)
+
+   f = BytesIO()
+
+   arrow_table = pa.Table.from_pandas(df)
+
+   _write_table(arrow_table, f, coerce_timestamps='ms')
+   f.seek(0)
+
+   table_read = pq.read_pandas(f)
+
+   df_read = table_read.to_pandas()
+   tm.assert_frame_equal(df, df_read)
+
+
@parquet
def test_pandas_parquet_custom_metadata(tmpdir):
    import pyarrow.parquet as pq
python/pyarrow/tests/test_serialization.py (2 additions, 2 deletions)

@@ -100,9 +100,9 @@ def assert_equal(obj1, obj2):
if sys.version_info >= (3, 0):
    PRIMITIVE_OBJECTS += [0, np.array([["hi", u"hi"], [1.3, 1]])]
else:
-   PRIMITIVE_OBJECTS += [long(42), long(1 << 62), long(0),
+   PRIMITIVE_OBJECTS += [long(42), long(1 << 62), long(0),  # noqa
                          np.array([["hi", u"hi"],
-                                   [1.3, long(1)]])]  # noqa: E501,F821
+                                   [1.3, long(1)]])]  # noqa


COMPLEX_OBJECTS = [
python/pyarrow/types.pxi (17 additions, 1 deletion)

@@ -134,6 +134,16 @@ cdef class TimestampType(DataType):
        else:
            return None

+   def to_pandas_dtype(self):
+       """
+       Return the NumPy dtype that would be used for storing this
+       """
+       if self.tz is None:
+           return _pandas_type_map[_Type_TIMESTAMP]
+       else:
+           # Return DatetimeTZ
+           return pdcompat.make_datetimetz(self.tz)
+

cdef class Time32Type(DataType):

@@ -431,7 +441,13 @@ cdef class Schema:
        with nogil:
            check_status(PrettyPrint(deref(self.schema), options, &result))

-       return frombytes(result)
+       printed = frombytes(result)
+       if self.metadata is not None:
+           import pprint
+           metadata_formatted = pprint.pformat(self.metadata)
+           printed += '\nmetadata\n--------\n' + metadata_formatted
+
+       return printed

    def __repr__(self):
        return self.__str__()
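A usage sketch for the new `TimestampType.to_pandas_dtype`, as exercised by `test_timestamp_with_tz_to_pandas_type` above (the naive branch returns the NumPy `datetime64` dtype from the internal type map):

```python
import pyarrow as pa
from pyarrow.compat import DatetimeTZDtype

t = pa.timestamp('ns', tz='America/Los_Angeles')
assert t.to_pandas_dtype() == DatetimeTZDtype('ns', tz='America/Los_Angeles')

pa.timestamp('ns').to_pandas_dtype()  # plain NumPy datetime64 dtype
```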
