New architecture for detection/layout #148

Merged: 10 commits, Jul 12, 2024
43 changes: 21 additions & 22 deletions README.md
@@ -63,12 +63,12 @@ Install with:
pip install surya-ocr
```

Model weights will automatically download the first time you run surya. Note that this does not work with the latest version of transformers `4.37+` [yet](https://github.com/huggingface/transformers/issues/28846#issuecomment-1926109135), so you will need to keep `4.36.2`, which is installed with surya.
Model weights will automatically download the first time you run surya.

# Usage

- Inspect the settings in `surya/settings.py`. You can override any settings with environment variables.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. For text detection, the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.

## Interactive App

@@ -79,8 +79,6 @@ pip install streamlit
surya_gui
```

Pass the `--math` command line argument to use the math text detection model instead of the default model. This will detect math better, but will be worse at everything else.

## OCR (text recognition)

This command will write out a json file with the detected text and bboxes:
@@ -151,7 +149,6 @@ surya_detect DATA_PATH --images
- `--images` will save images of the pages and detected text lines (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default
- `--math` uses a specialized math detection model instead of the default model. This will be better at math, but worse at everything else.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

@@ -166,14 +163,14 @@ The `results.json` file will contain a json dictionary where the keys are the in

**Performance tips**

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `280MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 9GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `2`.
Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `440MB` of VRAM, so very high batch sizes are possible. The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`.
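As a rough check on the updated numbers: 36 items at 440MB each comes to about 15.8GB, which agrees with the stated 16GB. The sketch below does the same arithmetic in reverse to suggest a `DETECTOR_BATCH_SIZE` for a given card. Both helper names and the `headroom_gb` parameter are hypothetical and not part of surya; the 440MB per-item figure is taken from the README line above and should be treated as an estimate.

```python
def estimated_vram_gb(batch_size: int, per_item_mb: float = 440.0) -> float:
    """Estimated VRAM use in GB for a batch, using the README's per-item figure."""
    return batch_size * per_item_mb / 1000


def max_detector_batch_size(vram_gb: float, per_item_mb: float = 440.0,
                            headroom_gb: float = 0.0) -> int:
    """Largest batch size whose estimated VRAM use fits in vram_gb.

    headroom_gb (hypothetical) reserves memory for other processes.
    """
    usable_mb = (vram_gb - headroom_gb) * 1000
    # Never suggest less than one item per batch.
    return max(1, int(usable_mb // per_item_mb))


if __name__ == "__main__":
    print(estimated_vram_gb(36))
    print(max_detector_batch_size(16))
```

The result could then be passed through the environment, for example `DETECTOR_BATCH_SIZE=36`, as the README describes.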

### From python

```python
from PIL import Image
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor

image = Image.open(IMAGE_PATH)
model, processor = load_model(), load_processor()
@@ -207,15 +204,15 @@ The `results.json` file will contain a json dictionary where the keys are the in

**Performance tips**

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `280MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 9GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `2`.
Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `400MB` of VRAM, so very high batch sizes are possible. The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`.

### From python

```python
from PIL import Image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.settings import settings

image = Image.open(IMAGE_PATH)
@@ -334,16 +331,16 @@ For Google Cloud, I aligned the output from Google Cloud with the ground truth.

![Benchmark chart](static/images/benchmark_chart_small.png)

| Model | Time (s) | Time per page (s) | precision | recall |
| Model | Time (s) | Time per page (s) | precision | recall |
|-----------|------------|---------------------|-------------|----------|
| surya | 52.6892 | 0.205817 | 0.844426 | 0.937818 |
| tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 |
| surya | 50.2099 | 0.196133 | 0.821061 | 0.956556 |
| tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 |


Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a system with an A6000 GPU, and a 32 core CPU. This was the resource usage:
Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU. This was the resource usage:

- tesseract - 32 CPU cores, or 8 workers using 4 cores each
- surya - 32 batch size, for 9GB VRAM usage
- surya - 36 batch size, for 16GB VRAM usage

**Methodology**

@@ -362,14 +359,14 @@ Then we calculate precision and recall for the whole dataset.

![Benchmark chart](static/images/benchmark_layout_chart.png)

| Layout Type | precision | recall |
|---------------|-------------|----------|
| Image | 0.95 | 0.99 |
| Table | 0.95 | 0.96 |
| Text | 0.89 | 0.95 |
| Title | 0.92 | 0.89 |
| Layout Type | precision | recall |
| ----------- | --------- | ------ |
| Image | 0.97 | 0.96 |
| Table | 0.99 | 0.99 |
| Text | 0.9 | 0.97 |
| Title | 0.94 | 0.88 |

Time per image - .79 seconds on GPU (A6000).
Time per image - .4 seconds on GPU (A10).

**Methodology**

@@ -446,7 +443,7 @@ python benchmark/ordering.py

# Training

Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation.

Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).

@@ -455,6 +452,8 @@ Text recognition was trained on 4x A6000s for 2 weeks.  It was trained using a m
This work would not have been possible without amazing open source AI work:

- [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA
- [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT
- [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman
- [Donut](https://github.com/clovaai/donut) from Naver
- [transformers](https://github.com/huggingface/transformers) from huggingface
- [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model
2 changes: 1 addition & 1 deletion benchmark/detection.py
@@ -6,7 +6,7 @@
from surya.benchmark.bbox import get_pdf_lines
from surya.benchmark.metrics import precision_recall
from surya.benchmark.tesseract import tesseract_parallel
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.detection import batch_text_detection
from surya.postprocessing.heatmap import draw_polys_on_image
2 changes: 1 addition & 1 deletion benchmark/layout.py
@@ -5,7 +5,7 @@

from surya.benchmark.metrics import precision_recall
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.layout import batch_layout_detection
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
2 changes: 1 addition & 1 deletion detect_layout.py
@@ -6,7 +6,7 @@
from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os
10 changes: 7 additions & 3 deletions detect_text.py
@@ -1,10 +1,11 @@
import argparse
import copy
import json
import time
from collections import defaultdict

from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.detection import batch_text_detection
from surya.postprocessing.affinity import draw_lines_on_image
from surya.postprocessing.heatmap import draw_polys_on_image
@@ -20,10 +21,9 @@ def main():
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)
args = parser.parse_args()

checkpoint = settings.DETECTOR_MATH_MODEL_CHECKPOINT if args.math else settings.DETECTOR_MODEL_CHECKPOINT
checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
model = load_model(checkpoint=checkpoint)
processor = load_processor(checkpoint=checkpoint)

@@ -34,9 +34,13 @@ def main():
images, names = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]

start = time.time()
predictions = batch_text_detection(images, model, processor)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
end = time.time()
if args.debug:
print(f"Detection took {end - start} seconds")
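The timing added in this hunk spans both the detection call and the creation of the results directory, and it reads the wall clock with `time.time()`. A minimal alternative sketch, assuming only the standard library, is a small context manager built on `time.perf_counter()`, a monotonic clock intended for measuring intervals. The `timed` helper and the placeholder workload are hypothetical and not part of this PR.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, enabled: bool = True):
    """Measure the wrapped block; report only when enabled (e.g. a debug flag)."""
    result = {"seconds": None}
    start = time.perf_counter()
    try:
        yield result
    finally:
        # Record the elapsed time even if the wrapped block raises.
        result["seconds"] = time.perf_counter() - start
        if enabled:
            print(f"{label} took {result['seconds']:.3f} seconds")


# Placeholder workload standing in for the detection call.
with timed("Detection", enabled=False) as t:
    total = sum(range(100_000))
```

Wrapping only the detection call would keep directory creation out of the reported figure.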

if args.images:
for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
15 changes: 2 additions & 13 deletions ocr_app.py
@@ -1,13 +1,11 @@
import os
import argparse
import io
from typing import List

import pypdfium2
import streamlit as st
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
@@ -22,18 +20,9 @@
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
from surya.settings import settings

parser = argparse.ArgumentParser(description="Run OCR on an image or PDF.")
parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)

try:
args = parser.parse_args()
except SystemExit as e:
print(f"Error parsing arguments: {e}")
os._exit(e.code)

@st.cache_resource()
def load_det_cached():
checkpoint = settings.DETECTOR_MATH_MODEL_CHECKPOINT if args.math else settings.DETECTOR_MODEL_CHECKPOINT
checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
return load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint)


2 changes: 1 addition & 1 deletion ocr_text.py
@@ -7,7 +7,7 @@

from surya.input.langs import replace_lang_with_code, get_unique_langs
from surya.input.load import load_from_folder, load_from_file, load_lang_file
from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.4.14"
version = "0.4.15"
description = "OCR, layout, reading order, and line detection in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
2 changes: 1 addition & 1 deletion reading_order.py
@@ -7,7 +7,7 @@
from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.ordering import batch_ordering
7 changes: 0 additions & 7 deletions run_ocr_app.py
@@ -4,16 +4,9 @@


def run_app():
parser = argparse.ArgumentParser(description="Run the streamlit OCR app")
parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)
args = parser.parse_args()

cur_dir = os.path.dirname(os.path.abspath(__file__))
ocr_app_path = os.path.join(cur_dir, "ocr_app.py")
cmd = ["streamlit", "run", ocr_app_path]
if args.math:
cmd.append("--")
cmd.append("--math")
subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})

if __name__ == "__main__":
8 changes: 5 additions & 3 deletions surya/detection.py
@@ -4,7 +4,7 @@
import numpy as np
from PIL import Image

from surya.model.detection.segformer import SegformerForRegressionMask
from surya.model.detection.model import EfficientViTForSemanticSegmentation
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
@@ -19,12 +19,14 @@ def get_batch_size():
batch_size = settings.DETECTOR_BATCH_SIZE
if batch_size is None:
batch_size = 6
if settings.TORCH_DEVICE_MODEL == "mps":
batch_size = 8
if settings.TORCH_DEVICE_MODEL == "cuda":
batch_size = 24
batch_size = 36
return batch_size
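Read together, the removed and added lines in this hunk change the CUDA default from 24 to 36 and add an `mps` default of 8, with 6 remaining the fallback. The sketch below restates that selection logic as a standalone function so the resulting behaviour is easier to see. The function name and its parameters are hypothetical; the real code reads `settings.DETECTOR_BATCH_SIZE` and `settings.TORCH_DEVICE_MODEL` instead of taking arguments.

```python
from typing import Optional


def pick_detector_batch_size(configured: Optional[int], device: str) -> int:
    """Return the explicit setting if present, else a per-device default."""
    if configured is not None:
        # An explicit DETECTOR_BATCH_SIZE always wins.
        return configured
    # Defaults as they read after this change; anything else falls back to 6.
    defaults = {"mps": 8, "cuda": 36}
    return defaults.get(device, 6)


if __name__ == "__main__":
    for device in ("cpu", "mps", "cuda"):
        print(device, pick_detector_batch_size(None, device))
```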


def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
assert all([isinstance(image, Image.Image) for image in images])
if batch_size is None:
batch_size = get_batch_size()
51 changes: 51 additions & 0 deletions surya/model/detection/config.py
@@ -0,0 +1,51 @@
from transformers import PretrainedConfig


class EfficientViTConfig(PretrainedConfig):
r"""
```"""

model_type = "efficientvit"

def __init__(
self,
num_classes=2,
num_channels=3,
widths=(32, 64, 128, 256, 512),
head_dim=32,
num_stages=4,
depths=(1, 1, 1, 6, 6),
strides=(2, 2, 2, 2, 2),
hidden_sizes=(32, 64, 160, 256),
patch_size=(7, 7),
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
classifier_dropout_prob=0.0,
layer_norm_eps=1e-6,
decoder_layer_hidden_size=128,
decoder_hidden_size=512,
semantic_loss_ignore_index=255,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)

self.num_classes = num_classes
self.widths = widths
self.head_dim = head_dim

self.num_channels = num_channels
self.num_stages = num_stages
self.depths = depths
self.strides = strides
self.hidden_sizes = hidden_sizes
self.patch_size = patch_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.classifier_dropout_prob = classifier_dropout_prob
self.layer_norm_eps = layer_norm_eps
self.decoder_hidden_size = decoder_hidden_size
self.decoder_layer_hidden_size = decoder_layer_hidden_size
self.semantic_loss_ignore_index = semantic_loss_ignore_index

self.initializer_range = initializer_range