
Segfault from Dataset.prepare() on MPS device #2504

@Ubehebe

Description

I am trying to combine the MNIST tutorial with the directions in #2037 to train a model on the MPS PyTorch backend. I have gotten the MPSTest in that PR to pass on my M2 machine: it successfully creates an NDArray on the MPS device.
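
For reference, the check I ran is roughly the following (a minimal sketch in Kotlin, adapted from the MPSTest in that PR; the exact assertions in the PR differ):

import ai.djl.Device
import ai.djl.ndarray.NDManager

fun mpsSanityCheck() {
  // Creating an NDArray on the MPS device works for me, so the backend itself loads.
  NDManager.newBaseManager(Device.fromName("mps")).use { manager ->
    val array = manager.create(floatArrayOf(1f, 2f, 3f))
    println(array.device) // reports the MPS device on my M2 machine
  }
}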

Expected Behavior

I expected the MNIST dataset to prepare successfully on the MPS device, so that I could create a model from it. (Please let me know if this is unreasonable, or if I have to prepare the dataset on the CPU and then use some other API to load it onto MPS.)
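
In case it helps clarify the question, this is the kind of CPU-first workaround I had in mind. It is only a sketch, and I'm assuming NDList.toDevice is the right call for moving a prepared batch onto MPS; please correct me if there is a more idiomatic API:

import ai.djl.Device
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.NDManager

fun prepareOnCpuThenMove() {
  val cpuManager = NDManager.newBaseManager(Device.cpu())
  val mps = Device.fromName("mps")
  val mnist = Mnist.builder()
      .optManager(cpuManager)
      .setSampling(32 /* batchSize */, true /* random */)
      .build()
  mnist.prepare() // works fine on the CPU
  for (batch in mnist.getData(cpuManager)) {
    // Assumption: copy each batch's NDLists to MPS before feeding the model.
    val dataOnMps = batch.data.toDevice(mps, true)
    val labelsOnMps = batch.labels.toDevice(mps, true)
    // ... run the training step with dataOnMps / labelsOnMps ...
    batch.close()
  }
}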

Error Message

#
# A fatal error has been detected by the Java Runtime Environment:
#
#  SIGSEGV (0xb) at pc=0x0000000197d43f00, pid=94738, tid=8707
#
# JRE version: OpenJDK Runtime Environment Zulu17.32+13-CA (17.0.2+8) (build 17.0.2+8-LTS)
# Java VM: OpenJDK 64-Bit Server VM Zulu17.32+13-CA (17.0.2+8-LTS, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, bsd-aarch64)
# Problematic frame:
# C  [libobjc.A.dylib+0x7f00]  objc_retain+0x10
#
# No core dump will be written. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again
#
# If you would like to submit a bug report, please visit:
#   http://www.azul.com/support/
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#

---------------  S U M M A R Y ------------

Command Line: ml.djl.DjlSampleKt

Host: "Mac14,5" arm64 1 MHz, 12 cores, 64G, Darwin 22.4.0, macOS 13.3 (22E252)
Time: Wed Apr  5 07:23:05 2023 PDT elapsed time: 0.326062 seconds (0d 0h 0m 0s)

---------------  T H R E A D  ---------------

Current thread (0x000000014500dc00):  JavaThread "main" [_thread_in_native, id=8707, stack(0x000000016b62c000,0x000000016b82f000)]

Stack: [0x000000016b62c000,0x000000016b82f000],  sp=0x000000016b82d7a0,  free space=2053k
Native frames: (J=compiled Java code, j=interpreted, Vv=VM code, C=native code)
C  [libobjc.A.dylib+0x7f00]  objc_retain+0x10
C  [libtorch_cpu.dylib+0x4458830]  at::native::mps::copy_cast_mps(at::Tensor&, at::Tensor const&, id<MTLBuffer>, objc_object<MTLBuffer>, bool)+0x2ec
C  [libtorch_cpu.dylib+0x445a958]  at::native::mps::mps_copy_(at::Tensor&, at::Tensor const&, bool)+0x1e10
C  [libtorch_cpu.dylib+0x49b6d8]  at::native::copy_impl(at::Tensor&, at::Tensor const&, bool)+0x5cc
C  [libtorch_cpu.dylib+0x49b04c]  at::native::copy_(at::Tensor&, at::Tensor const&, bool)+0x64
C  [libtorch_cpu.dylib+0x10d3f14]  at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool)+0x120
C  [libtorch_cpu.dylib+0x7c0f0c]  at::native::_to_copy(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbd8
C  [libtorch_cpu.dylib+0xc64120]  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbc
C  [libtorch_cpu.dylib+0xc64120]  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbc
C  [libtorch_cpu.dylib+0x280348c]  c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>), &torch::autograd::VariableType::(anonymous namespace)::_to_copy(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>>>, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0x448
C  [libtorch_cpu.dylib+0xc63de8]  at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0x154
C  [libtorch_cpu.dylib+0xe1bed0]  at::_ops::to_device::call(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>)+0x140
C  [0.21.0-libdjl_torch.dylib+0x78974]  at::Tensor::to(c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) const+0x8c
C  [0.21.0-libdjl_torch.dylib+0x786c4]  Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTo+0xb4
j  ai.djl.pytorch.jni.PyTorchLibrary.torchTo(JI[I)J+0
j  ai.djl.pytorch.jni.JniUtils.to(Lai/djl/pytorch/engine/PtNDArray;Lai/djl/ndarray/types/DataType;Lai/djl/Device;)Lai/djl/pytorch/engine/PtNDArray;+61
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/pytorch/engine/PtNDArray;+23
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/ndarray/NDArray;+3
j  ai.djl.basicdataset.cv.classification.Mnist.readLabel(Lai/djl/repository/Artifact$Item;)Lai/djl/ndarray/NDArray;+88
j  ai.djl.basicdataset.cv.classification.Mnist.prepare(Lai/djl/util/Progress;)V+146
j  ai.djl.training.dataset.Dataset.prepare()V+2
j  ml.djl.DjlSampleKt.main([Ljava/lang/String;)V+36
v  ~StubRoutines::call_stub
V  [libjvm.dylib+0x46b270]  JavaCalls::call_helper(JavaValue*, methodHandle const&, JavaCallArguments*, JavaThread*)+0x38c
V  [libjvm.dylib+0x4cfa64]  jni_invoke_static(JNIEnv_*, JavaValue*, _jobject*, JNICallType, _jmethodID*, JNI_ArgumentPusher*, JavaThread*)+0x12c
V  [libjvm.dylib+0x4d30f8]  jni_CallStaticVoidMethod+0x130
C  [libjli.dylib+0x5378]  JavaMain+0x9d4
C  [libjli.dylib+0x76e8]  ThreadJavaMain+0xc
C  [libsystem_pthread.dylib+0x6fa8]  _pthread_start+0x94

Java frames: (J=compiled Java code, j=interpreted, Vv=VM code)
j  ai.djl.pytorch.jni.PyTorchLibrary.torchTo(JI[I)J+0
j  ai.djl.pytorch.jni.JniUtils.to(Lai/djl/pytorch/engine/PtNDArray;Lai/djl/ndarray/types/DataType;Lai/djl/Device;)Lai/djl/pytorch/engine/PtNDArray;+61
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/pytorch/engine/PtNDArray;+23
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/ndarray/NDArray;+3
j  ai.djl.basicdataset.cv.classification.Mnist.readLabel(Lai/djl/repository/Artifact$Item;)Lai/djl/ndarray/NDArray;+88
j  ai.djl.basicdataset.cv.classification.Mnist.prepare(Lai/djl/util/Progress;)V+146
j  ai.djl.training.dataset.Dataset.prepare()V+2
j  ml.djl.DjlSampleKt.main([Ljava/lang/String;)V+36
v  ~StubRoutines::call_stub

The top Java frame at the moment of the segfault corresponds to this snippet in Mnist.readLabel, where the dataset casts the freshly created UINT8 NDArray to float32:

try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) {
    array.set(buf);
    return array.toType(DataType.FLOAT32, false);
}

The top native frame is objc_retain, called from at::native::mps::copy_cast_mps. (I can provide more details from the JVM crash report if that would be valuable.)
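
I haven't tried to narrow it down further yet, but based on that frame I would guess (untested) that a bare toType on an MPS-backed UINT8 array, mirroring the readLabel code above, hits the same path:

import ai.djl.Device
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape

fun castOnMps() {
  NDManager.newBaseManager(Device.fromName("mps")).use { manager ->
    manager.create(Shape(16L), DataType.UINT8).use { array ->
      array.set(ByteArray(16))
      array.toType(DataType.FLOAT32, false) // I would expect the same native crash here
    }
  }
}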

How to Reproduce?

Here is a minimal repro. It's in Kotlin, but I'm happy to rewrite it in Java if you prefer:

import ai.djl.Device
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.NDManager

fun main(args: Array<String>) {
  val device = Device.fromName("mps")
  val manager = NDManager.newBaseManager(device)
  Mnist.builder()
      .optManager(manager)
      .setSampling(32 /* batchSize */, true /* random */)
      .build()
      .prepare() // segfault!
}

Steps to reproduce

Run the program above, with the pytorch-engine, pytorch-jni, and pytorch-native-cpu-osx-aarch64 jars on the runtime classpath.
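
For completeness, my Gradle (Kotlin DSL) dependencies look roughly like this. The coordinates, versions, and the classifier on the native artifact reflect my understanding of the standard DJL setup for 0.21.0, so please correct me if they should be resolved differently:

dependencies {
    implementation("ai.djl:api:0.21.0")
    implementation("ai.djl:basicdataset:0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-engine:0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-jni:1.13.1-0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-native-cpu:1.13.1:osx-aarch64")
}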

What have you tried to solve it?

Nothing, beyond isolating this repro. (I'm new to ML in general, but have a lot of JVM experience.)

Environment Info

I'm using DJL v0.21.0 and pytorch-jni v1.13.1. System information:

$ uname -mrsv
Darwin 22.4.0 Darwin Kernel Version 22.4.0: Mon Mar  6 20:59:58 PST 2023; root:xnu-8796.101.5~3/RELEASE_ARM64_T6020 arm64
