diff --git a/src/ndarray.jl b/src/ndarray.jl index d62a72c39684..33b94c05e559 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -669,13 +669,21 @@ Matrix (2D NDArray) multiplication. Elementwise divide a scalar or an `NDArray` of the same shape from `dst`. Inplace updating. """ -function div_from!(dst::NDArray{T}, arg::NDArrayOrReal) where {T} +function div_from!(dst::NDArray, arg::NDArrayOrReal) @assert dst.writable if isa(arg, Real) - _div_scalar(dst, scalar = convert(T, arg), out = dst) + _div_scalar(dst, scalar = arg, out = dst) else _div(dst, arg, out = dst) end + dst +end + +function div_from!(dst::NDArray{T}, arg::Real) where {T<:Integer} + @assert dst.writable + @assert(round(T, arg) != zero(T), "Integer divided by zero") + _div_scalar(dst, scalar = arg, out = dst) + dst end """ diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl index 7c74536a8a18..8bd87c65ec9e 100644 --- a/test/unittest/ndarray.jl +++ b/test/unittest/ndarray.jl @@ -396,6 +396,19 @@ function test_div() t6, a6 = rand_tensors(Float16, dims) scalar_large = 1e4 @test t6 ./ scalar_large ≈ copy(a6 ./ scalar_large) + + info("NDArray::div::scalar::type convert") + let x = mx.NDArray([1, 2, 3]) + y = x ./ 1.1 + @test eltype(y) == Int + @test copy(y) == [1, 2, 3] + + y = x ./ 2 + @test eltype(y) == Int # this differs from julia + @test copy(y) == [0, 1, 1] + + @test_throws AssertionError x ./ 0.5 + end end