From 2f49e3084f49e5818a7849c8ef56ec011e09421e Mon Sep 17 00:00:00 2001
From: Kulin Seth
Date: Mon, 3 Jul 2023 10:52:24 -0700
Subject: [PATCH] Fix the FP16 clamp issue. (#457)

* Fix the FP16 clamp issue.

* Fix clamp (cached graph nodes were previously replaced with the cast
  version)

---------

Co-authored-by: Denis Vieriu
---
 .../native/mps/operations/TensorCompare.mm    | 39 +++++++++++++------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm
index 4f8def1cbb777..70532577e1869 100644
--- a/aten/src/ATen/native/mps/operations/TensorCompare.mm
+++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -14,24 +14,40 @@
   MPSGraphTensor *minTensor = nil, *maxTensor = nil;
 };
 
-void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor)
+void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor,
+                     const Tensor& min_tensor, const Tensor& max_tensor)
 {
+  auto input_dtype = input_tensor.scalar_type();
+  auto min_dtype = input_dtype;
+  auto max_dtype = input_dtype;
+  if (cachedGraph->minTensor)
+    min_dtype = min_tensor.scalar_type();
+  if (cachedGraph->maxTensor)
+    max_dtype = max_tensor.scalar_type();
+
   MPSGraph *mpsGraph = cachedGraph->graph();
 
   cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
 
+  MPSGraphTensor* minTensor = cachedGraph->minTensor;
+  MPSGraphTensor* maxTensor = cachedGraph->maxTensor;
+  if (input_dtype != min_dtype) {
+    minTensor = castMPSTensor(mpsGraph, cachedGraph->minTensor, input_dtype);
+  }
+  if (input_dtype != max_dtype) {
+    maxTensor = castMPSTensor(mpsGraph, cachedGraph->maxTensor, input_dtype);
+  }
   if (cachedGraph->minTensor && cachedGraph->maxTensor) {
     cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
-                                           minValueTensor:cachedGraph->minTensor
-                                           maxValueTensor:cachedGraph->maxTensor
+                                           minValueTensor:minTensor
+                                           maxValueTensor:maxTensor
                                                      name:nil];
   } else if (cachedGraph->maxTensor) {
     cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
-                                                   secondaryTensor:cachedGraph->maxTensor
+                                                   secondaryTensor:maxTensor
                                                               name:nil];
   } else if (cachedGraph->minTensor) {
     cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
-                                                   secondaryTensor:cachedGraph->minTensor
+                                                   secondaryTensor:minTensor
                                                               name:nil];
   }
 }
@@ -147,18 +163,19 @@ void clamp_tensor_out_mps(const Tensor& input_t,
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          if (has_min)
+          if (has_min) {
             newCachedGraph->minTensor = mpsGraphRankedPlaceHolder(mpsGraph, min_opt_tensor);
-          if (has_max)
-            newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor);
+          }
+          if (has_max) {
+            newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor);
+          }
 
-          clamp_mps_graph(newCachedGraph, input_t);
+          clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor, max_opt_tensor);
         }
         return newCachedGraph;
       });
       cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
     }
-
     auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
     auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
 
@@ -228,7 +245,7 @@ void clamp_scalar_out_mps(const Tensor& input_t,
                                                      shape:(mps::getMPSShape(input_t))
                                                   dataType:(mps::getMPSScalarType(input_t.scalar_type()))
                                              ];
-          clamp_mps_graph(newCachedGraph, input_t);
+          clamp_mps_graph(newCachedGraph, input_t, input_t, input_t);
         }
         return newCachedGraph;
       });
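
For context, a minimal sketch of the kind of call this fix targets, written
against PyTorch's Python API. This is illustrative only: the original bug
report is not included in the patch, and the shapes, values, and the
float32-bounds-vs-float16-input combination below are assumptions inferred
from what the fix does (cast the min/max graph tensors to the input dtype
before building the clamp op).

    # Hypothetical repro sketch, not taken from the PR: clamp a float16 MPS
    # tensor against float32 bound tensors, so the graph's min/max nodes
    # start out with a dtype different from the input's.
    import torch

    assert torch.backends.mps.is_available()
    x = torch.randn(8, device="mps", dtype=torch.float16)
    lo = torch.full((8,), -0.5, device="mps", dtype=torch.float32)
    hi = torch.full((8,), 0.5, device="mps", dtype=torch.float32)
    # With this patch, clamp_mps_graph casts the cached minTensor/maxTensor
    # into local variables of the input dtype, so the MPSGraph clamp op sees
    # matching dtypes.
    print(torch.clamp(x, min=lo, max=hi))

Note the second bullet in the commit message: an earlier revision of the fix
overwrote cachedGraph->minTensor/maxTensor with the cast versions, which
corrupted the cached graph's placeholders on reuse. The final version keeps
the cached nodes intact and feeds the cast results through the local
minTensor/maxTensor variables instead.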