Skip to content

Commit

Permalink
feat(api): initial support for textual inversion embeddings from civitai and others (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 2, 2023
1 parent 1f3a5f6 commit 46aac26
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
7 changes: 6 additions & 1 deletion api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,16 @@ def convert_models(ctx: ConversionContext, args, models: Models):
for inversion in model.get("inversions", []):
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_format = inversion.get("format", "huggingface")
inversion_source = fetch_model(
ctx, f"{name}-inversion-{inversion_name}", inversion_source
)
convert_diffusion_textual_inversion(
ctx, inversion_name, model["source"], inversion_source
ctx,
inversion_name,
model["source"],
inversion_source,
inversion_format,
)

except Exception as e:
Expand Down
36 changes: 25 additions & 11 deletions api/onnx_web/convert/diffusion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@torch.no_grad()
def convert_diffusion_textual_inversion(
context: ConversionContext, name: str, base_model: str, inversion: str
context: ConversionContext, name: str, base_model: str, inversion: str, format: str
):
dest_path = path.join(context.model_path, f"inversion-{name}")
logger.info(
Expand All @@ -26,11 +26,31 @@ def convert_diffusion_textual_inversion(

makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)

embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
if format == "huggingface":
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")

with open(token_file, "r") as f:
token = f.read()
with open(token_file, "r") as f:
token = f.read()

loaded_embeds = torch.load(embeds_file, map_location=context.map_location)

# separate token and the embeds
trained_token = list(loaded_embeds.keys())[0]
embeds = loaded_embeds[trained_token]
elif format == "embeddings":
loaded_embeds = torch.load(inversion, map_location=context.map_location)

string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]

token = name

# separate token and embeds
trained_token = list(string_to_token.keys())[0]
embeds = string_to_param[trained_token]

logger.info("found embedding for token %s: %s", trained_token, embeds.shape)

tokenizer = CLIPTokenizer.from_pretrained(
base_model,
Expand All @@ -41,12 +61,6 @@ def convert_diffusion_textual_inversion(
subfolder="text_encoder",
)

loaded_embeds = torch.load(embeds_file, map_location=context.map_location)

# separate token and the embeds
trained_token = list(loaded_embeds.keys())[0]
embeds = loaded_embeds[trained_token]

# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
embeds.to(dtype)
Expand Down

0 comments on commit 46aac26

Please sign in to comment.