Add Object Detection 3D #3979

Merged Oct 2, 2024 (41 commits)

Commits
ee6e29e
added MonoDetr model as is
kprokofi Sep 11, 2024
f70c9ca
model is initialized, dataset works
kprokofi Sep 12, 2024
f0550bf
added recipe, data module
kprokofi Sep 12, 2024
2882444
continue debugging
kprokofi Sep 13, 2024
9f5e5fc
added metric. not working. cuda only
kprokofi Sep 13, 2024
08fe9ad
metrics work, 1 epoch works; metric is not aligned, result is different
kprokofi Sep 18, 2024
b99bc06
train e2e works, but metric is still not the same
kprokofi Sep 18, 2024
721f234
experiment without early stopping
kprokofi Sep 18, 2024
e35a584
a bit refactoring
kprokofi Sep 19, 2024
4507fe6
refactored model outputs a bit
kprokofi Sep 19, 2024
3dd043b
add export
kprokofi Sep 19, 2024
969479c
metric works properly
kprokofi Sep 20, 2024
4f261ad
update metric
kprokofi Sep 23, 2024
2ba3132
start cleaning the code
kprokofi Sep 23, 2024
1086a7a
some fixes
kprokofi Sep 23, 2024
584f6e6
few fixes
kprokofi Sep 24, 2024
3a7c3c2
refactoring
kprokofi Sep 24, 2024
957528a
continue refactoring
kprokofi Sep 25, 2024
0a6c3a8
merge develop
kprokofi Sep 25, 2024
7c44761
added datumaro dataset
kprokofi Sep 26, 2024
3aa999f
fix linter
kprokofi Sep 26, 2024
e454860
added two versions of 3D AP calculation
kprokofi Sep 26, 2024
83b4dd8
rename backbone
kprokofi Sep 26, 2024
e241501
minor
kprokofi Sep 26, 2024
4e31082
revert recipe back
kprokofi Sep 26, 2024
73ae0a0
revert recipe back
kprokofi Sep 26, 2024
a45c746
moved transformer layers to common. Added integration tests.
kprokofi Sep 27, 2024
4cfcd0f
fix unit tests
kprokofi Sep 27, 2024
44ae515
ready_to_integrate
kprokofi Sep 27, 2024
819a1b5
merge datumaro integration branch
kprokofi Sep 27, 2024
b39d7e5
added datumaro support. 1 class not working. Validation works
kprokofi Sep 27, 2024
fd811c0
fixed issues
kprokofi Sep 28, 2024
50faed9
fix pre-commit
kprokofi Sep 28, 2024
ffa2970
change inputs name
kprokofi Sep 29, 2024
dc31f2b
integration tests OK
kprokofi Sep 29, 2024
d992730
restore missing init
kprokofi Sep 29, 2024
b30423e
Add intg test into CI & tox
harimkang Sep 30, 2024
648e3e4
Revert load_checkpoint_to_model
harimkang Sep 30, 2024
af96e58
Merge branch 'develop' into kp/add_3d_det_task
sovrasov Sep 30, 2024
686653c
upgrade datumaro 1.10.0rc0
wonjuleee Sep 30, 2024
3e8506d
unpin py3.11.8 for unittesting
yunchu Sep 30, 2024
5 changes: 2 additions & 3 deletions .github/workflows/pre_merge.yaml
@@ -48,9 +48,7 @@ jobs:
include:
- python-version: "3.10"
tox-env: "py310"
# TODO(vinnamki): Revisit after fixing in the upstream: https://github.com/omni-us/jsonargparse/issues/484
# Ticket no. 138075
- python-version: "3.11.8"
- python-version: "3.11"
tox-env: "py311"
name: Unit-Test-with-Python${{ matrix.python-version }}
steps:
@@ -112,6 +110,7 @@ jobs:
- task: "anomaly_detection"
- task: "anomaly_segmentation"
- task: "keypoint_detection"
- task: "object_detection_3d"
name: Integration-Test-${{ matrix.task }}-py310
steps:
- name: Checkout repository
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
dependencies = [
"datumaro==1.7.0",
"datumaro==1.10.0rc0",
"omegaconf==2.3.0",
"rich==13.8.0",
"jsonargparse==4.30.0",
@@ -39,6 +39,7 @@ dependencies = [
"einops==0.8.0",
"decord==0.6.0",
"typeguard==4.3.*",
"numba==0.60.0",
# TODO(ashwinvaidya17): https://github.com/openvinotoolkit/anomalib/issues/2126
"setuptools<70",
]
122 changes: 122 additions & 0 deletions src/otx/algo/common/layers/transformer_layers.py
@@ -0,0 +1,122 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Implementation of common transformer layers."""

from __future__ import annotations

import copy
from typing import Callable

import torch
from torch import nn


class TransformerEncoderLayer(nn.Module):
"""TransformerEncoderLayer."""

def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Callable[..., nn.Module] = nn.GELU,
normalize_before: bool = False,
batch_first: bool = True,
key_mask: bool = False,
) -> None:
super().__init__()
self.normalize_before = normalize_before
self.key_mask = key_mask

self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=batch_first)

self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)

self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation()

@staticmethod
def with_pos_embed(tensor: torch.Tensor, pos_embed: torch.Tensor | None) -> torch.Tensor:
"""Attach position embeddings to the tensor."""
return tensor if pos_embed is None else tensor + pos_embed

def forward(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward the transformer encoder layer.

Args:
src (torch.Tensor): The input tensor.
src_mask (torch.Tensor | None, optional): The mask tensor. Defaults to None.
pos_embed (torch.Tensor | None, optional): The position embedding tensor. Defaults to None.
"""
residual = src
if self.normalize_before:
src = self.norm1(src)

q = k = self.with_pos_embed(src, pos_embed)
if self.key_mask:
src = self.self_attn(q, k, value=src, key_padding_mask=src_mask)[0]

else:
src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask)

src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)

residual = src
if self.normalize_before:
src = self.norm2(src)

src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return src


class TransformerEncoder(nn.Module):
"""TransformerEncoder."""

def __init__(self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module | None = None) -> None:
"""Initialize the TransformerEncoder.

Args:
encoder_layer (nn.Module): The encoder layer module.
num_layers (int): The number of layers.
norm (nn.Module | None, optional): The normalization module. Defaults to None.
"""
super().__init__()
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = norm

def forward(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
pos_embed: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward the transformer encoder.

Args:
src (torch.Tensor): The input tensor.
src_mask (torch.Tensor | None, optional): The mask tensor. Defaults to None.
pos_embed (torch.Tensor | None, optional): The position embedding tensor. Defaults to None.
"""
output = src
for layer in self.layers:
output = layer(output, src_mask=src_mask, pos_embed=pos_embed)

if self.norm is not None:
output = self.norm(output)

return output
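
For orientation, here is a minimal usage sketch of the two classes added above. It is not part of the diff; it assumes the classes are importable from otx.algo.common.layers.transformer_layers (matching the new file path) and uses illustrative shapes.

import torch

from otx.algo.common.layers.transformer_layers import TransformerEncoder, TransformerEncoderLayer

# One encoder layer with the defaults defined above (post-norm, batch_first=True, GELU).
layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=1024, dropout=0.1)

# Stack three deep copies of the layer; no final norm module.
encoder = TransformerEncoder(layer, num_layers=3)

# Batch of 2 sequences, 100 tokens each, 256-dim features (batch_first=True).
src = torch.randn(2, 100, 256)
pos_embed = torch.randn(2, 100, 256)  # optional positional embedding, added to q and k only

out = encoder(src, pos_embed=pos_embed)
print(out.shape)  # torch.Size([2, 100, 256])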
182 changes: 180 additions & 2 deletions src/otx/algo/common/losses/focal_loss.py
@@ -8,11 +8,12 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import torch
import torch.nn.functional
from otx.algo.common.losses.utils import weight_reduce_loss
from torch import nn

if TYPE_CHECKING:
from torch import Tensor
@@ -50,7 +51,7 @@
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
# Thus it's pt.pow(gamma) rather than (1 - pt).pow(gamma)
focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, target, reduction="none") * focal_weight
loss = nn.functional.binary_cross_entropy_with_logits(pred, target, reduction="none") * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
@@ -70,3 +71,180 @@
msg = "The number of dimensions in weight should be equal to the number of dimensions in loss."
raise ValueError(msg)
return weight_reduce_loss(loss, weight, reduction, avg_factor)


def one_hot(
labels: torch.Tensor,
num_classes: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
eps: float = 1e-6,
) -> torch.Tensor:
r"""Convert an integer label x-D tensor to a one-hot (x+1)-D tensor.

Args:
labels: tensor with labels of shape :math:`(N, *)`, where N is batch size.
Each value is an integer representing correct classification.
num_classes: number of classes in labels.
device: the desired device of returned tensor.
dtype: the desired data type of returned tensor.

Returns:
the one-hot encoded labels of shape :math:`(N, C, *)`.

Examples:
>>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
>>> one_hot(labels, num_classes=3)
tensor([[[[1.0000e+00, 1.0000e-06],
[1.0000e-06, 1.0000e+00]],
<BLANKLINE>
[[1.0000e-06, 1.0000e+00],
[1.0000e-06, 1.0000e-06]],
<BLANKLINE>
[[1.0000e-06, 1.0000e-06],
[1.0000e+00, 1.0000e-06]]]])
"""
if not isinstance(labels, torch.Tensor):
msg = f"Input labels type is not a torch.Tensor. Got {type(labels)}"
raise TypeError(msg)

if labels.dtype != torch.int64:
msg = f"labels must be of the same dtype torch.int64. Got: {labels.dtype}"
raise ValueError(msg)

if num_classes < 1:
msg = f"The number of classes must be bigger than one. Got: {num_classes}"
raise ValueError(msg)
shape = labels.shape
one_hot = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps


def focal_loss(
inputs: torch.Tensor,
target: torch.Tensor,
alpha: float,
gamma: float = 2.0,
reduction: str = "none",
eps: float | None = None,
) -> torch.Tensor:
r"""Criterion that computes Focal loss.

According to :cite:`lin2018focal`, the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
Where:
- :math:`p_t` is the model's estimated probability for each class.

Args:
inputs: logits tensor with shape :math:`(N, C, *)` where C = number of classes.
target: labels tensor with shape :math:`(N, *)` where each value is :math:`0 ≤ targets[i] ≤ C-1`.
alpha: Weighting factor :math:`\alpha \in [0, 1]`.
gamma: Focusing parameter :math:`\gamma >= 0`.
reduction: Specifies the reduction to apply to the
output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
will be applied, ``'mean'``: the sum of the output will be divided by
the number of elements in the output, ``'sum'``: the output will be
summed.
eps: Deprecated: scalar to enforce numerical stability. This is no longer used.

Return:
the computed loss.

Example:
>>> N = 5 # num_classes
>>> inputs = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = focal_loss(inputs, target, alpha=0.5, gamma=2.0, reduction='mean')
>>> output.backward()
"""
if eps is not None and not torch.jit.is_scripting():
warnings.warn(

"`focal_loss` has been reworked for improved numerical stability "
"and the `eps` argument is no longer necessary",
DeprecationWarning,
stacklevel=2,
)

if not isinstance(inputs, torch.Tensor):
msg = f"inputs type is not a torch.Tensor. Got {type(inputs)}"
raise TypeError(msg)

if not len(inputs.shape) >= 2:
msg = f"Invalid inputs shape, we expect BxCx*. Got: {inputs.shape}"
raise ValueError(msg)

if inputs.size(0) != target.size(0):
msg = f"Expected inputs batch_size ({inputs.size(0)}) to match target batch_size ({target.size(0)})."
raise ValueError(msg)

n = inputs.size(0)
out_size = (n,) + inputs.size()[2:]
if target.size()[1:] != inputs.size()[2:]:
msg = f"Expected target size {out_size}, got {target.size()}"
raise ValueError(msg)

if inputs.device != target.device:
msg = f"inputs and target must be in the same device. Got: {inputs.device} and {target.device}"
raise ValueError(msg)

# compute softmax over the classes axis
input_soft: torch.Tensor = nn.functional.softmax(inputs, dim=1)
log_input_soft: torch.Tensor = nn.functional.log_softmax(inputs, dim=1)

# create the labels one hot tensor
target_one_hot: torch.Tensor = one_hot(

target,
num_classes=inputs.shape[1],
device=inputs.device,
dtype=inputs.dtype,
)

# compute the actual focal loss
weight = torch.pow(-input_soft + 1.0, gamma)

focal = -alpha * weight * log_input_soft
loss_tmp = torch.einsum("bc...,bc...->b...", (target_one_hot, focal))
return weight_reduce_loss(loss_tmp, reduction=reduction, avg_factor=None)


class FocalLoss(nn.Module):
"""Criterion that computes Focal loss."""

def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none", eps: float | None = None) -> None:
r"""Criterion that computes Focal loss.

According to :cite:`lin2018focal`, the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
Where:
- :math:`p_t` is the model's estimated probability for each class.

Args:
alpha: Weighting factor :math:`\alpha \in [0, 1]`.
gamma: Focusing parameter :math:`\gamma >= 0`.
reduction: Specifies the reduction to apply to the
output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
will be applied, ``'mean'``: the sum of the output will be divided by
the number of elements in the output, ``'sum'``: the output will be
summed.
eps: Deprecated: scalar to enforce numerical stability. This is no longer
used.

Example:
>>> N = 5 # num_classes
>>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
>>> criterion = FocalLoss(**kwargs)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = criterion(input, target)
>>> output.backward()
"""
super().__init__()
self.alpha: float = alpha
self.gamma: float = gamma
self.reduction: str = reduction
self.eps: float | None = eps

def forward(self, inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward."""
return focal_loss(inputs, target, self.alpha, self.gamma, self.reduction, self.eps)
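
As a quick sanity check of the new multi-class focal loss (a sketch, not part of the PR; it assumes the functions above are importable from otx.algo.common.losses.focal_loss and that weight_reduce_loss applies the usual mean reduction): with gamma = 0 the focal weight (1 - p_t)^gamma collapses to 1, so focal_loss should reduce to an alpha-scaled cross-entropy, up to the tiny eps that one_hot adds to every class.

import torch
from torch import nn

from otx.algo.common.losses.focal_loss import focal_loss

torch.manual_seed(0)
num_classes = 5
inputs = torch.randn(2, num_classes, 3, 5)  # logits of shape (N, C, *)
target = torch.empty(2, 3, 5, dtype=torch.long).random_(num_classes)

alpha = 0.5
fl = focal_loss(inputs, target, alpha=alpha, gamma=0.0, reduction="mean")
ce = nn.functional.cross_entropy(inputs, target, reduction="mean")

# With gamma == 0 the (1 - p)^gamma weight is 1, so the result is ~ alpha * CE;
# the eps=1e-6 added in one_hot contributes only a negligible offset.
assert torch.isclose(fl, alpha * ce, atol=1e-4)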
