SIGSEGV fatal error when re-assigning a tensor that has been previously split #39
Quick update: I tried adding `.clone()` to each tensor returned by `split`:

```diff
--- a/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
+++ b/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
@@ -1014,7 +1014,7 @@ private[torch] trait IndexingSlicingJoiningOps {
       case i: Int => torchNative.split(input.native, i.toLong, dim.toLong)
       case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
     }
-    (0L until result.size()).map(i => Tensor(result.get(i)))
+    (0L until result.size()).map(i => Tensor(result.get(i)).clone())
   }
   /** Returns a tensor with all specified dimensions of `input` of size 1 removed.
```

I also wonder what the implications of calling `clone()` there are.
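On the implications of `clone()`: a view shares its storage with the input, while a clone allocates a new buffer and copies the chunk's elements into it, so cloning every chunk copies each element of the input once. The plain-Python toy model below sketches that difference. `StridedView`, `split_1d` and `clone` are hypothetical names for illustration only, not storch or LibTorch code, and the model is 1-D.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StridedView:
    """Hypothetical toy model of a 1-D tensor: a window into a storage buffer."""
    storage: List[int]  # backing buffer, possibly shared with other views
    offset: int         # index of the view's first element in `storage`
    size: int           # number of elements in the view

    def get(self, i: int) -> int:
        return self.storage[self.offset + i]

    def to_list(self) -> List[int]:
        return [self.get(i) for i in range(self.size)]


def split_1d(storage: List[int], split_size: int) -> List[StridedView]:
    """Split into views of at most `split_size` elements; no data is copied."""
    return [StridedView(storage, start, min(split_size, len(storage) - start))
            for start in range(0, len(storage), split_size)]


def clone(view: StridedView) -> StridedView:
    """Copy only the view's elements into a fresh buffer that the result owns."""
    return StridedView(view.to_list(), 0, view.size)


data = list(range(10))
a, b = split_1d(data, 6)
a_owned = clone(a)
print(a.storage is data)        # True: the view still points at the whole buffer
print(a_owned.storage is data)  # False: the clone owns a separate 6-element buffer
data[0] = 99
print(a.get(0), a_owned.get(0))  # 99 0: only the view sees the write
```

In this model the trade-off is an extra allocation and copy per chunk in exchange for chunks that no longer depend on the original buffer.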
It might have something to do with the fact that `split` returns views: the docs at https://pytorch.org/docs/stable/generated/torch.split.html state that each chunk is a view of the original tensor.

It's just a guess for now, but it would explain why I ran into this while implementing tensor printing: printing needs to copy tensor values into buffers, and it crashed on non-contiguous values, because the memory layout of a view can be non-contiguous (see `storch/core/src/main/scala/torch/Tensor.scala`, line 558 in 19abf3e).
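To illustrate why a view can be non-contiguous, here is a plain-Python toy model of strided indexing. The helper names are hypothetical, and this is not how storch or LibTorch implement it. Splitting a row-major 3x4 tensor along dim 0 yields chunks whose strides still match a freshly allocated layout, so they are contiguous. Splitting along dim 1 yields chunks that skip elements in every row, so a copy that reads `numel` consecutive elements from the view's offset (as a buffer conversion that assumes contiguity would) picks up the wrong values.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def row_major_strides(sizes: Shape) -> Shape:
    """Strides of a freshly allocated (contiguous, row-major) tensor."""
    strides, step = [], 1
    for size in reversed(sizes):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(sizes: Shape, strides: Shape) -> bool:
    """Contiguous iff strides match row-major strides (size-1 dims are ignored)."""
    expected = row_major_strides(sizes)
    return all(n == 1 or s == e for n, s, e in zip(sizes, strides, expected))


def strided_copy(storage: List[int], offset: int, sizes: Shape, strides: Shape) -> List[int]:
    """Stride-aware copy of a 2-D view into a flat row-major list."""
    rows, cols = sizes
    return [storage[offset + r * strides[0] + c * strides[1]]
            for r in range(rows) for c in range(cols)]


def naive_copy(storage: List[int], offset: int, sizes: Shape) -> List[int]:
    """Assumes contiguity: reads numel consecutive elements from `offset`."""
    rows, cols = sizes
    return storage[offset:offset + rows * cols]


storage = list(range(12))              # a 3x4 tensor holding 0..11
strides = row_major_strides((3, 4))    # (4, 1)

# First chunk of a split along dim 0 (rows 0-1): same strides, contiguous.
print(is_contiguous((2, 4), strides))  # True
# First chunk of a split along dim 1 (columns 0-1): row stride 4 != 2.
print(is_contiguous((3, 2), strides))  # False
print(strided_copy(storage, 0, (3, 2), strides))  # [0, 1, 4, 5, 8, 9]
print(naive_copy(storage, 0, (3, 2)))             # [0, 1, 2, 3, 4, 5]: wrong elements
```

A 1-D tensor split along its only dimension behaves like the dim-0 case, which matches the expectation that the views in this issue are contiguous.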
In this case the view should be contiguous, so it's not exactly the same issue, but it could still be related to it being a view. Interestingly, your example works on my machine (I tried it in Ammonite too):

```scala
object Split extends App {
  val data = torch.arange(0L, 1_000_000L)
  val Seq(a, b) = torch.split(data, 600_000)
  println(a)
  println(b)
  val x = a
  println(x)
}
```
Perhaps we also need to understand why the Python implementation calls a different native function depending on the type of `split_size`:

```python
def split(self, split_size, dim=0):
    r"""See :func:`torch.split`"""
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor.split, (self,), self, split_size, dim=dim
        )
    if isinstance(split_size, Tensor):
        try:
            split_size = int(split_size)
        except ValueError:
            pass
    if isinstance(split_size, (int, torch.SymInt)):
        return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
    else:
        return torch._VF.split_with_sizes(self, split_size, dim)
```
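For reference, the two native calls differ in how chunk sizes are given: `split` takes a single chunk size (the last chunk may be smaller), while `split_with_sizes` takes an explicit list of sizes that must add up to the size of `dim`. A plain-Python sketch of the resulting chunk boundaries follows. `chunk_bounds` is a hypothetical helper, not part of PyTorch or storch, and it ignores edge cases such as empty dimensions.

```python
from typing import List, Sequence, Tuple, Union


def chunk_bounds(dim_size: int, split_size: Union[int, Sequence[int]]) -> List[Tuple[int, int]]:
    """Return (start, length) for each chunk along one dimension (toy model)."""
    if isinstance(split_size, int):
        # Like `split`: equal chunks of `split_size`; the last one may be smaller.
        if split_size <= 0:
            raise ValueError("split_size must be positive")
        return [(start, min(split_size, dim_size - start))
                for start in range(0, dim_size, split_size)]
    # Like `split_with_sizes`: explicit sizes that must sum to the dimension size.
    if sum(split_size) != dim_size:
        raise ValueError(f"sizes {list(split_size)} do not sum to {dim_size}")
    bounds, start = [], 0
    for length in split_size:
        bounds.append((start, length))
        start += length
    return bounds


print(chunk_bounds(1_000_000, 600_000))  # [(0, 600000), (600000, 400000)]
print(chunk_bounds(10, [3, 7]))          # [(0, 3), (3, 7)]
```

With the numbers from the repro, the integer case yields two chunks of 600,000 and 400,000 elements. Note that in the storch code from the diff, the `Seq[Int]` case also goes through `torchNative.split` (passing an array of sizes), whereas Python routes sequences to `split_with_sizes`.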
Hello!

I stumbled upon a fatal error while using `torch.split` and then reassigning one of the resulting tensors. I'm not sure how to even start debugging this, but I am documenting it here in case someone knows how to investigate it further. Here is a way to replicate the error.

Trying other variations: any operation performed on any of the tensors returned by `tensor.split` causes this panic, even `a + 1`.