
What is the difference between the ReinMask2FormerHead and original Mask2FormerHead? #12

Closed
hongge831 opened this issue Mar 16, 2024 · 3 comments

Comments

@hongge831

Nice work. I noticed that you rewrote the Mask2FormerHead as ReinMask2FormerHead; I wonder what the difference between them is.
Also, when you use VPT to compare against your Rein, what are the detailed settings of VPT? For example, how many tokens are used, and are there any differences in the hyperparameter settings?

Looking forward to your reply.

@w1oves
Owner

w1oves commented Mar 16, 2024 via email

@w1oves w1oves closed this as completed Mar 17, 2024
@hongge831
Author

Thanks for your reply. Would you mind uploading this part of the code for us as a reference? I would be grateful if you could.

@w1oves
Owner

w1oves commented Mar 18, 2024

ReinMask2Former

You can find details of ReinMask2Former at rein/models/heads/rein_mask2former.py.

Implemented VPT Version:

import math
from functools import reduce
from operator import mul

import torch
import torch.nn as nn
from torch import Tensor


class SimpleVPT(nn.Module):
    def __init__(self, embed_dim, depth, num_tokens) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.embed_dim = embed_dim
        # One set of learnable prompt tokens per transformer layer
        self.prompt_embeddings = nn.Parameter(torch.zeros(depth, num_tokens, embed_dim))
        # Xavier-uniform-style initialization following the VPT paper:
        # fan-in of the patch embedding (3 * 16 * 16) plus embed_dim
        val = math.sqrt(6.0 / float(3 * reduce(mul, (16, 16), 1) + embed_dim))
        nn.init.uniform_(self.prompt_embeddings.data, -val, val)

    def forward(self, idx, layer, x: Tensor, batch_first=True, cls_token=True):
        # `layer` is the transformer block to wrap; `idx` selects the prompts for this depth
        if not batch_first:
            x = x.permute(1, 0, 2)  # convert to batch-first layout (B, N, C)
        B, _, _ = x.shape
        # Prepend this layer's prompt tokens to the input sequence
        x = torch.cat([self.prompt_embeddings[idx].expand(B, -1, -1), x], dim=1)
        x = layer(x)
        # Discard the prompt tokens, keeping only the original tokens
        _, x = torch.tensor_split(x, [self.num_tokens], dim=1)
        if not batch_first:
            x = x.permute(1, 0, 2)  # restore the original (N, B, C) layout
        return x
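
For reference, a minimal usage sketch of my own (not from the repo), using standard nn.TransformerEncoderLayer blocks as a stand-in for DINOv2 blocks just to show the token flow:

# Illustrative only: wrap a stack of standard encoder layers with SimpleVPT.
embed_dim, depth, num_tokens = 768, 12, 150
vpt = SimpleVPT(embed_dim=embed_dim, depth=depth, num_tokens=num_tokens)
blocks = nn.ModuleList(
    nn.TransformerEncoderLayer(d_model=embed_dim, nhead=12, batch_first=True)
    for _ in range(depth)
)

x = torch.randn(2, 1 + 16 * 16, embed_dim)  # (B, 1 cls token + 256 patch tokens, C)
for idx, blk in enumerate(blocks):
    x = vpt.forward(idx, blk, x, batch_first=True, cls_token=True)
print(x.shape)  # torch.Size([2, 257, 768]) -- prompt tokens are stripped after every block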

DINOv2 + VPT Integration:

class VPTDinoVisionTransformer(DinoVisionTransformer):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=partial(Block, attn_class=MemEffAttention),
        ffn_layer="mlp",
        block_chunks=1,
        out_indices=[7, 11, 15, 23],
        upscale_feats=False,
        init_cfg=None,
    ):
        super().__init__(
            img_size,
            patch_size,
            in_chans,
            embed_dim,
            depth,
            num_heads,
            mlp_ratio,
            qkv_bias,
            ffn_bias,
            proj_bias,
            drop_path_rate,
            drop_path_uniform,
            init_values,
            embed_layer,
            act_layer,
            block_fn,
            ffn_layer,
            block_chunks,
            out_indices,
            upscale_feats,
            init_cfg,
        )
        # Initialize VPT
        self.vpt = SimpleVPT(embed_dim=embed_dim, depth=depth, num_tokens=150)

    def forward_features(self, x, masks=None):
        # Handle the list input case before unpacking the tensor shape
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        B, _, h, w = x.shape
        H, W = h // self.patch_size, w // self.patch_size

        x = self.prepare_tokens_with_masks(x, masks)
        outs = []
        for idx, blk in enumerate(self.blocks):
            # Run each block through SimpleVPT so prompts are injected at every layer
            x = self.vpt.forward(idx, blk, x, batch_first=True, cls_token=True)
            if idx in self.out_indices:
                # Drop the cls token and reshape patch tokens into a (B, C, H, W) feature map
                outs.append(x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, H, W).contiguous())
        return outs

    def train(self, mode: bool = True):
        if not mode:
            return super().train(mode)
        # In training mode, only the VPT prompts receive gradients and stay in train mode
        set_requires_grad(self, ["vpt"])
        set_train(self, ["vpt"])
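
The train override relies on set_requires_grad and set_train helpers defined elsewhere in the repo. Below is a minimal sketch of how keyword-based freezing utilities of this kind are commonly written (illustrative only, not necessarily the repo's exact implementation):

import torch.nn as nn


def set_requires_grad(model: nn.Module, keywords):
    # Illustrative sketch: enable gradients only for parameters whose
    # qualified name contains one of the keywords (e.g. "vpt").
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in keywords)


def set_train(model: nn.Module, keywords):
    # Illustrative sketch: put everything in eval mode, then switch the
    # matching submodules (e.g. the VPT prompts) back to train mode.
    model.train(False)
    for name, module in model.named_modules():
        if any(key in name for key in keywords):
            module.train(True)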

@w1oves w1oves pinned this issue Mar 18, 2024
@w1oves w1oves unpinned this issue Apr 5, 2024