diff --git a/src/gpuhunt/providers/oci.py b/src/gpuhunt/providers/oci.py index 795ba9c..0e83fa3 100644 --- a/src/gpuhunt/providers/oci.py +++ b/src/gpuhunt/providers/oci.py @@ -1,5 +1,6 @@ import logging import re +from dataclasses import asdict, dataclass from typing import Iterable, List, Optional, Type import oci @@ -48,54 +49,29 @@ def get( if ( shape.hidden or shape.status != "ACTIVE" - or shape.shape_type.value != "vm" - or shape.sub_type.value == "flexible" + or shape.shape_type.value not in ("vm", "bm") + or shape.sub_type.value not in ("standard", "gpu", "optimized") + or ".A1." in shape.name ): continue - # extra validation, failing here would mean we are not handling some - # case that was not present in the data at the time of writing - if ( - len(shape.products) != 1 - or (ocpu_product := shape.products[0]).type.value != "ocpu" - or (product_details := products.find(ocpu_product.part_number)) is None - or product_details.billing_model != "UCM" - or product_details.price_type != "HOUR" - or (price_l10n := product_details.find_price_l10n("USD")) is None - or len(price_l10n.prices) != 1 - or (product_price := price_l10n.prices[0]).model != "PAY_AS_YOU_GO" - ): - logger.warning( - "Skipping shape %s due to unexpected cost estimator data", - shape.name, - ) - continue - - if shape.sub_type.value == "gpu" and shape.gpu_qty is not None: - shape_price = product_price.value * shape.gpu_qty - else: - shape_price = product_price.value * ocpu_product.qty - - vcpu = ocpu_product.qty if shape.is_arm_cpu() else ocpu_product.qty * 2 - - gpu = dict( - gpu_count=shape.gpu_qty or 0, - gpu_name=get_gpu_name(shape.name), - gpu_memory=shape.get_gpu_unit_memory_gb(), - ) - if any(gpu.values()) and not all(gpu.values()): + try: + resources = shape_to_resources(shape, products) + except CostEstimatorDataError as e: logger.warning( - "Skipping shape %s due to incomplete GPU parameters: %s", shape.name, gpu + "Skipping shape %s due to unexpected Cost Estimator data: %s", shape.name, e ) continue catalog_item = RawCatalogItem( instance_name=shape.name, location=None, - price=shape_price, - cpu=vcpu, - memory=shape.bundle_memory_qty, - **gpu, + price=resources.total_price(), + cpu=resources.cpu.vcpus, + memory=resources.memory.gbs, + gpu_count=resources.gpu.units_count, + gpu_name=resources.gpu.name, + gpu_memory=resources.gpu.unit_memory_gb, spot=False, disk_size=None, ) @@ -216,6 +192,107 @@ def _get(self, resource: str, ResponseModel: Type[BaseModel]): return ResponseModel.parse_raw(resp.content) +class CostEstimatorDataError(Exception): + pass + + +@dataclass +class CPUConfiguration: + vcpus: int + price: float + + +@dataclass +class MemoryConfiguration: + gbs: int + price: float + + +@dataclass +class GPUConfiguration: + units_count: int + unit_memory_gb: Optional[float] + name: Optional[str] + price: float + + def __post_init__(self): + d = asdict(self) + if any(d.values()) and not all(d.values()): + raise CostEstimatorDataError(f"Incomplete GPU parameters: {self}") + + +@dataclass +class ResourcesConfiguration: + cpu: CPUConfiguration + memory: MemoryConfiguration + gpu: GPUConfiguration + + def total_price(self) -> float: + return self.cpu.price + self.memory.price + self.gpu.price + + +def shape_to_resources( + shape: CostEstimatorShape, products: CostEstimatorProductList +) -> ResourcesConfiguration: + cpu = None + gpu = GPUConfiguration(units_count=0, unit_memory_gb=None, name=None, price=0.0) + memory = MemoryConfiguration(gbs=shape.bundle_memory_qty, price=0.0) + + for product in shape.products: + product_details = products.find(product.part_number) + if product_details is None: + raise CostEstimatorDataError(f"Could not find product {product.part_number!r}") + product_price = get_product_price_usd_per_hour(product_details) + + if product.type.value == "ocpu": + vcpus = product.qty if shape.is_arm_cpu() else product.qty * 2 + if shape.gpu_qty: + gpu = GPUConfiguration( + units_count=shape.gpu_qty, + unit_memory_gb=shape.get_gpu_unit_memory_gb(), + name=get_gpu_name(shape.name), + price=product_price * shape.gpu_qty, + ) + cpu = CPUConfiguration(vcpus=vcpus, price=0.0) + else: + cpu = CPUConfiguration(vcpus=vcpus, price=product_price * product.qty) + + elif product.type.value == "memory": + memory = MemoryConfiguration(gbs=product.qty, price=product_price * product.qty) + + else: + raise CostEstimatorDataError(f"Unknown product type {product.type.value!r}") + + if cpu is None: + raise CostEstimatorDataError(f"No ocpu product") + + return ResourcesConfiguration(cpu, memory, gpu) + + +def get_product_price_usd_per_hour(product: CostEstimatorProduct) -> float: + if product.billing_model != "UCM": + raise CostEstimatorDataError( + f"Billing model for product {product.part_number!r} is {product.billing_model!r}" + ) + if product.price_type != "HOUR": + raise CostEstimatorDataError( + f"Price type for product {product.part_number!r} is {product.price_type!r}" + ) + price_l10n = product.find_price_l10n("USD") + if price_l10n is None: + raise CostEstimatorDataError(f"No USD price for product {product.part_number!r}") + if len(price_l10n.prices) != 1: + raise CostEstimatorDataError( + f"Product {product.part_number!r} has {len(price_l10n.prices)} USD prices" + ) + price = price_l10n.prices[0] + if price.model != "PAY_AS_YOU_GO": + raise CostEstimatorDataError( + f"Pricing model for product {product.part_number!r} is {price.model!r}" + ) + return price.value + + def get_gpu_name(shape_name: str) -> Optional[str]: parts = re.split(r"[\.-]", shape_name.upper()) diff --git a/src/integrity_tests/test_oci.py b/src/integrity_tests/test_oci.py index ffe25b7..24b2a13 100644 --- a/src/integrity_tests/test_oci.py +++ b/src/integrity_tests/test_oci.py @@ -12,7 +12,7 @@ def data_rows(catalog_dir: Path) -> List[dict]: return list(csv.DictReader(f)) -@pytest.mark.parametrize("gpu", ["P100", "V100", "A10", ""]) +@pytest.mark.parametrize("gpu", ["P100", "V100", "A10", "A100", "H100", ""]) def test_gpu_present(gpu: str, data_rows: List[dict]): assert gpu in map(itemgetter("gpu_name"), data_rows) @@ -21,8 +21,9 @@ def test_on_demand_present(data_rows: List[dict]): assert "False" in map(itemgetter("spot"), data_rows) -def test_vm_present(data_rows: List[dict]): - assert any(name.startswith("VM") for name in map(itemgetter("instance_name"), data_rows)) +@pytest.mark.parametrize("prefix", ["VM.Standard", "BM.Standard", "VM.GPU", "BM.GPU"]) +def test_family_present(prefix: str, data_rows: List[dict]): + assert any(name.startswith(prefix) for name in map(itemgetter("instance_name"), data_rows)) def test_quantity_decreases_as_query_complexity_increases(data_rows: List[dict]):