Skip to content

Commit

Permalink
Merge pull request #139 from invoke-ai/allow-passing-dataset-to-autocaption-script
Browse files Browse the repository at this point in the history

Update autocaption main function to accept either a dataset or directory
  • Loading branch information
brandonrising committed Jun 4, 2024
2 parents 0b44077 + 78d4522 commit 36e0ac7
Showing 1 changed file with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ def process_images(images: list[Image.Image], prompt: str, moondream, tokenizer)
return answers


def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_path: str):
def main(
prompt: str,
use_cpu: bool,
batch_size: int,
output_path: str,
dataset: torch.utils.data.Dataset,
):
device, dtype = select_device_and_dtype(use_cpu)
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")
Expand All @@ -52,9 +58,6 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
).to(device=device, dtype=dtype)
moondream_model.eval()

# Prepare the dataloader.
dataset = ImageDirDataset(image_dir)
print(f"Found {len(dataset)} images in '{image_dir}'.")
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False
)
Expand Down Expand Up @@ -107,4 +110,8 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
)
args = parser.parse_args()

main(args.dir, args.prompt, args.cpu, args.batch_size, args.output)
# Prepare the dataset.
dataset = ImageDirDataset(args.dir)
print(f"Found {len(dataset)} images in '{args.dir}'.")

main(args.prompt, args.cpu, args.batch_size, args.output, dataset)

0 comments on commit 36e0ac7

Please sign in to comment.