Skip to content

Commit

Permalink
Merge a234d8e into 5b5f895
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 authored Aug 7, 2023
2 parents 5b5f895 + a234d8e commit 858a650
Show file tree
Hide file tree
Showing 13 changed files with 1,785 additions and 5 deletions.
4 changes: 3 additions & 1 deletion mmagic/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .comp1k_dataset import AdobeComp1kDataset
from .controlnet_dataset import ControlNetDataset
from .dreambooth_dataset import DreamBoothDataset
from .dummy_dataset import DummyDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .imagenet_dataset import ImageNet
from .mscoco_dataset import MSCoCoDataset
Expand All @@ -19,5 +20,6 @@
'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset',
'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset'
'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset',
'DummyDataset'
]
30 changes: 30 additions & 0 deletions mmagic/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

from torch.utils.data import Dataset

from mmagic.registry import DATASETS


@DATASETS.register_module()
class DummyDataset(Dataset):

def __init__(self, max_length=100, batch_size=None, sample_kwargs=None):
super().__init__()
self.max_length = max_length
self.sample_kwargs = sample_kwargs
self.batch_size = batch_size

Check warning on line 16 in mmagic/datasets/dummy_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmagic/datasets/dummy_dataset.py#L13-L16

Added lines #L13 - L16 were not covered by tests

def __len__(self):
return self.max_length

Check warning on line 19 in mmagic/datasets/dummy_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmagic/datasets/dummy_dataset.py#L19

Added line #L19 was not covered by tests

def __getitem__(self, index):
data_dict = dict()
input_dict = dict()

Check warning on line 23 in mmagic/datasets/dummy_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmagic/datasets/dummy_dataset.py#L22-L23

Added lines #L22 - L23 were not covered by tests
if self.batch_size is not None:
input_dict['num_batches'] = self.batch_size

Check warning on line 25 in mmagic/datasets/dummy_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmagic/datasets/dummy_dataset.py#L25

Added line #L25 was not covered by tests
if self.sample_kwargs is not None:
input_dict['sample_kwargs'] = deepcopy(self.sample_kwargs)

Check warning on line 27 in mmagic/datasets/dummy_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmagic/datasets/dummy_dataset.py#L27

Added line #L27 was not covered by tests

data_dict['inputs'] = input_dict
return data_dict

Check warning on line 30 in mmagic/datasets/dummy_dataset.py

View check run for this annotation

Codecov / codecov/patch

mmagic/datasets/dummy_dataset.py#L29-L30

Added lines #L29 - L30 were not covered by tests
3 changes: 2 additions & 1 deletion mmagic/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dreamfusion_hook import DreamFusionTrainingHook
from .ema import ExponentialMovingAverageHook
from .iter_time_hook import IterTimerHook
from .pggan_fetch_data_hook import PGGANFetchDataHook
Expand All @@ -9,5 +10,5 @@
__all__ = [
'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'VisualizationHook',
'ExponentialMovingAverageHook', 'IterTimerHook', 'PGGANFetchDataHook',
'PickleDataHook'
'PickleDataHook', 'DreamFusionTrainingHook'
]
55 changes: 55 additions & 0 deletions mmagic/engine/hooks/dreamfusion_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmagic.registry import HOOKS


@HOOKS.register_module()
class DreamFusionTrainingHook(Hook):

def __init__(self, albedo_iters: int):
super().__init__()
self.albedo_iters = albedo_iters

Check warning on line 15 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L14-L15

Added lines #L14 - L15 were not covered by tests

self.shading_test = 'albedo'
self.ambident_ratio_test = 1.0

Check warning on line 18 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L17-L18

Added lines #L17 - L18 were not covered by tests

def set_shading_and_ambient(self, runner, shading: str,
ambient_ratio: str) -> None:
model = runner.model

Check warning on line 22 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L22

Added line #L22 was not covered by tests
if is_model_wrapper(model):
model = model.module
renderer = model.renderer

Check warning on line 25 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L24-L25

Added lines #L24 - L25 were not covered by tests
if is_model_wrapper(renderer):
renderer = renderer.module
renderer.set_shading(shading)
renderer.set_ambient_ratio(ambient_ratio)

Check warning on line 29 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L27-L29

Added lines #L27 - L29 were not covered by tests

def after_train_iter(self, runner, batch_idx: int, *args,
**kwargs) -> None:
if batch_idx < self.albedo_iters or self.albedo_iters == -1:
shading = 'albedo'
ambient_ratio = 1.0

Check warning on line 35 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L34-L35

Added lines #L34 - L35 were not covered by tests
else:
rand = random.random()

Check warning on line 37 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L37

Added line #L37 was not covered by tests
if rand > 0.8: # NOTE: this should be 0.75 in paper
shading = 'albedo'
ambient_ratio = 1.0

Check warning on line 40 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L39-L40

Added lines #L39 - L40 were not covered by tests
elif rand > 0.4: # NOTE: this should be 0.75 * 0.5 = 0.325
shading = 'textureless'
ambient_ratio = 0.1

Check warning on line 43 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L42-L43

Added lines #L42 - L43 were not covered by tests
else:
shading = 'lambertian'
ambient_ratio = 0.1
self.set_shading_and_ambient(runner, shading, ambient_ratio)

Check warning on line 47 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L45-L47

Added lines #L45 - L47 were not covered by tests

def before_test(self, runner) -> None:
self.set_shading_and_ambient(runner, self.shading_test,

Check warning on line 50 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L50

Added line #L50 was not covered by tests
self.ambident_ratio_test)

def before_val(self, runner) -> None:
self.set_shading_and_ambient(runner, self.shading_test,

Check warning on line 54 in mmagic/engine/hooks/dreamfusion_hook.py

View check run for this annotation

Codecov / codecov/patch

mmagic/engine/hooks/dreamfusion_hook.py#L54

Added line #L54 was not covered by tests
self.ambident_ratio_test)
3 changes: 2 additions & 1 deletion mmagic/models/editors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .dim import DIM
from .disco_diffusion import ClipWrapper, DiscoDiffusion
from .dreambooth import DreamBooth
from .dreamfusion import DreamFusion
from .edsr import EDSRNet
from .edvr import EDVR, EDVRNet
from .eg3d import EG3D
Expand Down Expand Up @@ -89,5 +90,5 @@
'StyleGAN3Generator', 'InstColorization', 'NAFBaseline',
'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DenoisingUnet',
'ClipWrapper', 'EG3D', 'Restormer', 'SwinIRNet', 'StableDiffusion',
'ControlStableDiffusion', 'DreamBooth', 'TextualInversion'
'ControlStableDiffusion', 'DreamBooth', 'TextualInversion', 'DreamFusion'
]
10 changes: 10 additions & 0 deletions mmagic/models/editors/dreamfusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .camera import DreamFusionCamera
from .dreamfusion import DreamFusion
from .renderer import DreamFusionRenderer
from .stable_diffusion_wrapper import StableDiffusionWrapper

__all__ = [
'DreamFusion', 'DreamFusionRenderer', 'DreamFusionCamera',
'StableDiffusionWrapper'
]
22 changes: 22 additions & 0 deletions mmagic/models/editors/dreamfusion/activate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class _trunc_exp(Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float)
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.exp(x)

Check warning on line 13 in mmagic/models/editors/dreamfusion/activate.py

View check run for this annotation

Codecov / codecov/patch

mmagic/models/editors/dreamfusion/activate.py#L12-L13

Added lines #L12 - L13 were not covered by tests

@staticmethod
@custom_bwd
def backward(ctx, g):
x = ctx.saved_tensors[0]
return g * torch.exp(x.clamp(max=15))

Check warning on line 19 in mmagic/models/editors/dreamfusion/activate.py

View check run for this annotation

Codecov / codecov/patch

mmagic/models/editors/dreamfusion/activate.py#L18-L19

Added lines #L18 - L19 were not covered by tests


trunc_exp = _trunc_exp.apply
Loading

0 comments on commit 858a650

Please sign in to comment.