diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4018d1bb3..4e80cdc37 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -21,7 +21,7 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo -from typing_extensions import get_args, get_origin +from typing_extensions import Annotated, get_args, get_origin # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION @@ -177,16 +177,17 @@ def is_field_noneable(field: "FieldInfo") -> bool: return False return False - def get_type_from_field(field: Any) -> Any: - type_: Any = field.annotation + def get_sa_type_from_type_annotation(annotation: Any) -> Any: # Resolve Optional fields - if type_ is None: + if annotation is None: raise ValueError("Missing field type") - origin = get_origin(type_) + origin = get_origin(annotation) if origin is None: - return type_ + return annotation + elif origin is Annotated: + return get_sa_type_from_type_annotation(get_args(annotation)[0]) if _is_union_type(origin): - bases = get_args(type_) + bases = get_args(annotation) if len(bases) > 2: raise ValueError( "Cannot have a (non-optional) union as a SQLAlchemy field" @@ -197,9 +198,14 @@ def get_type_from_field(field: Any) -> Any: "Cannot have a (non-optional) union as a SQLAlchemy field" ) # Optional unions are allowed - return bases[0] if bases[0] is not NoneType else bases[1] + use_type = bases[0] if bases[0] is not NoneType else bases[1] + return get_sa_type_from_type_annotation(use_type) return origin + def get_sa_type_from_field(field: Any) -> Any: + type_: Any = field.annotation + return get_sa_type_from_type_annotation(type_) + def get_field_metadata(field: Any) -> Any: for meta in field.metadata: if isinstance(meta, (PydanticMetadata, MaxLen)): @@ -444,7 +450,7 @@ def is_field_noneable(field: "FieldInfo") -> bool: ) return field.allow_none # type: ignore[no-any-return, attr-defined] - def get_type_from_field(field: Any) -> Any: + def get_sa_type_from_field(field: Any) -> Any: if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d8fced51f..1597e4e04 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -71,7 +71,7 @@ get_field_metadata, get_model_fields, get_relationship_to, - get_type_from_field, + get_sa_type_from_field, init_pydantic_private_attrs, is_field_noneable, is_table_model_class, @@ -649,7 +649,7 @@ def get_sqlalchemy_type(field: Any) -> Any: if sa_type is not Undefined: return sa_type - type_ = get_type_from_field(field) + type_ = get_sa_type_from_field(field) metadata = get_field_metadata(field) # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI diff --git a/tests/test_annotated_uuid.py b/tests/test_annotated_uuid.py new file mode 100644 index 000000000..b0e25ab09 --- /dev/null +++ b/tests/test_annotated_uuid.py @@ -0,0 +1,26 @@ +import uuid +from typing import Optional + +from sqlmodel import Field, Session, SQLModel, create_engine, select + +from tests.conftest import needs_pydanticv2 + + +@needs_pydanticv2 +def test_annotated_optional_types(clear_sqlmodel) -> None: + from pydantic import UUID4 + + class Hero(SQLModel, table=True): + # Pydantic UUID4 is: Annotated[UUID, UuidVersion(4)] + id: Optional[UUID4] = Field(default_factory=uuid.uuid4, primary_key=True) + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + db.commit() + statement = select(Hero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(hero.id, uuid.UUID)