diff --git a/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala b/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala index d5b8110d5dc..c6d615cda0f 100644 --- a/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala +++ b/backend/src/main/scala/cromwell/backend/standard/pollmonitoring/PollResultMonitorActor.scala @@ -1,4 +1,5 @@ package cromwell.backend.standard.pollmonitoring + import akka.actor.{Actor, ActorRef} import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform} import cromwell.backend.validation.{ @@ -9,6 +10,7 @@ import cromwell.backend.validation.{ ValidatedRuntimeAttributes } import cromwell.core.logging.JobLogger +import cromwell.services.cost.InstantiatedVmInfo import cromwell.services.metadata.CallMetadataKeys import cromwell.services.metrics.bard.BardEventing.BardEventRequest import cromwell.services.metrics.bard.model.TaskSummaryEvent @@ -26,7 +28,7 @@ case class PollMonitorParameters( jobDescriptor: BackendJobDescriptor, validatedRuntimeAttributes: ValidatedRuntimeAttributes, platform: Option[Platform], - logger: Option[JobLogger] + logger: JobLogger ) /** @@ -42,6 +44,9 @@ trait PollResultMonitorActor[PollResultType] extends Actor { // Time that the user VM started spending money. def extractStartTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime] + // Used to kick off a cost calculation + def extractVmInfoFromRunState(pollStatus: PollResultType): Option[InstantiatedVmInfo] + // Time that the user VM stopped spending money. 
def extractEndTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime] @@ -99,6 +104,7 @@ trait PollResultMonitorActor[PollResultType] extends Actor { Option.empty private var vmStartTime: Option[OffsetDateTime] = Option.empty private var vmEndTime: Option[OffsetDateTime] = Option.empty + protected var vmCostPerHour: Option[BigDecimal] = Option.empty def processPollResult(pollStatus: PollResultType): Unit = { // Make sure jobStartTime remains the earliest event time ever seen @@ -122,8 +128,16 @@ trait PollResultMonitorActor[PollResultType] extends Actor { tellMetadata(Map(CallMetadataKeys.VmEndTime -> end)) } } + // If we don't yet have a cost per hour and we can extract VM info, send a cost request to the catalog service. + // We expect it to reply with an answer, which is handled in receive. + // NB: Due to the nature of async code, we may send a few cost requests before we get a response back. + if (vmCostPerHour.isEmpty) { + extractVmInfoFromRunState(pollStatus).foreach(handleVmCostLookup) + } } + def handleVmCostLookup(vmInfo: InstantiatedVmInfo): Unit + // When a job finishes, the bard actor needs to know about the timing in order to record metrics. // Cost related metadata should already have been handled in processPollResult. 
def handleAsyncJobFinish(terminalStateName: String): Unit = diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index 50c6a79e358..9da59f16b7b 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -607,9 +607,11 @@ services { } } - CostCatalogService { + // When enabled, Cromwell will store vmCostPerHour metadata for GCP tasks + GcpCostCatalogService { class = "cromwell.services.cost.GcpCostCatalogService" config { + enabled = false catalogExpirySeconds = 86400 } } diff --git a/services/src/main/scala/cromwell/services/cost/CostCatalogConfig.scala b/services/src/main/scala/cromwell/services/cost/CostCatalogConfig.scala index 47b78245920..ae342ea84e5 100644 --- a/services/src/main/scala/cromwell/services/cost/CostCatalogConfig.scala +++ b/services/src/main/scala/cromwell/services/cost/CostCatalogConfig.scala @@ -3,8 +3,9 @@ package cromwell.services.cost import com.typesafe.config.Config import net.ceedubs.ficus.Ficus._ -final case class CostCatalogConfig(catalogExpirySeconds: Int) +final case class CostCatalogConfig(enabled: Boolean, catalogExpirySeconds: Int) object CostCatalogConfig { - def apply(config: Config): CostCatalogConfig = CostCatalogConfig(config.as[Int]("catalogExpirySeconds")) + def apply(config: Config): CostCatalogConfig = + CostCatalogConfig(config.as[Boolean]("enabled"), config.as[Int]("catalogExpirySeconds")) } diff --git a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala index 35012b7e424..4c715cd3628 100644 --- a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala +++ b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogService.scala @@ -1,24 +1,76 @@ package cromwell.services.cost import akka.actor.{Actor, ActorRef} +import cats.implicits.catsSyntaxValidatedId +import com.google.`type`.Money import 
com.google.cloud.billing.v1._ import com.typesafe.config.Config import com.typesafe.scalalogging.LazyLogging import common.util.StringUtil.EnhancedToStringable +import common.validation.ErrorOr._ +import common.validation.ErrorOr.ErrorOr +import cromwell.services.ServiceRegistryActor.ServiceRegistryMessage import cromwell.services.cost.GcpCostCatalogService.{COMPUTE_ENGINE_SERVICE_NAME, DEFAULT_CURRENCY_CODE} import cromwell.util.GracefulShutdownHelper.ShutdownCommand import java.time.{Duration, Instant} import scala.jdk.CollectionConverters.IterableHasAsScala import java.time.temporal.ChronoUnit.SECONDS +import scala.util.Using -case class CostCatalogKey(machineType: Option[MachineType], - usageType: Option[UsageType], - machineCustomization: Option[MachineCustomization], - resourceGroup: Option[ResourceGroup] +case class CostCatalogKey(machineType: MachineType, + usageType: UsageType, + machineCustomization: MachineCustomization, + resourceType: ResourceType, + region: String ) + +object CostCatalogKey { + + // Specifically support only the SKUs that we know we can use. This is brittle and I hate it, but the more structured + // fields available in the SKU don't give us enough information without relying on the human-readable descriptions. + // + // N1: We usually use custom machines but SKUs are only available for predefined; we'll fall back to these SKUs. + // N2 and N2D: We only use custom machines. + + // Use this regex to filter down to just the SKUs we are interested in. 
+ // NB: This should be updated if we add new machine types or the cost catalog descriptions change + final val expectedSku = + (".*?N1 Predefined Instance (Core|Ram) .*|" + + ".*?N2 Custom Instance (Core|Ram) .*|" + + ".*?N2D AMD Custom Instance (Core|Ram) .*").r + + def apply(sku: Sku): List[CostCatalogKey] = + for { + _ <- expectedSku.findFirstIn(sku.getDescription).toList + machineType <- MachineType.fromSku(sku).toList + resourceType <- ResourceType.fromSku(sku).toList + usageType <- UsageType.fromSku(sku).toList + machineCustomization <- MachineCustomization.fromSku(sku).toList + region <- sku.getServiceRegionsList.asScala.toList + } yield CostCatalogKey(machineType, usageType, machineCustomization, resourceType, region) + + def apply(instantiatedVmInfo: InstantiatedVmInfo, resourceType: ResourceType): ErrorOr[CostCatalogKey] = + MachineType.fromGoogleMachineTypeString(instantiatedVmInfo.machineType).map { mType => + CostCatalogKey( + mType, + UsageType.fromBoolean(instantiatedVmInfo.preemptible), + MachineCustomization.fromMachineTypeString(instantiatedVmInfo.machineType), + resourceType, + instantiatedVmInfo.region + ) + } +} + +case class GcpCostLookupRequest(vmInfo: InstantiatedVmInfo, replyTo: ActorRef) extends ServiceRegistryMessage { + override def serviceName: String = "GcpCostCatalogService" +} +case class GcpCostLookupResponse(vmInfo: InstantiatedVmInfo, calculatedCost: ErrorOr[BigDecimal]) case class CostCatalogValue(catalogObject: Sku) case class ExpiringGcpCostCatalog(catalog: Map[CostCatalogKey, CostCatalogValue], fetchTime: Instant) +object ExpiringGcpCostCatalog { + def empty: ExpiringGcpCostCatalog = ExpiringGcpCostCatalog(Map.empty, Instant.MIN) +} object GcpCostCatalogService { // Can be gleaned by using googleClient.listServices @@ -26,6 +78,44 @@ object GcpCostCatalogService { // ISO 4217 https://developers.google.com/adsense/management/appendix/currencies private val DEFAULT_CURRENCY_CODE = "USD" + + def getMostRecentPricingInfo(sku: 
Sku): PricingInfo = { + val mostRecentPricingInfoIndex = sku.getPricingInfoCount - 1 + sku.getPricingInfo(mostRecentPricingInfoIndex) + } + + // See: https://cloud.google.com/billing/v1/how-tos/catalog-api + def calculateCpuPricePerHour(cpuSku: Sku, coreCount: Int): ErrorOr[BigDecimal] = { + val pricingInfo = getMostRecentPricingInfo(cpuSku) + val usageUnit = pricingInfo.getPricingExpression.getUsageUnit + if (usageUnit == "h") { + // Price per hour of a single core + // NB: Ignoring "TieredRates" here (the idea that stuff gets cheaper the more you use). + // Technically, we should write code that determines which tier(s) to use. + // In practice, from what I've seen, CPU cores and RAM don't have more than a single tier. + val costPerUnit: Money = pricingInfo.getPricingExpression.getTieredRates(0).getUnitPrice + val costPerCorePerHour: BigDecimal = + costPerUnit.getUnits + (costPerUnit.getNanos * 1e-9) // Money.nanos counts 10^-9 units, so scale by 1e-9 (10e-9 would be 1e-8) + val result = costPerCorePerHour * coreCount + result.validNel + } else { + s"Expected usage units of CPUs to be 'h'. Got ${usageUnit}".invalidNel + } + } + + def calculateRamPricePerHour(ramSku: Sku, ramGbCount: Double): ErrorOr[BigDecimal] = { + val pricingInfo = getMostRecentPricingInfo(ramSku) + val usageUnit = pricingInfo.getPricingExpression.getUsageUnit + if (usageUnit == "GiBy.h") { + val costPerUnit: Money = pricingInfo.getPricingExpression.getTieredRates(0).getUnitPrice + val costPerGbHour: BigDecimal = + costPerUnit.getUnits + (costPerUnit.getNanos * 1e-9) // Money.nanos counts 10^-9 units, so scale by 1e-9 (10e-9 would be 1e-8) + val result = costPerGbHour * ramGbCount + result.validNel + } else { + s"Expected usage units of RAM to be 'GiBy.h'. 
Got ${usageUnit}".invalidNel + } + } } /** @@ -36,37 +126,38 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service extends Actor with LazyLogging { - private val maxCatalogLifetime: Duration = - Duration.of(CostCatalogConfig(serviceConfig).catalogExpirySeconds.longValue, SECONDS) + private val costCatalogConfig = CostCatalogConfig(serviceConfig) - private var googleClient: Option[CloudCatalogClient] = Option.empty + private val maxCatalogLifetime: Duration = + Duration.of(costCatalogConfig.catalogExpirySeconds.longValue, SECONDS) // Cached catalog. Refreshed lazily when older than maxCatalogLifetime. - private var costCatalog: Option[ExpiringGcpCostCatalog] = Option.empty + private var costCatalog: ExpiringGcpCostCatalog = ExpiringGcpCostCatalog.empty /** * Returns the SKU for a given key, if it exists */ def getSku(key: CostCatalogKey): Option[CostCatalogValue] = getOrFetchCachedCatalog().get(key) - protected def fetchNewCatalog: Iterable[Sku] = { - if (googleClient.isEmpty) { - // We use option rather than lazy here so that the client isn't created when it is told to shutdown (see receive override) - googleClient = Some(CloudCatalogClient.create) + protected def fetchSkuIterable(googleClient: CloudCatalogClient): Iterable[Sku] = + makeInitialWebRequest(googleClient).iterateAll().asScala + + private def fetchNewCatalog: ExpiringGcpCostCatalog = + Using.resource(CloudCatalogClient.create) { googleClient => + val skus = fetchSkuIterable(googleClient) + ExpiringGcpCostCatalog(processCostCatalog(skus), Instant.now()) } - makeInitialWebRequest(googleClient.get).iterateAll().asScala - } - def getCatalogAge: Duration = - Duration.between(costCatalog.map(c => c.fetchTime).getOrElse(Instant.ofEpochMilli(0)), Instant.now()) - private def isCurrentCatalogExpired: Boolean = getCatalogAge.toNanos > maxCatalogLifetime.toNanos + def getCatalogAge: Duration = Duration.between(costCatalog.fetchTime, Instant.now()) + + private def 
isCurrentCatalogExpired: Boolean = getCatalogAge.toSeconds > maxCatalogLifetime.toSeconds private def getOrFetchCachedCatalog(): Map[CostCatalogKey, CostCatalogValue] = { - if (costCatalog.isEmpty || isCurrentCatalogExpired) { + if (isCurrentCatalogExpired) { logger.info("Fetching a new GCP public cost catalog.") - costCatalog = Some(ExpiringGcpCostCatalog(processCostCatalog(fetchNewCatalog), Instant.now())) + costCatalog = fetchNewCatalog } - costCatalog.map(expiringCatalog => expiringCatalog.catalog).getOrElse(Map.empty) + costCatalog.catalog } /** @@ -88,23 +179,63 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service * Ideally, we don't want to have an entire, unprocessed, cost catalog in memory at once since it's ~20MB. */ private def processCostCatalog(skus: Iterable[Sku]): Map[CostCatalogKey, CostCatalogValue] = - // TODO: Account for key collisions (same key can be in multiple regions) - // TODO: reduce memory footprint of returned map (don't store entire SKU object) skus.foldLeft(Map.empty[CostCatalogKey, CostCatalogValue]) { case (acc, sku) => - acc + convertSkuToKeyValuePair(sku) + val keys = CostCatalogKey(sku) + + // We expect that every cost catalog key is unique, but changes to the SKUs returned by Google may + // break this assumption. Check and log an error if we find collisions. 
+ val collisions = keys.flatMap(acc.get(_).toList).map(_.catalogObject.getDescription) + if (collisions.nonEmpty) + logger.error( + s"Found SKU key collision when adding ${sku.getDescription}, collides with ${collisions.mkString(", ")}" + ) + + acc ++ keys.map(k => (k, CostCatalogValue(sku))) } - private def convertSkuToKeyValuePair(sku: Sku): (CostCatalogKey, CostCatalogValue) = CostCatalogKey( - machineType = MachineType.fromSku(sku), - usageType = UsageType.fromSku(sku), - machineCustomization = MachineCustomization.fromSku(sku), - resourceGroup = ResourceGroup.fromSku(sku) - ) -> CostCatalogValue(sku) + def lookUpSku(instantiatedVmInfo: InstantiatedVmInfo, resourceType: ResourceType): ErrorOr[Sku] = + CostCatalogKey(instantiatedVmInfo, resourceType).flatMap { key => + // As of Sept 2024 the cost catalog does not contain entries for custom N1 machines. If we're using N1, attempt + // to fall back to predefined. + lazy val n1PredefinedKey = + (key.machineType, key.machineCustomization) match { + case (N1, Custom) => Option(key.copy(machineCustomization = Predefined)) + case _ => None + } + val sku = getSku(key).orElse(n1PredefinedKey.flatMap(getSku)).map(_.catalogObject) + sku match { + case Some(sku) => sku.validNel + case None => s"Failed to look up ${resourceType} SKU for ${instantiatedVmInfo}".invalidNel + } + } + + // TODO consider caching this, answers won't change until we reload the SKUs + def calculateVmCostPerHour(instantiatedVmInfo: InstantiatedVmInfo): ErrorOr[BigDecimal] = + for { + cpuSku <- lookUpSku(instantiatedVmInfo, Cpu) + coreCount <- MachineType.extractCoreCountFromMachineTypeString(instantiatedVmInfo.machineType) + cpuPricePerHour <- GcpCostCatalogService.calculateCpuPricePerHour(cpuSku, coreCount) + ramSku <- lookUpSku(instantiatedVmInfo, Ram) + ramMbCount <- MachineType.extractRamMbFromMachineTypeString(instantiatedVmInfo.machineType) + ramGbCount = ramMbCount / 1024d // need sub-integer resolution + ramPricePerHour <- 
GcpCostCatalogService.calculateRamPricePerHour(ramSku, ramGbCount) + totalCost = cpuPricePerHour + ramPricePerHour + _ = logger.info( + s"Calculated vmCostPerHour of ${totalCost} " + + s"(CPU ${cpuPricePerHour} for ${coreCount} cores [${cpuSku.getDescription}], " + + s"RAM ${ramPricePerHour} for ${ramGbCount} Gb [${ramSku.getDescription}]) " + + s"for ${instantiatedVmInfo}" + ) + } yield totalCost def serviceRegistryActor: ActorRef = serviceRegistry override def receive: Receive = { + case GcpCostLookupRequest(vmInfo, replyTo) if costCatalogConfig.enabled => + val calculatedCost = calculateVmCostPerHour(vmInfo) + val response = GcpCostLookupResponse(vmInfo, calculatedCost) + replyTo ! response + case GcpCostLookupRequest(_, _) => // do nothing if we're disabled case ShutdownCommand => - googleClient.foreach(client => client.shutdownNow()) context stop self case other => logger.error( diff --git a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala index 7507560c810..d189de43f1b 100644 --- a/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala +++ b/services/src/main/scala/cromwell/services/cost/GcpCostCatalogTypes.scala @@ -1,58 +1,124 @@ package cromwell.services.cost +import cats.implicits.catsSyntaxValidatedId import com.google.cloud.billing.v1.Sku +import common.validation.ErrorOr.ErrorOr + +import java.util.regex.{Matcher, Pattern} + +/* + * Case class that contains information retrieved from Google about a VM that cromwell has started + */ +case class InstantiatedVmInfo(region: String, machineType: String, preemptible: Boolean) /* * These types reflect hardcoded strings found in a google cost catalog. 
*/ + +sealed trait MachineType { def machineTypeName: String } +case object N1 extends MachineType { override val machineTypeName = "n1" } +case object N2 extends MachineType { override val machineTypeName = "n2" } +case object N2d extends MachineType { override val machineTypeName = "n2d" } + object MachineType { def fromSku(sku: Sku): Option[MachineType] = { - val tokenizedDescription = sku.getDescription.split(" ") + val tokenizedDescription = sku.getDescription.toLowerCase.split(" ") if (tokenizedDescription.contains(N1.machineTypeName)) Some(N1) else if (tokenizedDescription.contains(N2.machineTypeName)) Some(N2) else if (tokenizedDescription.contains(N2d.machineTypeName)) Some(N2d) else Option.empty } + + // expects a string that looks something like "n1-standard-1" or "custom-1-4096" + def fromGoogleMachineTypeString(machineTypeString: String): ErrorOr[MachineType] = { + val mType = machineTypeString.toLowerCase + if (mType.startsWith("n1-")) N1.validNel + else if (mType.startsWith("n2d-")) N2d.validNel + else if (mType.startsWith("n2-")) N2.validNel + else if (mType.startsWith("custom-")) N1.validNel // by convention + else s"Unrecognized machine type: $machineTypeString".invalidNel + } + + def extractCoreCountFromMachineTypeString(machineTypeString: String): ErrorOr[Int] = { + // Regex to capture second-to-last hyphen-delimited token as number + val pattern: Pattern = Pattern.compile("-(\\d+)-[^-]+$") + val matcher: Matcher = pattern.matcher(machineTypeString) + if (matcher.find()) { + matcher.group(1).toInt.validNel + } else { + s"Could not extract core count from ${machineTypeString}".invalidNel + } + } + def extractRamMbFromMachineTypeString(machineTypeString: String): ErrorOr[Int] = { + // Regular expression to match the number after a hyphen at the end of the string + val pattern: Pattern = Pattern.compile("-(\\d+)$") + val matcher: Matcher = pattern.matcher(machineTypeString); + if (matcher.find()) { + matcher.group(1).toInt.validNel + } else { + 
s"Could not extract Ram MB count from ${machineTypeString}".invalidNel + } + } } -sealed trait MachineType { def machineTypeName: String } -case object N1 extends MachineType { override val machineTypeName = "N1" } -case object N2 extends MachineType { override val machineTypeName = "N2" } -case object N2d extends MachineType { override val machineTypeName = "N2D" } + +sealed trait UsageType { def typeName: String } +case object OnDemand extends UsageType { override val typeName = "ondemand" } +case object Preemptible extends UsageType { override val typeName = "preemptible" } object UsageType { def fromSku(sku: Sku): Option[UsageType] = - sku.getCategory.getUsageType match { + sku.getCategory.getUsageType.toLowerCase match { case OnDemand.typeName => Some(OnDemand) case Preemptible.typeName => Some(Preemptible) case _ => Option.empty } + def fromBoolean(isPreemptible: Boolean): UsageType = isPreemptible match { + case true => Preemptible + case false => OnDemand + } + } -sealed trait UsageType { def typeName: String } -case object OnDemand extends UsageType { override val typeName = "OnDemand" } -case object Preemptible extends UsageType { override val typeName = "Preemptible" } + +sealed trait MachineCustomization { def customizationName: String } +case object Custom extends MachineCustomization { override val customizationName = "custom" } +case object Predefined extends MachineCustomization { override val customizationName = "predefined" } object MachineCustomization { + def fromMachineTypeString(machineTypeString: String): MachineCustomization = + if (machineTypeString.toLowerCase.contains("custom")) Custom else Predefined + + /* + The cost catalog is annoyingly inconsistent and unstructured in this area. + - For N1 machines, only predefined SKUs are included, and they have "Predefined" in their description strings. + We will eventually fall back to using these SKUs for custom machines, but accurately represent them as Predefined here. 
+ - For non-N1 machines, both custom and predefined SKUs are included, custom ones include "Custom" in their description + strings and predefined SKUs are only identifiable by the absence of "Custom." + */ def fromSku(sku: Sku): Option[MachineCustomization] = { - val tokenizedDescription = sku.getDescription.split(" ") + val tokenizedDescription = sku.getDescription.toLowerCase.split(" ") + + // ex. "N1 Predefined Instance Core running in Montreal" if (tokenizedDescription.contains(Predefined.customizationName)) Some(Predefined) + // ex. "N2 Custom Instance Core running in Paris" else if (tokenizedDescription.contains(Custom.customizationName)) Some(Custom) - else Option.empty + // ex. "N2 Instance Core running in Paris" + else Some(Predefined) } } -sealed trait MachineCustomization { def customizationName: String } -case object Custom extends MachineCustomization { override val customizationName = "Custom" } -case object Predefined extends MachineCustomization { override val customizationName = "Predefined" } -object ResourceGroup { - def fromSku(sku: Sku): Option[ResourceGroup] = - sku.getCategory.getResourceGroup match { +sealed trait ResourceType { def groupName: String } +case object Cpu extends ResourceType { override val groupName = "cpu" } +case object Ram extends ResourceType { override val groupName = "ram" } + +object ResourceType { + def fromSku(sku: Sku): Option[ResourceType] = { + val tokenizedDescription = sku.getDescription.toLowerCase.split(" ") + sku.getCategory.getResourceGroup.toLowerCase match { case Cpu.groupName => Some(Cpu) case Ram.groupName => Some(Ram) - case N1Standard.groupName => Some(N1Standard) + case "n1standard" if tokenizedDescription.contains("ram") => Some(Ram) + case "n1standard" if tokenizedDescription.contains("core") => Some(Cpu) case _ => Option.empty } + } } -sealed trait ResourceGroup { def groupName: String } -case object Cpu extends ResourceGroup { override val groupName = "CPU" } -case object Ram extends ResourceGroup 
{ override val groupName = "RAM" } -case object N1Standard extends ResourceGroup { override val groupName = "N1Standard" } diff --git a/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala b/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala index 19d9e505ac6..9973aaa4082 100644 --- a/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala +++ b/services/src/test/scala/cromwell/services/cost/GcpCostCatalogServiceSpec.scala @@ -2,21 +2,25 @@ package cromwell.services.cost import akka.actor.ActorRef import akka.testkit.{ImplicitSender, TestActorRef, TestProbe} -import com.google.cloud.billing.v1.Sku +import com.google.cloud.billing.v1.{CloudCatalogClient, Sku} import com.typesafe.config.{Config, ConfigFactory} import cromwell.core.TestKitSuite import cromwell.util.GracefulShutdownHelper.ShutdownCommand import org.scalatest.concurrent.Eventually import org.scalatest.flatspec.AnyFlatSpecLike import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks import java.time.Duration import java.io.{File, FileInputStream, FileOutputStream} import scala.collection.mutable.ListBuffer +import scala.util.Using object GcpCostCatalogServiceSpec { val catalogExpirySeconds: Long = 1 // Short duration so we can do a cache expiry test - val config: Config = ConfigFactory.parseString(s"catalogExpirySeconds = $catalogExpirySeconds") + val config: Config = ConfigFactory.parseString( + s"catalogExpirySeconds = $catalogExpirySeconds, enabled = true" + ) val mockTestDataFilePath: String = "services/src/test/scala/cromwell/services/cost/serializedSkuList.testData" } class GcpCostCatalogServiceTestActor(serviceConfig: Config, globalConfig: Config, serviceRegistry: ActorRef) @@ -34,18 +38,21 @@ class GcpCostCatalogServiceTestActor(serviceConfig: Config, globalConfig: Config fis.close() skuList.toSeq } - def saveMockData(): Unit = { - val fetchedData = super.fetchNewCatalog - val 
fos = new FileOutputStream(new File(GcpCostCatalogServiceSpec.mockTestDataFilePath)) - fetchedData.foreach { sku => - sku.writeDelimitedTo(fos) + def saveMockData(): Unit = + Using.resources(CloudCatalogClient.create, + new FileOutputStream(new File(GcpCostCatalogServiceSpec.mockTestDataFilePath)) + ) { (googleClient, fos) => + val skus = super.fetchSkuIterable(googleClient) + skus.foreach { sku => + sku.writeDelimitedTo(fos) + } } - fos.close() - } + override def receive: Receive = { case ShutdownCommand => context stop self } - override def fetchNewCatalog: Iterable[Sku] = loadMockData + + override def fetchSkuIterable(client: CloudCatalogClient): Iterable[Sku] = loadMockData } class GcpCostCatalogServiceSpec @@ -53,7 +60,8 @@ class GcpCostCatalogServiceSpec with AnyFlatSpecLike with Matchers with Eventually - with ImplicitSender { + with ImplicitSender + with TableDrivenPropertyChecks { behavior of "CostCatalogService" def constructTestActor: GcpCostCatalogServiceTestActor = @@ -72,10 +80,11 @@ class GcpCostCatalogServiceSpec it should "cache catalogs properly" in { val testLookupKey = CostCatalogKey( - machineType = Some(N2), - usageType = Some(Preemptible), - machineCustomization = Some(Predefined), - resourceGroup = Some(Cpu) + machineType = N2, + usageType = Preemptible, + machineCustomization = Predefined, + resourceType = Cpu, + region = "europe-west9" ) val freshActor = constructTestActor @@ -87,7 +96,7 @@ class GcpCostCatalogServiceSpec freshActor.getCatalogAge.toNanos should (be < shortDuration.toNanos) // Simulate the cached catalog living longer than its lifetime - Thread.sleep(shortDuration.toMillis) + Thread.sleep(shortDuration.plus(shortDuration).toMillis) // Confirm that the catalog is old freshActor.getCatalogAge.toNanos should (be > shortDuration.toNanos) @@ -98,14 +107,198 @@ class GcpCostCatalogServiceSpec freshActor.getCatalogAge.toNanos should (be < shortDuration.toNanos) } - it should "contain an expected SKU" in { - val expectedKey = 
CostCatalogKey( - machineType = Some(N2d), - usageType = Some(Preemptible), - machineCustomization = None, - resourceGroup = Some(Ram) + it should "find CPU and RAM skus for all supported machine types" in { + val lookupRows = Table( + ("machineType", "usage", "customization", "resource", "region", "exists"), + (N1, Preemptible, Predefined, Cpu, "us-west1", true), + (N1, Preemptible, Predefined, Ram, "us-west1", true), + (N1, OnDemand, Predefined, Cpu, "us-west1", true), + (N1, OnDemand, Predefined, Ram, "us-west1", true), + (N1, Preemptible, Custom, Cpu, "us-west1", false), + (N1, Preemptible, Custom, Ram, "us-west1", false), + (N1, OnDemand, Custom, Cpu, "us-west1", false), + (N1, OnDemand, Custom, Ram, "us-west1", false), + (N2, Preemptible, Predefined, Cpu, "us-west1", false), + (N2, Preemptible, Predefined, Ram, "us-west1", false), + (N2, OnDemand, Predefined, Cpu, "us-west1", false), + (N2, OnDemand, Predefined, Ram, "us-west1", false), + (N2, Preemptible, Custom, Cpu, "us-west1", true), + (N2, Preemptible, Custom, Ram, "us-west1", true), + (N2, OnDemand, Custom, Cpu, "us-west1", true), + (N2, OnDemand, Custom, Ram, "us-west1", true), + (N2d, Preemptible, Predefined, Cpu, "us-west1", false), + (N2d, Preemptible, Predefined, Ram, "us-west1", false), + (N2d, OnDemand, Predefined, Cpu, "us-west1", false), + (N2d, OnDemand, Predefined, Ram, "us-west1", false), + (N2d, Preemptible, Custom, Cpu, "us-west1", true), + (N2d, Preemptible, Custom, Ram, "us-west1", true), + (N2d, OnDemand, Custom, Cpu, "us-west1", true), + (N2d, OnDemand, Custom, Ram, "us-west1", true) + ) + + forAll(lookupRows) { case (machineType, usage, customization, resource, region, exists: Boolean) => + val key = CostCatalogKey(machineType, usage, customization, resource, region) + val result = testActorRef.getSku(key) + result.nonEmpty shouldBe exists + } + } + + it should "find the skus for a VM when appropriate" in { + val lookupRows = Table( + ("instantiatedVmInfo", "resource", 
"skuDescription"), + (InstantiatedVmInfo("europe-west9", "custom-16-32768", false), + Cpu, + "N1 Predefined Instance Core running in Paris" + ), + (InstantiatedVmInfo("europe-west9", "custom-16-32768", false), + Ram, + "N1 Predefined Instance Ram running in Paris" + ), + (InstantiatedVmInfo("us-central1", "custom-4-4096", true), + Cpu, + "Spot Preemptible N1 Predefined Instance Core running in Americas" + ), + (InstantiatedVmInfo("us-central1", "custom-4-4096", true), + Ram, + "Spot Preemptible N1 Predefined Instance Ram running in Americas" + ), + (InstantiatedVmInfo("europe-west9", "n1-custom-16-32768", false), + Cpu, + "N1 Predefined Instance Core running in Paris" + ), + (InstantiatedVmInfo("europe-west9", "n1-custom-16-32768", false), + Ram, + "N1 Predefined Instance Ram running in Paris" + ), + (InstantiatedVmInfo("us-central1", "n1-custom-4-4096", true), + Cpu, + "Spot Preemptible N1 Predefined Instance Core running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n1-custom-4-4096", true), + Ram, + "Spot Preemptible N1 Predefined Instance Ram running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", true), + Cpu, + "Spot Preemptible N2 Custom Instance Core running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", true), + Ram, + "Spot Preemptible N2 Custom Instance Ram running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", false), + Cpu, + "N2 Custom Instance Core running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", false), Ram, "N2 Custom Instance Ram running in Americas"), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", true), + Cpu, + "Spot Preemptible N2D AMD Custom Instance Core running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", true), + Ram, + "Spot Preemptible N2D AMD Custom Instance Ram running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", false), + Cpu, + "N2D AMD 
Custom Instance Core running in Americas" + ), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", false), + Ram, + "N2D AMD Custom Instance Ram running in Americas" + ) + ) + + forAll(lookupRows) { case (instantiatedVmInfo: InstantiatedVmInfo, resource: ResourceType, expectedSku: String) => + val skuOr = testActorRef.lookUpSku(instantiatedVmInfo, resource) + skuOr.isValid shouldBe true + skuOr.map(sku => sku.getDescription shouldEqual expectedSku) + } + } + + it should "fail to find the skus for a VM when appropriate" in { + val lookupRows = Table( + ("instantiatedVmInfo", "resource", "errors"), + (InstantiatedVmInfo("us-central1", "custooooooom-4-4096", true), + Cpu, + List("Unrecognized machine type: custooooooom-4-4096") + ), + (InstantiatedVmInfo("us-central1", "n2custom-4-4096", true), + Cpu, + List("Unrecognized machine type: n2custom-4-4096") + ), + (InstantiatedVmInfo("us-central1", "standard-4-4096", true), + Cpu, + List("Unrecognized machine type: standard-4-4096") + ), + (InstantiatedVmInfo("planet-mars1", "custom-4-4096", true), + Cpu, + List("Failed to look up Cpu SKU for InstantiatedVmInfo(planet-mars1,custom-4-4096,true)") + ) ) - val foundValue = testActorRef.getSku(expectedKey) - foundValue.get.catalogObject.getDescription shouldBe "Spot Preemptible N2D AMD Instance Ram running in Johannesburg" + + forAll(lookupRows) { + case (instantiatedVmInfo: InstantiatedVmInfo, resource: ResourceType, expectedErrors: List[String]) => + val skuOr = testActorRef.lookUpSku(instantiatedVmInfo, resource) + skuOr.isValid shouldBe false + skuOr.leftMap(errors => errors.toList shouldEqual expectedErrors) + } + } + + it should "calculate the cost per hour for a VM" in { + // Create BigDecimals from strings to avoid inequality due to floating point shenanigans + val lookupRows = Table( + ("instantiatedVmInfo", "costPerHour"), + (InstantiatedVmInfo("us-central1", "custom-4-4096", true), BigDecimal(".361")), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", 
true), BigDecimal(".42544000000000004")), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", true), BigDecimal(".2371600000000000024")), + (InstantiatedVmInfo("us-central1", "custom-4-4096", false), BigDecimal("1.43392")), + (InstantiatedVmInfo("us-central1", "n2-custom-4-4096", false), BigDecimal("1.50561600000000012")), + (InstantiatedVmInfo("us-central1", "n2d-custom-4-4096", false), BigDecimal("1.309896")), + (InstantiatedVmInfo("europe-west9", "custom-4-4096", true), BigDecimal(".3501808")), + (InstantiatedVmInfo("europe-west9", "n2-custom-4-4096", true), BigDecimal("0.49532")), + (InstantiatedVmInfo("europe-west9", "n2d-custom-4-4096", true), BigDecimal("0.30608")), + (InstantiatedVmInfo("europe-west9", "custom-4-4096", false), BigDecimal("1.663347200000000016")), + (InstantiatedVmInfo("europe-west9", "n2-custom-4-4352", false), BigDecimal("1.75941630500000012")), + (InstantiatedVmInfo("europe-west9", "n2d-custom-4-4096", false), BigDecimal("1.51947952")) + ) + + forAll(lookupRows) { case (instantiatedVmInfo: InstantiatedVmInfo, expectedCostPerHour: BigDecimal) => + val costOr = testActorRef.calculateVmCostPerHour(instantiatedVmInfo) + costOr.isValid shouldBe true + costOr.map(cost => cost shouldEqual expectedCostPerHour) + } + } + + it should "fail to calculate the cost per hour for a VM" in { + + val lookupRows = Table( + ("instantiatedVmInfo", "errors"), + (InstantiatedVmInfo("us-central1", "custooooooom-4-4096", true), + List("Unrecognized machine type: custooooooom-4-4096") + ), + (InstantiatedVmInfo("us-central1", "n2_custom_4_4096", true), + List("Unrecognized machine type: n2_custom_4_4096") + ), + (InstantiatedVmInfo("us-central1", "custom-foo-4096", true), + List("Could not extract core count from custom-foo-4096") + ), + (InstantiatedVmInfo("us-central1", "custom-16-bar", true), + List("Could not extract Ram MB count from custom-16-bar") + ), + (InstantiatedVmInfo("us-central1", "123-456-789", true), List("Unrecognized machine type: 
123-456-789")), + (InstantiatedVmInfo("us-central1", "n2-16-4096", true), + List("Failed to look up Cpu SKU for InstantiatedVmInfo(us-central1,n2-16-4096,true)") + ), + (InstantiatedVmInfo("planet-mars1", "n2-custom-4-4096", true), + List("Failed to look up Cpu SKU for InstantiatedVmInfo(planet-mars1,n2-custom-4-4096,true)") + ) + ) + + forAll(lookupRows) { case (instantiatedVmInfo: InstantiatedVmInfo, expectedErrors: List[String]) => + val costOr = testActorRef.calculateVmCostPerHour(instantiatedVmInfo) + costOr.isValid shouldBe false + costOr.leftMap(errors => errors.toList shouldEqual expectedErrors) + } } } diff --git a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala index 8b05bf4057b..0f5c00fa834 100644 --- a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala +++ b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala @@ -1,6 +1,7 @@ package cromwell.backend.google.batch.actors import akka.actor.{ActorRef, Props} +import cats.data.Validated.{Invalid, Valid} import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform} import cromwell.backend.google.batch.models.RunStatus import cromwell.backend.standard.pollmonitoring.{ @@ -12,6 +13,7 @@ import cromwell.backend.standard.pollmonitoring.{ } import cromwell.backend.validation.ValidatedRuntimeAttributes import cromwell.core.logging.JobLogger +import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo} import cromwell.services.metadata.CallMetadataKeys import java.time.OffsetDateTime @@ -25,13 +27,7 @@ object BatchPollResultMonitorActor { logger: JobLogger ): Props = Props( new BatchPollResultMonitorActor( - PollMonitorParameters(serviceRegistry, 
- workflowDescriptor, - jobDescriptor, - runtimeAttributes, - platform, - Option(logger) - ) + PollMonitorParameters(serviceRegistry, workflowDescriptor, jobDescriptor, runtimeAttributes, platform, logger) ) ) } @@ -51,31 +47,51 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters) case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime } + override def handleVmCostLookup(vmInfo: InstantiatedVmInfo) = { + val request = GcpCostLookupRequest(vmInfo, self) + params.serviceRegistry ! request + } + + def handleCostResponse(costLookupResponse: GcpCostLookupResponse): Unit = + if (vmCostPerHour.isEmpty) { // Skip later responses once any result (cost or -1 on failure) is recorded. + val cost = costLookupResponse.calculatedCost match { + case Valid(c) => + params.logger.info(s"vmCostPerHour for ${costLookupResponse.vmInfo} is ${c}") + c + case Invalid(errors) => + params.logger.error( + s"Failed to calculate VM cost per hour for ${costLookupResponse.vmInfo}. ${errors.toList.mkString(", ")}" + ) + BigDecimal(-1) + } + vmCostPerHour = Option(cost) + tellMetadata(Map(CallMetadataKeys.VmCostPerHour -> cost)) + } + override def receive: Receive = { case message: PollResultMessage => message match { case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult) case ProcessThisPollResult(result) => - params.logger.foreach(logger => - logger.error( - s"Programmer error: Received Poll Result of unknown type. Expected ${RunStatus.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." - ) + params.logger.error( + s"Programmer error: Received Poll Result of unknown type. Expected ${RunStatus.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." 
) + case AsyncJobHasFinished(pollResult: RunStatus) => handleAsyncJobFinish(pollResult.getClass.getSimpleName) case AsyncJobHasFinished(result) => - params.logger.foreach(logger => - logger.error( - s"Programmer error: Received Poll Result of unknown type. Expected ${AsyncJobHasFinished.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." - ) + params.logger.error( + s"Programmer error: Received Poll Result of unknown type. Expected ${AsyncJobHasFinished.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." ) + } case _ => - params.logger.foreach(logger => - logger.error( - s"Programmer error: Cost Helper received message of type other than CostPollingMessage" - ) + params.logger.error( + s"Programmer error: Cost Helper received message of type other than CostPollingMessage" ) + } override def params: PollMonitorParameters = pollMonitorParameters + + override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = Option.empty // TODO } diff --git a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala index 597c0ed8d35..3e9d55b9e1b 100644 --- a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala +++ b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PapiPollResultMonitorActor.scala @@ -1,17 +1,13 @@ package cromwell.backend.google.pipelines.common import akka.actor.{ActorRef, Props} -import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform} +import cats.data.Validated.{Invalid, Valid} import cromwell.backend.google.pipelines.common.api.RunStatus -import cromwell.backend.standard.pollmonitoring.{ - AsyncJobHasFinished, - PollMonitorParameters, - 
PollResultMessage, - PollResultMonitorActor, - ProcessThisPollResult -} +import cromwell.backend.standard.pollmonitoring._ import cromwell.backend.validation.ValidatedRuntimeAttributes +import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform} import cromwell.core.logging.JobLogger +import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo} import cromwell.services.metadata.CallMetadataKeys import java.time.OffsetDateTime @@ -25,13 +21,7 @@ object PapiPollResultMonitorActor { logger: JobLogger ): Props = Props( new PapiPollResultMonitorActor( - PollMonitorParameters(serviceRegistry, - workflowDescriptor, - jobDescriptor, - runtimeAttributes, - platform, - Option(logger) - ) + PollMonitorParameters(serviceRegistry, workflowDescriptor, jobDescriptor, runtimeAttributes, platform, logger) ) ) } @@ -40,6 +30,7 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll override def extractEarliestEventTimeFromRunState(pollStatus: RunStatus): Option[OffsetDateTime] = pollStatus.eventList.minByOption(_.offsetDateTime).map(e => e.offsetDateTime) + override def extractStartTimeFromRunState(pollStatus: RunStatus): Option[OffsetDateTime] = pollStatus.eventList.collectFirst { case event if event.name == CallMetadataKeys.VmStartTime => event.offsetDateTime @@ -50,30 +41,51 @@ class PapiPollResultMonitorActor(parameters: PollMonitorParameters) extends Poll case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime } + override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = + pollStatus.instantiatedVmInfo + + override def handleVmCostLookup(vmInfo: InstantiatedVmInfo) = { + val request = GcpCostLookupRequest(vmInfo, self) + params.serviceRegistry ! 
request + } + + override def params: PollMonitorParameters = parameters + + def handleCostResponse(costLookupResponse: GcpCostLookupResponse): Unit = + if (vmCostPerHour.isEmpty) { // Skip later responses once any result (cost or -1 on failure) is recorded. + val cost = costLookupResponse.calculatedCost match { + case Valid(c) => + params.logger.info(s"vmCostPerHour for ${costLookupResponse.vmInfo} is ${c}") + c + case Invalid(errors) => + params.logger.error( + s"Failed to calculate VM cost per hour for ${costLookupResponse.vmInfo}. ${errors.toList.mkString(", ")}" + ) + BigDecimal(-1) + } + vmCostPerHour = Option(cost) + tellMetadata(Map(CallMetadataKeys.VmCostPerHour -> cost)) + } + override def receive: Receive = { + case costResponse: GcpCostLookupResponse => handleCostResponse(costResponse) case message: PollResultMessage => message match { case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult) case ProcessThisPollResult(result) => - params.logger.foreach(logger => - logger.error( - s"Programmer error: Received Poll Result of unknown type. Expected ${RunStatus.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." - ) + params.logger.error( + s"Programmer error: Received Poll Result of unknown type. Expected ${RunStatus.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." ) case AsyncJobHasFinished(pollResult: RunStatus) => handleAsyncJobFinish(pollResult.getClass.getSimpleName) case AsyncJobHasFinished(result) => - params.logger.foreach(logger => - logger.error( - s"Programmer error: Received Poll Result of unknown type. Expected ${AsyncJobHasFinished.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." - ) + params.logger.error( + s"Programmer error: Received Poll Result of unknown type. Expected ${AsyncJobHasFinished.getClass.getSimpleName} but got ${result.getClass.getSimpleName}." 
) } - case _ => - params.logger.foreach(logger => - logger.error( - s"Programmer error: Cost Helper received message of type other than CostPollingMessage" - ) + case unexpected => + params.logger.error( + s"Programmer error: Cost Helper received message of unexpected type. Was ${unexpected.getClass.getSimpleName}." ) + } - override def params: PollMonitorParameters = parameters } diff --git a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala index 402427dc81b..388ec1df89c 100644 --- a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala +++ b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala @@ -777,7 +777,7 @@ class PipelinesApiAsyncBackendJobExecutionActor(override val standardParams: Sta super[PipelinesApiStatusRequestClient].pollStatus(workflowId, handle.pendingJob) override def checkAndRecordQuotaExhaustion(runStatus: RunStatus): Unit = runStatus match { - case AwaitingCloudQuota(_) => + case AwaitingCloudQuota(_, _) => standardParams.groupMetricsActor ! 
RecordGroupQuotaExhaustion(googleProject(jobDescriptor.workflowDescriptor)) case _ => } diff --git a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala index 03e49e5c1c1..d5be6707161 100644 --- a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala +++ b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/api/RunStatus.scala @@ -3,20 +3,26 @@ package cromwell.backend.google.pipelines.common.api import _root_.io.grpc.Status import cromwell.backend.google.pipelines.common.PipelinesApiAsyncBackendJobExecutionActor import cromwell.core.ExecutionEvent +import cromwell.services.cost.InstantiatedVmInfo import scala.util.Try - sealed trait RunStatus { def eventList: Seq[ExecutionEvent] def toString: String + + val instantiatedVmInfo: Option[InstantiatedVmInfo] } object RunStatus { - case class Initializing(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Initializing" } - case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent]) extends RunStatus { + case class Initializing(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty) + extends RunStatus { override def toString = "Initializing" } + case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty + ) extends RunStatus { override def toString = "AwaitingCloudQuota" } - case class Running(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Running" } + case class Running(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty) + extends RunStatus { override def toString = "Running" } sealed trait TerminalRunStatus extends RunStatus 
{ def machineType: Option[String] @@ -38,7 +44,8 @@ object RunStatus { case class Success(eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends TerminalRunStatus { override def toString = "Success" } @@ -88,7 +95,8 @@ object RunStatus { eventList, machineType, zone, - instanceName + instanceName, + Option.empty ) } } @@ -99,7 +107,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "Failed" } @@ -113,7 +122,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "Cancelled" } @@ -124,7 +134,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "Preempted" } @@ -139,7 +150,8 @@ object RunStatus { eventList: Seq[ExecutionEvent], machineType: Option[String], zone: Option[String], - instanceName: Option[String] + instanceName: Option[String], + instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty ) extends UnsuccessfulRunStatus { override def toString = "QuotaFailed" } diff --git a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala 
b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala index 77e53176df3..bda6981084b 100644 --- a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala +++ b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/ErrorReporter.scala @@ -86,7 +86,15 @@ class ErrorReporter(machineType: Option[String], // Reverse the list because the first failure (likely the most relevant, will appear last otherwise) val unexpectedExitEvents: List[String] = unexpectedExitStatusErrorStrings(events, actions).reverse - builder(status, None, failed.toList ++ unexpectedExitEvents, executionEvents, machineType, zone, instanceName) + builder(status, + None, + failed.toList ++ unexpectedExitEvents, + executionEvents, + machineType, + zone, + instanceName, + Option.empty + ) } // There's maybe one FailedEvent per operation with a summary error message diff --git a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala index 9cfbf62ca4f..4a1d780685f 100644 --- a/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala +++ b/supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/request/GetRequestHandler.scala @@ -22,6 +22,7 @@ import cromwell.backend.google.pipelines.v2beta.api.Deserialization._ import cromwell.backend.google.pipelines.v2beta.api.request.ErrorReporter._ import cromwell.cloudsupport.gcp.auth.GoogleAuthMode import cromwell.core.ExecutionEvent +import cromwell.services.cost.InstantiatedVmInfo import cromwell.services.metadata.CallMetadataKeys 
import io.grpc.Status import org.apache.commons.lang3.exception.ExceptionUtils @@ -81,33 +82,45 @@ trait GetRequestHandler { this: RequestHandler => .toList .flatten val executionEvents = getEventList(metadata, events, actions) + val workerAssignedEvent: Option[WorkerAssignedEvent] = + events.collectFirst { + case event if event.getWorkerAssigned != null => event.getWorkerAssigned + } + val virtualMachineOption = for { + pipelineValue <- pipeline + resources <- Option(pipelineValue.getResources) + virtualMachine <- Option(resources.getVirtualMachine) + } yield virtualMachine + + // Correlate `executionEvents` to `actions` to potentially assign a grouping into the appropriate events. + val machineType = virtualMachineOption.flatMap(virtualMachine => Option(virtualMachine.getMachineType)) + + /* + preemptible is used to build instantiatedVmInfo and, if the job fails, to guess if the VM was preempted. + If we can't get the value of preemptible we still need to return something, returning false will not make the + failure count as a preemption which seems better than saying that it was preemptible when we really don't know + */ + val preemptibleOption = for { + pipelineValue <- pipeline + resources <- Option(pipelineValue.getResources) + virtualMachine <- Option(resources.getVirtualMachine) + preemptible <- Option(virtualMachine.getPreemptible) + } yield preemptible + val preemptible = preemptibleOption.exists(_.booleanValue) + val instanceName = + workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getInstance())) + val zone = workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getZone)) + val region = zone.map { zoneString => + val lastDashIndex = zoneString.lastIndexOf("-") + if (lastDashIndex != -1) zoneString.substring(0, lastDashIndex) else zoneString + } + + val instantiatedVmInfo: Option[InstantiatedVmInfo] = (region, machineType) match { + case (Some(instantiatedRegion), Some(instantiatedMachineType)) => 
Option(InstantiatedVmInfo(instantiatedRegion, instantiatedMachineType, preemptible)) + case _ => Option.empty + } if (operation.getDone) { - val workerAssignedEvent: Option[WorkerAssignedEvent] = - events.collectFirst { - case event if event.getWorkerAssigned != null => event.getWorkerAssigned - } - val virtualMachineOption = for { - pipelineValue <- pipeline - resources <- Option(pipelineValue.getResources) - virtualMachine <- Option(resources.getVirtualMachine) - } yield virtualMachine - // Correlate `executionEvents` to `actions` to potentially assign a grouping into the appropriate events. - val machineType = virtualMachineOption.flatMap(virtualMachine => Option(virtualMachine.getMachineType)) - /* - preemptible is only used if the job fails, as a heuristic to guess if the VM was preempted. - If we can't get the value of preempted we still need to return something, returning false will not make the - failure count as a preemption which seems better than saying that it was preemptible when we really don't know - */ - val preemptibleOption = for { - pipelineValue <- pipeline - resources <- Option(pipelineValue.getResources) - virtualMachine <- Option(resources.getVirtualMachine) - preemptible <- Option(virtualMachine.getPreemptible) - } yield preemptible - val preemptible = preemptibleOption.exists(_.booleanValue) - val instanceName = - workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getInstance())) - val zone = workerAssignedEvent.flatMap(workerAssignedEvent => Option(workerAssignedEvent.getZone)) // If there's an error, generate an unsuccessful status. Otherwise, we were successful! 
Option(operation.getError) match { case Some(error) => @@ -122,14 +135,14 @@ trait GetRequestHandler { this: RequestHandler => pollingRequest.workflowId ) errorReporter.toUnsuccessfulRunStatus(error, events) - case None => Success(executionEvents, machineType, zone, instanceName) + case None => Success(executionEvents, machineType, zone, instanceName, instantiatedVmInfo) } } else if (isQuotaDelayed(events)) { - AwaitingCloudQuota(executionEvents) + AwaitingCloudQuota(executionEvents, instantiatedVmInfo) } else if (operation.hasStarted) { - Running(executionEvents) + Running(executionEvents, instantiatedVmInfo) } else { - Initializing(executionEvents) + Initializing(executionEvents, instantiatedVmInfo) } } catch { case nullPointerException: NullPointerException =>