
Commit af71748

add flag to control lora
b8raoult committed Sep 15, 2024
1 parent 9c4b262 commit af71748
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/ai_models_aurora/model.py
@@ -42,14 +42,17 @@ class AuroraModel(Model):
     # Output
 
     expver = "auro"
+    lora = None
 
     def run(self):
 
         # TODO: control location of cache
 
-        LOG.info(f"Model is {self.__class__.__name__}, use_lora={self.use_lora}")
+        use_lora = self.lora if self.lora is not None else self.use_lora
 
-        model = self.klass(use_lora=self.use_lora)
+        LOG.info(f"Model is {self.__class__.__name__}, use_lora={use_lora}")
+
+        model = self.klass(use_lora=use_lora)
         model = model.to(self.device)
 
         path = os.path.join(self.assets, os.path.basename(self.checkpoint))
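Note (annotation, not part of the commit): the new `use_lora` assignment above is a three-state override. The class attribute `lora`, filled in from the `--lora` flag added below, wins whenever it is not `None`; otherwise each model's own `use_lora` default applies. A minimal sketch of that resolution logic, using a hypothetical `resolve_use_lora` helper:

# Sketch only: mirrors `use_lora = self.lora if self.lora is not None else self.use_lora`.
def resolve_use_lora(lora_flag, model_default):
    """Return the --lora value when the user set one, otherwise the model default."""
    return lora_flag if lora_flag is not None else model_default

assert resolve_use_lora(None, True) is True    # flag omitted: keep the model default
assert resolve_use_lora(False, True) is False  # --lora false: turn LoRA off
assert resolve_use_lora(True, False) is True   # --lora: force LoRA on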
@@ -158,6 +161,16 @@ def nan_extend(self, data):
             axis=0,
         )
 
+    def add_model_args(self, parser):
+        parser.add_argument(
+            "--lora",
+            type=lambda x: (str(x).lower() in ["true", "1", "yes"]),
+            nargs="?",
+            const=True,
+            default=None,
+            help="Use LoRA model (true/false). Default depends on the model.",
+        )
+
 
 class Aurora2p5(AuroraModel):
     klass = Aurora
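Note (annotation, not part of the commit): the `add_model_args` hook above registers `--lora` as an optional boolean. With `nargs="?"` and `const=True`, a bare `--lora` enables LoRA, an explicit value goes through the lambda (only "true", "1" or "yes" count as true), and omitting the flag leaves the default `None`, so the per-model default is used. A self-contained sketch of that behaviour:

# Sketch only: reproduces the parser configuration from add_model_args above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora",
    type=lambda x: (str(x).lower() in ["true", "1", "yes"]),
    nargs="?",
    const=True,
    default=None,
)

print(parser.parse_args([]).lora)                   # None  -> model default applies
print(parser.parse_args(["--lora"]).lora)           # True  -> bare flag, via const=True
print(parser.parse_args(["--lora", "false"]).lora)  # False -> any other value parses to False
print(parser.parse_args(["--lora", "yes"]).lora)    # True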

