
SIGSEGV fatal error when re-assigning a tensor that has been previously split #39

Open · davoclavo opened this issue Jul 16, 2023 · 3 comments · Labels: bug

davoclavo (Contributor) commented Jul 16, 2023

Hello!

I stumbled upon a fatal error while using torch.split followed by reassignment of one of the resulting tensors. I'm not sure how to even start debugging this, but I am documenting it here in case someone knows how to investigate further.

Here is a way to replicate the error.

@ val data = torch.arange(0L, 1_000_000L)
data: Tensor[Int64] = tensor dtype=int64, shape=[1000000], device=CPU
[0, 1, 2, ..., 999997, 999998, 999999]

@ val Seq(a,b) = torch.split(data, 600_000)

@ a
res2: Tensor[Int64] = tensor dtype=int64, shape=[600000], device=CPU
[0, 1, 2, ..., 599997, 599998, 599999]

@ b
res3: Tensor[Int64] = tensor dtype=int64, shape=[400000], device=CPU
[600000, 600001, 600002, ..., 999997, 999998, 999999]

@ val x = a
#
# A fatal error has been detected by the Java Runtime Environment:
#
#  SIGSEGV (0xb) at pc=0x00000001b984535a, pid=19939, tid=9731
#
# JRE version: OpenJDK Runtime Environment Zulu19.30+11-CA (19.0.1+10) (build 19.0.1+10)
# Java VM: OpenJDK 64-Bit Server VM Zulu19.30+11-CA (19.0.1+10, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, bsd-amd64)
# Problematic frame:
# C  [libjnitorch.dylib+0x4ff35a]  Java_org_bytedeco_pytorch_TensorBase_sizes+0x5a
#
# No core dump will be written. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again
#
# An error report file with more information is saved as:
# /experiments/hs_err_pid19939.log
#
# If you would like to submit a bug report, please visit:
#   http://www.azul.com/support/
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#

Trying other variations: any operation performed on either of the chunks returned by torch.split causes this crash, even a + 1.

davoclavo (Author) commented

Quick update:

Adding .clone() seems to fix the issue. I wonder what other operations might require the same fix. Will open a PR soon.

I also wonder what the implications of calling .clone() are in terms of memory usage or other performance factors.

--- a/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
+++ b/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
@@ -1014,7 +1014,7 @@ private[torch] trait IndexingSlicingJoiningOps {
         case i: Int      => torchNative.split(input.native, i.toLong, dim.toLong)
         case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
       }
-    (0L until result.size()).map(i => Tensor(result.get(i)))
+    (0L until result.size()).map(i => Tensor(result.get(i)).clone())
   }

   /** Returns a tensor with all specified dimensions of `input` of size 1 removed.
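On the memory question: clone() copies each chunk into fresh storage, so the chunks stop sharing memory with the original tensor (the data briefly exists twice, until the original can be collected). As a call-site workaround before a fix like the above lands, a minimal sketch along the same lines:

val data = torch.arange(0L, 1_000_000L)
// clone() detaches each chunk from the original tensor's native storage
val Seq(a, b) = torch.split(data, 600_000).map(_.clone())
val x = a // no longer crashes: x does not depend on the split view's memory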

sbrunk added the bug label on Jul 16, 2023
sbrunk (Owner) commented Jul 16, 2023

It might have something to do with the fact that split returns views.

https://pytorch.org/docs/stable/generated/torch.split.html:

Splits the tensor into chunks. Each chunk is a view of the original tensor.

It's just a guess for now, but it would explain why clone() makes a difference.
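To illustrate the guess (a sketch, not from the issue): a view shares storage with its base tensor, so if the base's native memory is freed while the JVM side still holds a reference to the view, any later access touches freed memory, which would match the SIGSEGV above. clone() copies the data into storage the new tensor owns itself:

val base  = torch.arange(0L, 10L)
val view  = torch.split(base, 5).head // shares base's native storage
val owned = view.clone()              // independent copy with its own storage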

I ran into something similar while implementing tensor printing, which needs to convert tensor values to buffers and crashed on non-contiguous values, since the memory layout of views can be non-contiguous.

val buf = tensor.native.contiguous.createBuffer[B]

In this case the view should be contiguous, so it's not exactly the same issue, but it could still be related to being a view.
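That expectation can be checked from the native side (a sketch assuming the is_contiguous binding from the JavaCPP presets, which mirrors ATen's Tensor::is_contiguous):

val data  = torch.arange(0L, 10L)
val chunk = torch.split(data, 5).head
// split chunks view a contiguous range of the base tensor,
// so this is expected to print true
println(chunk.native.is_contiguous())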

Interestingly, your example works on my machine (I tried in Ammonite too):

object Split extends App {
  val data = torch.arange(0L, 1_000_000L)
  val Seq(a, b) = torch.split(data, 600_000)
  println(a)
  println(b)
  val x = a
  println(x)
}

sbrunk (Owner) commented Jul 16, 2023

Perhaps we also need to understand why the Python implementation calls split_with_sizes in certain cases; we might need to do something similar.

https://github.com/pytorch/pytorch/blob/9adfaf880784ec0cf5f085fc3f282cf53650050f/torch/_tensor.py#L770

    def split(self, split_size, dim=0):
        r"""See :func:`torch.split`"""
        if has_torch_function_unary(self):
            return handle_torch_function(
                Tensor.split, (self,), self, split_size, dim=dim
            )
        if isinstance(split_size, Tensor):
            try:
                split_size = int(split_size)
            except ValueError:
                pass

        if isinstance(split_size, (int, torch.SymInt)):
            return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
        else:
            return torch._VF.split_with_sizes(self, split_size, dim)
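If storch mirrors that dispatch, the Scala side could look roughly like this (a hedged sketch, not the actual implementation; it assumes a split_with_sizes binding is available via torchNative, mirroring ATen's split_with_sizes):

def split[D <: DType](
    input: Tensor[D],
    splitSizeOrSections: Int | Seq[Int],
    dim: Int = 0
): Seq[Tensor[D]] =
  val result = splitSizeOrSections match
    // a uniform chunk size dispatches to split, like torch._VF.split
    case i: Int => torchNative.split(input.native, i.toLong, dim.toLong)
    // explicit per-chunk sizes dispatch to split_with_sizes,
    // like torch._VF.split_with_sizes
    case s: Seq[Int] =>
      torchNative.split_with_sizes(input.native, s.map(_.toLong).toArray, dim.toLong)
  // clone() detaches each chunk from the returned native vector (see the fix above)
  (0L until result.size()).map(i => Tensor(result.get(i)).clone())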
