Skip to content

Commit

Permalink
[feature] change val loss arg name
Browse files Browse the repository at this point in the history
  • Loading branch information
wang-tf committed May 22, 2023
1 parent 344ac21 commit 9671ed8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 14 deletions.
7 changes: 5 additions & 2 deletions configs/resnet/resnet50_4xb128_bp_ppg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
val_evaluator = [
dict(type='MAE', gt_key='gt_label',
pred_key='pred_label', loss_key='pred_loss'),
pred_key='pred_label', compute_loss=True),
# dict(type='BlandAltmanPlot')
]
test_dataloader = dict(
Expand All @@ -51,7 +51,10 @@
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = val_evaluator
test_evaluator = [
dict(type='MAE', gt_key='gt_label',
pred_key='pred_label')
]

# optimizer
optim_wrapper = dict(optimizer=dict(type='Adam',
Expand Down
14 changes: 3 additions & 11 deletions deep_vital/evaluation/metrics/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
import torch
import torch.nn.functional as F
from mmengine.evaluator import BaseMetric
from mmengine import MessageHub
from deep_vital.registry import METRICS

message_hub = MessageHub.get_current_instance()


@METRICS.register_module()
class MAE(BaseMetric):
Expand All @@ -16,32 +13,27 @@ class MAE(BaseMetric):
def __init__(self,
gt_key='gt_label',
pred_key='pred_label',
loss_key=None,
compute_loss=False,
collect_device: str = 'cpu',
prefix: Optional[str] = None):
super().__init__(collect_device=collect_device, prefix=prefix)
self.gt_key = gt_key
self.pred_key = pred_key
self.loss_key = loss_key
self.compute_loss = compute_loss

def process(self, data_batch, data_samples):
for data_sample in data_samples:
result = {
'pred_label': data_sample[self.pred_key],
'gt_label': data_sample[self.gt_key]
}
# if self.loss_key:
# result['pred_loss'] = data_sample[self.loss_key]
self.results.append(result)

def compute_metrics(self, results):
metrics = {}
target = torch.stack([res['gt_label'] for res in results])
pred = torch.stack([res['pred_label'] for res in results])
if self.loss_key:
# loss = torch.stack([res['pred_loss'] for res in results])
# val_loss = loss.mean()
# metrics['pred_loss'] = val_loss
if self.compute_loss:
sbp_pred = pred[:, 0]
dbp_pred = pred[:, 1]
sbp_target = target[:, 0]
Expand Down
2 changes: 1 addition & 1 deletion deep_vital/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.1.0'
__version__ = '1.0.0'
short_version = __version__


Expand Down

0 comments on commit 9671ed8

Please sign in to comment.