refactor: rename temp->temperature
0xlws committed Sep 13, 2023
1 parent a2b9675 commit 0335d54
Showing 6 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion audiocraft/models/audiogen.py
@@ -119,7 +119,7 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
         self.duration = duration
         self.generation_params = {
             'use_sampling': use_sampling,
-            'temp': temperature,
+            'temperature': temperature,
             'top_k': top_k,
             'top_p': top_p,
             'cfg_coef': cfg_coef,
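The public set_generation_params() keyword was already temperature; this commit makes the internal generation_params key match it. A minimal usage sketch (the checkpoint name and prompt below are illustrative, not part of this diff):

    from audiocraft.models import AudioGen

    model = AudioGen.get_pretrained('facebook/audiogen-medium')  # illustrative checkpoint
    model.set_generation_params(
        use_sampling=True,
        temperature=1.0,  # stored under the 'temperature' key after this commit
        top_k=250,
        duration=5,
    )
    wav = model.generate(['dog barking in the distance'])  # tensor of shape [B, C, T]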
16 changes: 8 additions & 8 deletions audiocraft/models/lm.py
@@ -311,7 +311,7 @@ def _sample_next_token(self,
                            cfg_conditions: CFGConditions,
                            unconditional_state: State,
                            use_sampling: bool = False,
-                           temp: float = 1.0,
+                           temperature: float = 1.0,
                            top_k: int = 0,
                            top_p: float = 0.0,
                            cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
@@ -325,7 +325,7 @@ def _sample_next_token(self,
             condition_tensors (dict[str, ConditionType]): Set of conditions. If CFG is used,
                 should be twice the batch size, being the concatenation of the conditions + null conditions.
             use_sampling (bool): Whether to use a sampling strategy or not.
-            temp (float): Sampling temperature.
+            temperature (float): Sampling temperature.
             top_k (int): K for "top-k" sampling.
             top_p (float): P for "top-p" sampling.
             cfg_coef (float, optional): Classifier-free guidance coefficient.
@@ -363,9 +363,9 @@ def _sample_next_token(self,
         logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
         logits = logits[..., -1]  # [B x K x card]

-        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
-        if use_sampling and temp > 0.0:
-            probs = torch.softmax(logits / temp, dim=-1)
+        # Apply softmax for sampling if temperature > 0. Else, do greedy sampling to avoid zero division error.
+        if use_sampling and temperature > 0.0:
+            probs = torch.softmax(logits / temperature, dim=-1)
             if top_p > 0.0:
                 next_token = utils.sample_top_p(probs, p=top_p)
             elif top_k > 0:
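The logic above is standard temperature sampling: logits are divided by temperature before the softmax, so values below 1.0 sharpen the distribution and values above 1.0 flatten it, while temperature == 0.0 falls back to greedy argmax to avoid a division by zero. A self-contained sketch of the same idea in plain PyTorch (not the audiocraft helpers; utils.sample_top_p is replaced here by a simple top-k path):

    import torch

    def sample_with_temperature(logits: torch.Tensor, temperature: float = 1.0,
                                top_k: int = 0) -> torch.Tensor:
        # Sample token ids from [B, vocab] logits; greedy when temperature == 0.
        if temperature <= 0.0:
            # Greedy decoding: no division by temperature, no randomness.
            return torch.argmax(logits, dim=-1, keepdim=True)
        probs = torch.softmax(logits / temperature, dim=-1)
        if top_k > 0:
            # Keep only the k most likely tokens, renormalize, then sample.
            top_probs, top_idx = torch.topk(probs, k=top_k, dim=-1)
            top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
            choice = torch.multinomial(top_probs, num_samples=1)
            return torch.gather(top_idx, -1, choice)
        return torch.multinomial(probs, num_samples=1)

    # e.g. sample_with_temperature(torch.randn(2, 2048), temperature=0.8, top_k=250)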
@@ -384,7 +384,7 @@ def generate(self,
                  num_samples: tp.Optional[int] = None,
                  max_gen_len: int = 256,
                  use_sampling: bool = True,
-                 temp: float = 1.0,
+                 temperature: float = 1.0,
                  top_k: int = 250,
                  top_p: float = 0.0,
                  cfg_coef: tp.Optional[float] = None,
@@ -401,7 +401,7 @@ def generate(self,
             num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
             max_gen_len (int): Maximum generation length.
             use_sampling (bool): Whether to use a sampling strategy or not.
-            temp (float): Sampling temperature.
+            temperature (float): Sampling temperature.
             top_k (int): K for "top-k" sampling.
             top_p (float): P for "top-p" sampling.
             cfg_coef (float, optional): Classifier-free guidance coefficient.
@@ -492,7 +492,7 @@ def generate(self,
             assert not (curr_sequence == unknown_token).any()
             # sample next token from the model, next token shape is [B, K, 1]
             next_token = self._sample_next_token(
-                curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+                curr_sequence, cfg_conditions, unconditional_state, use_sampling, temperature, top_k, top_p,
                 cfg_coef=cfg_coef)
             # ensure the tokens that should be masked are properly set to special_token_id
             # as the model never outputs special_token_id
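The rename has to be consistent end to end because generation_params is eventually unpacked into generate() as keyword arguments (e.g. lm.generate(..., **self.generation_params) in the MusicGen wrapper), so the dict key must match the parameter name exactly. A simplified illustration, not the actual call site:

    generation_params = {'use_sampling': True, 'temperature': 1.0, 'top_k': 250, 'top_p': 0.0}
    tokens = lm.generate(prompt_tokens, attributes, **generation_params)
    # With the old 'temp' key this unpacking would now raise
    # TypeError: generate() got an unexpected keyword argument 'temp'.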
2 changes: 1 addition & 1 deletion audiocraft/models/musicgen.py
@@ -146,7 +146,7 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
         self.duration = duration
         self.generation_params = {
             'use_sampling': use_sampling,
-            'temp': temperature,
+            'temperature': temperature,
             'top_k': top_k,
             'top_p': top_p,
             'cfg_coef': cfg_coef,
2 changes: 1 addition & 1 deletion audiocraft/solvers/musicgen.py
@@ -39,7 +39,7 @@ def __init__(self, cfg: omegaconf.DictConfig):
         # easier access to sampling parameters
         self.generation_params = {
             'use_sampling': self.cfg.generate.lm.use_sampling,
-            'temp': self.cfg.generate.lm.temp,
+            'temperature': self.cfg.generate.lm.temperature,
             'top_k': self.cfg.generate.lm.top_k,
             'top_p': self.cfg.generate.lm.top_p,
         }
2 changes: 1 addition & 1 deletion config/solver/default.yaml
@@ -59,7 +59,7 @@ generate:
     sample_rate: null
   lm:
     use_sampling: false
-    temp: 1.0
+    temperature: 1.0
     top_k: 0
     top_p: 0.0
 evaluate:
2 changes: 1 addition & 1 deletion config/solver/musicgen/default.yaml
@@ -84,7 +84,7 @@ generate:
     remove_prompts: false
     # generation params
     use_sampling: false
-    temp: 1.0
+    temperature: 1.0
     top_k: 0
     top_p: 0.0
 evaluate:
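On the config side, anything that read generate.lm.temp must now read generate.lm.temperature, as the solver change above does. A minimal OmegaConf illustration of the new key (a standalone sketch, not solver code):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create("""
    generate:
      lm:
        use_sampling: false
        temperature: 1.0
        top_k: 0
        top_p: 0.0
    """)
    # Mirrors how the solver builds its sampling dict after this commit.
    generation_params = {
        'use_sampling': cfg.generate.lm.use_sampling,
        'temperature': cfg.generate.lm.temperature,
        'top_k': cfg.generate.lm.top_k,
        'top_p': cfg.generate.lm.top_p,
    }
    print(generation_params)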
