From 6ac4dc5b771e7ecf7f8031acc30c842aa08ca861 Mon Sep 17 00:00:00 2001
From: Tyler Osterberg
Date: Thu, 22 Aug 2024 16:24:18 -0700
Subject: [PATCH] [lmi][neuron] Add smart defaults to LMI Neuron

---
 tests/integration/llm/prepare.py              |  14 +-
 wlm/build.gradle.kts                          |   1 +
 .../djl/serving/wlm/LmiConfigRecommender.java |  20 +-
 .../java/ai/djl/serving/wlm/LmiUtils.java     | 112 +++++++
 .../serving/wlm/NeuronSmartDefaultUtils.java  | 297 ++++++++++++++++++
 .../wlm/NeuronSmartDefaultUtilsTest.java      | 240 ++++++++++++++
 .../smart-default-model/2b/config.json        |  26 ++
 .../smart-default-model/70b/config.json       |  38 +++
 .../smart-default-model/8b/config.json        |  34 ++
 .../smart-default-model/empty/config.json     |  20 ++
 .../smart-default-model/unit/config.json      |  26 ++
 11 files changed, 819 insertions(+), 9 deletions(-)
 create mode 100644 wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java
 create mode 100644 wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java
 create mode 100644 wlm/src/test/resources/smart-default-model/2b/config.json
 create mode 100644 wlm/src/test/resources/smart-default-model/70b/config.json
 create mode 100644 wlm/src/test/resources/smart-default-model/8b/config.json
 create mode 100644 wlm/src/test/resources/smart-default-model/empty/config.json
 create mode 100644 wlm/src/test/resources/smart-default-model/unit/config.json

diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py
index c97fbf22a8..ab1120202c 100644
--- a/tests/integration/llm/prepare.py
+++ b/tests/integration/llm/prepare.py
@@ -82,6 +82,7 @@
         "max_dynamic_batch_size": 4,
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600
     },
@@ -90,6 +91,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600,
         "option.quantize": "static_int8"
     },
@@ -99,6 +101,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600
     },
@@ -107,6 +110,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 8,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp32",
         "option.model_loading_timeout": 2400
     },
@@ -115,6 +119,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 900
     },
@@ -123,6 +128,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 256,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 1200
     },
@@ -130,6 +136,7 @@
         "option.model_id": "s3://djl-llm/mixtral-8x7b/",
         "option.tensor_parallel_degree": 8,
         "option.n_positions": 1024,
+        "option.rolling_batch": 'disable',
         "batch_size": 4,
         "option.model_loading_timeout": 3600,
     },
@@ -138,6 +145,7 @@
         "batch_size": 2,
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600,
         "option.enable_streaming": True,
@@ -180,7 +188,6 @@
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 512,
         "option.max_rolling_batch_size": 4,
-        "option.rolling_batch": 'auto',
         "option.model_loading_timeout": 2400,
         "option.load_split_model": True,
         "option.output_formatter": "jsonlines"
@@ -206,7 +213,6 @@
     },
"mistral-7b-rb": { "option.model_id": "s3://djl-llm/mistral-7b-instruct-v02/", - "option.rolling_batch": "auto", "option.max_rolling_batch_size": 4, "option.tensor_parallel_degree": 4, "option.n_positions": 1024, @@ -217,7 +223,6 @@ "option.speculative_draft_model": "s3://djl-llm/llama-2-tiny/", "option.speculative_length": 7, "option.tensor_parallel_degree": 12, - "option.rolling_batch": "auto", "option.max_rolling_batch_size": 1, "option.model_loading_timeout": 3600, "option.output_formatter": "jsonlines" @@ -231,7 +236,6 @@ "s3://djl-llm/inf2-compiled-graphs/llama-2-tiny/", "option.speculative_length": 4, "option.tensor_parallel_degree": 12, - "option.rolling_batch": "auto", "option.max_rolling_batch_size": 1, "option.model_loading_timeout": 3600, "option.output_formatter": "jsonlines" @@ -241,7 +245,6 @@ "option.tensor_parallel_degree": 2, "option.n_positions": 1024, "option.max_rolling_batch_size": 4, - "option.rolling_batch": 'auto', "option.model_loading_timeout": 1200, }, "tiny-llama-rb-aot-quant": { @@ -250,7 +253,6 @@ "option.tensor_parallel_degree": 2, "option.n_positions": 1024, "option.max_rolling_batch_size": 4, - "option.rolling_batch": 'auto', "option.model_loading_timeout": 1200, } } diff --git a/wlm/build.gradle.kts b/wlm/build.gradle.kts index d395feaa1b..dfdc4afec9 100644 --- a/wlm/build.gradle.kts +++ b/wlm/build.gradle.kts @@ -12,6 +12,7 @@ dependencies { testImplementation(libs.testng) { exclude(group = "junit", module = "junit") } + testImplementation("org.mockito:mockito-core:5.11.0") testRuntimeOnly("ai.djl:model-zoo") testRuntimeOnly("ai.djl.pytorch:pytorch-engine") diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java index 1ec9ba0b34..42972405c1 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java @@ -13,6 +13,7 @@ package ai.djl.serving.wlm; import ai.djl.util.Ec2Utils; +import ai.djl.util.NeuronUtils; import ai.djl.util.Utils; import ai.djl.util.cuda.CudaUtils; @@ -84,6 +85,7 @@ static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES"); setRollingBatch(lmiProperties, modelConfig, features); setMpiMode(lmiProperties); + setHeuristicDefaults(lmiProperties, modelConfig); setTensorParallelDegree(lmiProperties); setPipelineParallelDegree(lmiProperties); setRollingBatchSize(lmiProperties); @@ -99,9 +101,7 @@ private static void setRollingBatch( return; } - String defaultRollingBatch = isTnxEnabled(features) ? 
"disable" : "auto"; - String rollingBatch = - lmiProperties.getProperty("option.rolling_batch", defaultRollingBatch); + String rollingBatch = lmiProperties.getProperty("option.rolling_batch", "auto"); String modelType = modelConfig.getModelType(); if (!"auto".equals(rollingBatch)) { return; @@ -142,6 +142,11 @@ private static void setTensorParallelDegree(Properties lmiProperties) { int numGpus = CudaUtils.getGpuCount(); if (numGpus > 0) { tpDegree = String.valueOf(numGpus); + } else if (NeuronUtils.hasNeuron()) { + int numAccelerators = NeuronUtils.getNeuronCores(); + if (numAccelerators > 0) { + tpDegree = String.valueOf(numAccelerators); + } } else { tpDegree = null; } @@ -196,6 +201,15 @@ private static boolean isTnxEnabled(String features) { return features != null && features.contains("tnx"); } + private static void setHeuristicDefaults( + Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) { + if (NeuronUtils.hasNeuron() && isTextGenerationModel(modelConfig)) { + // Set default values for Neuron text generation models + NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils(); + smartDefaultUtils.applySmartDefaults(lmiProperties, modelConfig); + } + } + private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) { for (String arch : modelConfig.getArchitectures()) { boolean isTextGenerationModel = diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index 7311b2958e..57a1e55610 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -507,6 +507,27 @@ static final class HuggingFaceModelConfig { @SerializedName("_diffusers_version") private String diffusersVersion; + @SerializedName("hidden_size") + private int hiddenSize; + + @SerializedName("intermediate_size") + private int intermediateSize; + + @SerializedName("max_position_embeddings") + private int maxPositionEmbeddings; + + @SerializedName("num_attention_heads") + private int numAttentionHeads; + + @SerializedName("num_hidden_layers") + private int numHiddenLayers; + + @SerializedName("num_key_value_heads") + private int numKeyValueHeads; + + @SerializedName("vocab_size") + private int vocabSize; + private Set allArchitectures; public String getModelType() { @@ -523,6 +544,97 @@ public Set getArchitectures() { return allArchitectures; } + /** + * Returns the default value for the n_positions model configuration. For models that do not + * have a pre-defined value for n_positions, this function returns the minimum of + * max_position_embeddings and 4096. If both max_position_embeddings and 4096 are not + * available, this function returns 0. + * + * @return The default value for n_positions. + */ + public Integer getDefaultNPositions() { + try { + return Math.min(maxPositionEmbeddings, 4096); + } catch (Exception e) { + return 0; + } + } + + /** + * Calculates the number of parameters in a model that is similar to LLaMA. This function + * takes into account the hidden size, intermediate size, maximum position embeddings, + * number of hidden layers, vocabulary size, and number of attention heads and key-value + * heads to calculate the total parameter count. + * + * @return The total parameter count for the model. 
+ */ + private long getLlamaLikeParameterCount() { + long headDim = (long) numAttentionHeads * numKeyValueHeads; + long embeddings = (long) vocabSize * hiddenSize; + long qkvProjection = headDim * hiddenSize * numKeyValueHeads * 3; + long oProjection = (long) hiddenSize * hiddenSize; + long gateProjection = (long) hiddenSize * intermediateSize * 3; + return embeddings + + numHiddenLayers + * (qkvProjection + + oProjection + + gateProjection + + hiddenSize + + hiddenSize) + + hiddenSize + + embeddings; + } + + /** + * Calculates the default parameter count for a model (GPT-2-like). + * + *

This function takes into account the hidden size, maximum position embeddings, number + * of hidden layers, vocabulary size, and number of attention heads to calculate the total + * parameter count. + * + * @return The total parameter count for the model. + */ + private long getDefaultParameterCount() { + long embeddingLayerTotal = (long) (vocabSize + maxPositionEmbeddings) * hiddenSize; + long attentionTotal = (long) 4 * (hiddenSize * hiddenSize); + long feedForwardTotal = (long) 8 * (hiddenSize * hiddenSize); + long layerNormTotal = (long) 4 * hiddenSize; + long transformerBlockTotal = + (attentionTotal + feedForwardTotal + layerNormTotal) * numHiddenLayers; + long finalLayerTotal = (long) hiddenSize * vocabSize; + return embeddingLayerTotal + transformerBlockTotal + finalLayerTotal; + } + + /** + * Calculates the total parameter count for the model. + * + * @return The total parameter count for the model. + */ + public long getModelParameters() { + try { + if ("llama".equals(modelType) || "mistral".equals(modelType)) { + return getLlamaLikeParameterCount(); + } + return getDefaultParameterCount(); + } catch (Exception e) { + return 0L; + } + } + + /** + * Returns the memory required to store a single batch of sequence data. + * + *

The memory required is calculated as the product of the sequence length, hidden size, + * number of hidden layers, and weight in bytes. + * + * @param sequenceLength The length in tokens of the sequence. + * @param weightBytes The weight in bytes. + * @return The memory required to store a single batch of sequence data. + */ + public long getApproxMemoryForSingleSequence(int sequenceLength, int weightBytes) { + return (long) sequenceLength * hiddenSize * numHiddenLayers * weightBytes; + } + private void determineAllArchitectures() { allArchitectures = new HashSet<>(); if (configArchitectures != null) { diff --git a/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java new file mode 100644 index 0000000000..98d016e706 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java @@ -0,0 +1,297 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm; + +import ai.djl.util.NeuronUtils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +/** A utility class to auto configure LMI Neuron model properties. */ +public class NeuronSmartDefaultUtils { + + private static final float BILLION = 1_000_000_000.0F; + private static final int MAX_ROLLING_BATCH = 128; // Current cap for NeuronSDK 2.19.1 + private static final float MEMORY_PER_CORE = + 16.0F; // Currently there is only one config w/ 16 gb per core + + // Internal settings + private Integer nPositions; + private Integer availableCores; + private Float modelSizeInGb; + private Float sequenceSizeInGb; + + /** + * Applies smart defaults for Neuron models. + * + *

This method sets the following properties if not already set: + * + *

    + *
  • option.n_positions: The default n_positions for the model. + *
  • option.tensor_parallel_degree: A heuristic based on available memory. + *
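For intuition, the LLaMA-style estimator can be checked by hand against the Llama-3-70B fixture added later in this patch. A standalone sketch (not part of the patch; the class name is illustrative, and the constants are copied from the 70b config.json below):

    // Reproduces getLlamaLikeParameterCount() for the 70b test fixture.
    public final class LlamaParamCountCheck {
        public static void main(String[] args) {
            long hiddenSize = 8192;
            long intermediateSize = 28672;
            long numHiddenLayers = 80;
            long vocabSize = 128256;
            long numAttentionHeads = 64;
            long numKeyValueHeads = 8;
            long headDim = numAttentionHeads * numKeyValueHeads;
            long embeddings = vocabSize * hiddenSize;
            long qkvProjection = headDim * hiddenSize * numKeyValueHeads * 3;
            long oProjection = hiddenSize * hiddenSize;
            long gateProjection = hiddenSize * intermediateSize * 3;
            long perLayer = qkvProjection + oProjection + gateProjection + 2 * hiddenSize;
            long total = embeddings + numHiddenLayers * perLayer + hiddenSize + embeddings;
            System.out.println(total); // 71895883776, the value asserted in the tests below
        }
    }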
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java
new file mode 100644
index 0000000000..98d016e706
--- /dev/null
+++ b/wlm/src/main/java/ai/djl/serving/wlm/NeuronSmartDefaultUtils.java
@@ -0,0 +1,297 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.serving.wlm;
+
+import ai.djl.util.NeuronUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Properties;
+
+/** A utility class to auto configure LMI Neuron model properties. */
+public class NeuronSmartDefaultUtils {
+
+    private static final float BILLION = 1_000_000_000.0F;
+    private static final int MAX_ROLLING_BATCH = 128; // Current cap for NeuronSDK 2.19.1
+    private static final float MEMORY_PER_CORE =
+            16.0F; // Currently there is only one config w/ 16 GB per core
+
+    // Internal settings
+    private Integer nPositions;
+    private Integer availableCores;
+    private Float modelSizeInGb;
+    private Float sequenceSizeInGb;
+
+    /**
+     * Applies smart defaults for Neuron models.
+     *
+     * <p>This method sets the following properties if not already set:
+     *
+     * <ul>
+     *   <li>option.n_positions: The default n_positions for the model.
+     *   <li>option.tensor_parallel_degree: A heuristic based on available memory.
+     *   <li>option.max_rolling_batch_size: A heuristic based on available memory.
+     * </ul>
+     *
+     * @param prop The properties to update.
+     * @param modelConfig The model configuration to use.
+     */
+    public void applySmartDefaults(Properties prop, LmiUtils.HuggingFaceModelConfig modelConfig) {
+        if (!prop.containsKey("option.n_positions")) {
+            prop.setProperty("option.n_positions", modelConfig.getDefaultNPositions().toString());
+        }
+        setInternalSettings(prop, modelConfig);
+        setHeuristicNeuronTPDegree(prop);
+        setHeuristicNeuronMaxRollingBatch(prop);
+    }
+
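+    // Illustrative effect, taken from this PR's unit tests: without a Neuron device, the 2b
+    // fixture resolves to option.n_positions=2048, option.tensor_parallel_degree=1 and
+    // option.max_rolling_batch_size=64; values the user has already set are never overwritten.
+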
+    /**
+     * Sets the internal settings for this NeuronSmartDefaultUtils instance.
+     *
+     * @param prop The properties to retrieve settings from.
+     * @param modelConfig The model configuration to use for calculations.
+     */
+    private void setInternalSettings(Properties prop, LmiUtils.HuggingFaceModelConfig modelConfig) {
+        clearInternalSettings();
+        nPositions = Integer.parseInt(prop.getProperty("option.n_positions"));
+        if (NeuronUtils.hasNeuron()) {
+            availableCores = NeuronUtils.getNeuronCores();
+        } else {
+            availableCores = 1;
+        }
+        int paramBytes = prop.containsKey("option.quantize") ? 1 : 2;
+        modelSizeInGb = (paramBytes * modelConfig.getModelParameters()) / BILLION;
+        sequenceSizeInGb =
+                modelConfig.getApproxMemoryForSingleSequence(nPositions, paramBytes)
+                        / (1024.0F * 1024.0F * 1024.0F);
+    }
+
+    /**
+     * Calculates the adjusted model size in GB, given a tensor parallel degree.
+     *
+     * <p>The model size is scaled by a linear factor that estimates the memory overhead
+     * introduced by sharding the model across the given number of cores.
+     *
+     * @param tpDegree The tensor parallel degree.
+     * @return The adjusted model size in GB.
+     */
+    private float getAdjustedModelSizeInGb(int tpDegree) {
+        return modelSizeInGb * (1.0F + ((tpDegree * 2 - 2) / 100.0F));
+    }
+
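+    // Worked numbers for the overhead factor 1 + (2 * tp - 2) / 100 used above (illustrative):
+    // tp=1 -> 1.00x, tp=2 -> 1.02x, tp=8 -> 1.14x, tp=32 -> 1.62x.
+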
+    /**
+     * Clears the internal settings for this NeuronSmartDefaultUtils instance.
+     *
+     * <p>This method clears the following fields:
+     *
+     * <ul>
+     *   <li>{@link #nPositions}
+     *   <li>{@link #availableCores}
+     *   <li>{@link #modelSizeInGb}
+     *   <li>{@link #sequenceSizeInGb}
+     * </ul>
+     */
+    private void clearInternalSettings() {
+        nPositions = null;
+        availableCores = null;
+        modelSizeInGb = null;
+        sequenceSizeInGb = null;
+    }
+
+    /**
+     * Sets a heuristic value for tensor parallel degree if not already set in model properties.
+     *
+     * <p>This method iterates through the available core configurations and selects the one
+     * that supports the largest rolling batch size that fits in the available memory. When the
+     * maximum rolling batch size is already fixed, it instead selects the smallest tensor
+     * parallel degree that still supports that batch size.
+     *
+     * <p>If the maximum rolling batch size is not set, the maximum instance concurrency is
+     * used as the maximum rolling batch size.
+     *
+     * <p>This method is called by the LMI model server at startup to set the tensor parallel
+     * degree if it is not already set in the model properties.
+     *
+     * @param prop The model properties.
+     */
+    private void setHeuristicNeuronTPDegree(Properties prop) {
+        int tpDegree = availableCores;
+        float totalMemory = tpDegree * MEMORY_PER_CORE;
+
+        // Disambiguate "max" and available cores
+        if (prop.containsKey("option.tensor_parallel_degree")
+                && "max".equals(prop.getProperty("option.tensor_parallel_degree"))) {
+            prop.setProperty("option.tensor_parallel_degree", String.valueOf(availableCores));
+            return;
+        }
+
+        List<Integer> coreConfigs = availableCoreConfigs();
+        if (!prop.containsKey("option.tensor_parallel_degree")
+                && !prop.containsKey("option.max_rolling_batch_size")) {
+            // Set tensor parallel degree by maximizing instance concurrency with a variable
+            // rolling batch size
+            int totalInstanceConcurrency = getMaxConcurrency(totalMemory, tpDegree);
+            for (int coreConfig : coreConfigs) {
+                float maxMemory = coreConfig * MEMORY_PER_CORE;
+                int maxConcurrency = getMaxConcurrency(maxMemory, coreConfig);
+                if (maxConcurrency >= totalInstanceConcurrency && coreConfig <= tpDegree) {
+                    tpDegree = coreConfig;
+                    totalInstanceConcurrency = maxConcurrency;
+                }
+            }
+            prop.setProperty("option.tensor_parallel_degree", String.valueOf(tpDegree));
+        } else if (!prop.containsKey("option.tensor_parallel_degree")) {
+            // Set tensor parallel degree by minimizing the TP degree that supports the fixed
+            // batch size
+            int batchSize = Integer.parseInt(prop.getProperty("option.max_rolling_batch_size"));
+            int totalInstanceConcurrency =
+                    getMaxConcurrencyWithBatch(totalMemory, tpDegree, batchSize);
+            for (int coreConfig : coreConfigs) {
+                float maxMemory = coreConfig * MEMORY_PER_CORE;
+                int maxConcurrency = getMaxConcurrencyWithBatch(maxMemory, coreConfig, batchSize);
+                if (maxConcurrency >= totalInstanceConcurrency && coreConfig <= tpDegree) {
+                    tpDegree = coreConfig;
+                    totalInstanceConcurrency = maxConcurrency;
+                }
+            }
+            prop.setProperty("option.tensor_parallel_degree", String.valueOf(tpDegree));
+        }
+    }
+
+    /**
+     * Finds the largest power of 2 less than or equal to n.
+     *
+     * @param n The input number.
+     * @return The largest power of 2 less than or equal to n.
+     */
+    private int getMaxPowerOf2(int n) {
+        if (n != 0 && (n & (n - 1)) == 0) {
+            return n;
+        }
+        int maxPowerOf2 = 1;
+        while (maxPowerOf2 < n) {
+            maxPowerOf2 *= 2;
+        }
+        return maxPowerOf2 / 2;
+    }
+
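+    // Worked example of the concurrency math below, using the 70b test fixture: fp16 weights
+    // give modelSizeInGb ~= 143.8 and a 4096-token sequence ~= 5.0 GB. On 32 cores (512 GB),
+    // the adjusted model size is 143.8 * 1.62 ~= 232.9 GB, so floor((512 - 232.9) / 5.0) = 55
+    // sequences fit; the largest power of 2 below that is 32, which is the
+    // max_rolling_batch_size that testApplySmartDefaultsWithNeuron expects.
+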
+    /**
+     * Calculates the maximum number of concurrent requests that can be served by a model given
+     * the total memory available for the model and the per-sequence memory size.
+     *
+     * <p>The maximum number of concurrent requests is the largest power of 2 less than or
+     * equal to the memory left after loading the model divided by the sequence size, capped at
+     * MAX_ROLLING_BATCH.
+     *
+     * @param totalMemory The total memory available for the model.
+     * @param tpDegree The tensor parallel degree, used to adjust the model size.
+     * @return The maximum number of concurrent requests.
+     */
+    private int getMaxConcurrency(float totalMemory, int tpDegree) {
+        int maxConcurrency =
+                (int) ((totalMemory - getAdjustedModelSizeInGb(tpDegree)) / sequenceSizeInGb);
+        maxConcurrency = getMaxPowerOf2(maxConcurrency);
+        return Math.min(maxConcurrency, MAX_ROLLING_BATCH);
+    }
+
+    /**
+     * Calculates the maximum number of concurrent requests for a fixed batch size.
+     *
+     * <p>Returns the batch size when it fits into the available memory, and 0 otherwise.
+     *
+     * @param totalMemory The total memory available for the model.
+     * @param tpDegree The tensor parallel degree, used to adjust the model size.
+     * @param batchSize The maximum number of requests that can be processed in a single batch.
+     * @return The maximum number of concurrent requests that can be served.
+     */
+    private int getMaxConcurrencyWithBatch(float totalMemory, int tpDegree, int batchSize) {
+        int maxConcurrency =
+                (int) ((totalMemory - getAdjustedModelSizeInGb(tpDegree)) / sequenceSizeInGb);
+        maxConcurrency = getMaxPowerOf2(maxConcurrency);
+        maxConcurrency = Math.min(maxConcurrency, batchSize);
+        if (maxConcurrency == batchSize) {
+            return maxConcurrency;
+        } else {
+            return 0;
+        }
+    }
+
+    /**
+     * Builds the core configurations that are candidates for this model.
+     *
+     * <p>Of the configurations produced by {@link #buildCoreConfigs(int)}, only those that do
+     * not exceed the minimum number of cores needed to hold the model weights
+     * (ceil(modelSizeInGb / MEMORY_PER_CORE)) are kept.
+     *
+     * @return The list of available core configurations.
+     */
+    private List<Integer> availableCoreConfigs() {
+        List<Integer> coreConfigs = new ArrayList<>();
+        List<Integer> availableCoreConfigs = buildCoreConfigs(availableCores);
+        int coresPerModel = (int) Math.ceil(modelSizeInGb / MEMORY_PER_CORE);
+        for (int coreConfig : availableCoreConfigs) {
+            if (coresPerModel >= coreConfig) {
+                coreConfigs.add(coreConfig);
+            }
+        }
+        return coreConfigs;
+    }
+
+    /**
+     * Builds the available core configurations for a given number of cores.
+     *
+     * <p>Candidates are the powers of 2 up to 8, plus the given number of cores itself. TP=4 is
+     * skipped on 32-core instances, where it is not supported.
+     *
+     * @param nCores The number of cores to build the configurations for.
+     * @return The list of available core configurations.
+     */
+    private List<Integer> buildCoreConfigs(int nCores) {
+        List<Integer> coreConfigs = new ArrayList<>();
+        // Add the powers of 2 up to 8
+        for (int i = 1; i <= 8; i *= 2) {
+            // Skip TP=4 for nCores=32 as it is not supported
+            if (i != 4 || nCores != 32) {
+                coreConfigs.add(i);
+            }
+        }
+        // Add the given number of cores to the list
+        coreConfigs.add(nCores);
+        return coreConfigs;
+    }
+
+    /**
+     * Sets the max rolling batch size based on the TP degree and the model memory size.
+     *
+     * <p>If the max rolling batch size is not set, this method sets it to the maximum number of
+     * concurrent requests that can be served given the memory left after loading the model.
+     *
+     * @param prop The properties to set the max rolling batch size on.
+     */
+    private void setHeuristicNeuronMaxRollingBatch(Properties prop) {
+        int tpDegree;
+        try {
+            tpDegree = Integer.parseInt(prop.getProperty("option.tensor_parallel_degree"));
+        } catch (Exception e) {
+            // A tensor parallel degree that is not an integer means "max": use all cores
+            tpDegree = availableCores;
+        }
+        if (!prop.containsKey("option.max_rolling_batch_size")) {
+            int maxRollingBatchSize = getMaxConcurrency(tpDegree * MEMORY_PER_CORE, tpDegree);
+            if (maxRollingBatchSize > 0) {
+                prop.setProperty(
+                        "option.max_rolling_batch_size", String.valueOf(maxRollingBatchSize));
+            }
+        }
+    }
+}
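What buildCoreConfigs produces is easiest to see with concrete inputs. A minimal standalone sketch (mirroring the loop above; the class name is illustrative):

    import java.util.ArrayList;
    import java.util.List;

    public final class CoreConfigCheck {
        // Same logic as NeuronSmartDefaultUtils.buildCoreConfigs
        static List<Integer> buildCoreConfigs(int nCores) {
            List<Integer> coreConfigs = new ArrayList<>();
            for (int i = 1; i <= 8; i *= 2) {
                if (i != 4 || nCores != 32) { // TP=4 unsupported on 32-core hosts
                    coreConfigs.add(i);
                }
            }
            coreConfigs.add(nCores);
            return coreConfigs;
        }

        public static void main(String[] args) {
            System.out.println(buildCoreConfigs(32)); // [1, 2, 8, 32]
            System.out.println(buildCoreConfigs(2)); // [1, 2, 4, 8, 2]
        }
    }

For the 70b fixture in fp16 (about 143.8 GB of weights), coresPerModel = ceil(143.8 / 16) = 9, so availableCoreConfigs() keeps only [1, 2, 8]; none of those holds the adjusted model plus a sequence, so the TP-degree loop keeps all 32 cores, matching testApplySmartDefaultsWithNeuron below.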
diff --git a/wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java b/wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java
new file mode 100644
index 0000000000..4dbb2a5442
--- /dev/null
+++ b/wlm/src/test/java/ai/djl/serving/wlm/NeuronSmartDefaultUtilsTest.java
@@ -0,0 +1,240 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.serving.wlm;
+
+import ai.djl.util.JsonUtils;
+import ai.djl.util.NeuronUtils;
+
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.FileReader;
+import java.io.Reader;
+import java.util.Properties;
+
+/**
+ * Unit tests for {@link NeuronSmartDefaultUtils}.
+ *
+ * @author tyoster @ Amazon.com, Inc.
+ */
+public class NeuronSmartDefaultUtilsTest {
+
+    // Known model parameter tests for LmiUtils.HuggingFaceModelConfig
+    @Test
+    public void testModelConfigParametersLlama() {
+        LmiUtils.HuggingFaceModelConfig modelConfig = get70BLlamaHuggingFaceModelConfig();
+        Assert.assertEquals(modelConfig.getDefaultNPositions(), 4096);
+        Assert.assertEquals(modelConfig.getModelParameters(), 71895883776L);
+    }
+
+    @Test
+    public void testModelConfigParametersDefault() {
+        LmiUtils.HuggingFaceModelConfig modelConfig = getDefaultHuggingFaceModelConfig();
+        Assert.assertEquals(modelConfig.getDefaultNPositions(), 1);
+        Assert.assertEquals(modelConfig.getModelParameters(), 19L);
+    }
+
+    @Test
+    public void testModelConfigParametersNoParameters() {
+        LmiUtils.HuggingFaceModelConfig modelConfig = getNoParametersHuggingFaceModelConfig();
+        Assert.assertEquals(modelConfig.getDefaultNPositions(), 1);
+        Assert.assertEquals(modelConfig.getModelParameters(), 0L);
+    }
+
+    // Standard-use tests without a Neuron device available
+    @Test
+    public void testApplySmartDefaults70BModel() {
+        Properties prop = new Properties();
+        LmiUtils.HuggingFaceModelConfig modelConfig = get70BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "4096");
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1");
+        Assert.assertEquals(prop.containsKey("option.max_rolling_batch_size"), false);
+    }
+
+    @Test
+    public void testApplySmartDefaultsQuantize8BModel() {
+        Properties prop = new Properties();
+        prop.setProperty("option.quantize", "static_int8");
+        LmiUtils.HuggingFaceModelConfig modelConfig = get8BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "4096");
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1");
+        Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "8");
+    }
+
+    @Test
+    public void testApplySmartDefaults2BModel() {
+        Properties prop = new Properties();
+        LmiUtils.HuggingFaceModelConfig modelConfig = get2BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "2048");
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1");
+        Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "64");
+    }
+
+    @Test
+    public void testApplySmartDefaultsQuantize2BModel() {
+        Properties prop = new Properties();
+        prop.setProperty("option.quantize", "static_int8");
+        LmiUtils.HuggingFaceModelConfig modelConfig = get2BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "2048");
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1");
+        Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "128");
+    }
+
+    @Test
+    public void testApplySmartDefaultsWithNPositions() {
+        Properties prop = new Properties();
+        prop.setProperty("option.n_positions", "128");
+        LmiUtils.HuggingFaceModelConfig modelConfig = get2BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1");
+        Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "128");
+    }
+
+    @Test
+    public void testApplySmartDefaultsWithTPDegree() {
+        Properties prop = new Properties();
+        prop.setProperty("option.tensor_parallel_degree", "1");
+        LmiUtils.HuggingFaceModelConfig modelConfig = get2BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "2048");
+        Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "64");
+    }
+
+    @Test
+    public void testApplySmartDefaultsWithMaxRollingBatch() {
+        Properties prop = new Properties();
+        prop.setProperty("option.max_rolling_batch_size", "64");
+        LmiUtils.HuggingFaceModelConfig modelConfig = get2BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "2048");
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1");
+    }
+
+    @Test
+    public void testApplySmartDefaultsWithTPMax() {
+        Properties prop = new Properties();
+        prop.setProperty("option.tensor_parallel_degree", "max");
+        LmiUtils.HuggingFaceModelConfig modelConfig = get2BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(false);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "2048");
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "1");
+        Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "64");
+    }
+
+    @Test
+    public void testApplySmartDefaultsWithNeuron() {
+        Properties prop = new Properties();
+        LmiUtils.HuggingFaceModelConfig modelConfig = get70BLlamaHuggingFaceModelConfig();
+        try (MockedStatic<NeuronUtils> mockedStatic = Mockito.mockStatic(NeuronUtils.class)) {
+            mockedStatic.when(() -> NeuronUtils.hasNeuron()).thenReturn(true);
+            mockedStatic.when(() -> NeuronUtils.getNeuronCores()).thenReturn(32);
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(prop, modelConfig);
+        }
+        Assert.assertEquals(prop.getProperty("option.n_positions"), "4096");
+        Assert.assertEquals(prop.getProperty("option.tensor_parallel_degree"), "32");
+        Assert.assertEquals(prop.getProperty("option.max_rolling_batch_size"), "32");
+    }
+
+    // Helper methods
+    public LmiUtils.HuggingFaceModelConfig get2BLlamaHuggingFaceModelConfig() {
+        try {
+            Reader reader = new FileReader("src/test/resources/smart-default-model/2b/config.json");
+            return JsonUtils.GSON.fromJson(reader, LmiUtils.HuggingFaceModelConfig.class);
+        } catch (Exception e) {
+            Assert.fail();
+        }
+        return null;
+    }
+
+    public LmiUtils.HuggingFaceModelConfig get8BLlamaHuggingFaceModelConfig() {
+        try {
+            Reader reader = new FileReader("src/test/resources/smart-default-model/8b/config.json");
+            return JsonUtils.GSON.fromJson(reader, LmiUtils.HuggingFaceModelConfig.class);
+        } catch (Exception e) {
+            Assert.fail();
+        }
+        return null;
+    }
+
+    public LmiUtils.HuggingFaceModelConfig get70BLlamaHuggingFaceModelConfig() {
+        try {
+            Reader reader =
+                    new FileReader("src/test/resources/smart-default-model/70b/config.json");
+            return JsonUtils.GSON.fromJson(reader, LmiUtils.HuggingFaceModelConfig.class);
+        } catch (Exception e) {
+            Assert.fail();
+        }
+        return null;
+    }
+
+    public LmiUtils.HuggingFaceModelConfig getDefaultHuggingFaceModelConfig() {
+        try {
+            Reader reader =
+                    new FileReader("src/test/resources/smart-default-model/unit/config.json");
+            return JsonUtils.GSON.fromJson(reader, LmiUtils.HuggingFaceModelConfig.class);
+        } catch (Exception e) {
+            Assert.fail();
+        }
+        return null;
+    }
+
+    public LmiUtils.HuggingFaceModelConfig getNoParametersHuggingFaceModelConfig() {
+        try {
+            Reader reader =
+                    new FileReader("src/test/resources/smart-default-model/empty/config.json");
+            return JsonUtils.GSON.fromJson(reader, LmiUtils.HuggingFaceModelConfig.class);
+        } catch (Exception e) {
+            Assert.fail();
+        }
+        return null;
+    }
+}
diff --git a/wlm/src/test/resources/smart-default-model/2b/config.json b/wlm/src/test/resources/smart-default-model/2b/config.json
new file mode 100644
index 0000000000..8b27b73df9
--- /dev/null
+++ b/wlm/src/test/resources/smart-default-model/2b/config.json
@@ -0,0 +1,26 @@
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
+  "hidden_size": 2048,
+  "initializer_range": 0.02,
+  "intermediate_size": 5632,
+  "max_position_embeddings": 2048,
+  "model_type": "llama",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 22,
+  "num_key_value_heads": 4,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.35.0",
+  "use_cache": true,
+  "vocab_size": 32000
+}
\ No newline at end of file
diff --git a/wlm/src/test/resources/smart-default-model/70b/config.json b/wlm/src/test/resources/smart-default-model/70b/config.json
new file mode 100644
index 0000000000..d4e2754e77
--- /dev/null
+++ b/wlm/src/test/resources/smart-default-model/70b/config.json
@@ -0,0 +1,38 @@
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": [
+    128001,
+    128008,
+    128009
+  ],
+  "hidden_act": "silu",
+  "hidden_size": 8192,
+  "initializer_range": 0.02,
+  "intermediate_size": 28672,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 64,
+  "num_hidden_layers": 80,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 8.0,
+    "low_freq_factor": 1.0,
+    "high_freq_factor": 4.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.42.3",
+  "use_cache": true,
+  "vocab_size": 128256
+}
\ No newline at end of file
diff --git a/wlm/src/test/resources/smart-default-model/8b/config.json b/wlm/src/test/resources/smart-default-model/8b/config.json
new file mode 100644
index 0000000000..9b3ff8de3f
--- /dev/null
+++ b/wlm/src/test/resources/smart-default-model/8b/config.json
@@ -0,0 +1,34 @@
+{
+  "architectures": [
+    "LlamaForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "bos_token_id": 128000,
+  "eos_token_id": 128001,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 14336,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "llama",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 32,
+  "num_key_value_heads": 8,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "factor": 8.0,
+    "low_freq_factor": 1.0,
+    "high_freq_factor": 4.0,
+    "original_max_position_embeddings": 8192,
+    "rope_type": "llama3"
+  },
+  "rope_theta": 500000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.43.0.dev0",
+  "use_cache": true,
+  "vocab_size": 128256
+}
\ No newline at end of file
diff --git a/wlm/src/test/resources/smart-default-model/empty/config.json b/wlm/src/test/resources/smart-default-model/empty/config.json
new file mode 100644
index 0000000000..3025cade60
--- /dev/null
+++ b/wlm/src/test/resources/smart-default-model/empty/config.json
@@ -0,0 +1,20 @@
+{
+  "architectures": [
+    "DefaultForCausalLM"
+  ],
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
+  "initializer_range": 0.02,
+  "max_position_embeddings": 1,
+  "model_type": "default",
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.35.0",
+  "use_cache": true
+}
\ No newline at end of file
diff --git a/wlm/src/test/resources/smart-default-model/unit/config.json b/wlm/src/test/resources/smart-default-model/unit/config.json
new file mode 100644
index 0000000000..d96c7ee11d
--- /dev/null
+++ b/wlm/src/test/resources/smart-default-model/unit/config.json
@@ -0,0 +1,26 @@
+{
+  "architectures": [
+    "DefaultForCausalLM"
+  ],
+  "attention_bias": false,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
+  "hidden_size": 1,
+  "initializer_range": 0.02,
+  "intermediate_size": 1,
+  "max_position_embeddings": 1,
+  "model_type": "default",
+  "num_attention_heads": 1,
+  "num_hidden_layers": 1,
+  "num_key_value_heads": 1,
+  "pretraining_tp": 1,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000.0,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.35.0",
+  "use_cache": true,
+  "vocab_size": 1
+}
\ No newline at end of file