Merge pull request #51 from VikParuchuri/dev
Update line detection, beta layout detection
VikParuchuri authored Mar 6, 2024
2 parents 5919af1 + aefb78d commit f68379a
Showing 26 changed files with 1,154 additions and 119 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -61,6 +61,8 @@ pip install streamlit
surya_gui
```

Pass the `--math` command line argument to use the math detection model instead of the default model. It detects math more accurately, but is worse at everything else.
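
For example, to launch the GUI with the math model:

```
surya_gui --math
```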

## OCR (text recognition)

You can OCR text in an image, PDF, or folder of images/PDFs with the following command. It writes out a JSON file with the detected text and bboxes, and can optionally save images of the reconstructed page.
@@ -81,6 +83,7 @@ The `results.json` file will contain a JSON dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- `text_lines` - the detected text and bounding boxes for each line
- `text` - the text in the line
- `confidence` - the confidence of the model in the detected text (0-1)
- `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `languages` - the languages specified for the page
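
As an illustration, here is a minimal sketch of reading this structure in Python. The field access follows the list above; this is not part of the surya API, and the filename is a placeholder:

```
import json

# Keys are input filenames without extensions; values are lists of pages.
with open("results.json") as f:
    results = json.load(f)

for name, pages in results.items():
    for page in pages:
        print(name, "languages:", page["languages"])
        for line in page["text_lines"]:
            # Each line has text, a 0-1 confidence, a polygon, and a bbox.
            print(line["text"], line["confidence"], line["bbox"])
```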
@@ -120,12 +123,14 @@ surya_detect DATA_PATH --images
- `--images` will save images of the pages and detected text lines (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default
- `--math` uses a specialized math detection model instead of the default model. It is better at math, but worse at everything else.

The `results.json` file will contain a JSON dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

- `bboxes` - detected bounding boxes for text
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- `confidence` - the confidence of the model in the detected text (0-1)
- `vertical_lines` - vertical lines detected in the document
- `bbox` - the axis-aligned line coordinates.
- `horizontal_lines` - horizontal lines detected in the document
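
A minimal reading sketch, again illustrative rather than part of the API. It assumes horizontal lines carry the same `bbox` field as vertical lines, which is elided above:

```
import json

with open("results.json") as f:
    results = json.load(f)

for name, pages in results.items():
    for page in pages:
        for box in page["bboxes"]:
            # Text detections: axis-aligned bbox, polygon, 0-1 confidence.
            print(box["bbox"], box["confidence"])
        # Assumed symmetric fields for horizontal lines (see note above).
        for line in page["vertical_lines"] + page["horizontal_lines"]:
            print(line["bbox"])
```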
4 changes: 2 additions & 2 deletions benchmark/detection.py
@@ -8,7 +8,7 @@
 from surya.benchmark.tesseract import tesseract_parallel
 from surya.model.detection.segformer import load_model, load_processor
 from surya.input.processing import open_pdf, get_page_images
-from surya.detection import batch_detection
+from surya.detection import batch_text_detection
 from surya.postprocessing.heatmap import draw_polys_on_image
 from surya.postprocessing.util import rescale_bbox
 from surya.settings import settings
@@ -54,7 +54,7 @@ def main():
         correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])
 
     start = time.time()
-    predictions = batch_detection(images, model, processor)
+    predictions = batch_text_detection(images, model, processor)
     surya_time = time.time() - start
 
     start = time.time()
149 changes: 149 additions & 0 deletions benchmark/gcloud_label.py
@@ -0,0 +1,149 @@
import argparse
import json
from collections import defaultdict

import datasets
from surya.settings import settings
from google.cloud import vision
import hashlib
import os
from tqdm import tqdm
import io

DATA_DIR = os.path.join(settings.BASE_DIR, settings.DATA_DIR)
RESULT_DIR = os.path.join(settings.BASE_DIR, settings.RESULT_DIR)

rtl_langs = ["ar", "fa", "he", "ur", "ps", "sd", "yi", "ug"]

def polygon_to_bbox(polygon):
    x = [vertex["x"] for vertex in polygon["vertices"]]
    y = [vertex["y"] for vertex in polygon["vertices"]]
    return (min(x), min(y), max(x), max(y))


def text_with_break(text, prop, is_rtl=False):
    # Reattach the break character that Vision reports alongside each symbol.
    # BreakType enum values: 1 = SPACE, 5 = LINE_BREAK.
    break_type = None
    prefix = False
    if prop and "detectedBreak" in prop:
        break_type = prop["detectedBreak"].get("type")
        prefix = prop["detectedBreak"].get("isPrefix", False)
    break_char = ""
    if break_type == 1:
        break_char = " "
    if break_type == 5:
        break_char = "\n"

    # For right-to-left text the break attaches on the opposite side.
    if is_rtl:
        prefix = not prefix

    if prefix:
        text = break_char + text
    else:
        text = text + break_char
    return text


def bbox_overlap_pct(box1, box2):
    # Fraction of box1's area covered by its intersection with box2.
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    dx = min(x2, x4) - max(x1, x3)
    dy = min(y2, y4) - max(y1, y3)
    if (dx >= 0) and (dy >= 0):
        return dx * dy / ((x2 - x1) * (y2 - y1))
    return 0
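

# Illustrative example: a box covering the left half of a 10x10 box overlaps
# half of box1's area, so bbox_overlap_pct((0, 0, 10, 10), (0, 0, 5, 10))
# returns 0.5.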


def annotate_image(img, client, language, cache_dir):
    img_byte_arr = io.BytesIO()
    img.save(img_byte_arr, format=img.format)
    img_byte_arr = img_byte_arr.getvalue()

    # Cache responses by image hash so reruns don't re-query the API.
    img_hash = hashlib.sha256(img_byte_arr).hexdigest()
    cache_path = os.path.join(cache_dir, f"{img_hash}.json")
    if os.path.exists(cache_path):
        with open(cache_path, "r") as f:
            response = json.load(f)
        return response

    gc_image = vision.Image(content=img_byte_arr)
    context = vision.ImageContext(language_hints=[language])
    response = client.document_text_detection(image=gc_image, image_context=context)
    response_json = vision.AnnotateImageResponse.to_json(response)
    loaded_response = json.loads(response_json)
    with open(cache_path, "w+") as f:
        json.dump(loaded_response, f)
    return loaded_response


def get_line_text(response, lines, is_rtl=False):
    document = response["fullTextAnnotation"]

    bounds = []
    for page in document["pages"]:
        for block in page["blocks"]:
            for paragraph in block["paragraphs"]:
                for word in paragraph["words"]:
                    for symbol in word["symbols"]:
                        bounds.append((symbol["boundingBox"], symbol["text"], symbol.get("property")))

    bboxes = [(polygon_to_bbox(b[0]), text_with_break(b[1], b[2], is_rtl)) for b in bounds]
    # Assign each symbol to the reference line it overlaps most.
    line_boxes = defaultdict(list)
    for i, bbox in enumerate(bboxes):
        max_overlap_pct = 0
        max_overlap_idx = None
        for j, line in enumerate(lines):
            overlap = bbox_overlap_pct(bbox[0], line)
            if overlap > max_overlap_pct:
                max_overlap_pct = overlap
                max_overlap_idx = j
        if max_overlap_idx is not None:
            line_boxes[max_overlap_idx].append(bbox)

    # Order symbols left-to-right within each line (reversed for RTL scripts).
    ocr_lines = []
    for j, line in enumerate(lines):
        ocr_bboxes = sorted(line_boxes[j], key=lambda x: x[0][0])
        if is_rtl:
            ocr_bboxes = list(reversed(ocr_bboxes))
        ocr_text = "".join([b[1] for b in ocr_bboxes])
        ocr_lines.append(ocr_text)

    assert len(ocr_lines) == len(lines)
    return ocr_lines


def main():
    parser = argparse.ArgumentParser(description="Label text in dataset with google cloud vision.")
    parser.add_argument("--project_id", type=str, help="Google cloud project id.", required=True)
    parser.add_argument("--service_account", type=str, help="Path to service account json.", required=True)
    parser.add_argument("--max", type=int, help="Maximum number of pages to label.", default=None)
    args = parser.parse_args()

    cache_dir = os.path.join(DATA_DIR, "gcloud_cache")
    os.makedirs(cache_dir, exist_ok=True)

    dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split="train")
    client = vision.ImageAnnotatorClient.from_service_account_json(args.service_account)

    all_gc_lines = []
    for i in tqdm(range(len(dataset))):
        img = dataset[i]["image"]
        lines = dataset[i]["bboxes"]
        language = dataset[i]["language"]

        response = annotate_image(img, client, language, cache_dir)
        ocr_lines = get_line_text(response, lines, is_rtl=language in rtl_langs)

        all_gc_lines.append(ocr_lines)

        # Stop once args.max pages have been labeled.
        if args.max is not None and i + 1 >= args.max:
            break

    with open(os.path.join(RESULT_DIR, "gcloud_ocr.json"), "w+") as f:
        json.dump(all_gc_lines, f)


if __name__ == "__main__":
    main()
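
A hypothetical invocation, with placeholder project id and key path (responses are cached under `DATA_DIR/gcloud_cache`, and results are written to `RESULT_DIR/gcloud_ocr.json`):

```
python benchmark/gcloud_label.py --project_id my-project --service_account /path/to/key.json --max 100
```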
4 changes: 2 additions & 2 deletions benchmark/recognition.py
@@ -148,9 +148,9 @@ def main():
         pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
         ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
         pred_text = [l.text for l in pred.text_lines]
-        pred_image = draw_text_on_image(bbox, pred_text, image.size)
+        pred_image = draw_text_on_image(bbox, pred_text, image.size, lang)
         pred_image.save(os.path.join(result_path, pred_image_name))
-        ref_image = draw_text_on_image(bbox, ref_text, image.size)
+        ref_image = draw_text_on_image(bbox, ref_text, image.size, lang)
         ref_image.save(os.path.join(result_path, ref_image_name))
         image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))
98 changes: 98 additions & 0 deletions detect_layout.py
@@ -0,0 +1,98 @@
import argparse
import copy
import json
from collections import defaultdict

from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os


def main():
    parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
    args = parser.parse_args()

    print("Layout detection is currently in beta! There may be issues with the output.")

    model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    det_model = load_model()
    det_processor = load_processor()

    if os.path.isdir(args.input_path):
        images, names = load_from_folder(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path)
    else:
        images, names = load_from_file(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    line_predictions = batch_text_detection(images, det_model, det_processor)

    layout_predictions = batch_layout_detection(images, model, processor, line_predictions)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    # Split layout blocks that straddle a detected vertical line (e.g. a
    # column rule) into left and right pieces.
    for idx, (layout_pred, line_pred, name) in enumerate(zip(layout_predictions, line_predictions, names)):
        blocks = layout_pred.bboxes
        for line in line_pred.vertical_lines:
            new_blocks = []
            for block in blocks:
                block_modified = False

                if line.bbox[0] > block.bbox[0] and line.bbox[2] < block.bbox[2]:
                    # Only split when the line vertically overlaps more than
                    # half of the block.
                    overlap_pct = (min(line.bbox[3], block.bbox[3]) - max(line.bbox[1], block.bbox[1])) / (block.bbox[3] - block.bbox[1])
                    if overlap_pct > 0.5:
                        block1 = copy.deepcopy(block)
                        block2 = copy.deepcopy(block)
                        block1.bbox[2] = line.bbox[0]
                        block2.bbox[0] = line.bbox[2]
                        new_blocks.append(block1)
                        new_blocks.append(block2)
                        block_modified = True
                if not block_modified:
                    new_blocks.append(block)
            blocks = new_blocks
        layout_pred.bboxes = blocks

    if args.images:
        for idx, (image, layout_pred, line_pred, name) in enumerate(zip(images, layout_predictions, line_predictions, names)):
            polygons = [p.polygon for p in layout_pred.bboxes]
            labels = [p.label for p in layout_pred.bboxes]
            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png"))

            if args.debug:
                heatmap = layout_pred.segmentation_map
                heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png"))

    predictions_by_page = defaultdict(list)
    for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)):
        out_pred = pred.model_dump(exclude=["segmentation_map"])
        out_pred["page"] = len(predictions_by_page[name]) + 1
        predictions_by_page[name].append(out_pred)

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
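
A sketch of typical usage, with a placeholder input file and the flags defined above:

```
python detect_layout.py document.pdf --images --max 10
```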
12 changes: 7 additions & 5 deletions detect_text.py
@@ -5,7 +5,7 @@

 from surya.input.load import load_from_folder, load_from_file
 from surya.model.detection.segformer import load_model, load_processor
-from surya.detection import batch_detection
+from surya.detection import batch_text_detection
 from surya.postprocessing.affinity import draw_lines_on_image
 from surya.postprocessing.heatmap import draw_polys_on_image
 from surya.settings import settings
@@ -20,10 +20,12 @@ def main():
     parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
     parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
     parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
+    parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)
     args = parser.parse_args()
 
-    model = load_model()
-    processor = load_processor()
+    checkpoint = settings.DETECTOR_MATH_MODEL_CHECKPOINT if args.math else settings.DETECTOR_MODEL_CHECKPOINT
+    model = load_model(checkpoint=checkpoint)
+    processor = load_processor(checkpoint=checkpoint)
 
     if os.path.isdir(args.input_path):
         images, names = load_from_folder(args.input_path, args.max)
@@ -32,7 +34,7 @@ def main():
         images, names = load_from_file(args.input_path, args.max)
         folder_name = os.path.basename(args.input_path).split(".")[0]
 
-    predictions = batch_detection(images, model, processor)
+    predictions = batch_text_detection(images, model, processor)
     result_path = os.path.join(args.results_dir, folder_name)
     os.makedirs(result_path, exist_ok=True)

@@ -58,7 +60,7 @@ def main():
         out_pred["page"] = len(predictions_by_page[name]) + 1
         predictions_by_page[name].append(out_pred)
 
-    with open(os.path.join(result_path, "results.json"), "w+") as f:
+    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
         json.dump(predictions_by_page, f, ensure_ascii=False)
 
     print(f"Wrote results to {result_path}")
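
For instance, running detection with the math checkpoint on a placeholder input:

```
python detect_text.py paper.pdf --math --images
```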