Skip to content

Commit

Permalink
correctly merge blocks back to original fmap
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 24, 2021
1 parent f34a404 commit b41c540
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 3 additions & 4 deletions halonet_pytorch/halonet_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def __init__(
self.register_buffer('mask', mask == 0)

def forward(self, x):
shape = x.shape
b, c, h, w, block, halo, heads, device = *shape, self.block_size, self.halo_size, self.heads, x.device
b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device
assert h == w, 'dimensions of fmap must be same on both sides, for now'
assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'

Expand Down Expand Up @@ -175,5 +174,5 @@ def forward(self, x):

# merge blocks back to original feature map

out = rearrange(out, '(b i) j c -> b c j i', i = (h // block) * (w // block))
return out.reshape(shape)
out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b = b, h = (h // block), w = (w // block), p1 = block, p2 = block)
return out
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'halonet-pytorch',
packages = find_packages(),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'HaloNet - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit b41c540

Please sign in to comment.