Skip to content

Commit

Permalink
add timezone to type definition
Browse files Browse the repository at this point in the history
Signed-off-by: pyalex <[email protected]>
  • Loading branch information
pyalex committed Apr 21, 2022
1 parent 12a8459 commit ce45062
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions sdk/python/feast/embedded_go/type_map.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import List

import pyarrow as pa
import pytz

from feast.protos.feast.types import Value_pb2
from feast.types import Array, PrimitiveFeastType

PA_TIMESTAMP_TYPE = pa.timestamp("s", tz=pytz.UTC)

ARROW_TYPE_TO_PROTO_FIELD = {
pa.int32(): "int32_val",
pa.int64(): "int64_val",
Expand All @@ -13,7 +16,7 @@
pa.bool_(): "bool_val",
pa.string(): "string_val",
pa.binary(): "bytes_val",
pa.timestamp("s"): "unix_timestamp_val",
PA_TIMESTAMP_TYPE: "unix_timestamp_val",
}

ARROW_LIST_TYPE_TO_PROTO_FIELD = {
Expand All @@ -24,7 +27,7 @@
pa.bool_(): "bool_list_val",
pa.string(): "string_list_val",
pa.binary(): "bytes_list_val",
pa.timestamp("s"): "unix_timestamp_list_val",
PA_TIMESTAMP_TYPE: "unix_timestamp_list_val",
}

ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS = {
Expand All @@ -35,7 +38,7 @@
pa.bool_(): Value_pb2.BoolList,
pa.string(): Value_pb2.StringList,
pa.binary(): Value_pb2.BytesList,
pa.timestamp("s"): Value_pb2.Int64List,
PA_TIMESTAMP_TYPE: Value_pb2.Int64List,
}

FEAST_TYPE_TO_ARROW_TYPE = {
Expand Down Expand Up @@ -66,7 +69,7 @@ def arrow_array_to_array_of_proto(
proto_list_class = ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS[arrow_type.value_type]
proto_field_name = ARROW_LIST_TYPE_TO_PROTO_FIELD[arrow_type.value_type]

if arrow_type.value_type == pa.timestamp("s"):
if arrow_type.value_type == PA_TIMESTAMP_TYPE:
arrow_array = arrow_array.cast(pa.list_(pa.int64()))

for v in arrow_array.tolist():
Expand All @@ -76,7 +79,7 @@ def arrow_array_to_array_of_proto(
else:
proto_field_name = ARROW_TYPE_TO_PROTO_FIELD[arrow_type]

if arrow_type == pa.timestamp("s"):
if arrow_type == PA_TIMESTAMP_TYPE:
arrow_array = arrow_array.cast(pa.int64())

for v in arrow_array.tolist():
Expand Down

0 comments on commit ce45062

Please sign in to comment.