Skip to content

Commit

Permalink
add device
Browse files Browse the repository at this point in the history
  • Loading branch information
Robin San Roman committed Aug 11, 2023
1 parent b948971 commit ac311af
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions audiocraft/models/multibanddiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_mbd_musicgen(device=None):
models, processors, cfgs = load_diffusion_models(path, device=device)
DPs = []
for i in range(len(models)):
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i])
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

Expand Down Expand Up @@ -106,7 +106,7 @@ def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
models, processors, cfgs = load_diffusion_models(path, device=device)
DPs = []
for i in range(len(models)):
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i])
schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

Expand Down

0 comments on commit ac311af

Please sign in to comment.