Skip to content

Commit

Permalink
Feature/improve codegen (#19)
Browse files Browse the repository at this point in the history
* Fix naming duplication

* Fix print enum duplicates

* Fix string const

* Improve const support

* Improve class naming

* Improve Optional detection

* Improve detection for list models

* Improve multilines code
  • Loading branch information
ehooo authored Sep 30, 2024
1 parent 918f520 commit 0b6fdd8
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 93 deletions.
10 changes: 9 additions & 1 deletion src/lima_api/code_generator/cli_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,15 @@ def _get_type(self, status: str) -> str:
else:
# TODO generate model on fly
options.add("dict")
returned_type = f"typing.Union[{', '.join(options)}]" if len(options) > 1 else options.pop()

if len(options) == 1:
returned_type = options.pop()
else:
mode = "typing.Union"
if "None" in options:
options.remove("None")
mode = "typing.Optional"
returned_type = f"{mode}[{', '.join(options)}]"
elif schema.get("type") in OPENAPI_2_TYPE_MAPPING:
returned_type = OPENAPI_2_TYPE_MAPPING[schema.get("type")]
else:
Expand Down
232 changes: 143 additions & 89 deletions src/lima_api/code_generator/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,25 @@ def __str__(self):
key = f"OP_{choice}"
if self.type == OpenApiType.STRING:
key = camel_to_snake(choice)
if key[0].isdigit():
key = f"{camel_to_snake(self.name).upper()}_{key}"
attributes += f' {key.upper()} = "{choice}"\n'
return BASE_CLASS.substitute(
model_class_name=self.name,
model_class_parent=f"{OPENAPI_2_TYPE_MAPPING.get(self.type)}, Enum",
class_attributes=attributes,
class_methods="",
).replace("\n\n", "\n")
return (
BASE_CLASS.substitute(
model_class_name=self.name,
model_class_parent=f"{OPENAPI_2_TYPE_MAPPING.get(self.type)}, Enum",
class_attributes=attributes,
class_methods="",
)
.replace("\n\n\n", "\n")
.replace("\n\n", "\n")
)

def __hash__(self):
return hash(str(self))

def __eq__(self, other):
return str(self) == str(other)


class SchemaObjectType(str, Enum):
Expand All @@ -38,6 +50,7 @@ class SchemaObjectType(str, Enum):
ENUM = "enum"
UNION = "union"
CONST = "const"
TYPE = "type"


class PropertyParser:
Expand All @@ -63,66 +76,15 @@ def _get_final_type(self, obj):
return def_type

def parse(self):
def_type = self.definition.get("type")
if def_type in OPENAPI_2_TYPE_MAPPING:
def_type: str = OPENAPI_2_TYPE_MAPPING.get(def_type)
if "items" in self.definition:
items = self.definition.get("items")
item_type = items.get("type")
if "$ref" in items:
ref = self.parser.get_ref(items.get("$ref"))
self.models.add(ref.name)
self.models.update(ref.models)
def_type = f"{def_type}[{ref.name}]"
else:
is_list = False
item_type = OPENAPI_2_TYPE_MAPPING.get(item_type)
if "anyOf" in items or "oneOf" in items:
key_of = "anyOf" if "anyOf" in items else "oneOf"
any_of = []
for item in items[key_of]:
if "$ref" in item:
ref = self.parser.get_ref(item.get("$ref"))
any_of.append(ref.name)
self.models.add(ref.name)
self.models.update(ref.models)
elif item.get("type") in OPENAPI_2_TYPE_MAPPING:
item_type = OPENAPI_2_TYPE_MAPPING.get(item.get("type"))
any_of.append(item_type)
else:
raise NotImplementedError(f"Unsupported type {item.get('type')}")
item_type = ", ".join(any_of)
if def_type == "list":
is_list = True
def_type = "typing.Union"
elif item_type is None:
if "properties" in items:
item_type = snake_to_camel(self.name) + "Embed"
obj = SchemaObject(self.parser, item_type)
obj.set_as_object(items.get("properties"), items.get("required", []))
self.models.add(obj.name)
self.models.update(obj.models)
self.embed_cls = obj
else:
raise NotImplementedError("Unsupported jet")

def_type = f"{def_type}[{item_type}]"
if is_list:
def_type = f"list[{def_type}]"

if "enum" in self.definition:
type_name = snake_to_camel(self.definition.get("title", self.name))
self.enum = EnumObject(
type_name=type_name,
options=self.definition.get("enum"),
enum_type=self.definition.get("type"),
)
def_type = type_name
if "$ref" in self.definition:
ref = self.parser.get_ref(self.definition.get("$ref"))
self.models.add(ref.name)
self.models.update(ref.models)
def_type = ref.name
schema = self.parser.process_schema(snake_to_camel(self.name), self.definition)
def_type = schema.attributes if schema.type != SchemaObjectType.OBJECT else schema.name
self.models.update(schema.models)
if schema.type == SchemaObjectType.ENUM:
self.enum = schema
elif schema.enums:
self.enum = schema.enums.pop()
schema.enums.add(self.enum)

field_kwargs = ""
field_name = camel_to_snake(self.name)
kwargs = []
Expand All @@ -136,6 +98,13 @@ def parse(self):
kwargs.append(f' {key}="{value}",\n')
else:
kwargs.append(f" {key}={value},\n")
if "default" in self.definition:
default_value = self.definition["default"]
if isinstance(default_value, str):
default_value = f'"{default_value}"'
kwargs.append(f" default={default_value},\n")
elif "Optional" in def_type:
kwargs.append(" default=None,\n")
if kwargs:
field_kwargs += "\n"
field_kwargs += "".join(kwargs)
Expand All @@ -153,7 +122,7 @@ def __init__(self, parser: "SchemaParser", name: str):
self.parser: SchemaParser = parser
self.type: Optional[SchemaObjectType] = None
self._str: str = ""
self.enums: list[EnumObject] = []
self.enums: set[EnumObject] = set()
self.embed_cls: list[SchemaObject] = []
self.attributes: str = ""
self.models = set()
Expand All @@ -180,21 +149,27 @@ def set_as_object(self, properties: dict, required: Optional[list] = None) -> No
parser.parse()
self.models.update(parser.models)
if parser.enum:
self.enums.append(parser.enum)
self.enums.add(parser.enum)
props.append(parser)
if parser.embed_cls is not None:
self.embed_cls.append(parser.embed_cls)
self.attributes += f"\n {parser}"
self._str += "".join([str(enum) for enum in self.enums])
if not properties and not self.attributes:
self.attributes = " ..."
# self._str += "".join([str(enum) for enum in self.enums])
self._str += "\n\n".join([str(cls) for cls in self.embed_cls])
if self.embed_cls:
self._str += "\n"
self._str += BASE_CLASS.substitute(
model_class_name=self.name,
model_class_parent="pydantic.BaseModel",
class_attributes=self.attributes,
class_methods="",
).replace("\n\n", "\n")
self._str += (
BASE_CLASS.substitute(
model_class_name=self.name,
model_class_parent="pydantic.BaseModel",
class_attributes=self.attributes,
class_methods="",
)
.replace("\n\n\n", "\n")
.replace("\n\n", "\n")
)

def set_as_alias(self, alias_type: str) -> None:
self.type = SchemaObjectType.ALIAS
Expand All @@ -208,28 +183,40 @@ def set_as_array(self, array_type: str, required=False) -> None:
self._str = f"typing.Optional[list[{array_type}]] = None"
else:
self._str = f"list[{array_type}]"
self.attributes = self._str

def set_as_enum(self, type_name, options: list[Any], enum_type: OpenApiType):
self.type = SchemaObjectType.ENUM
self.models.add(type_name)
self._str = str(EnumObject(type_name, options, enum_type))
self.attributes = type_name

def set_as_const(self, value):
self.type = SchemaObjectType.CONST
self._str = f"Literal[{value}]"
if isinstance(value, str):
value = f'"{value}"'
self._str = f"typing.Literal[{value}]"
self.attributes = self._str

def set_as_union(self, items):
self.type = SchemaObjectType.UNION
options: set[str] = set()
for any_of in items:
if "$ref" in any_of:
if any_of == {}:
options.add("dict")
elif "$ref" in any_of:
ref = self.parser.get_ref(any_of["$ref"])
self.models.add(ref.name)
options.add(ref.name)
elif "const" in any_of:
const = any_of["const"]
if isinstance(const, str):
const = f'"{const}"'
options.add(f"typing.Literal[{const}]")
elif "items" in any_of:
ops = []
if "$ref" in any_of["items"]:
ref = self.parser.get_ref(any_of["$ref"])
ref = self.parser.get_ref(any_of["items"]["$ref"])
self.models.add(ref.name)
ops.append(ref.name)
elif any_of["items"].get("type") in OPENAPI_2_TYPE_MAPPING:
Expand All @@ -243,7 +230,21 @@ def set_as_union(self, items):
else:
# Add on embed_cls
raise NotImplementedError("not supported")
self._str = f"typing.Union[{', '.join(options)}]" if len(options) > 1 else options.pop()

if len(options) == 1:
self._str = options.pop()
else:
mode = "typing.Union"
if "None" in options:
options.remove("None")
mode = "typing.Optional"
self._str = f"{mode}[{', '.join(sorted(options))}]"
self.attributes = self._str

def set_as_type(self, std_type: str):
self.type = SchemaObjectType.TYPE
self._str = OPENAPI_2_TYPE_MAPPING[std_type]
self.attributes = self._str


class SchemaParser:
Expand Down Expand Up @@ -272,10 +273,15 @@ def process_schema(self, schema_name: str, schema_data: dict) -> SchemaObject:

if "allOf" in schema_data:
new_schema = SchemaObject(self, schema_name)
for item in schema_data.get("allOf", []):
schema = self.process_schema(schema_name, item)
self.models.update(new_schema.models)
new_schema.attributes += schema.attributes
if len(schema_data["allOf"]) == 1:
new_schema = self.process_schema(schema_name, schema_data["allOf"][0])
return new_schema
else:
for item in schema_data.get("allOf", []):
schema = self.process_schema(schema_name, item)
self.models.update(schema.models)
new_schema.attributes += schema.attributes

new_schema.set_as_object({}, [])
self.models.update(new_schema.models)
return new_schema
Expand All @@ -291,7 +297,7 @@ def process_schema(self, schema_name: str, schema_data: dict) -> SchemaObject:
required=schema_data.get("required"),
)
self.models.update(new_schema.models)
case "array":
case OpenApiType.ARRAY:
items = schema_data.get("items", {})
array_type = items.get("type")
if array_type in OPENAPI_2_TYPE_MAPPING:
Expand All @@ -300,24 +306,58 @@ def process_schema(self, schema_name: str, schema_data: dict) -> SchemaObject:
obj = self.get_ref(items.get("$ref"))
self.models.add(obj.name)
self.models.update(obj.models)
new_schema.enums.update(obj.enums)
new_schema.models.update(obj.models)
new_schema.set_as_alias(f"list[{obj.name}]")
elif "anyOf" in items:
obj = self.process_schema(schema_name, items)
new_schema.enums.update(obj.enums)
new_schema.models.update(obj.models)
self.models.update(obj.models)
new_schema.set_as_alias(f"list[{obj}]")
else:
raise NotImplementedError(f"Type for list {array_type} not supported")
case "string":
raise NotImplementedError(f"Type for list '{array_type}' not supported")
case OpenApiType.BOOLEAN:
if "const" in schema_data:
new_schema.set_as_const(schema_data.get("const"))
return new_schema
new_schema.set_as_type(schema_data.get("type"))
return new_schema
case OpenApiType.NUMBER:
if "const" in schema_data:
new_schema.set_as_const(schema_data.get("const"))
return new_schema
new_schema.set_as_type(schema_data.get("type"))
return new_schema
case OpenApiType.INTEGER:
if "const" in schema_data:
new_schema.set_as_const(schema_data.get("const"))
return new_schema
new_schema.set_as_type(schema_data.get("type"))
return new_schema
case OpenApiType.STRING:
if "enum" in schema_data:
new_schema.set_as_enum(schema_name, schema_data.get("enum"), schema_data.get("type"))
enums = schema_data.get("enum")
if len(enums) == 1:
new_schema.set_as_const(enums[0])
return new_schema
new_schema.set_as_enum(schema_name, enums, schema_data.get("type"))
self.models.update(new_schema.models)
return new_schema
if "const" in schema_data:
new_schema.set_as_const(schema_data.get("const"))
return new_schema
raise NotImplementedError(f"Type {schema_data.get('type')} not supported")
new_schema.set_as_type(schema_data.get("type"))
return new_schema
case _:
if "$ref" in schema_data:
obj = self.get_ref(schema_data.get("$ref"))
self.models.add(obj.name)
self.models.update(obj.models)
return obj
if "const" in schema_data:
new_schema.set_as_const(schema_data.get("const"))
return new_schema
raise NotImplementedError(f"Type {schema_data.get('type')} not supported")
return new_schema

Expand All @@ -328,9 +368,23 @@ def get_ref(self, ref: str) -> SchemaObject:
raise ValueError(f"Schema {ref} not found")

self.schemas[ref] = self.process_schema(schema_name, self.raw_schemas[schema_name])
if self.base_name in ref and self.schemas[ref].type == SchemaObjectType.CONST:
self.schemas[ref].set_as_alias(self.schemas[ref]._str)
self.order.append(ref)
return self.schemas[ref]

def print(self, file=None):
enums = set()
for schema_name in self.order:
print(self.schemas.get(schema_name), file=file)
schema = self.schemas.get(schema_name)
enums_str = []
for enum in schema.enums:
if enum.name not in enums:
enums_str.append(str(enum))
enums.add(enum.name)

if enums_str:
print("\n".join(enums_str), file=file, end="\n")
if schema.type == SchemaObjectType.ENUM:
enums.add(schema.name)
print(schema, file=file)
Loading

0 comments on commit 0b6fdd8

Please sign in to comment.