Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invoke a DeviceLostClosure immediately if set on an invalid device. #5358

Merged
merged 8 commits into from
Mar 21, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ By @cwfitzgerald in [#5325](https://github.com/gfx-rs/wgpu/pull/5325).
- Fix behavior of integer `clamp` when `min` argument > `max` argument. By @cwfitzgerald in [#5300](https://github.com/gfx-rs/wgpu/pull/5300).
- Fix missing validation for `Device::clear_buffer` where `offset + size > buffer.size` was not checked when `size` was omitted. By @ErichDonGubler in [#5282](https://github.com/gfx-rs/wgpu/pull/5282).
- Fix linking when targeting android. By @ashdnazg in [#5326](https://github.com/gfx-rs/wgpu/pull/5326).
- Setting the device lost closure on an invalid device now calls the closure immediately, before returning. By @bradwerth in [#5358](https://github.com/gfx-rs/wgpu/pull/5358).

#### glsl-in

Expand Down
91 changes: 50 additions & 41 deletions tests/tests/device.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::atomic::AtomicBool;

use wgpu_test::{fail, gpu_test, FailureCase, GpuTestConfiguration, TestParameters};

#[gpu_test]
Expand Down Expand Up @@ -518,12 +520,11 @@ static DEVICE_DESTROY_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::ne
.run_async(|ctx| async move {
// This test checks that when device.destroy is called, the provided
// DeviceLostClosure is called with reason DeviceLostReason::Destroyed.
let was_called = std::sync::Arc::<std::sync::atomic::AtomicBool>::new(false.into());
static WAS_CALLED: AtomicBool = AtomicBool::new(false);

// Set a LoseDeviceCallback on the device.
let was_called_clone = was_called.clone();
let callback = Box::new(move |reason, _m| {
was_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
let callback = Box::new(|reason, _m| {
WAS_CALLED.store(true, std::sync::atomic::Ordering::SeqCst);
assert!(
matches!(reason, wgt::DeviceLostReason::Destroyed),
"Device lost info reason should match DeviceLostReason::Destroyed."
Expand All @@ -542,7 +543,7 @@ static DEVICE_DESTROY_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::ne
.is_queue_empty());

assert!(
was_called.load(std::sync::atomic::Ordering::SeqCst),
WAS_CALLED.load(std::sync::atomic::Ordering::SeqCst),
"Device lost callback should have been called."
);
});
Expand All @@ -554,28 +555,48 @@ static DEVICE_DROP_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::new()
// This test checks that when the device is dropped (such as in a GC),
// the provided DeviceLostClosure is called with reason DeviceLostReason::Dropped.
// Fails on webgl because webgl doesn't implement drop.
let was_called = std::sync::Arc::<std::sync::atomic::AtomicBool>::new(false.into());
static WAS_CALLED: std::sync::atomic::AtomicBool = AtomicBool::new(false);

// Set a LoseDeviceCallback on the device.
let was_called_clone = was_called.clone();
let callback = Box::new(move |reason, message| {
was_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
assert!(
matches!(reason, wgt::DeviceLostReason::Dropped),
"Device lost info reason should match DeviceLostReason::Dropped."
);
assert!(
message == "Device dropped.",
"Device lost info message should be \"Device dropped.\"."
);
let callback = Box::new(|reason, message| {
WAS_CALLED.store(true, std::sync::atomic::Ordering::SeqCst);
assert_eq!(reason, wgt::DeviceLostReason::Dropped);
assert_eq!(message, "Device dropped.");
});
ctx.device.set_device_lost_callback(callback);

// Drop the device.
drop(ctx.device);

assert!(
was_called.load(std::sync::atomic::Ordering::SeqCst),
WAS_CALLED.load(std::sync::atomic::Ordering::SeqCst),
"Device lost callback should have been called."
);
});

#[gpu_test]
static DEVICE_INVALID_THEN_SET_LOST_CALLBACK: GpuTestConfiguration = GpuTestConfiguration::new()
    .parameters(TestParameters::default().expect_fail(FailureCase::webgl2()))
    .run_sync(|ctx| {
        // Verifies that registering a device lost callback on a device that
        // is already invalid invokes that callback right away. The device is
        // invalidated through a testing-only method; the test is expected to
        // fail on webgl because webgl doesn't implement make_invalid.

        // Flag flipped by the callback so the test can observe the call.
        static WAS_CALLED: AtomicBool = AtomicBool::new(false);

        // Put the device into the invalid state before any callback is set.
        ctx.device.make_invalid();

        // Register a LoseDeviceCallback on the now-invalid device.
        let on_device_lost = Box::new(|reason, _message| {
            WAS_CALLED.store(true, std::sync::atomic::Ordering::SeqCst);
            assert_eq!(reason, wgt::DeviceLostReason::DeviceInvalid);
        });
        ctx.device.set_device_lost_callback(on_device_lost);

        // The callback must already have run by the time the setter returns.
        let callback_ran = WAS_CALLED.load(std::sync::atomic::Ordering::SeqCst);
        assert!(callback_ran, "Device lost callback should have been called.");
    });
Expand All @@ -586,16 +607,12 @@ static DEVICE_LOST_REPLACED_CALLBACK: GpuTestConfiguration = GpuTestConfiguratio
.run_sync(|ctx| {
// This test checks that a device_lost_callback is called when it is
// replaced by another callback.
let was_called = std::sync::Arc::<std::sync::atomic::AtomicBool>::new(false.into());
static WAS_CALLED: AtomicBool = AtomicBool::new(false);

// Set a LoseDeviceCallback on the device.
let was_called_clone = was_called.clone();
let callback = Box::new(move |reason, _m| {
was_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
assert!(
matches!(reason, wgt::DeviceLostReason::ReplacedCallback),
"Device lost info reason should match DeviceLostReason::ReplacedCallback."
);
let callback = Box::new(|reason, _m| {
WAS_CALLED.store(true, std::sync::atomic::Ordering::SeqCst);
assert_eq!(reason, wgt::DeviceLostReason::ReplacedCallback);
});
ctx.device.set_device_lost_callback(callback);

Expand All @@ -604,7 +621,7 @@ static DEVICE_LOST_REPLACED_CALLBACK: GpuTestConfiguration = GpuTestConfiguratio
ctx.device.set_device_lost_callback(replacement_callback);

assert!(
was_called.load(std::sync::atomic::Ordering::SeqCst),
WAS_CALLED.load(std::sync::atomic::Ordering::SeqCst),
"Device lost callback should have been called."
);
});
Expand All @@ -619,29 +636,21 @@ static DROPPED_GLOBAL_THEN_DEVICE_LOST: GpuTestConfiguration = GpuTestConfigurat
// wgpu without providing a more orderly shutdown. In such a case, the
// device lost callback should be invoked with the message "Device is
// dying."
let was_called = std::sync::Arc::<std::sync::atomic::AtomicBool>::new(false.into());
static WAS_CALLED: AtomicBool = AtomicBool::new(false);

// Set a LoseDeviceCallback on the device.
let was_called_clone = was_called.clone();
let callback = Box::new(move |reason, message| {
was_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
assert!(
matches!(reason, wgt::DeviceLostReason::Dropped),
"Device lost info reason should match DeviceLostReason::Dropped."
);
assert!(
message == "Device is dying.",
"Device lost info message is \"{}\" and it should be \"Device is dying.\".",
message
);
let callback = Box::new(|reason, message| {
WAS_CALLED.store(true, std::sync::atomic::Ordering::SeqCst);
assert_eq!(reason, wgt::DeviceLostReason::Dropped);
assert_eq!(message, "Device is dying.");
});
ctx.device.set_device_lost_callback(callback);

// TODO: Drop the Global, somehow.

// Confirm that the callback was invoked.
assert!(
was_called.load(std::sync::atomic::Ordering::SeqCst),
WAS_CALLED.load(std::sync::atomic::Ordering::SeqCst),
"Device lost callback should have been called."
);
});
17 changes: 16 additions & 1 deletion wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2240,6 +2240,15 @@ impl Global {
}
}

// This is a test-only function to force the device into an
// invalid state by inserting an error value in its place in
// the registry.
pub fn device_make_invalid<A: HalApi>(&self, device_id: DeviceId) {
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
let hub = A::hub(self);
hub.devices
.force_replace_with_error(device_id, "Made invalid.");
}

pub fn device_drop<A: HalApi>(&self, device_id: DeviceId) {
profiling::scope!("Device::drop");
api_log!("Device::drop {device_id:?}");
Expand Down Expand Up @@ -2275,7 +2284,7 @@ impl Global {
) {
let hub = A::hub(self);

if let Ok(device) = hub.devices.get(device_id) {
if let Ok(Some(device)) = hub.devices.try_get(device_id) {
let mut life_tracker = device.lock_life();
if let Some(existing_closure) = life_tracker.device_lost_closure.take() {
// It's important to not hold the lock while calling the closure.
Expand All @@ -2284,6 +2293,12 @@ impl Global {
life_tracker = device.lock_life();
}
life_tracker.device_lost_closure = Some(device_lost_closure);
} else {
// No device? Okay. Just like we have to call any existing closure
// before we drop it, we need to call this closure before we exit
// this function, because there's no device that is ever going to
// call it.
device_lost_closure.call(DeviceLostReason::DeviceInvalid, "".to_string());
}
}

Expand Down
8 changes: 7 additions & 1 deletion wgpu-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7183,7 +7183,7 @@ mod send_sync {
///
/// Corresponds to [WebGPU `GPUDeviceLostReason`](https://gpuweb.github.io/gpuweb/#enumdef-gpudevicelostreason).
#[repr(u8)]
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DeviceLostReason {
/// Triggered by driver
Unknown = 0,
Expand All @@ -7203,4 +7203,10 @@ pub enum DeviceLostReason {
/// exactly once before it is dropped, which helps with managing the
/// memory owned by the callback.
ReplacedCallback = 3,
/// Triggered when the callback is set on a device that is already invalid
///
/// As above, when the callback is provided, wgpu guarantees that it
/// will eventually be called. If the device is already invalid, wgpu
/// will call the callback immediately, with this reason.
DeviceInvalid = 4,
}
5 changes: 5 additions & 0 deletions wgpu/src/backend/webgpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1948,6 +1948,11 @@ impl crate::context::Context for ContextWebGpu {
create_identified(device_data.0.create_render_bundle_encoder(&mapped_desc))
}

#[doc(hidden)]
fn device_make_invalid(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
// Unimplemented
}

fn device_drop(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {
// Device is dropped automatically
}
Expand Down
11 changes: 7 additions & 4 deletions wgpu/src/backend/wgpu_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1346,14 +1346,17 @@ impl crate::Context for ContextWgpuCore {
Err(e) => panic!("Error in Device::create_render_bundle_encoder: {e}"),
}
}
#[doc(hidden)]
fn device_make_invalid(&self, device: &Self::DeviceId, _device_data: &Self::DeviceData) {
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
wgc::gfx_select!(device => self.0.device_make_invalid(*device));
}
#[cfg_attr(not(any(native, Emscripten)), allow(unused))]
fn device_drop(&self, device: &Self::DeviceId, _device_data: &Self::DeviceData) {
#[cfg(any(native, Emscripten))]
{
match wgc::gfx_select!(device => self.0.device_poll(*device, wgt::Maintain::wait())) {
Ok(_) => {}
Err(err) => self.handle_error_fatal(err, "Device::drop"),
}
// Call device_poll, but don't check for errors. We have to use its
// return value, but we just drop it.
let _ = wgc::gfx_select!(device => self.0.device_poll(*device, wgt::Maintain::wait()));
wgc::gfx_select!(device => self.0.device_drop(*device));
}
}
Expand Down
11 changes: 11 additions & 0 deletions wgpu/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ pub trait Context: Debug + WasmNotSendSync + Sized {
device_data: &Self::DeviceData,
desc: &RenderBundleEncoderDescriptor<'_>,
) -> (Self::RenderBundleEncoderId, Self::RenderBundleEncoderData);
#[doc(hidden)]
fn device_make_invalid(&self, device: &Self::DeviceId, device_data: &Self::DeviceData);
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
fn device_drop(&self, device: &Self::DeviceId, device_data: &Self::DeviceData);
fn device_set_device_lost_callback(
&self,
Expand Down Expand Up @@ -1293,6 +1295,8 @@ pub(crate) trait DynContext: Debug + WasmNotSendSync {
device_data: &crate::Data,
desc: &RenderBundleEncoderDescriptor<'_>,
) -> (ObjectId, Box<crate::Data>);
#[doc(hidden)]
fn device_make_invalid(&self, device: &ObjectId, device_data: &crate::Data);
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
fn device_drop(&self, device: &ObjectId, device_data: &crate::Data);
fn device_set_device_lost_callback(
&self,
Expand Down Expand Up @@ -2350,6 +2354,13 @@ where
(render_bundle_encoder.into(), Box::new(data) as _)
}

#[doc(hidden)]
fn device_make_invalid(&self, device: &ObjectId, device_data: &crate::Data) {
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
let device = <T::DeviceId>::from(*device);
let device_data = downcast_ref(device_data);
Context::device_make_invalid(self, &device, device_data)
}

fn device_drop(&self, device: &ObjectId, device_data: &crate::Data) {
let device = <T::DeviceId>::from(*device);
let device_data = downcast_ref(device_data);
Expand Down
6 changes: 6 additions & 0 deletions wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2703,6 +2703,12 @@ impl Device {
Box::new(callback),
)
}

/// Test-only function to make this device invalid.
///
/// Forwards to the backend context's `device_make_invalid`, passing this
/// device's id and data. Marked `#[doc(hidden)]` because it exists for
/// tests rather than as part of the documented API.
#[doc(hidden)]
pub fn make_invalid(&self) {
DynContext::device_make_invalid(&*self.context, &self.id, self.data.as_ref())
}
}

impl Drop for Device {
Expand Down
Loading