-
Notifications
You must be signed in to change notification settings - Fork 864
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #51 from VikParuchuri/dev
Update line detection, beta layout detection
- Loading branch information
Showing
26 changed files
with
1,154 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import argparse | ||
import json | ||
from collections import defaultdict | ||
|
||
import datasets | ||
from surya.settings import settings | ||
from google.cloud import vision | ||
import hashlib | ||
import os | ||
from tqdm import tqdm | ||
import io | ||
|
||
DATA_DIR = os.path.join(settings.BASE_DIR, settings.DATA_DIR) | ||
RESULT_DIR = os.path.join(settings.BASE_DIR, settings.RESULT_DIR) | ||
|
||
rtl_langs = ["ar", "fa", "he", "ur", "ps", "sd", "yi", "ug"] | ||
|
||
def polygon_to_bbox(polygon): | ||
x = [vertex["x"] for vertex in polygon["vertices"]] | ||
y = [vertex["y"] for vertex in polygon["vertices"]] | ||
return (min(x), min(y), max(x), max(y)) | ||
|
||
|
||
def text_with_break(text, property, is_rtl=False): | ||
break_type = None | ||
prefix = False | ||
if property: | ||
if "detectedBreak" in property: | ||
if "type" in property["detectedBreak"]: | ||
break_type = property["detectedBreak"]["type"] | ||
if "isPrefix" in property["detectedBreak"]: | ||
prefix = property["detectedBreak"]["isPrefix"] | ||
break_char = "" | ||
if break_type == 1: | ||
break_char = " " | ||
if break_type == 5: | ||
break_char = "\n" | ||
|
||
if is_rtl: | ||
prefix = not prefix | ||
|
||
if prefix: | ||
text = break_char + text | ||
else: | ||
text = text + break_char | ||
return text | ||
|
||
|
||
def bbox_overlap_pct(box1, box2): | ||
x1, y1, x2, y2 = box1 | ||
x3, y3, x4, y4 = box2 | ||
dx = min(x2, x4) - max(x1, x3) | ||
dy = min(y2, y4) - max(y1, y3) | ||
if (dx >= 0) and (dy >= 0): | ||
return dx * dy / ((x2 - x1) * (y2 - y1)) | ||
return 0 | ||
|
||
|
||
def annotate_image(img, client, language, cache_dir): | ||
img_byte_arr = io.BytesIO() | ||
img.save(img_byte_arr, format=img.format) | ||
img_byte_arr = img_byte_arr.getvalue() | ||
|
||
img_hash = hashlib.sha256(img_byte_arr).hexdigest() | ||
cache_path = os.path.join(cache_dir, f"{img_hash}.json") | ||
if os.path.exists(cache_path): | ||
with open(cache_path, "r") as f: | ||
response = json.load(f) | ||
return response | ||
|
||
gc_image = vision.Image(content=img_byte_arr) | ||
context = vision.ImageContext(language_hints=[language]) | ||
response = client.document_text_detection(image=gc_image, image_context=context) | ||
response_json = vision.AnnotateImageResponse.to_json(response) | ||
loaded_response = json.loads(response_json) | ||
with open(cache_path, "w+") as f: | ||
json.dump(loaded_response, f) | ||
return loaded_response | ||
|
||
|
||
def get_line_text(response, lines, is_rtl=False): | ||
document = response["fullTextAnnotation"] | ||
|
||
bounds = [] | ||
for page in document["pages"]: | ||
for block in page["blocks"]: | ||
for paragraph in block["paragraphs"]: | ||
for word in paragraph["words"]: | ||
for symbol in word["symbols"]: | ||
bounds.append((symbol["boundingBox"], symbol["text"], symbol.get("property"))) | ||
|
||
bboxes = [(polygon_to_bbox(b[0]), text_with_break(b[1], b[2], is_rtl)) for b in bounds] | ||
line_boxes = defaultdict(list) | ||
for i, bbox in enumerate(bboxes): | ||
max_overlap_pct = 0 | ||
max_overlap_idx = None | ||
for j, line in enumerate(lines): | ||
overlap = bbox_overlap_pct(bbox[0], line) | ||
if overlap > max_overlap_pct: | ||
max_overlap_pct = overlap | ||
max_overlap_idx = j | ||
if max_overlap_idx is not None: | ||
line_boxes[max_overlap_idx].append(bbox) | ||
|
||
ocr_lines = [] | ||
for j, line in enumerate(lines): | ||
ocr_bboxes = sorted(line_boxes[j], key=lambda x: x[0][0]) | ||
if is_rtl: | ||
ocr_bboxes = list(reversed(ocr_bboxes)) | ||
ocr_text = "".join([b[1] for b in ocr_bboxes]) | ||
ocr_lines.append(ocr_text) | ||
|
||
assert len(ocr_lines) == len(lines) | ||
return ocr_lines | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Label text in dataset with google cloud vision.") | ||
parser.add_argument("--project_id", type=str, help="Google cloud project id.", required=True) | ||
parser.add_argument("--service_account", type=str, help="Path to service account json.", required=True) | ||
parser.add_argument("--max", type=int, help="Maximum number of pages to label.", default=None) | ||
args = parser.parse_args() | ||
|
||
cache_dir = os.path.join(DATA_DIR, "gcloud_cache") | ||
os.makedirs(cache_dir, exist_ok=True) | ||
|
||
dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split="train") | ||
client = vision.ImageAnnotatorClient.from_service_account_json(args.service_account) | ||
|
||
all_gc_lines = [] | ||
for i in tqdm(range(len(dataset))): | ||
img = dataset[i]["image"] | ||
lines = dataset[i]["bboxes"] | ||
language = dataset[i]["language"] | ||
|
||
response = annotate_image(img, client, language, cache_dir) | ||
ocr_lines = get_line_text(response, lines, is_rtl=language in rtl_langs) | ||
|
||
all_gc_lines.append(ocr_lines) | ||
|
||
if args.max is not None and i >= args.max: | ||
break | ||
|
||
with open(os.path.join(RESULT_DIR, "gcloud_ocr.json"), "w+") as f: | ||
json.dump(all_gc_lines, f) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import argparse | ||
import copy | ||
import json | ||
from collections import defaultdict | ||
|
||
from surya.detection import batch_text_detection | ||
from surya.input.load import load_from_folder, load_from_file | ||
from surya.layout import batch_layout_detection | ||
from surya.model.detection.segformer import load_model, load_processor | ||
from surya.postprocessing.heatmap import draw_polys_on_image | ||
from surya.settings import settings | ||
import os | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).") | ||
parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.") | ||
parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) | ||
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) | ||
parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False) | ||
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) | ||
args = parser.parse_args() | ||
|
||
print("Layout detection is currently in beta! There may be issues with the output.") | ||
|
||
model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) | ||
processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) | ||
det_model = load_model() | ||
det_processor = load_processor() | ||
|
||
if os.path.isdir(args.input_path): | ||
images, names = load_from_folder(args.input_path, args.max) | ||
folder_name = os.path.basename(args.input_path) | ||
else: | ||
images, names = load_from_file(args.input_path, args.max) | ||
folder_name = os.path.basename(args.input_path).split(".")[0] | ||
|
||
line_predictions = batch_text_detection(images, det_model, det_processor) | ||
|
||
layout_predictions = batch_layout_detection(images, model, processor, line_predictions) | ||
result_path = os.path.join(args.results_dir, folder_name) | ||
os.makedirs(result_path, exist_ok=True) | ||
|
||
for idx, (layout_pred, line_pred, name) in enumerate(zip(layout_predictions, line_predictions, names)): | ||
blocks = layout_pred.bboxes | ||
for line in line_pred.vertical_lines: | ||
new_blocks = [] | ||
for block in blocks: | ||
block_modified = False | ||
|
||
if line.bbox[0] > block.bbox[0] and line.bbox[2] < block.bbox[2]: | ||
overlap_pct = (min(line.bbox[3], block.bbox[3]) - max(line.bbox[1], block.bbox[1])) / ( | ||
block.bbox[3] - block.bbox[1]) | ||
if overlap_pct > 0.5: | ||
block1 = copy.deepcopy(block) | ||
block2 = copy.deepcopy(block) | ||
block1.bbox[2] = line.bbox[0] | ||
block2.bbox[0] = line.bbox[2] | ||
new_blocks.append(block1) | ||
new_blocks.append(block2) | ||
block_modified = True | ||
if not block_modified: | ||
new_blocks.append(block) | ||
blocks = new_blocks | ||
layout_pred.bboxes = blocks | ||
|
||
if args.images: | ||
for idx, (image, layout_pred, line_pred, name) in enumerate(zip(images, layout_predictions, line_predictions, names)): | ||
polygons = [p.polygon for p in layout_pred.bboxes] | ||
labels = [p.label for p in layout_pred.bboxes] | ||
bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels) | ||
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png")) | ||
|
||
if args.debug: | ||
heatmap = layout_pred.segmentation_map | ||
heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png")) | ||
|
||
predictions_by_page = defaultdict(list) | ||
for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)): | ||
out_pred = pred.model_dump(exclude=["segmentation_map"]) | ||
out_pred["page"] = len(predictions_by_page[name]) + 1 | ||
predictions_by_page[name].append(out_pred) | ||
|
||
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: | ||
json.dump(predictions_by_page, f, ensure_ascii=False) | ||
|
||
print(f"Wrote results to {result_path}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.