From aebe8532d583c6bcf644f69e34fc45fdf4a1422e Mon Sep 17 00:00:00 2001 From: Marcelo Reis Date: Thu, 14 Mar 2024 23:59:21 -0400 Subject: [PATCH] Updated sdwan_config_builder with native support for Pydantic 2.6.x --- config/config.example.yaml | 2 +- sdwan_config_builder/README.md | 4 +- sdwan_config_builder/requirements.txt | 5 +- .../src/sdwan_config_builder/__version__.py | 4 +- .../src/sdwan_config_builder/commands.py | 7 +- .../sdwan_config_builder/loader/__init__.py | 2 +- .../src/sdwan_config_builder/loader/models.py | 166 +++++++++--------- .../sdwan_config_builder/loader/validators.py | 27 +-- 8 files changed, 110 insertions(+), 107 deletions(-) diff --git a/config/config.example.yaml b/config/config.example.yaml index 04c9d7b..bedf63b 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -1,7 +1,7 @@ --- global_config: # Different cloud have different restrictions on the characters allowed in - # the tags. AWS is the most permisive, GCP is the most restrictive, Azure is + # the tags. AWS is the most permissive, GCP is the most restrictive, Azure is # somewhere in the middle. If you use all lowercase letters and numbers, you # should be fine. common_tags: diff --git a/sdwan_config_builder/README.md b/sdwan_config_builder/README.md index 5dc8461..0e75311 100644 --- a/sdwan_config_builder/README.md +++ b/sdwan_config_builder/README.md @@ -35,7 +35,7 @@ Install config builder: Validate that config builder is installed: ``` (venv) % sdwan_config_build --version -SDWAN Config Builder Tool Version 0.7 +SDWAN Config Builder Tool Version 0.9 ``` ## Running @@ -45,7 +45,7 @@ should be saved. By default sdwan_config_build looks for a 'metadata.yaml' file The CONFIG_BUILDER_METADATA environment variable can be used to specify a custom location for the metadata file. ``` -(venv) % % sdwan_config_build --help +(venv) % sdwan_config_build --help usage: sdwan_config_build [-h] [--version] {render,export,schema} ... SDWAN Config Builder Tool diff --git a/sdwan_config_builder/requirements.txt b/sdwan_config_builder/requirements.txt index 1792b6d..473ca70 100644 --- a/sdwan_config_builder/requirements.txt +++ b/sdwan_config_builder/requirements.txt @@ -1,5 +1,6 @@ -pydantic>=2 -PyYAML>=6.0 +pydantic>=2.6 +pydantic-settings>=2.2.1 +PyYAML>=6.0.1 Jinja2>=3.1 passlib>=1.7.4 sshpubkeys>=3.3 diff --git a/sdwan_config_builder/src/sdwan_config_builder/__version__.py b/sdwan_config_builder/src/sdwan_config_builder/__version__.py index d85bf0a..6545775 100644 --- a/sdwan_config_builder/src/sdwan_config_builder/__version__.py +++ b/sdwan_config_builder/src/sdwan_config_builder/__version__.py @@ -1,2 +1,2 @@ -__copyright__ = "Copyright (c) 2022-2023 Cisco Systems, Inc. and/or its affiliates" -__version__ = "0.8.2" +__copyright__ = "Copyright (c) 2022-2024 Cisco Systems, Inc. and/or its affiliates" +__version__ = "0.9.0" diff --git a/sdwan_config_builder/src/sdwan_config_builder/commands.py b/sdwan_config_builder/src/sdwan_config_builder/commands.py index a611bc3..f4c0e28 100644 --- a/sdwan_config_builder/src/sdwan_config_builder/commands.py +++ b/sdwan_config_builder/src/sdwan_config_builder/commands.py @@ -1,5 +1,6 @@ import argparse import logging +import json from typing import Union from pathlib import Path from ipaddress import IPv4Interface, IPv4Network @@ -93,7 +94,7 @@ def render_cmd(cli_args: argparse.Namespace) -> None: 'ipv4_subnet_host': ipv4_subnet_host_filter, } jinja_env.filters.update(custom_filters) - jinja_env.globals = config_obj.dict(by_alias=True) + jinja_env.globals = config_obj.model_dump(by_alias=True) for jinja_target in app_config.targets_config.jinja_renderer.targets: try: @@ -128,7 +129,7 @@ def export_cmd(cli_args: argparse.Namespace) -> None: try: config_obj = load_yaml(ConfigModel, 'config', app_config.loader_config.top_level_config) with open(cli_args.file, 'w') as export_file: - export_file.write(config_obj.json(by_alias=True, indent=2)) + export_file.write(config_obj.model_dump_json(by_alias=True, indent=2)) logger.info(f"Exported source configuration as '{cli_args.file}'") @@ -143,6 +144,6 @@ def schema_cmd(cli_args: argparse.Namespace) -> None: :return: None """ with open(cli_args.file, 'w') as schema_file: - schema_file.write(ConfigModel.schema_json(indent=2)) + schema_file.write(json.dumps(ConfigModel.model_json_schema(), indent=2)) logger.info(f"Saved configuration schema as '{cli_args.file}'") diff --git a/sdwan_config_builder/src/sdwan_config_builder/loader/__init__.py b/sdwan_config_builder/src/sdwan_config_builder/loader/__init__.py index b68e3fe..bd010ea 100644 --- a/sdwan_config_builder/src/sdwan_config_builder/loader/__init__.py +++ b/sdwan_config_builder/src/sdwan_config_builder/loader/__init__.py @@ -2,7 +2,7 @@ import sys import yaml from typing import Any, TypeVar, Type, Union, List -from pydantic.v1 import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError from .models import ConfigModel diff --git a/sdwan_config_builder/src/sdwan_config_builder/loader/models.py b/sdwan_config_builder/src/sdwan_config_builder/loader/models.py index 88ac519..bc11ecd 100644 --- a/sdwan_config_builder/src/sdwan_config_builder/loader/models.py +++ b/sdwan_config_builder/src/sdwan_config_builder/loader/models.py @@ -1,10 +1,12 @@ from functools import partial from secrets import token_urlsafe -from typing import List, Dict, Any, Optional, Iterable, Union +from typing import List, Dict, Optional, Iterable, Union +from typing_extensions import Annotated from enum import Enum from pathlib import Path from ipaddress import IPv4Network, IPv6Network, IPv4Interface, IPv4Address -from pydantic.v1 import BaseModel, BaseSettings, Field, validator, root_validator, constr, conint +from pydantic import field_validator, field_serializer, model_validator, Field, ConfigDict, BaseModel, ValidationInfo +from pydantic_settings import SettingsConfigDict, BaseSettings from passlib.hash import sha512_crypt from sshpubkeys import SSHKey, InvalidKeyError from .validators import (formatted_string, unique_system_ip, constrained_cidr, cidr_subnet, subnet_interface, @@ -58,20 +60,23 @@ class GlobalConfigModel(BaseSettings): """ GlobalConfigModel is a special config block as field values can use environment variables as their default value """ - home_dir: str = Field(..., env='HOME') - project_root: str = Field(..., env='PROJ_ROOT') + model_config = SettingsConfigDict(case_sensitive=True) + + home_dir: Annotated[str, Field(validation_alias='HOME')] + project_root: Annotated[str, Field(validation_alias='PROJ_ROOT')] common_tags: Dict[str, str] = None ubuntu_image: str - ssh_public_key_file: str = Field(None, description='Can use python format string syntax to reference other ' - 'previous fields in this model') - ssh_public_key: Optional[str] = None - ssh_public_key_fp: Optional[str] = None + ssh_public_key_file: Annotated[Optional[str], Field(description='Can use python format string syntax to reference ' + 'other previous fields in this model')] = None + ssh_public_key: Annotated[Optional[str], Field(validate_default=True)] = None + ssh_public_key_fp: Annotated[Optional[str], Field(validate_default=True)] = None - _validate_formatted_strings = validator('ssh_public_key_file', allow_reuse=True)(formatted_string) + _validate_formatted_strings = field_validator('ssh_public_key_file')(formatted_string) - @validator('ssh_public_key', always=True) - def resolve_ssh_public_key(cls, v, values: Dict[str, Any]): - pub_key = resolve_ssh_public_key(v, values.get('ssh_public_key_file')) + @field_validator('ssh_public_key') + @classmethod + def resolve_ssh_public_key(cls, v, info: ValidationInfo): + pub_key = resolve_ssh_public_key(v, info.data.get('ssh_public_key_file')) try: ssh_key = SSHKey(pub_key, strict=True) ssh_key.parse() @@ -82,10 +87,11 @@ def resolve_ssh_public_key(cls, v, values: Dict[str, Any]): return pub_key - @validator('ssh_public_key_fp', always=True) - def resolve_ssh_public_key_fp(cls, v: Union[str, None], values: Dict[str, Any]) -> str: + @field_validator('ssh_public_key_fp') + @classmethod + def resolve_ssh_public_key_fp(cls, v: Union[str, None], info: ValidationInfo) -> str: if v is None: - pub_key = values.get('ssh_public_key') + pub_key = info.data.get('ssh_public_key') if pub_key is None: raise ValueError("Field 'ssh_public_key' or 'ssh_public_key_file' not present") ssh_key = SSHKey(pub_key, strict=True) @@ -96,19 +102,17 @@ def resolve_ssh_public_key_fp(cls, v: Union[str, None], values: Dict[str, Any]) return fp - @validator('project_root') + @field_validator('project_root') + @classmethod def resolve_project_root(cls, v: str) -> str: return str(Path(v).resolve()) - class Config: - case_sensitive = True - # # infra providers block # class InfraProviderConfigModel(BaseModel): - ntp_server: constr(regex=r'^[a-zA-Z0-9.-]+$') + ntp_server: Annotated[str, Field(pattern=r'^[a-zA-Z0-9.-]+$')] class GCPConfigModel(InfraProviderConfigModel): @@ -133,38 +137,37 @@ class InfraProvidersModel(BaseModel): # class ControllerCommonInfraModel(BaseModel): + model_config = ConfigDict(use_enum_values=True) + provider: InfraProviderControllerOptionsEnum region: str - dns_domain: constr(regex=r'^[a-zA-Z0-9.-]+$') = Field( - '', description="If set, add A records for control plane element external addresses in AWS Route 53") - sw_version: constr(regex=r'^\d+(?:\.\d+)+$') + dns_domain: Annotated[str, Field(description="If set, add A records for control plane element external " + "addresses in AWS Route 53", pattern=r'^[a-zA-Z0-9.-]+$')] = '' + sw_version: Annotated[str, Field(pattern=r'^\d+(?:\.\d+)+$')] cloud_init_format: CloudInitEnum = CloudInitEnum.v1 - class Config: - use_enum_values = True - class ControllerCommonConfigModel(BaseModel): organization_name: str - site_id: conint(ge=0, le=4294967295) + site_id: Annotated[int, Field(ge=0, le=4294967295)] acl_ingress_ipv4: List[IPv4Network] acl_ingress_ipv6: List[IPv6Network] cidr: IPv4Network vpn0_gateway: IPv4Address - @validator('acl_ingress_ipv4', 'acl_ingress_ipv6') - def acl_str(cls, v: Iterable[IPv4Network]) -> str: + @field_serializer('acl_ingress_ipv4', 'acl_ingress_ipv6') + def acl_str(self, v: Iterable[IPv4Network]) -> str: return ', '.join(f'"{entry}"' for entry in v) - _validate_cidr = validator('cidr', allow_reuse=True)(constrained_cidr(max_length=23)) + _validate_cidr = field_validator('cidr')(constrained_cidr(max_length=23)) class CertAuthModel(BaseModel): passphrase: str = Field(default_factory=partial(token_urlsafe, 15)) cert_dir: str - ca_cert: str = '{cert_dir}/myCA.pem' + ca_cert: Annotated[str, Field(validate_default=True)] = '{cert_dir}/myCA.pem' - _validate_formatted_strings = validator('ca_cert', always=True, allow_reuse=True)(formatted_string) + _validate_formatted_strings = field_validator('ca_cert')(formatted_string) class ControllerConfigModel(BaseModel): @@ -172,18 +175,19 @@ class ControllerConfigModel(BaseModel): vpn0_interface_ipv4: IPv4Interface # Validators - _validate_system_ip = validator('system_ip', allow_reuse=True)(unique_system_ip) + _validate_system_ip = field_validator('system_ip')(unique_system_ip) class VmanageConfigModel(ControllerConfigModel): username: str = 'admin' password: str = Field(default_factory=partial(token_urlsafe, 12)) - password_hashed: Optional[str] = None + password_hashed: Annotated[Optional[str], Field(validate_default=True)] = None - @validator('password_hashed', always=True) - def hash_password(cls, v: Union[str, None], values: Dict[str, Any]) -> str: + @field_validator('password_hashed') + @classmethod + def hash_password(cls, v: Union[str, None], info: ValidationInfo) -> str: if v is None: - clear_password = values.get('password') + clear_password = info.data.get('password') if clear_password is None: raise ValueError("Field 'password' is not present") # Using 'openssl passwd -6' recipe @@ -227,78 +231,74 @@ class InfraVmwareModel(BaseModel): class EdgeInfraModel(ComputeInstanceModel): + model_config = ConfigDict(use_enum_values=True) + provider: InfraProviderOptionsEnum - region: Optional[str] = None - zone: Optional[str] = None - sw_version: constr(regex=r'^\d+(?:\.\d+)+') + region: Annotated[Optional[str], Field(validate_default=True)] = None + zone: Annotated[Optional[str], Field(validate_default=True)] = None + sw_version: Annotated[str, Field(pattern=r'^\d+(?:\.\d+)+')] cloud_init_format: CloudInitEnum = CloudInitEnum.v1 sdwan_model: str sdwan_uuid: str - vmware: Optional[InfraVmwareModel] = None - - @validator('region', always=True) - def region_validate(cls, v, values: Dict[str, Any]): - if v is None and values['provider'] != InfraProviderOptionsEnum.vmware: - raise ValueError(f"{values['provider']} provider requires 'region' to be defined") - if v is not None and values['provider'] == InfraProviderOptionsEnum.vmware: + vmware: Annotated[Optional[InfraVmwareModel], Field(validate_default=True)] = None + + @field_validator('region') + @classmethod + def region_validate(cls, v, info: ValidationInfo): + if v is None and info.data['provider'] != InfraProviderOptionsEnum.vmware: + raise ValueError(f"{info.data['provider']} provider requires 'region' to be defined") + if v is not None and info.data['provider'] == InfraProviderOptionsEnum.vmware: raise ValueError(f"'region' is not allowed when provider is {InfraProviderOptionsEnum.vmware}") return v - @validator('zone', always=True) - def zone_validate(cls, v, values: Dict[str, Any]): - if v is None and values['provider'] == InfraProviderOptionsEnum.gcp: + @field_validator('zone') + @classmethod + def zone_validate(cls, v, info: ValidationInfo): + if v is None and info.data['provider'] == InfraProviderOptionsEnum.gcp: raise ValueError("GCP requires zone to be defined") return v - @validator('vmware', always=True) - def vmware_section(cls, v, values: Dict[str, Any]): - if v is None and values['provider'] == InfraProviderOptionsEnum.vmware: + @field_validator('vmware') + @classmethod + def vmware_section(cls, v, info: ValidationInfo): + if v is None and info.data['provider'] == InfraProviderOptionsEnum.vmware: raise ValueError(f"{InfraProviderOptionsEnum.vmware} provider requires 'vmware' section to be defined") - if v is not None and values['provider'] != InfraProviderOptionsEnum.vmware: + if v is not None and info.data['provider'] != InfraProviderOptionsEnum.vmware: raise ValueError(f"'vmware' section is only allowed when provider is {InfraProviderOptionsEnum.vmware}") return v - @root_validator - def instance_type_validate(cls, values: Dict[str, Any]): - if values['instance_type'] is None and values['provider'] != InfraProviderOptionsEnum.vmware: - raise ValueError(f"{values['provider']} provider requires 'instance_type' to be defined") - if values['instance_type'] is not None and values['provider'] == InfraProviderOptionsEnum.vmware: + @model_validator(mode='after') + def instance_type_validate(self) -> 'EdgeInfraModel': + if self.instance_type is None and self.provider != InfraProviderOptionsEnum.vmware: + raise ValueError(f"{self.provider} provider requires 'instance_type' to be defined") + if self.instance_type is not None and self.provider == InfraProviderOptionsEnum.vmware: raise ValueError(f"'instance_type' is not allowed when provider is {InfraProviderOptionsEnum.vmware}") - return values - - class Config: - use_enum_values = True + return self class EdgeConfigModel(BaseModel): - site_id: conint(ge=0, le=4294967295) + site_id: Annotated[int, Field(ge=0, le=4294967295)] system_ip: IPv4Address cidr: Optional[IPv4Network] = None - vpn0_range: Optional[IPv4Network] = None - vpn0_interface_ipv4: Optional[IPv4Interface] = None - vpn0_gateway: Optional[IPv4Address] = None - vpn1_range: Optional[IPv4Network] = None - vpn1_interface_ipv4: Optional[IPv4Interface] = None + vpn0_range: Annotated[Optional[IPv4Network], Field(validate_default=True)] = None + vpn0_interface_ipv4: Annotated[Optional[IPv4Interface], Field(validate_default=True)] = None + vpn0_gateway: Annotated[Optional[IPv4Address], Field(validate_default=True)] = None + vpn1_range: Annotated[Optional[IPv4Network], Field(validate_default=True)] = None + vpn1_interface_ipv4: Annotated[Optional[IPv4Interface], Field(validate_default=True)] = None # Validators - _validate_system_ip = validator('system_ip', allow_reuse=True)(unique_system_ip) - _validate_cidr = validator('cidr', allow_reuse=True)(constrained_cidr(max_length=23)) - _validate_vpn_range = validator('vpn0_range', 'vpn1_range', always=True, allow_reuse=True)( - cidr_subnet(cidr_field='cidr', prefix_len=24) - ) - _validate_vpn0_ipv4 = validator('vpn0_interface_ipv4', always=True, allow_reuse=True)( - subnet_interface(subnet_field='vpn0_range', host_index=10) - ) - _validate_vpn1_ipv4 = validator('vpn1_interface_ipv4', always=True, allow_reuse=True)( - subnet_interface(subnet_field='vpn1_range', host_index=10) - ) - _validate_vpn0_gw = validator('vpn0_gateway', always=True, allow_reuse=True)( - subnet_address(subnet_field='vpn0_range', host_index=0) - ) + _validate_system_ip = field_validator('system_ip')(unique_system_ip) + _validate_cidr = field_validator('cidr')(constrained_cidr(max_length=23)) + _validate_vpn_range = field_validator('vpn0_range', 'vpn1_range')(cidr_subnet(cidr_field='cidr', prefix_len=24)) + _validate_vpn0_ipv4 = field_validator('vpn0_interface_ipv4')(subnet_interface(subnet_field='vpn0_range', + host_index=10)) + _validate_vpn1_ipv4 = field_validator('vpn1_interface_ipv4')(subnet_interface(subnet_field='vpn1_range', + host_index=10)) + _validate_vpn0_gw = field_validator('vpn0_gateway')(subnet_address(subnet_field='vpn0_range', host_index=0)) class EdgeModel(BaseModel): diff --git a/sdwan_config_builder/src/sdwan_config_builder/loader/validators.py b/sdwan_config_builder/src/sdwan_config_builder/loader/validators.py index 5ddef18..686a160 100644 --- a/sdwan_config_builder/src/sdwan_config_builder/loader/validators.py +++ b/sdwan_config_builder/src/sdwan_config_builder/loader/validators.py @@ -1,19 +1,20 @@ -from typing import Set, Dict, Any, Optional, Callable, Iterator +from typing import Set, Dict, Optional, Callable, Iterator from ipaddress import IPv4Address, IPv4Network, IPv4Interface +from pydantic import ValidationInfo # # Reusable validators # -def formatted_string(v: str, values: Dict[str, Any]) -> str: +def formatted_string(v: str, info: ValidationInfo) -> str: """ Process v as a python formatted string :param v: Value to be validated - :param values: {: ...} dict of previously validated model fields + :param info: A ValidationInfo instance with previously validated model fields :return: Expanded formatted string """ try: - return v.format(**values) if v is not None else v + return v.format(**info.data) if v is not None else v except KeyError as ex: raise ValueError(f"Variable not found: {ex}") from None @@ -52,13 +53,13 @@ def cidr_subnet( *, cidr_field: str, prefix_len: int = 24 -) -> Callable[[IPv4Network, Dict[str, Any]], IPv4Network]: +) -> Callable[[IPv4Network, ValidationInfo], IPv4Network]: subnet_gen_map: Dict[IPv4Network, Iterator[IPv4Network]] = {} - def validator(subnet: IPv4Network, values: Dict[str, Any]) -> IPv4Network: + def validator(subnet: IPv4Network, info: ValidationInfo) -> IPv4Network: if subnet is None: - cidr = values.get(cidr_field, ...) + cidr = info.data.get(cidr_field, ...) if cidr is ...: raise ValueError(f"no cidr_field name {cidr_field}") if cidr is None: @@ -77,10 +78,10 @@ def subnet_interface( *, subnet_field: str, host_index: int -) -> Callable[[IPv4Interface, Dict[str, Any]], IPv4Interface]: - def validator(ipv4_interface: IPv4Interface, values: Dict[str, Any]) -> IPv4Interface: +) -> Callable[[IPv4Interface, ValidationInfo], IPv4Interface]: + def validator(ipv4_interface: IPv4Interface, info: ValidationInfo) -> IPv4Interface: if ipv4_interface is None: - subnet = values.get(subnet_field, ...) + subnet = info.data.get(subnet_field, ...) if subnet is ...: raise ValueError(f"no subnet_field name {subnet_field}") if subnet is None: @@ -99,10 +100,10 @@ def subnet_address( *, subnet_field: str, host_index: int -) -> Callable[[IPv4Address, Dict[str, Any]], IPv4Address]: - def validator(ipv4_address: IPv4Address, values: Dict[str, Any]) -> IPv4Address: +) -> Callable[[IPv4Address, ValidationInfo], IPv4Address]: + def validator(ipv4_address: IPv4Address, info: ValidationInfo) -> IPv4Address: if ipv4_address is None: - subnet = values.get(subnet_field, ...) + subnet = info.data.get(subnet_field, ...) if subnet is ...: raise ValueError(f"no subnet_field name {subnet_field}") if subnet is None: