Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OCI Bare Metal instances in catalog #71

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 115 additions & 38 deletions src/gpuhunt/providers/oci.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import re
from dataclasses import asdict, dataclass
from typing import Iterable, List, Optional, Type

import oci
Expand Down Expand Up @@ -48,54 +49,29 @@ def get(
if (
shape.hidden
or shape.status != "ACTIVE"
or shape.shape_type.value != "vm"
or shape.sub_type.value == "flexible"
or shape.shape_type.value not in ("vm", "bm")
or shape.sub_type.value not in ("standard", "gpu", "optimized")
or ".A1." in shape.name
):
continue

# extra validation, failing here would mean we are not handling some
# case that was not present in the data at the time of writing
if (
len(shape.products) != 1
or (ocpu_product := shape.products[0]).type.value != "ocpu"
or (product_details := products.find(ocpu_product.part_number)) is None
or product_details.billing_model != "UCM"
or product_details.price_type != "HOUR"
or (price_l10n := product_details.find_price_l10n("USD")) is None
or len(price_l10n.prices) != 1
or (product_price := price_l10n.prices[0]).model != "PAY_AS_YOU_GO"
):
logger.warning(
"Skipping shape %s due to unexpected cost estimator data",
shape.name,
)
continue

if shape.sub_type.value == "gpu" and shape.gpu_qty is not None:
shape_price = product_price.value * shape.gpu_qty
else:
shape_price = product_price.value * ocpu_product.qty

vcpu = ocpu_product.qty if shape.is_arm_cpu() else ocpu_product.qty * 2

gpu = dict(
gpu_count=shape.gpu_qty or 0,
gpu_name=get_gpu_name(shape.name),
gpu_memory=shape.get_gpu_unit_memory_gb(),
)
if any(gpu.values()) and not all(gpu.values()):
try:
resources = shape_to_resources(shape, products)
except CostEstimatorDataError as e:
logger.warning(
"Skipping shape %s due to incomplete GPU parameters: %s", shape.name, gpu
"Skipping shape %s due to unexpected Cost Estimator data: %s", shape.name, e
)
continue

catalog_item = RawCatalogItem(
instance_name=shape.name,
location=None,
price=shape_price,
cpu=vcpu,
memory=shape.bundle_memory_qty,
**gpu,
price=resources.total_price(),
cpu=resources.cpu.vcpus,
memory=resources.memory.gbs,
gpu_count=resources.gpu.units_count,
gpu_name=resources.gpu.name,
gpu_memory=resources.gpu.unit_memory_gb,
spot=False,
disk_size=None,
)
Expand Down Expand Up @@ -216,6 +192,107 @@ def _get(self, resource: str, ResponseModel: Type[BaseModel]):
return ResponseModel.parse_raw(resp.content)


class CostEstimatorDataError(Exception):
    """Raised when Cost Estimator data does not have the expected form.

    The message describes the unexpected piece of data, e.g. a missing
    product, an unknown product type or an unexpected pricing model.
    """


@dataclass
class CPUConfiguration:
    """CPU part of a shape: the vCPU count and the price attributed to it."""

    # Number of virtual CPUs.
    vcpus: int
    # Price attributed to the CPUs; 0.0 when the cost is accounted for elsewhere.
    # NOTE(review): presumably USD per hour, as in get_product_price_usd_per_hour - confirm.
    price: float


@dataclass
class MemoryConfiguration:
    """Memory part of a shape: the amount of memory and the price attributed to it."""

    # Amount of memory.
    # NOTE(review): the field name suggests gigabytes - confirm against the data source.
    gbs: int
    # Price attributed to the memory; 0.0 when memory is not billed separately.
    price: float


@dataclass
class GPUConfiguration:
    """GPU part of a shape.

    A configuration must be either entirely empty (every field falsy, i.e. no
    GPUs) or entirely filled in; anything in between is rejected on creation.
    """

    # Number of GPUs; 0 when the shape has none.
    units_count: int
    # Memory of a single GPU; None when the shape has no GPUs.
    unit_memory_gb: Optional[float]
    # GPU model name; None when the shape has no GPUs.
    name: Optional[str]
    # Price attributed to the GPUs; 0.0 when the shape has no GPUs.
    price: float

    def __post_init__(self) -> None:
        """Reject partially specified GPU parameters.

        Raises:
            CostEstimatorDataError: if some, but not all, fields are truthy.
        """
        # NOTE(review): ``price`` takes part in this check as well, so a GPU
        # with a price of 0.0 is reported as incomplete - confirm this is intended.
        values = asdict(self).values()
        if any(values) and not all(values):
            raise CostEstimatorDataError(f"Incomplete GPU parameters: {self}")


@dataclass
class ResourcesConfiguration:
    """CPU, memory and GPU configuration of a shape, each with its own price."""

    cpu: CPUConfiguration
    memory: MemoryConfiguration
    gpu: GPUConfiguration

    def total_price(self) -> float:
        """Return the sum of the CPU, memory and GPU prices."""
        return self.cpu.price + self.memory.price + self.gpu.price


def shape_to_resources(
    shape: CostEstimatorShape, products: CostEstimatorProductList
) -> ResourcesConfiguration:
    """Build the resources configuration, including prices, of a shape.

    Every product attached to the shape is looked up in ``products`` and priced
    with ``get_product_price_usd_per_hour``. "ocpu" products define the CPUs
    and, for shapes with GPUs, the GPU price; "memory" products define memory
    that is priced separately.

    Args:
        shape: the Cost Estimator shape to convert.
        products: the product list used to resolve product part numbers.

    Returns:
        The CPU, memory and GPU configuration of the shape.

    Raises:
        CostEstimatorDataError: if a product is missing from ``products``, has
            unexpected pricing data or an unknown type, if the shape has no
            "ocpu" product, or if the GPU parameters are incomplete.
    """
    cpu = None
    # Defaults for shapes that have no GPUs and no separately priced memory.
    gpu = GPUConfiguration(units_count=0, unit_memory_gb=None, name=None, price=0.0)
    memory = MemoryConfiguration(gbs=shape.bundle_memory_qty, price=0.0)

    for product in shape.products:
        product_details = products.find(product.part_number)
        if product_details is None:
            raise CostEstimatorDataError(f"Could not find product {product.part_number!r}")
        product_price = get_product_price_usd_per_hour(product_details)
        product_type = product.type.value

        if product_type == "ocpu":
            # Non-ARM shapes are counted as two vCPUs per unit of the product.
            vcpus = product.qty if shape.is_arm_cpu() else product.qty * 2
            if shape.gpu_qty:
                # For shapes with GPUs the product price is multiplied by the
                # number of GPUs, so the whole cost is attributed to the GPUs
                # and the CPUs are priced at zero.
                gpu = GPUConfiguration(
                    units_count=shape.gpu_qty,
                    unit_memory_gb=shape.get_gpu_unit_memory_gb(),
                    name=get_gpu_name(shape.name),
                    price=product_price * shape.gpu_qty,
                )
                cpu = CPUConfiguration(vcpus=vcpus, price=0.0)
            else:
                cpu = CPUConfiguration(vcpus=vcpus, price=product_price * product.qty)

        elif product_type == "memory":
            # Separately priced memory replaces the bundled default set above.
            memory = MemoryConfiguration(gbs=product.qty, price=product_price * product.qty)

        else:
            raise CostEstimatorDataError(f"Unknown product type {product_type!r}")

    if cpu is None:
        raise CostEstimatorDataError("No ocpu product")

    return ResourcesConfiguration(cpu, memory, gpu)


def get_product_price_usd_per_hour(product: CostEstimatorProduct) -> float:
    """Return the pay-as-you-go hourly USD price of a product.

    The product is validated first: any data that does not match the single
    supported pricing form is reported instead of being guessed at.

    Args:
        product: the Cost Estimator product to price.

    Returns:
        The value of the product's only USD price.

    Raises:
        CostEstimatorDataError: if the billing model is not "UCM", the price
            type is not "HOUR", there is no USD price, there is not exactly one
            USD price, or the pricing model is not "PAY_AS_YOU_GO".
    """
    if product.billing_model != "UCM":
        raise CostEstimatorDataError(
            f"Billing model for product {product.part_number!r} is {product.billing_model!r}"
        )
    if product.price_type != "HOUR":
        raise CostEstimatorDataError(
            f"Price type for product {product.part_number!r} is {product.price_type!r}"
        )
    price_l10n = product.find_price_l10n("USD")
    if price_l10n is None:
        raise CostEstimatorDataError(f"No USD price for product {product.part_number!r}")
    # Several prices would make it ambiguous which one applies, so reject them.
    if len(price_l10n.prices) != 1:
        raise CostEstimatorDataError(
            f"Product {product.part_number!r} has {len(price_l10n.prices)} USD prices"
        )
    price = price_l10n.prices[0]
    if price.model != "PAY_AS_YOU_GO":
        raise CostEstimatorDataError(
            f"Pricing model for product {product.part_number!r} is {price.model!r}"
        )
    return price.value


def get_gpu_name(shape_name: str) -> Optional[str]:
parts = re.split(r"[\.-]", shape_name.upper())

Expand Down
7 changes: 4 additions & 3 deletions src/integrity_tests/test_oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def data_rows(catalog_dir: Path) -> List[dict]:
return list(csv.DictReader(f))


@pytest.mark.parametrize("gpu", ["P100", "V100", "A10", "A100", "H100", ""])
def test_gpu_present(gpu: str, data_rows: List[dict]):
    """Check that at least one catalog row has the given ``gpu_name``.

    The empty string matches rows whose ``gpu_name`` is empty.
    """
    assert gpu in map(itemgetter("gpu_name"), data_rows)

Expand All @@ -21,8 +21,9 @@ def test_on_demand_present(data_rows: List[dict]):
assert "False" in map(itemgetter("spot"), data_rows)


def test_vm_present(data_rows: List[dict]):
    """Check that at least one catalog row has an instance name starting with "VM"."""
    assert any(name.startswith("VM") for name in map(itemgetter("instance_name"), data_rows))
@pytest.mark.parametrize("prefix", ["VM.Standard", "BM.Standard", "VM.GPU", "BM.GPU"])
def test_family_present(prefix: str, data_rows: List[dict]):
    """Check that at least one catalog row has an instance name starting with ``prefix``."""
    assert any(name.startswith(prefix) for name in map(itemgetter("instance_name"), data_rows))


def test_quantity_decreases_as_query_complexity_increases(data_rows: List[dict]):
Expand Down
Loading