Fix the FP16 clamp issue. (#457)
* Fix the FP16 clamp issue.

* Fix clamp (cached graph nodes were previously replaced with the cast version)

---------

Co-authored-by: Denis Vieriu <[email protected]>
Authored by kulinseth and DenisVieriu97 on Jul 3, 2023
1 parent 302584f · commit 2f49e30
Showing 1 changed file with 28 additions and 11 deletions.
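As context for the diff below, here is a hypothetical repro sketch of the failure mode the commit title describes (not taken from the commit itself; it assumes a machine with the MPS backend available and that the bug is triggered by bound tensors whose dtype differs from the float16 input):

# Hypothetical repro: clamp a float16 MPS tensor against float32 bounds.
# The fix below makes the MPS kernel cast the min/max tensors to the
# input dtype inside the graph instead of feeding mismatched dtypes
# into MPSGraph's clamp op.
import torch

x  = torch.rand(8, device="mps", dtype=torch.float16)
lo = torch.full((8,), 0.25, device="mps", dtype=torch.float32)
hi = torch.full((8,), 0.75, device="mps", dtype=torch.float32)

print(torch.clamp(x, min=lo, max=hi))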
aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -14,24 +14,40 @@
   MPSGraphTensor *minTensor = nil, *maxTensor = nil;
 };
 
-void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor)
+void clamp_mps_graph(CachedGraph* cachedGraph, const Tensor& input_tensor,
+                     const Tensor& min_tensor, const Tensor& max_tensor)
 {
+  auto input_dtype = input_tensor.scalar_type();
+  auto min_dtype = input_dtype;
+  auto max_dtype = input_dtype;
+  if (cachedGraph->minTensor)
+    min_dtype = min_tensor.scalar_type();
+  if (cachedGraph->maxTensor)
+    max_dtype = max_tensor.scalar_type();
+
   MPSGraph *mpsGraph = cachedGraph->graph();
 
   cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
 
+  MPSGraphTensor* minTensor = cachedGraph->minTensor;
+  MPSGraphTensor* maxTensor = cachedGraph->maxTensor;
+  if (input_dtype != min_dtype) {
+    minTensor = castMPSTensor(mpsGraph, cachedGraph->minTensor, input_dtype);
+  }
+  if (input_dtype != max_dtype) {
+    maxTensor = castMPSTensor(mpsGraph, cachedGraph->maxTensor, input_dtype);
+  }
   if (cachedGraph->minTensor && cachedGraph->maxTensor) {
     cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
-                                           minValueTensor:cachedGraph->minTensor
-                                           maxValueTensor:cachedGraph->maxTensor
+                                           minValueTensor:minTensor
+                                           maxValueTensor:maxTensor
                                                      name:nil];
   } else if (cachedGraph->maxTensor) {
     cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
-                                                   secondaryTensor:cachedGraph->maxTensor
+                                                   secondaryTensor:maxTensor
                                                                name:nil];
   } else if (cachedGraph->minTensor) {
     cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
-                                                   secondaryTensor:cachedGraph->minTensor
+                                                   secondaryTensor:minTensor
                                                                name:nil];
   }
 }
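The second bullet of the commit message refers to the subtle point in the hunk above: the cast results land in the locals minTensor/maxTensor, while cachedGraph->minTensor and cachedGraph->maxTensor keep referring to the original placeholder nodes that each run feeds data into. A toy Python analogue of the pitfall, with all names invented for illustration rather than taken from the real MPS code:

# Toy model of the cached-graph pitfall: the cache must keep the original
# placeholder node, because each execution feeds input data to that exact
# node. Overwriting it with the cast node breaks every run after the first.
class CachedGraph:
    def __init__(self, min_placeholder):
        self.min_placeholder = min_placeholder

def cast(node, dtype):
    # Stand-in for castMPSTensor: produces a new downstream graph node.
    return ("cast", node, dtype)

cached = CachedGraph(("placeholder", "min", "float32"))

# Buggy pattern (what this commit removes): replace the cached node in place.
#   cached.min_placeholder = cast(cached.min_placeholder, "float16")

# Fixed pattern, mirroring the diff above: cast into a local and leave
# the cached placeholder untouched.
min_for_clamp = cast(cached.min_placeholder, "float16")
assert cached.min_placeholder == ("placeholder", "min", "float32")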
@@ -147,18 +163,19 @@ void clamp_tensor_out_mps(const Tensor& input_t,
       MPSGraph* mpsGraph = make_mps_graph();
       newCachedGraph = new CachedGraph(mpsGraph);
 
-      if (has_min)
+      if (has_min) {
         newCachedGraph->minTensor = mpsGraphRankedPlaceHolder(mpsGraph, min_opt_tensor);
-      if (has_max)
-        newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor);
+      }
+      if (has_max) {
+        newCachedGraph->maxTensor = mpsGraphRankedPlaceHolder(mpsGraph, max_opt_tensor);
+      }
 
-      clamp_mps_graph(newCachedGraph, input_t);
+      clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor, max_opt_tensor);
     }
     return newCachedGraph;
   });
   cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
 }
 
 auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
 auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
@@ -228,7 +245,7 @@ void clamp_scalar_out_mps(const Tensor& input_t,
                                           shape:(mps::getMPSShape(input_t))
                                        dataType:(mps::getMPSScalarType(input_t.scalar_type()))];
 
-      clamp_mps_graph(newCachedGraph, input_t);
+      clamp_mps_graph(newCachedGraph, input_t, input_t, input_t);
     }
     return newCachedGraph;
   });
