Skip to content

Commit

Permalink
inferred types for csv files are now optional (#7358)
Browse files Browse the repository at this point in the history
  • Loading branch information
evanevanevanevannnn authored Aug 2, 2024
1 parent 59d098d commit ca41215
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace NKikimr::NExternalSource::NObjectStorage::NInference {

namespace {

bool ArrowToYdbType(Ydb::Type& resType, const arrow::DataType& type) {
bool ArrowToYdbType(Ydb::Type& optionalType, const arrow::DataType& type) {
auto& resType = *optionalType.mutable_optional_type()->mutable_item();
switch (type.id()) {
case arrow::Type::NA:
resType.set_type_id(Ydb::Type::UTF8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,16 @@ TEST_F(ArrowInferenceTest, csv_simple) {
ASSERT_NE(response, nullptr);

auto& fields = response->Fields;
ASSERT_TRUE(fields[0].type().has_type_id());
ASSERT_EQ(response->Fields[0].type().type_id(), Ydb::Type::INT64);
ASSERT_EQ(response->Fields[0].name(), "A");
ASSERT_TRUE(fields[0].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[0].type().optional_type().item().type_id(), Ydb::Type::INT64);
ASSERT_EQ(fields[0].name(), "A");

ASSERT_TRUE(fields[1].type().has_type_id());
ASSERT_EQ(fields[1].type().type_id(), Ydb::Type::UTF8);
ASSERT_TRUE(fields[1].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[1].type().optional_type().item().type_id(), Ydb::Type::UTF8);
ASSERT_EQ(fields[1].name(), "B");

ASSERT_TRUE(fields[2].type().has_type_id());
ASSERT_EQ(fields[2].type().type_id(), Ydb::Type::DOUBLE);
ASSERT_TRUE(fields[2].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[2].type().optional_type().item().type_id(), Ydb::Type::DOUBLE);
ASSERT_EQ(fields[2].name(), "C");
}

Expand All @@ -129,16 +129,16 @@ TEST_F(ArrowInferenceTest, tsv_simple) {
ASSERT_NE(response, nullptr);

auto& fields = response->Fields;
ASSERT_TRUE(fields[0].type().has_type_id());
ASSERT_EQ(response->Fields[0].type().type_id(), Ydb::Type::INT64);
ASSERT_EQ(response->Fields[0].name(), "A");
ASSERT_TRUE(fields[0].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[0].type().optional_type().item().type_id(), Ydb::Type::INT64);
ASSERT_EQ(fields[0].name(), "A");

ASSERT_TRUE(fields[1].type().has_type_id());
ASSERT_EQ(fields[1].type().type_id(), Ydb::Type::UTF8);
ASSERT_TRUE(fields[1].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[1].type().optional_type().item().type_id(), Ydb::Type::UTF8);
ASSERT_EQ(fields[1].name(), "B");

ASSERT_TRUE(fields[2].type().has_type_id());
ASSERT_EQ(fields[2].type().type_id(), Ydb::Type::DOUBLE);
ASSERT_TRUE(fields[2].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[2].type().optional_type().item().type_id(), Ydb::Type::DOUBLE);
ASSERT_EQ(fields[2].name(), "C");
}

Expand Down
74 changes: 67 additions & 7 deletions ydb/tests/fq/s3/test_s3_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ydb.tests.library.common.yatest_common as yatest_common
from ydb.tests.tools.datastreams_helpers.test_yds_base import TestYdsBase
from ydb.tests.tools.fq_runner.kikimr_utils import yq_v1, yq_v2, yq_all
from google.protobuf.struct_pb2 import NullValue


class TestS3(TestYdsBase):
Expand Down Expand Up @@ -115,13 +116,13 @@ def test_inference(self, kikimr, s3, client, unique_prefix):
logging.debug(str(result_set))
assert len(result_set.columns) == 4
assert result_set.columns[0].name == "Date"
assert result_set.columns[0].type.type_id == ydb.Type.DATE
assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.DATE
assert result_set.columns[1].name == "Fruit"
assert result_set.columns[1].type.type_id == ydb.Type.UTF8
assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8
assert result_set.columns[2].name == "Price"
assert result_set.columns[2].type.type_id == ydb.Type.INT64
assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64
assert result_set.columns[3].name == "Weight"
assert result_set.columns[3].type.type_id == ydb.Type.INT64
assert result_set.columns[3].type.optional_type.item.type_id == ydb.Type.INT64
assert len(result_set.rows) == 3
assert result_set.rows[0].items[0].uint32_value == 19724
assert result_set.rows[0].items[1].text_value == "Banana"
Expand Down Expand Up @@ -175,11 +176,11 @@ def test_inference_null_column(self, kikimr, s3, client, unique_prefix):
logging.debug(str(result_set))
assert len(result_set.columns) == 3
assert result_set.columns[0].name == "Fruit"
assert result_set.columns[0].type.type_id == ydb.Type.UTF8
assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.UTF8
assert result_set.columns[1].name == "Missing column"
assert result_set.columns[1].type.type_id == ydb.Type.UTF8
assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8
assert result_set.columns[2].name == "Price"
assert result_set.columns[2].type.type_id == ydb.Type.INT64
assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64
assert len(result_set.rows) == 3
assert result_set.rows[0].items[0].text_value == "Banana"
assert result_set.rows[0].items[1].text_value == ""
Expand All @@ -191,6 +192,65 @@ def test_inference_null_column(self, kikimr, s3, client, unique_prefix):
assert result_set.rows[2].items[1].text_value == ""
assert result_set.rows[2].items[2].int64_value == 15
assert sum(kikimr.control_plane.get_metering(1)) == 10

@yq_v2
@pytest.mark.parametrize("client", [{"folder_id": "my_folder"}], indirect=True)
def test_inference_optional_types(self, kikimr, s3, client, unique_prefix):
    """Verify that inferred CSV columns are optional and empty cells read back as NULL.

    Uploads a small CSV in which every column has exactly one empty cell,
    queries it with ``with_infer='true'``, and then checks that:

    * each result column's type is wrapped in ``optional_type``;
    * every empty cell is returned with ``null_flag_value`` set, while the
      filled cells keep their typed values.
    """
    resource = boto3.resource(
        "s3", endpoint_url=s3.s3_url, aws_access_key_id="key", aws_secret_access_key="secret_key"
    )

    # Start from an empty, publicly readable bucket.
    bucket = resource.Bucket("fbucket")
    bucket.create(ACL='public-read')
    bucket.objects.all().delete()

    s3_client = boto3.client(
        "s3", endpoint_url=s3.s3_url, aws_access_key_id="key", aws_secret_access_key="secret_key"
    )

    # No trailing newline after the last record, same as the uploaded payload
    # has always been.
    csv_lines = (
        "Fruit,Price,Weight,Date",
        "Banana,,,2024-01-02",
        "Apple,2,22,",
        ",15,33,2024-05-06",
    )
    fruits = "\n".join(csv_lines)
    s3_client.put_object(Body=fruits, Bucket='fbucket', Key='fruits.csv', ContentType='text/plain')

    kikimr.control_plane.wait_bootstrap(1)
    storage_connection_name = unique_prefix + "fruitbucket"
    client.create_storage_connection(storage_connection_name, "fbucket")

    sql = (
        "\n"
        "SELECT *\n"
        f"FROM `{storage_connection_name}`.`fruits.csv`\n"
        "WITH (format=csv_with_names, with_infer='true');\n"
    )

    created = client.create_query("simple", sql, type=fq.QueryContent.QueryType.ANALYTICS)
    query_id = created.result.query_id
    client.wait_query_status(query_id, fq.QueryMeta.COMPLETED)

    result_set = client.get_result_data(query_id).result.result_set
    logging.debug(str(result_set))

    # Columns are expected in this order (not the CSV header order), each
    # wrapped in an optional type.
    expected_columns = (
        ("Date", ydb.Type.DATE),
        ("Fruit", ydb.Type.UTF8),
        ("Price", ydb.Type.INT64),
        ("Weight", ydb.Type.INT64),
    )
    assert len(result_set.columns) == 4
    for position, (column_name, item_type_id) in enumerate(expected_columns):
        column = result_set.columns[position]
        assert column.name == column_name
        assert column.type.optional_type.item.type_id == item_type_id

    # Each cell is described as (value field to read, expected value).
    # Dates are compared through uint32_value: 19724 is 2024-01-02 and
    # 19849 is 2024-05-06, counted in days since 1970-01-01.
    null_cell = ("null_flag_value", NullValue.NULL_VALUE)
    expected_rows = (
        (("uint32_value", 19724), ("text_value", "Banana"), null_cell, null_cell),
        (null_cell, ("text_value", "Apple"), ("int64_value", 2), ("int64_value", 22)),
        (("uint32_value", 19849), null_cell, ("int64_value", 15), ("int64_value", 33)),
    )
    assert len(result_set.rows) == 3
    for row_index, expected_cells in enumerate(expected_rows):
        items = result_set.rows[row_index].items
        for cell_index, (value_field, expected_value) in enumerate(expected_cells):
            assert getattr(items[cell_index], value_field) == expected_value

@yq_all
@pytest.mark.parametrize("client", [{"folder_id": "my_folder"}], indirect=True)
Expand Down

0 comments on commit ca41215

Please sign in to comment.