From ca412151a1107422c38fdc7bbf21f7044ee08602 Mon Sep 17 00:00:00 2001 From: Ivan Sukhov Date: Fri, 2 Aug 2024 15:21:53 +0300 Subject: [PATCH] inferred types for csv files are now optional (#7358) --- .../inference/arrow_inferencinator.cpp | 3 +- .../inference/ut/arrow_inference_ut.cpp | 28 +++---- ydb/tests/fq/s3/test_s3_0.py | 74 +++++++++++++++++-- 3 files changed, 83 insertions(+), 22 deletions(-) diff --git a/ydb/core/external_sources/object_storage/inference/arrow_inferencinator.cpp b/ydb/core/external_sources/object_storage/inference/arrow_inferencinator.cpp index 7c143dec728f..33ae55c4a753 100644 --- a/ydb/core/external_sources/object_storage/inference/arrow_inferencinator.cpp +++ b/ydb/core/external_sources/object_storage/inference/arrow_inferencinator.cpp @@ -14,7 +14,8 @@ namespace NKikimr::NExternalSource::NObjectStorage::NInference { namespace { -bool ArrowToYdbType(Ydb::Type& resType, const arrow::DataType& type) { +bool ArrowToYdbType(Ydb::Type& optionalType, const arrow::DataType& type) { + auto& resType = *optionalType.mutable_optional_type()->mutable_item(); switch (type.id()) { case arrow::Type::NA: resType.set_type_id(Ydb::Type::UTF8); diff --git a/ydb/core/external_sources/object_storage/inference/ut/arrow_inference_ut.cpp b/ydb/core/external_sources/object_storage/inference/ut/arrow_inference_ut.cpp index 88a46386035f..9fe683f56c95 100644 --- a/ydb/core/external_sources/object_storage/inference/ut/arrow_inference_ut.cpp +++ b/ydb/core/external_sources/object_storage/inference/ut/arrow_inference_ut.cpp @@ -93,16 +93,16 @@ TEST_F(ArrowInferenceTest, csv_simple) { ASSERT_NE(response, nullptr); auto& fields = response->Fields; - ASSERT_TRUE(fields[0].type().has_type_id()); - ASSERT_EQ(response->Fields[0].type().type_id(), Ydb::Type::INT64); - ASSERT_EQ(response->Fields[0].name(), "A"); + ASSERT_TRUE(fields[0].type().optional_type().item().has_type_id()); + ASSERT_EQ(fields[0].type().optional_type().item().type_id(), Ydb::Type::INT64); + ASSERT_EQ(fields[0].name(), "A"); - ASSERT_TRUE(fields[1].type().has_type_id()); - ASSERT_EQ(fields[1].type().type_id(), Ydb::Type::UTF8); + ASSERT_TRUE(fields[1].type().optional_type().item().has_type_id()); + ASSERT_EQ(fields[1].type().optional_type().item().type_id(), Ydb::Type::UTF8); ASSERT_EQ(fields[1].name(), "B"); - ASSERT_TRUE(fields[2].type().has_type_id()); - ASSERT_EQ(fields[2].type().type_id(), Ydb::Type::DOUBLE); + ASSERT_TRUE(fields[2].type().optional_type().item().has_type_id()); + ASSERT_EQ(fields[2].type().optional_type().item().type_id(), Ydb::Type::DOUBLE); ASSERT_EQ(fields[2].name(), "C"); } @@ -129,16 +129,16 @@ TEST_F(ArrowInferenceTest, tsv_simple) { ASSERT_NE(response, nullptr); auto& fields = response->Fields; - ASSERT_TRUE(fields[0].type().has_type_id()); - ASSERT_EQ(response->Fields[0].type().type_id(), Ydb::Type::INT64); - ASSERT_EQ(response->Fields[0].name(), "A"); + ASSERT_TRUE(fields[0].type().optional_type().item().has_type_id()); + ASSERT_EQ(fields[0].type().optional_type().item().type_id(), Ydb::Type::INT64); + ASSERT_EQ(fields[0].name(), "A"); - ASSERT_TRUE(fields[1].type().has_type_id()); - ASSERT_EQ(fields[1].type().type_id(), Ydb::Type::UTF8); + ASSERT_TRUE(fields[1].type().optional_type().item().has_type_id()); + ASSERT_EQ(fields[1].type().optional_type().item().type_id(), Ydb::Type::UTF8); ASSERT_EQ(fields[1].name(), "B"); - ASSERT_TRUE(fields[2].type().has_type_id()); - ASSERT_EQ(fields[2].type().type_id(), Ydb::Type::DOUBLE); + ASSERT_TRUE(fields[2].type().optional_type().item().has_type_id()); + ASSERT_EQ(fields[2].type().optional_type().item().type_id(), Ydb::Type::DOUBLE); ASSERT_EQ(fields[2].name(), "C"); } diff --git a/ydb/tests/fq/s3/test_s3_0.py b/ydb/tests/fq/s3/test_s3_0.py index 52073aafda99..76a27858a07b 100644 --- a/ydb/tests/fq/s3/test_s3_0.py +++ b/ydb/tests/fq/s3/test_s3_0.py @@ -11,6 +11,7 @@ import ydb.tests.library.common.yatest_common as yatest_common from ydb.tests.tools.datastreams_helpers.test_yds_base import TestYdsBase from ydb.tests.tools.fq_runner.kikimr_utils import yq_v1, yq_v2, yq_all +from google.protobuf.struct_pb2 import NullValue class TestS3(TestYdsBase): @@ -115,13 +116,13 @@ def test_inference(self, kikimr, s3, client, unique_prefix): logging.debug(str(result_set)) assert len(result_set.columns) == 4 assert result_set.columns[0].name == "Date" - assert result_set.columns[0].type.type_id == ydb.Type.DATE + assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.DATE assert result_set.columns[1].name == "Fruit" - assert result_set.columns[1].type.type_id == ydb.Type.UTF8 + assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8 assert result_set.columns[2].name == "Price" - assert result_set.columns[2].type.type_id == ydb.Type.INT64 + assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64 assert result_set.columns[3].name == "Weight" - assert result_set.columns[3].type.type_id == ydb.Type.INT64 + assert result_set.columns[3].type.optional_type.item.type_id == ydb.Type.INT64 assert len(result_set.rows) == 3 assert result_set.rows[0].items[0].uint32_value == 19724 assert result_set.rows[0].items[1].text_value == "Banana" @@ -175,11 +176,11 @@ def test_inference_null_column(self, kikimr, s3, client, unique_prefix): logging.debug(str(result_set)) assert len(result_set.columns) == 3 assert result_set.columns[0].name == "Fruit" - assert result_set.columns[0].type.type_id == ydb.Type.UTF8 + assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.UTF8 assert result_set.columns[1].name == "Missing column" - assert result_set.columns[1].type.type_id == ydb.Type.UTF8 + assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8 assert result_set.columns[2].name == "Price" - assert result_set.columns[2].type.type_id == ydb.Type.INT64 + assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64 assert len(result_set.rows) == 3 assert result_set.rows[0].items[0].text_value == "Banana" assert result_set.rows[0].items[1].text_value == "" @@ -191,6 +192,65 @@ def test_inference_null_column(self, kikimr, s3, client, unique_prefix): assert result_set.rows[2].items[1].text_value == "" assert result_set.rows[2].items[2].int64_value == 15 assert sum(kikimr.control_plane.get_metering(1)) == 10 + + @yq_v2 + @pytest.mark.parametrize("client", [{"folder_id": "my_folder"}], indirect=True) + def test_inference_optional_types(self, kikimr, s3, client, unique_prefix): + resource = boto3.resource( + "s3", endpoint_url=s3.s3_url, aws_access_key_id="key", aws_secret_access_key="secret_key" + ) + + bucket = resource.Bucket("fbucket") + bucket.create(ACL='public-read') + bucket.objects.all().delete() + + s3_client = boto3.client( + "s3", endpoint_url=s3.s3_url, aws_access_key_id="key", aws_secret_access_key="secret_key" + ) + + fruits = '''Fruit,Price,Weight,Date +Banana,,,2024-01-02 +Apple,2,22, +,15,33,2024-05-06''' + s3_client.put_object(Body=fruits, Bucket='fbucket', Key='fruits.csv', ContentType='text/plain') + kikimr.control_plane.wait_bootstrap(1) + storage_connection_name = unique_prefix + "fruitbucket" + client.create_storage_connection(storage_connection_name, "fbucket") + + sql = f''' + SELECT * + FROM `{storage_connection_name}`.`fruits.csv` + WITH (format=csv_with_names, with_infer='true'); + ''' + + query_id = client.create_query("simple", sql, type=fq.QueryContent.QueryType.ANALYTICS).result.query_id + client.wait_query_status(query_id, fq.QueryMeta.COMPLETED) + + data = client.get_result_data(query_id) + result_set = data.result.result_set + logging.debug(str(result_set)) + assert len(result_set.columns) == 4 + assert result_set.columns[0].name == "Date" + assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.DATE + assert result_set.columns[1].name == "Fruit" + assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8 + assert result_set.columns[2].name == "Price" + assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64 + assert result_set.columns[3].name == "Weight" + assert result_set.columns[3].type.optional_type.item.type_id == ydb.Type.INT64 + assert len(result_set.rows) == 3 + assert result_set.rows[0].items[0].uint32_value == 19724 + assert result_set.rows[0].items[1].text_value == "Banana" + assert result_set.rows[0].items[2].null_flag_value == NullValue.NULL_VALUE + assert result_set.rows[0].items[3].null_flag_value == NullValue.NULL_VALUE + assert result_set.rows[1].items[0].null_flag_value == NullValue.NULL_VALUE + assert result_set.rows[1].items[1].text_value == "Apple" + assert result_set.rows[1].items[2].int64_value == 2 + assert result_set.rows[1].items[3].int64_value == 22 + assert result_set.rows[2].items[0].uint32_value == 19849 + assert result_set.rows[2].items[1].null_flag_value == NullValue.NULL_VALUE + assert result_set.rows[2].items[2].int64_value == 15 + assert result_set.rows[2].items[3].int64_value == 33 @yq_all @pytest.mark.parametrize("client", [{"folder_id": "my_folder"}], indirect=True)