Skip to content

Commit

Permalink
Revert "[METAL] Fix memory leaks in Metal runtime (apache#7714)"
Browse files Browse the repository at this point in the history
This reverts commit 917b8b3.
  • Loading branch information
mehrdadh committed Mar 23, 2021
1 parent dd47bc4 commit 5124ed4
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 189 deletions.
5 changes: 2 additions & 3 deletions src/runtime/contrib/random/mt_random_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ class RandomEngine {
} else {
runtime::NDArray local = runtime::NDArray::Empty(
std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
DLTensor* tensor = const_cast<DLTensor*>(local.operator->());
FillData(tensor, size);
runtime::NDArray::CopyFromTo(tensor, data);
FillData(&local.ToDLPack()->dl_tensor, size);
runtime::NDArray::CopyFromTo(&local.ToDLPack()->dl_tensor, data);
}
}

Expand Down
258 changes: 121 additions & 137 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -30,54 +30,50 @@
namespace metal {

MetalWorkspace* MetalWorkspace::Global() {
@autoreleasepool {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}

void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
@autoreleasepool {
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = int(index < devices.size());
return;
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = int(index < devices.size());
return;
}
ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
*rv = static_cast<int>([devices[ctx.device_id] maxThreadsPerThreadgroup].width);
break;
}
ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
*rv = static_cast<int>([devices[ctx.device_id] maxThreadsPerThreadgroup].width);
break;
}
case kWarpSize: {
// Set warp size to be 1 for safty reason.
*rv = 1;
break;
}
case kMaxSharedMemoryPerBlock:
return;
case kComputeVersion:
return;
case kDeviceName:
return;
case kMaxClockRate:
return;
case kMultiProcessorCount:
return;
case kMaxThreadDimensions:
return;
case kExist:
return;
case kMaxRegistersPerBlock:
return;
case kGcnArch:
return;
case kApiVersion:
return;
case kWarpSize: {
// Set warp size to be 1 for safty reason.
*rv = 1;
break;
}
case kMaxSharedMemoryPerBlock:
return;
case kComputeVersion:
return;
case kDeviceName:
return;
case kMaxClockRate:
return;
case kMultiProcessorCount:
return;
case kMaxThreadDimensions:
return;
case kExist:
return;
case kMaxRegistersPerBlock:
return;
case kGcnArch:
return;
case kApiVersion:
return;
}
}

Expand Down Expand Up @@ -110,11 +106,7 @@ int GetWarpSize(id<MTLDevice> dev) {
ICHECK(f != nil);
id<MTLComputePipelineState> state = [dev newComputePipelineStateWithFunction:f error:&error_msg];
ICHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
int size = static_cast<int>(state.threadExecutionWidth);
[state release];
[f release];
[lib release];
return size;
return static_cast<int>(state.threadExecutionWidth);
}

MetalWorkspace::~MetalWorkspace() {
Expand All @@ -135,14 +127,14 @@ int GetWarpSize(id<MTLDevice> dev) {
#if TARGET_OS_IPHONE
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
devices.push_back(d);
queues.push_back([d newCommandQueue]);
devices.push_back([d retain]);
queues.push_back([[d newCommandQueue] retain]);
#else
NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d);
queues.push_back([d newCommandQueue]);
devices.push_back([d retain]);
queues.push_back([[d newCommandQueue] retain]);
LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
warp_size.push_back(GetWarpSize(d));
}
Expand All @@ -155,110 +147,102 @@ int GetWarpSize(id<MTLDevice> dev) {

void* MetalWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) {
@autoreleasepool {
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
/*
#if TARGET_OS_IPHONE
storage_mode = MTLResourceStorageModeShared;
#else
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
ICHECK(buf != nil);
return (void*)(buf);
}
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
/*
#if TARGET_OS_IPHONE
storage_mode = MTLResourceStorageModeShared;
#else
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
ICHECK(buf != nil);
return (void*)(CFBridgingRetain(buf));
}

void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
@autoreleasepool {
// MTLBuffer PurgeableState should be set to empty before manual
// release in order to prevent memory leak
[(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
// release the ptr.
CFRelease(ptr);
}
// MTLBuffer PurgeableState should be set to empty before manual
// release in order to prevent memory leak
[(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
// release the ptr.
CFRelease(ptr);
}

void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, TVMContext ctx_from,
TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) {
@autoreleasepool {
this->Init();
ICHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
this->Init();
ICHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);

if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
ICHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy.";
if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
ICHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy.";
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
sourceOffset:from_offset
toBuffer:(__bridge id<MTLBuffer>)(to)destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
} else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
// copy to a local buffer before get into global buffer.
id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:(id<MTLBuffer>)(from)
[encoder copyFromBuffer:from_buf
sourceOffset:from_offset
toBuffer:(id<MTLBuffer>)(to)destinationOffset:to_offset
toBuffer:temp
destinationOffset:0
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>([temp contents]), size);
} else {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>([from_buf contents]) + from_offset, size);
}
} else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size);
memcpy([temp contents], static_cast<const char*>(from) + from_offset, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:temp
sourceOffset:0
toBuffer:to_buf
destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
} else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
// copy to a local buffer before get into global buffer.
id<MTLBuffer> from_buf = (id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:from_buf
sourceOffset:from_offset
toBuffer:temp
destinationOffset:0
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>([temp contents]), size);
} else {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>([from_buf contents]) + from_offset, size);
}
} else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
id<MTLBuffer> to_buf = (id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size);
memcpy([temp contents], static_cast<const char*>(from) + from_offset, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:temp
sourceOffset:0
toBuffer:to_buf
destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
} else {
memcpy(static_cast<char*>([to_buf contents]) + to_offset,
static_cast<const char*>(from) + from_offset, size);
}
[cb waitUntilCompleted];
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type << ", to=" << to_dev_type;
memcpy(static_cast<char*>([to_buf contents]) + to_offset,
static_cast<const char*>(from) + from_offset, size);
}
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type << ", to=" << to_dev_type;
}
}

void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
@autoreleasepool {
ICHECK(stream == nullptr);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
[cb commit];
[cb waitUntilCompleted];
}
ICHECK(stream == nullptr);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
[cb commit];
[cb waitUntilCompleted];
}

void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
Expand All @@ -285,10 +269,10 @@ int GetWarpSize(id<MTLDevice> dev) {
if (temp_buffer_[ctx.device_id] == nil || temp_buffer_[ctx.device_id].length < size) {
id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
if (temp_buffer_[ctx.device_id] != nil) {
[temp_buffer_[ctx.device_id] setPurgeableState:MTLPurgeableStateEmpty];
[temp_buffer_[ctx.device_id] release];
}
temp_buffer_[ctx.device_id] = [dev newBufferWithLength:size options:MTLStorageModeShared];
temp_buffer_[ctx.device_id] = [[dev newBufferWithLength:size
options:MTLStorageModeShared] retain];
}
return temp_buffer_[ctx.device_id];
}
Expand Down
Loading

0 comments on commit 5124ed4

Please sign in to comment.