Skip to content

Commit

Permalink
lazy import_var error handling for saves
Browse files Browse the repository at this point in the history
  • Loading branch information
YellowRoseCx committed Jul 21, 2023
1 parent 9553e52 commit 521ad6b
Showing 1 changed file with 102 additions and 67 deletions.
169 changes: 102 additions & 67 deletions koboldcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,77 +1070,112 @@ def import_vars(dict):
smartcontext.set(1 if dict["smartcontext"] else 0)
unbantokens.set(1 if dict["unbantokens"] else 0)
runopts_var.set(runopts[0])
if dict["useclblast"]:
runopts_var.set(runopts[1])
gpu_choice_var.set(str(["0 0", "1 0", "0 1"].index(str(dict["useclblast"][0]) + " " + str(dict["useclblast"][1])) + 1))
elif dict["usecublas"]:
runopts_var.set(runopts[2])
if len(dict["usecublas"])==1:
lowvram_var.set(1 if dict["usecublas"][0]=="lowvram" else 0)
else:
lowvram_var.set(1 if "lowvram" in dict["usecublas"] else 0)
gpu_choice_var.set("1")
for g in range(3):
if str(g) in dict["usecublas"]:
gpu_choice_var.set(str(g+1))
break
if dict["gpulayers"]:
gpulayers_var.set(dict["gpulayers"])

if dict["noblas"] and dict["noavx2"]:
runopts_var.set(runopts[5])
elif dict["noavx2"]:
runopts_var.set(runopts[5])
elif dict["noblas"]:
runopts_var.set(runopts[3])
if dict["blasthreads"]:
blas_threads_var.set(str(dict["blasthreads"]))
else:
blas_threads_var.set("")

if dict["contextsize"]:
context_var.set(contextsize_text.index(str(dict["contextsize"])))
try:
if dict["useclblast"]:
runopts_var.set(runopts[1])
gpu_choice_var.set(str(["0 0", "1 0", "0 1"].index(str(dict["useclblast"][0]) + " " + str(dict["useclblast"][1])) + 1))
elif dict["usecublas"]:
runopts_var.set(runopts[2])
if len(dict["usecublas"])==1:
lowvram_var.set(1 if dict["usecublas"][0]=="lowvram" else 0)
else:
lowvram_var.set(1 if "lowvram" in dict["usecublas"] else 0)
gpu_choice_var.set("1")
for g in range(3):
if str(g) in dict["usecublas"]:
gpu_choice_var.set(str(g+1))
break
except (KeyError, IndexError):
pass

if dict["ropeconfig"] and len(dict["ropeconfig"])>1:
if dict["ropeconfig"][0]>0:
customrope_var.set(1)
customrope_scale.set(str(dict["ropeconfig"][0]))
customrope_base.set(str(dict["ropeconfig"][1]))
else:
customrope_var.set(0)

if dict["blasbatchsize"]:
blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"])))
if dict["forceversion"]:
version_var.set(str(dict["forceversion"]))

if dict["mirostat"] and len(dict["mirostat"])>1:
usemirostat.set(0 if str(dict["mirostat"][0])=="0" else 1)
mirostat_var.set(str(dict["mirostat"][0]))
mirostat_tau.set(str(dict["mirostat"][1]))
mirostat_eta.set(str(dict["mirostat"][2]))

if dict["model_param"]:
model_var.set(dict["model_param"])

if dict["lora"]:
if len(dict["lora"]) > 1:
lora_var.set(dict["lora"][0])
lora_base_var.set(dict["lora"][1])
try:
if dict["gpulayers"]:
gpulayers_var.set(dict["gpulayers"])
except (KeyError, IndexError):
pass
try:
if dict["noblas"] and dict["noavx2"]:
runopts_var.set(runopts[5])
elif dict["noavx2"]:
runopts_var.set(runopts[5])
elif dict["noblas"]:
runopts_var.set(runopts[3])
except (KeyError, IndexError):
pass
try:
if dict["blasthreads"]:
blas_threads_var.set(str(dict["blasthreads"]))
else:
lora_var.set(dict["lora"][0])

if dict["port_param"]:
port_var.set(dict["port_param"])

if dict["host"]:
host_var.set(dict["host"])
blas_threads_var.set("")
except (KeyError, IndexError):
pass
try:
if dict["contextsize"]:
context_var.set(contextsize_text.index(str(dict["contextsize"])))
except (KeyError, IndexError):
pass
try:
if dict["ropeconfig"] and len(dict["ropeconfig"])>1:
if dict["ropeconfig"][0]>0:
customrope_var.set(1)
customrope_scale.set(str(dict["ropeconfig"][0]))
customrope_base.set(str(dict["ropeconfig"][1]))
else:
customrope_var.set(0)
except (KeyError, IndexError):
pass
try:
if dict["blasbatchsize"]:
blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"])))
except (KeyError, IndexError):
pass
try:
if dict["forceversion"]:
version_var.set(str(dict["forceversion"]))
except (KeyError, IndexError):
pass
try:
if dict["mirostat"] and len(dict["mirostat"])>1:
usemirostat.set(0 if str(dict["mirostat"][0])=="0" else 1)
mirostat_var.set(str(dict["mirostat"][0]))
mirostat_tau.set(str(dict["mirostat"][1]))
mirostat_eta.set(str(dict["mirostat"][2]))
except (KeyError, IndexError):
pass
try:
if dict["model_param"]:
model_var.set(dict["model_param"])
except (KeyError, IndexError):
pass
try:
if dict["lora"]:
if len(dict["lora"]) > 1:
lora_var.set(dict["lora"][0])
lora_base_var.set(dict["lora"][1])
else:
lora_var.set(dict["lora"][0])
except (KeyError, IndexError):
pass
try:
if dict["port_param"]:
port_var.set(dict["port_param"])
except (KeyError, IndexError):
pass
try:
if dict["host"]:
host_var.set(dict["host"])
except (KeyError, IndexError):
pass
try:
if dict["hordeconfig"] and len(dict["hordeconfig"]) > 1:
horde_name_var.set(dict["hordeconfig"][0])
horde_gen_var.set(dict["hordeconfig"][1])
horde_context_var.set(dict["hordeconfig"][2])
except (KeyError, IndexError):
pass

if dict["hordeconfig"] and len(dict["hordeconfig"]) > 1:
horde_name_var.set(dict["hordeconfig"][0])
horde_gen_var.set(dict["hordeconfig"][1])
horde_context_var.set(dict["hordeconfig"][2])


def save_config():
file_type = [("KoboldCpp Settings", "*.kcpps")]
filename = asksaveasfile(filetypes=file_type, defaultextension=file_type)
Expand Down

0 comments on commit 521ad6b

Please sign in to comment.