Skip to content

Commit

Permalink
Fix a bug in format_float kernel (#1676)
Browse files Browse the repository at this point in the history
Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven authored Jan 5, 2024
1 parent 1c34077 commit e3fe415
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
44 changes: 21 additions & 23 deletions src/main/cpp/src/ftos_converter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1202,11 +1202,8 @@ __device__ inline T round_half_even(T const input, int const olength, int const
{
// "round" a integer to digits digits, with the half-even rounding mode.
if (digits > olength) {
T num = input;
for (int i = 0; i < digits - olength; i++) {
num *= 10;
}
return num;
// trailing zeros will be handled later
return input;
}
T div = POW10_TABLE[olength - digits];
T mod = input % div;
Expand All @@ -1215,10 +1212,10 @@ __device__ inline T round_half_even(T const input, int const olength, int const
return num;
}

__device__ inline int to_formated_chars(floating_decimal_64 const v,
bool const sign,
char* const result,
int digits)
__device__ inline int to_formated_double_chars(floating_decimal_64 const v,
bool const sign,
char* const result,
int digits)
{
int index = 0;
if (sign) { result[index++] = '-'; }
Expand Down Expand Up @@ -1289,9 +1286,10 @@ __device__ inline int to_formated_chars(floating_decimal_64 const v,
result[index++] = '0';
}
} else {
// 0 <= exp < olength - 1
uint32_t temp_d = digits, tailing_zero = 0;
if (exp + digits > olength) {
temp_d = olength - exp;
if (exp + digits + 1 > olength) {
temp_d = olength - exp - 1;
tailing_zero = digits - temp_d;
}
uint64_t rounded_output = round_half_even(output, olength, exp + temp_d + 1);
Expand Down Expand Up @@ -1329,7 +1327,7 @@ __device__ inline int to_formated_chars(floating_decimal_64 const v,
return index;
}

__device__ inline int format_float_size(floating_decimal_64 const v, bool const sign, int digits)
__device__ inline int format_double_size(floating_decimal_64 const v, bool const sign, int digits)
{
int index = 0;
if (sign) { index++; }
Expand All @@ -1342,7 +1340,7 @@ __device__ inline int format_float_size(floating_decimal_64 const v, bool const
index += exp + 1 + exp / 3 + 1 + digits;
} else {
uint32_t temp_d = digits;
if (exp + digits > olength) { temp_d = olength - exp; }
if (exp + digits + 1 > olength) { temp_d = olength - exp - 1; }
uint64_t rounded_output = round_half_even(output, olength, exp + temp_d + 1);
uint64_t pow10 = POW10_TABLE[temp_d];
uint64_t integer = rounded_output / pow10;
Expand All @@ -1353,10 +1351,10 @@ __device__ inline int format_float_size(floating_decimal_64 const v, bool const
return index;
}

__device__ inline int to_formated_chars(floating_decimal_32 const v,
bool const sign,
char* const result,
int digits)
__device__ inline int to_formated_float_chars(floating_decimal_32 const v,
bool const sign,
char* const result,
int digits)
{
int index = 0;
if (sign) { result[index++] = '-'; }
Expand Down Expand Up @@ -1428,8 +1426,8 @@ __device__ inline int to_formated_chars(floating_decimal_32 const v,
}
} else {
uint32_t temp_d = digits, tailing_zero = 0;
if (exp + digits > olength) {
temp_d = olength - exp;
if (exp + digits + 1 > olength) {
temp_d = olength - exp - 1;
tailing_zero = digits - temp_d;
}
uint32_t rounded_output = round_half_even(output, olength, exp + temp_d + 1);
Expand Down Expand Up @@ -1480,7 +1478,7 @@ __device__ inline int format_float_size(floating_decimal_32 const v, bool const
index += exp + 1 + exp / 3 + 1 + digits;
} else {
uint32_t temp_d = digits;
if (exp + digits > olength) { temp_d = olength - exp; }
if (exp + digits + 1 > olength) { temp_d = olength - exp - 1; }
uint64_t rounded_output = round_half_even(output, olength, exp + temp_d + 1);
uint64_t pow10 = POW10_TABLE[temp_d];
uint64_t integer = rounded_output / pow10;
Expand Down Expand Up @@ -1539,7 +1537,7 @@ __device__ inline int compute_format_float_size(double value, int digits, bool i
} else {
floating_decimal_64 v = d2d(value, sign, special);
if (special) { return special_format_str_size(sign, v.exponent, v.mantissa, digits); }
return format_float_size(v, sign, digits);
return format_double_size(v, sign, digits);
}
}

Expand All @@ -1549,11 +1547,11 @@ __device__ inline int format_float(double value, int digits, bool is_float, char
if (is_float) {
floating_decimal_32 v = f2d(value, sign, special);
if (special) { return copy_format_special_str(output, sign, v.exponent, v.mantissa, digits); }
return to_formated_chars(v, sign, output, digits);
return to_formated_float_chars(v, sign, output, digits);
} else {
floating_decimal_64 v = d2d(value, sign, special);
if (special) { return copy_format_special_str(output, sign, v.exponent, v.mantissa, digits); }
return to_formated_chars(v, sign, output, digits);
return to_formated_double_chars(v, sign, output, digits);
}
}

Expand Down
6 changes: 5 additions & 1 deletion src/main/cpp/tests/format_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ TEST_F(FormatFloatTests, FormatFloats64)
-4.0d,
std::numeric_limits<double>::quiet_NaN(),
839542223232.794248339d,
3232.794248339d,
11234000000.0d,
-0.0d};

auto const expected = cudf::test::strings_column_wrapper{"100.00000",
Expand All @@ -80,9 +82,11 @@ TEST_F(FormatFloatTests, FormatFloats64)
"-4.00000",
"\xEF\xBF\xBD",
"839,542,223,232.79420",
"3,232.79425",
"11,234,000,000.00000",
"-0.00000"};

auto results = spark_rapids_jni::format_float(floats, 5, cudf::get_default_stream());

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected, verbosity);
}
}

0 comments on commit e3fe415

Please sign in to comment.