fix: upsert rows when set autoid==true fail #2286

Merged · 1 commit · Oct 10, 2024
85 changes: 84 additions & 1 deletion pymilvus/client/prepare.py
@@ -450,6 +450,89 @@ def _parse_row_request(
            raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
        return request

    @staticmethod
    def _parse_upsert_row_request(
        request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
        fields_info: dict,
        enable_dynamic: bool,
        entities: List,
    ):
        fields_data = {
            field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"])
            for field in fields_info
        }
        field_info_map = {field["name"]: field for field in fields_info}

        if enable_dynamic:
            d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
            fields_data[d_field.field_name] = d_field
            field_info_map[d_field.field_name] = d_field

        try:
            for entity in entities:
                if not isinstance(entity, Dict):
                    msg = f"expected Dict, got '{type(entity).__name__}'"
                    raise TypeError(msg)
                for k, v in entity.items():
                    if k not in fields_data and not enable_dynamic:
                        raise DataNotMatchException(
                            message=ExceptionsMessage.InsertUnexpectedField % k
                        )

                    if k in fields_data:
                        field_info, field_data = field_info_map[k], fields_data[k]
                        if field_info.get("nullable", False) or field_info.get(
                            "default_value", None
                        ):
                            field_data.valid_data.append(v is not None)
                        entity_helper.pack_field_value_to_field_data(v, field_data, field_info)
                for field in fields_info:
                    key = field["name"]
                    if key in entity:
                        continue

                    field_info, field_data = field_info_map[key], fields_data[key]
                    if field_info.get("nullable", False) or field_info.get("default_value", None):
                        field_data.valid_data.append(False)
                        entity_helper.pack_field_value_to_field_data(None, field_data, field_info)
                    else:
                        raise DataNotMatchException(
                            message=ExceptionsMessage.InsertMissedField % key
                        )
                json_dict = {
                    k: v for k, v in entity.items() if k not in fields_data and enable_dynamic
                }

                if enable_dynamic:
                    json_value = entity_helper.convert_to_json(json_dict)
                    d_field.scalars.json_data.data.append(json_value)

        except (TypeError, ValueError) as e:
            raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e

        request.fields_data.extend([fields_data[field["name"]] for field in fields_info])

        if enable_dynamic:
            request.fields_data.append(d_field)

        for _, field in enumerate(fields_info):
            is_dynamic = False
            field_name = field["name"]

            if field.get("is_dynamic", False):
                is_dynamic = True

            for j, entity in enumerate(entities):
                if is_dynamic and field_name in entity:
                    raise ParamError(
                        message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]"
                    )
        if (enable_dynamic and len(fields_data) != len(fields_info) + 1) or (
            not enable_dynamic and len(fields_data) != len(fields_info)
        ):
            raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
        return request

    @classmethod
    def row_insert_param(
        cls,
@@ -492,7 +575,7 @@ def row_upsert_param(
            num_rows=len(entities),
        )

        return cls._parse_row_request(request, fields_info, enable_dynamic, entities)
        return cls._parse_upsert_row_request(request, fields_info, enable_dynamic, entities)

    @staticmethod
    def _pre_insert_batch_check(
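Why a separate parser: the insert-oriented `_parse_row_request` leaves `auto_id` primary keys out of the expected field set (on insert the server generates them), so routing upserts through it made rows that carry their own primary key fail — which is the behavior this PR fixes. The new `_parse_upsert_row_request` keeps the primary key, since upsert needs it to identify the rows being replaced. Below is a minimal sketch of the fixed call path, mirroring the test added further down; the collection name "demo", the field names, and the vector values are placeholders, not part of this PR.

```python
from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus.client.prepare import Prepare

dim = 4
schema = CollectionSchema([
    FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True),
])

# Upsert rows must name the primary key even though auto_id=True,
# otherwise the server cannot tell which rows are being replaced.
rows = [
    {"pk_field": 1, "float_vector": [0.1] * dim},
    {"pk_field": 2, "float_vector": [0.2] * dim},
]

# Previously routed through _parse_row_request (which rejected rows
# carrying an auto_id primary key); now handled by _parse_upsert_row_request.
request = Prepare.row_upsert_param(
    "demo", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=False
)
```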
33 changes: 33 additions & 0 deletions tests/test_prepare.py
@@ -209,6 +209,39 @@ def test_row_insert_param_with_none(self):

        Prepare.row_insert_param("", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=True)

    def test_row_upsert_param_with_auto_id(self):
        import numpy as np
        rng = np.random.default_rng(seed=19530)
        dim = 8
        schema = CollectionSchema([
            FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=dim),
            FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema("float", DataType.DOUBLE)
        ])
        rows = [
            {"pk_field": 1, "float": 1.0, "float_vector": rng.random((1, dim))[0], "a": 1},
            {"pk_field": 2, "float": 1.0, "float_vector": rng.random((1, dim))[0], "b": 1},
        ]

        Prepare.row_upsert_param("", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=True)

    def test_upsert_param_with_none(self):
        import numpy as np
        rng = np.random.default_rng(seed=19530)
        dim = 8
        schema = CollectionSchema([
            FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=dim),
            FieldSchema("nullable_field", DataType.INT64, nullable=True),
            FieldSchema("default_field", DataType.FLOAT, default_value=10),
            FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema("float", DataType.DOUBLE),
        ])
        rows = [
            {"pk_field": 1, "float": 1.0, "nullable_field": None, "default_field": None, "float_vector": rng.random((1, dim))[0], "a": 1},
            {"pk_field": 2, "float": 1.0, "float_vector": rng.random((1, dim))[0], "b": 1},
        ]

        Prepare.row_upsert_param("", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=True)

class TestAlterCollectionRequest:
    def test_alter_collection_request(self):
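As the second test shows, the upsert parser also carries over the nullable/default-value handling: when a field is declared nullable or has a `default_value`, a `None` or missing value is recorded in that field's `valid_data` flags instead of raising. A small sketch of this, under the same assumptions as above (placeholder names, client-side only, no server involved):

```python
from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus.client.prepare import Prepare

dim = 4
schema = CollectionSchema([
    FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema("pk_field", DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema("nullable_field", DataType.INT64, nullable=True),
])

rows = [
    {"pk_field": 1, "nullable_field": None, "float_vector": [0.1] * dim},  # explicit None
    {"pk_field": 2, "float_vector": [0.2] * dim},                          # field omitted
]

request = Prepare.row_upsert_param(
    "demo", rows, "", fields_info=schema.to_dict()["fields"], enable_dynamic=False
)

# The FieldData built for "nullable_field" should carry
# valid_data == [False, False]: one flag per row, both marked invalid
# because the value was None in row 0 and missing in row 1.
```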