Skip to content

Commit

Permalink
[webgpu] Check if runtime support WebGPU before initial a WebGPU back…
Browse files Browse the repository at this point in the history
  • Loading branch information
haoyunfeix authored Jun 22, 2021
1 parent 8a1fd30 commit 2d16dc9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 36 deletions.
18 changes: 12 additions & 6 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ export interface WebGPUTimingInfo extends TimingInfo {
const CPU_HANDOFF_SIZE_THRESHOLD =
env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');

const DEFAULT_GPUBUFFER_USAGE =
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;

export class WebGPUBackend extends KernelBackend {
device: GPUDevice;
queue: GPUQueue;
Expand Down Expand Up @@ -109,6 +106,9 @@ export class WebGPUBackend extends KernelBackend {

constructor(device: GPUDevice, glslang: Glslang, supportTimeQuery = false) {
super();
if (!webgpu_util.isWebGPUSupported()) {
throw new Error('WebGPU is not supported on this device');
}
this.layoutCache = {};
this.pipelineCache = {};
this.device = device;
Expand All @@ -133,6 +133,11 @@ export class WebGPUBackend extends KernelBackend {
return 32;
}

defaultGpuBufferUsage(): number {
return GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC |
GPUBufferUsage.COPY_DST;
}

flushDisposalQueue() {
this.tensorDisposalQueue.forEach(d => {
this.maybeReleaseBuffer(d);
Expand Down Expand Up @@ -191,7 +196,8 @@ export class WebGPUBackend extends KernelBackend {
}

acquireBuffer(
byteSize: number, usage: GPUBufferUsageFlags = DEFAULT_GPUBUFFER_USAGE) {
byteSize: number,
usage: GPUBufferUsageFlags = this.defaultGpuBufferUsage()) {
return this.bufferManager.acquireBuffer(byteSize, usage);
}

Expand Down Expand Up @@ -248,7 +254,7 @@ export class WebGPUBackend extends KernelBackend {
this.tensorMap.set(dataId, {
dtype,
values,
bufferInfo: {byteSize, usage: DEFAULT_GPUBUFFER_USAGE},
bufferInfo: {byteSize, usage: this.defaultGpuBufferUsage()},
refCount: 1
});
return dataId;
Expand All @@ -268,7 +274,7 @@ export class WebGPUBackend extends KernelBackend {
this.tensorMap.set(dataId, {
dtype,
values,
bufferInfo: {byteSize, usage: DEFAULT_GPUBUFFER_USAGE},
bufferInfo: {byteSize, usage: this.defaultGpuBufferUsage()},
refCount
});
}
Expand Down
62 changes: 32 additions & 30 deletions tfjs-backend-webgpu/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,42 @@
import './flags_webgpu';
import './register_all_kernels';

import {env, registerBackend} from '@tensorflow/tfjs-core';
import {device_util, env, registerBackend} from '@tensorflow/tfjs-core';
import glslangInit from '@webgpu/glslang/dist/web-devel/glslang.onefile';

import {WebGPUBackend} from './backend_webgpu';
import * as webgpu from './webgpu';

registerBackend('webgpu', async () => {
// Remove it once we figure out how to correctly read the tensor data before
// the tensor is disposed in profiling mode.
env().set('CHECK_COMPUTATION_FOR_ERRORS', false);

const glslang = await glslangInit();
const gpuDescriptor: GPURequestAdapterOptions = {
powerPreference: env().get('WEBGPU_USE_LOW_POWER_GPU') ? 'low-power' :
'high-performance'
};

const adapter = await navigator.gpu.requestAdapter(gpuDescriptor);
let deviceDescriptor: GPUDeviceDescriptor = {};
const supportTimeQuery = adapter.features.has('timestamp-query');

if (supportTimeQuery) {
deviceDescriptor = {
nonGuaranteedFeatures: ['timestamp-query' as const]
import {isWebGPUSupported} from './webgpu_util';

if (device_util.isBrowser() && isWebGPUSupported()) {
registerBackend('webgpu', async () => {
// Remove it once we figure out how to correctly read the tensor data
// before the tensor is disposed in profiling mode.
env().set('CHECK_COMPUTATION_FOR_ERRORS', false);

const glslang = await glslangInit();
const gpuDescriptor: GPURequestAdapterOptions = {
powerPreference: env().get('WEBGPU_USE_LOW_POWER_GPU') ?
'low-power' :
'high-performance'
};
} else {
console.warn(
`This device doesn't support timestamp-query extension. ` +
`Zero will shown for the kernel time when profiling mode is enabled. ` +
`Using performance.now is not workable for webgpu since it doesn't ` +
`support synchronously to read data from GPU.`);
}
const device: GPUDevice = await adapter.requestDevice(deviceDescriptor);
return new WebGPUBackend(device, glslang, supportTimeQuery);
}, 3 /*priority*/);

const adapter = await navigator.gpu.requestAdapter(gpuDescriptor);
let deviceDescriptor: GPUDeviceDescriptor = {};
const supportTimeQuery = adapter.features.has('timestamp-query');

if (supportTimeQuery) {
deviceDescriptor = {nonGuaranteedFeatures: ['timestamp-query' as const ]};
} else {
console.warn(
`This device doesn't support timestamp-query extension. ` +
`Zero will shown for the kernel time when profiling mode is` +
`enabled. Using performance.now is not workable for webgpu since` +
`it doesn't support synchronously to read data from GPU.`);
}
const device: GPUDevice = await adapter.requestDevice(deviceDescriptor);
return new WebGPUBackend(device, glslang, supportTimeQuery);
}, 3 /*priority*/);
}

export {webgpu};
7 changes: 7 additions & 0 deletions tfjs-backend-webgpu/src/webgpu_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,10 @@ export function ArrayBufferToTypedArray(data: ArrayBuffer, dtype: DataType) {
throw new Error(`Unknown dtype ${dtype}`);
}
}

export function isWebGPUSupported(): boolean {
if (!navigator.gpu) {
return false;
}
return true;
}

0 comments on commit 2d16dc9

Please sign in to comment.