Skip to content

Commit

Permalink
[METAL] Fix memory leaks in Metal runtime (apache#7714)
Browse files Browse the repository at this point in the history
* [METAL] Fix memory leaks in Metal runtime

1. In case when we build runtime without ARC, we can have problems with
   memory releasing. Due to some of Objective-C methods returns
   autoreleased pointers, we should specify `autoreleasepool` blocks to
   determine life cycle of these pointers.
2. Added workaround for problem with work group size.
   Sometimes auto scheduler generates parameters when work group size
   is more than possible. And in this case we got assert from Metal
   library. Added check for this situation and it helps to avoid
   assert.
3. Fixed memory leak problem when fill tensor by random data.
   DLManagedTensor increases reference counter in NDArray but nobody
   delete this DLManagedTensor in proper way. This is why memory which
   was allocated by NDArray was never released.
4. Removed unnecessary retains. It is not necessary use retain in some
   places where they were used, due to we build metal runtime without
   ARC.

* Use const_cast instead of creation DLManagedTensor
  • Loading branch information
echuraev authored and mehrdadh committed Mar 23, 2021
1 parent 24bba8c commit 917b8b3
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 162 deletions.
5 changes: 3 additions & 2 deletions src/runtime/contrib/random/mt_random_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ class RandomEngine {
} else {
runtime::NDArray local = runtime::NDArray::Empty(
std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
FillData(&local.ToDLPack()->dl_tensor, size);
runtime::NDArray::CopyFromTo(&local.ToDLPack()->dl_tensor, data);
DLTensor* tensor = const_cast<DLTensor*>(local.operator->());
FillData(tensor, size);
runtime::NDArray::CopyFromTo(tensor, data);
}
}

Expand Down
258 changes: 137 additions & 121 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -30,50 +30,54 @@
namespace metal {

MetalWorkspace* MetalWorkspace::Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
@autoreleasepool {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}
}

void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = int(index < devices.size());
return;
}
ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
*rv = static_cast<int>([devices[ctx.device_id] maxThreadsPerThreadgroup].width);
break;
@autoreleasepool {
this->Init();
size_t index = static_cast<size_t>(ctx.device_id);
if (kind == kExist) {
*rv = int(index < devices.size());
return;
}
case kWarpSize: {
// Set warp size to be 1 for safty reason.
*rv = 1;
break;
ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
switch (kind) {
case kMaxThreadsPerBlock: {
*rv = static_cast<int>([devices[ctx.device_id] maxThreadsPerThreadgroup].width);
break;
}
case kWarpSize: {
// Set warp size to be 1 for safty reason.
*rv = 1;
break;
}
case kMaxSharedMemoryPerBlock:
return;
case kComputeVersion:
return;
case kDeviceName:
return;
case kMaxClockRate:
return;
case kMultiProcessorCount:
return;
case kMaxThreadDimensions:
return;
case kExist:
return;
case kMaxRegistersPerBlock:
return;
case kGcnArch:
return;
case kApiVersion:
return;
}
case kMaxSharedMemoryPerBlock:
return;
case kComputeVersion:
return;
case kDeviceName:
return;
case kMaxClockRate:
return;
case kMultiProcessorCount:
return;
case kMaxThreadDimensions:
return;
case kExist:
return;
case kMaxRegistersPerBlock:
return;
case kGcnArch:
return;
case kApiVersion:
return;
}
}

Expand Down Expand Up @@ -106,7 +110,11 @@ int GetWarpSize(id<MTLDevice> dev) {
ICHECK(f != nil);
id<MTLComputePipelineState> state = [dev newComputePipelineStateWithFunction:f error:&error_msg];
ICHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
return static_cast<int>(state.threadExecutionWidth);
int size = static_cast<int>(state.threadExecutionWidth);
[state release];
[f release];
[lib release];
return size;
}

MetalWorkspace::~MetalWorkspace() {
Expand All @@ -127,14 +135,14 @@ int GetWarpSize(id<MTLDevice> dev) {
#if TARGET_OS_IPHONE
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
devices.push_back([d retain]);
queues.push_back([[d newCommandQueue] retain]);
devices.push_back(d);
queues.push_back([d newCommandQueue]);
#else
NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back([d retain]);
queues.push_back([[d newCommandQueue] retain]);
devices.push_back(d);
queues.push_back([d newCommandQueue]);
LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
warp_size.push_back(GetWarpSize(d));
}
Expand All @@ -147,102 +155,110 @@ int GetWarpSize(id<MTLDevice> dev) {

void* MetalWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
DLDataType type_hint) {
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
/*
#if TARGET_OS_IPHONE
storage_mode = MTLResourceStorageModeShared;
#else
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
ICHECK(buf != nil);
return (void*)(CFBridgingRetain(buf));
@autoreleasepool {
this->Init();
id<MTLDevice> dev = GetDevice(ctx);
// GPU memory only
MTLResourceOptions storage_mode = MTLResourceStorageModePrivate;
/*
#if TARGET_OS_IPHONE
storage_mode = MTLResourceStorageModeShared;
#else
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
ICHECK(buf != nil);
return (void*)(buf);
}
}

void MetalWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
// MTLBuffer PurgeableState should be set to empty before manual
// release in order to prevent memory leak
[(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
// release the ptr.
CFRelease(ptr);
@autoreleasepool {
// MTLBuffer PurgeableState should be set to empty before manual
// release in order to prevent memory leak
[(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
// release the ptr.
CFRelease(ptr);
}
}

void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, TVMContext ctx_from,
TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) {
this->Init();
ICHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);
@autoreleasepool {
this->Init();
ICHECK(stream == nullptr);
TVMContext ctx = ctx_from;
if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
int from_dev_type = static_cast<int>(ctx_from.device_type);
int to_dev_type = static_cast<int>(ctx_to.device_type);

if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
ICHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy.";
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:(__bridge id<MTLBuffer>)(from)
sourceOffset:from_offset
toBuffer:(__bridge id<MTLBuffer>)(to)destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
} else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
// copy to a local buffer before get into global buffer.
id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size);
if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
ICHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy.";
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:from_buf
[encoder copyFromBuffer:(id<MTLBuffer>)(from)
sourceOffset:from_offset
toBuffer:temp
destinationOffset:0
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>([temp contents]), size);
} else {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>([from_buf contents]) + from_offset, size);
}
} else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size);
memcpy([temp contents], static_cast<const char*>(from) + from_offset, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:temp
sourceOffset:0
toBuffer:to_buf
destinationOffset:to_offset
toBuffer:(id<MTLBuffer>)(to)destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
} else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
// copy to a local buffer before get into global buffer.
id<MTLBuffer> from_buf = (id<MTLBuffer>)(from);
if (from_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:from_buf
sourceOffset:from_offset
toBuffer:temp
destinationOffset:0
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>([temp contents]), size);
} else {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<char*>([from_buf contents]) + from_offset, size);
}
} else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
id<MTLBuffer> to_buf = (id<MTLBuffer>)(to);
if (to_buf.storageMode != MTLStorageModeShared) {
id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size);
memcpy([temp contents], static_cast<const char*>(from) + from_offset, size);
id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
[encoder copyFromBuffer:temp
sourceOffset:0
toBuffer:to_buf
destinationOffset:to_offset
size:size];
[encoder endEncoding];
[cb commit];
[cb waitUntilCompleted];
} else {
memcpy(static_cast<char*>([to_buf contents]) + to_offset,
static_cast<const char*>(from) + from_offset, size);
}
} else {
memcpy(static_cast<char*>([to_buf contents]) + to_offset,
static_cast<const char*>(from) + from_offset, size);
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type << ", to=" << to_dev_type;
}
} else {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type << ", to=" << to_dev_type;
}
}

void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
ICHECK(stream == nullptr);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
[cb commit];
[cb waitUntilCompleted];
@autoreleasepool {
ICHECK(stream == nullptr);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(ctx);
id<MTLCommandBuffer> cb = [queue commandBuffer];
[cb commit];
[cb waitUntilCompleted];
}
}

void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
Expand All @@ -269,10 +285,10 @@ int GetWarpSize(id<MTLDevice> dev) {
if (temp_buffer_[ctx.device_id] == nil || temp_buffer_[ctx.device_id].length < size) {
id<MTLDevice> dev = MetalWorkspace::Global()->GetDevice(ctx);
if (temp_buffer_[ctx.device_id] != nil) {
[temp_buffer_[ctx.device_id] setPurgeableState:MTLPurgeableStateEmpty];
[temp_buffer_[ctx.device_id] release];
}
temp_buffer_[ctx.device_id] = [[dev newBufferWithLength:size
options:MTLStorageModeShared] retain];
temp_buffer_[ctx.device_id] = [dev newBufferWithLength:size options:MTLStorageModeShared];
}
return temp_buffer_[ctx.device_id];
}
Expand Down
Loading

0 comments on commit 917b8b3

Please sign in to comment.