[lmi][neuron] Add smart defaults to LMI Neuron
tosterberg committed Aug 29, 2024
1 parent e55bbc1 commit 6ac4dc5
Showing 11 changed files with 819 additions and 9 deletions.
14 changes: 8 additions & 6 deletions tests/integration/llm/prepare.py
@@ -82,6 +82,7 @@
         "max_dynamic_batch_size": 4,
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600
     },
@@ -90,6 +91,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600,
         "option.quantize": "static_int8"
@@ -99,6 +101,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600
     },
@@ -107,6 +110,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 8,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp32",
         "option.model_loading_timeout": 2400
     },
@@ -115,6 +119,7 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 900
     },
@@ -123,13 +128,15 @@
         "batch_size": 4,
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 256,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 1200
     },
     "mixtral-8x7b": {
         "option.model_id": "s3://djl-llm/mixtral-8x7b/",
         "option.tensor_parallel_degree": 8,
         "option.n_positions": 1024,
+        "option.rolling_batch": 'disable',
         "batch_size": 4,
         "option.model_loading_timeout": 3600,
     },
@@ -138,6 +145,7 @@
         "batch_size": 2,
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 512,
+        "option.rolling_batch": 'disable',
         "option.dtype": "fp16",
         "option.model_loading_timeout": 600,
         "option.enable_streaming": True,
@@ -180,7 +188,6 @@
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 512,
         "option.max_rolling_batch_size": 4,
-        "option.rolling_batch": 'auto',
         "option.model_loading_timeout": 2400,
         "option.load_split_model": True,
         "option.output_formatter": "jsonlines"
@@ -206,7 +213,6 @@
     },
     "mistral-7b-rb": {
         "option.model_id": "s3://djl-llm/mistral-7b-instruct-v02/",
-        "option.rolling_batch": "auto",
         "option.max_rolling_batch_size": 4,
         "option.tensor_parallel_degree": 4,
         "option.n_positions": 1024,
@@ -217,7 +223,6 @@
         "option.speculative_draft_model": "s3://djl-llm/llama-2-tiny/",
         "option.speculative_length": 7,
         "option.tensor_parallel_degree": 12,
-        "option.rolling_batch": "auto",
         "option.max_rolling_batch_size": 1,
         "option.model_loading_timeout": 3600,
         "option.output_formatter": "jsonlines"
@@ -231,7 +236,6 @@
             "s3://djl-llm/inf2-compiled-graphs/llama-2-tiny/",
         "option.speculative_length": 4,
         "option.tensor_parallel_degree": 12,
-        "option.rolling_batch": "auto",
         "option.max_rolling_batch_size": 1,
         "option.model_loading_timeout": 3600,
         "option.output_formatter": "jsonlines"
@@ -241,7 +245,6 @@
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 1024,
         "option.max_rolling_batch_size": 4,
-        "option.rolling_batch": 'auto',
         "option.model_loading_timeout": 1200,
     },
     "tiny-llama-rb-aot-quant": {
@@ -250,7 +253,6 @@
         "option.tensor_parallel_degree": 2,
         "option.n_positions": 1024,
         "option.max_rolling_batch_size": 4,
-        "option.rolling_batch": 'auto',
         "option.model_loading_timeout": 1200,
     }
 }
1 change: 1 addition & 0 deletions wlm/build.gradle.kts
@@ -12,6 +12,7 @@ dependencies {
     testImplementation(libs.testng) {
         exclude(group = "junit", module = "junit")
     }
+    testImplementation("org.mockito:mockito-core:5.11.0")
 
     testRuntimeOnly("ai.djl:model-zoo")
     testRuntimeOnly("ai.djl.pytorch:pytorch-engine")
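The new mockito-core test dependency makes it possible to stub the static accelerator probes (such as NeuronUtils) so the detection paths below can be exercised on hosts without Neuron hardware. A minimal sketch of that pattern, assuming TestNG (already used by this module); the test class, method, and assertions here are illustrative and not part of this commit:

import ai.djl.util.NeuronUtils;

import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.testng.Assert;
import org.testng.annotations.Test;

public class NeuronDetectionTest {

    @Test
    public void testNeuronCoreDetectionCanBeStubbed() {
        // mockito-core 5.x ships the inline mock maker by default,
        // so static methods can be stubbed without extra dependencies.
        try (MockedStatic<NeuronUtils> neuron = Mockito.mockStatic(NeuronUtils.class)) {
            neuron.when(NeuronUtils::hasNeuron).thenReturn(true);
            neuron.when(NeuronUtils::getNeuronCores).thenReturn(2);

            // Any code under test that consults NeuronUtils now sees a 2-core host.
            Assert.assertTrue(NeuronUtils.hasNeuron());
            Assert.assertEquals(NeuronUtils.getNeuronCores(), 2);
        }
    }
}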
20 changes: 17 additions & 3 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
@@ -13,6 +13,7 @@
 package ai.djl.serving.wlm;
 
 import ai.djl.util.Ec2Utils;
+import ai.djl.util.NeuronUtils;
 import ai.djl.util.Utils;
 import ai.djl.util.cuda.CudaUtils;
 
@@ -84,6 +85,7 @@ static void configure(Properties lmiProperties, LmiUtils.HuggingFaceModelConfig
         String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
         setRollingBatch(lmiProperties, modelConfig, features);
         setMpiMode(lmiProperties);
+        setHeuristicDefaults(lmiProperties, modelConfig);
         setTensorParallelDegree(lmiProperties);
         setPipelineParallelDegree(lmiProperties);
         setRollingBatchSize(lmiProperties);
@@ -99,9 +101,7 @@ private static void setRollingBatch(
             return;
         }
 
-        String defaultRollingBatch = isTnxEnabled(features) ? "disable" : "auto";
-        String rollingBatch =
-                lmiProperties.getProperty("option.rolling_batch", defaultRollingBatch);
+        String rollingBatch = lmiProperties.getProperty("option.rolling_batch", "auto");
         String modelType = modelConfig.getModelType();
         if (!"auto".equals(rollingBatch)) {
             return;
@@ -142,6 +142,11 @@ private static void setTensorParallelDegree(Properties lmiProperties) {
         int numGpus = CudaUtils.getGpuCount();
         if (numGpus > 0) {
             tpDegree = String.valueOf(numGpus);
+        } else if (NeuronUtils.hasNeuron()) {
+            int numAccelerators = NeuronUtils.getNeuronCores();
+            if (numAccelerators > 0) {
+                tpDegree = String.valueOf(numAccelerators);
+            }
         } else {
             tpDegree = null;
         }
@@ -196,6 +201,15 @@ private static boolean isTnxEnabled(String features) {
         return features != null && features.contains("tnx");
     }
 
+    private static void setHeuristicDefaults(
+            Properties lmiProperties, LmiUtils.HuggingFaceModelConfig modelConfig) {
+        if (NeuronUtils.hasNeuron() && isTextGenerationModel(modelConfig)) {
+            // Set default values for Neuron text generation models
+            NeuronSmartDefaultUtils smartDefaultUtils = new NeuronSmartDefaultUtils();
+            smartDefaultUtils.applySmartDefaults(lmiProperties, modelConfig);
+        }
+    }
+
     private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) {
         for (String arch : modelConfig.getArchitectures()) {
             boolean isTextGenerationModel =
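Condensed, the tensor-parallel default now follows a simple ladder: CUDA devices first, then NeuronCores, otherwise no default. A standalone sketch of that ladder for readability outside the diff (the class and method names are illustrative only):

import ai.djl.util.NeuronUtils;
import ai.djl.util.cuda.CudaUtils;

final class TpDegreeLadder {

    /** Mirrors the detection order above: GPUs win, then NeuronCores, else unset. */
    static String defaultTpDegree() {
        int numGpus = CudaUtils.getGpuCount();
        if (numGpus > 0) {
            return String.valueOf(numGpus);
        }
        if (NeuronUtils.hasNeuron() && NeuronUtils.getNeuronCores() > 0) {
            return String.valueOf(NeuronUtils.getNeuronCores());
        }
        return null; // CPU-only host: leave option.tensor_parallel_degree unset
    }
}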
112 changes: 112 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
@@ -507,6 +507,27 @@ static final class HuggingFaceModelConfig {
         @SerializedName("_diffusers_version")
         private String diffusersVersion;
 
+        @SerializedName("hidden_size")
+        private int hiddenSize;
+
+        @SerializedName("intermediate_size")
+        private int intermediateSize;
+
+        @SerializedName("max_position_embeddings")
+        private int maxPositionEmbeddings;
+
+        @SerializedName("num_attention_heads")
+        private int numAttentionHeads;
+
+        @SerializedName("num_hidden_layers")
+        private int numHiddenLayers;
+
+        @SerializedName("num_key_value_heads")
+        private int numKeyValueHeads;
+
+        @SerializedName("vocab_size")
+        private int vocabSize;
+
         private Set<String> allArchitectures;
 
         public String getModelType() {
@@ -523,6 +544,97 @@ public Set<String> getArchitectures() {
             return allArchitectures;
         }
 
+        /**
+         * Returns the default value for the n_positions model configuration. For models that do not
+         * have a pre-defined value for n_positions, this function returns the minimum of
+         * max_position_embeddings and 4096. If both max_position_embeddings and 4096 are not
+         * available, this function returns 0.
+         *
+         * @return The default value for n_positions.
+         */
+        public Integer getDefaultNPositions() {
+            try {
+                return Math.min(maxPositionEmbeddings, 4096);
+            } catch (Exception e) {
+                return 0;
+            }
+        }
+
+        /**
+         * Calculates the number of parameters in a model that is similar to LLaMA. This function
+         * takes into account the hidden size, intermediate size, maximum position embeddings,
+         * number of hidden layers, vocabulary size, and number of attention heads and key-value
+         * heads to calculate the total parameter count.
+         *
+         * @return The total parameter count for the model.
+         */
+        private long getLlamaLikeParameterCount() {
+            long headDim = (long) numAttentionHeads * numKeyValueHeads;
+            long embeddings = (long) vocabSize * hiddenSize;
+            long qkvProjection = headDim * hiddenSize * numKeyValueHeads * 3;
+            long oProjection = (long) hiddenSize * hiddenSize;
+            long gateProjection = (long) hiddenSize * intermediateSize * 3;
+            return embeddings
+                    + numHiddenLayers
+                            * (qkvProjection
+                                    + oProjection
+                                    + gateProjection
+                                    + hiddenSize
+                                    + hiddenSize)
+                    + hiddenSize
+                    + embeddings;
+        }
+
+        /**
+         * Calculates the default parameter count for a model (GPT-2-like).
+         *
+         * <p>This function takes into account the hidden size, maximum position embeddings, number
+         * of hidden layers, vocabulary size, and number of attention heads to calculate the total
+         * parameter count.
+         *
+         * @return The total parameter count for the model.
+         */
+        private long getDefaultParameterCount() {
+            long embeddingLayerTotal = (long) (vocabSize + maxPositionEmbeddings) * hiddenSize;
+            long attentionTotal = (long) 4 * (hiddenSize * hiddenSize);
+            long feedForwardTotal = (long) 8 * (hiddenSize * hiddenSize);
+            long layerNormTotal = (long) 4 * hiddenSize;
+            long transformerBlockTotal =
+                    (attentionTotal + feedForwardTotal + layerNormTotal) * numHiddenLayers;
+            long finalLayerTotal = (long) hiddenSize * vocabSize;
+            return embeddingLayerTotal + transformerBlockTotal + finalLayerTotal;
+        }
+
+        /**
+         * Calculates the total parameter count for the model.
+         *
+         * @return The total parameter count for the model.
+         */
+        public long getModelParameters() {
+            try {
+                if ("llama".equals(modelType) || "mistral".equals(modelType)) {
+                    return getLlamaLikeParameterCount();
+                }
+                return getDefaultParameterCount();
+            } catch (Exception e) {
+                return 0L;
+            }
+        }
+
+        /**
+         * Returns the memory required to store a single batch of sequence data.
+         *
+         * <p>The memory required is calculated as the product of the sequence length, hidden size,
+         * number of hidden layers, and weight in bytes.
+         *
+         * @param sequenceLength The length in tokens of the sequence.
+         * @param weightBytes The weight in bytes.
+         * @return The memory required to store a single batch of sequence data.
+         */
+        public long getApproxMemoryForSingleSequence(int sequenceLength, int weightBytes) {
+            return (long) sequenceLength * hiddenSize * numHiddenLayers * weightBytes;
+        }
+
         private void determineAllArchitectures() {
             allArchitectures = new HashSet<>();
             if (configArchitectures != null) {
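Plugging representative config.json values into the new helpers shows the scale of these heuristics. With Mistral-7B-like settings (vocab_size 32000, hidden_size 4096, intermediate_size 14336, 32 layers, 32 attention heads, 8 KV heads, max_position_embeddings 32768), the LLaMA-style count comes out to about 7.24e9 parameters, and a 4096-token sequence at 2 bytes per value is estimated at exactly 1 GiB. A self-contained sketch of that arithmetic; the hard-coded values are illustrative, not read from a real config:

public final class SmartDefaultMath {

    public static void main(String[] args) {
        // Mistral-7B-like configuration values (illustrative).
        long vocabSize = 32000;
        long hiddenSize = 4096;
        long intermediateSize = 14336;
        long numHiddenLayers = 32;
        long numAttentionHeads = 32;
        long numKeyValueHeads = 8;
        long maxPositionEmbeddings = 32768;

        // getDefaultNPositions(), inlined: capped at 4096.
        long nPositions = Math.min(maxPositionEmbeddings, 4096);
        System.out.println(nPositions); // 4096

        // getLlamaLikeParameterCount(), inlined:
        long headDim = numAttentionHeads * numKeyValueHeads;              // 256
        long embeddings = vocabSize * hiddenSize;                         // 131,072,000
        long qkvProjection = headDim * hiddenSize * numKeyValueHeads * 3; // 25,165,824
        long oProjection = hiddenSize * hiddenSize;                       // 16,777,216
        long gateProjection = hiddenSize * intermediateSize * 3;          // 176,160,768
        long total = embeddings
                + numHiddenLayers
                        * (qkvProjection + oProjection + gateProjection + hiddenSize + hiddenSize)
                + hiddenSize
                + embeddings;
        System.out.println(total); // 7,241,732,096 ~= 7.24B parameters

        // getApproxMemoryForSingleSequence(4096, 2), inlined:
        long seqBytes = 4096L * hiddenSize * numHiddenLayers * 2;
        System.out.println(seqBytes); // 1,073,741,824 bytes = exactly 1 GiB
    }
}

For GPT-2-small-like values (vocab 50257, hidden 768, 12 layers, 1024 positions), the GPT-2-style default formula yields about 163M rather than GPT-2's commonly reported 124M, because the heuristic counts the output projection separately instead of treating it as tied to the input embeddings; for capacity planning that overestimate is the safer direction.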