Skip to content

Commit

Permalink
Merge pull request #104 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Bump transformers version
  • Loading branch information
VikParuchuri committed May 17, 2024
2 parents c89e015 + 79831d2 commit 7a65c45
Show file tree
Hide file tree
Showing 17 changed files with 394 additions and 386 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ The `results.json` file will contain a json dictionary where the keys are the in
- `confidence` - the confidence of the model in the detected text (0-1)
- `vertical_lines` - vertical lines detected in the document
- `bbox` - the axis-aligned line coordinates.
- `horizontal_lines` - horizontal lines detected in the document
- `bbox` - the axis-aligned line coordinates.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.

Expand Down
669 changes: 329 additions & 340 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.4.5"
version = "0.4.6"
description = "OCR, layout, reading order, and line detection in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand All @@ -21,8 +21,8 @@ include = [

[tool.poetry.dependencies]
python = ">=3.9,<3.13,!=3.9.7"
transformers = "4.36.2"
torch = "^2.2.2"
transformers = "^4.41.0"
torch = "^2.3.0"
pydantic = "^2.5.3"
pydantic-settings = "^2.1.0"
python-dotenv = "^1.0.0"
Expand Down
5 changes: 4 additions & 1 deletion run_ocr_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ def run_app():
if args.math:
cmd.append("--")
cmd.append("--math")
subprocess.run(cmd)
subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})

if __name__ == "__main__":
run_app()
6 changes: 2 additions & 4 deletions surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from surya.model.detection.segformer import SegformerForRegressionMask
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines, get_horizontal_lines
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image, split_image, get_total_splits
from surya.schema import TextDetectionResult
from surya.settings import settings
Expand Down Expand Up @@ -109,12 +109,10 @@ def parallel_get_lines(preds, orig_sizes):
heatmap_size = list(reversed(heatmap.shape))
bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)
horizontal_lines = get_horizontal_lines(affinity_map, affinity_size, orig_sizes)

result = TextDetectionResult(
bboxes=bboxes,
vertical_lines=vertical_lines,
horizontal_lines=horizontal_lines,
heatmap=heat_img,
affinity_map=aff_img,
image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
Expand All @@ -125,7 +123,7 @@ def parallel_get_lines(preds, orig_sizes):
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
results = []
if len(images) == 1: # Ensures we don't parallelize with streamlit
if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit
for i in range(len(images)):
result = parallel_get_lines(preds[i], orig_sizes[i])
results.append(result)
Expand Down
25 changes: 16 additions & 9 deletions surya/input/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,31 +80,38 @@ def slice_bboxes_from_image(image: Image.Image, bboxes):


def slice_polys_from_image(image: Image.Image, polys):
image_array = np.array(image)
lines = []
for idx, poly in enumerate(polys):
lines.append(slice_and_pad_poly(image, poly, idx))
lines.append(slice_and_pad_poly(image, image_array, poly, idx))
return lines


def slice_and_pad_poly(image: Image.Image, coordinates, idx):
def slice_and_pad_poly(image: Image.Image, image_array: np.array, coordinates, idx):
# Create a mask for the polygon
mask = Image.new('L', image.size, 0)

# coordinates must be in tuple form for PIL
# Draw polygon onto mask
coordinates = [(corner[0], corner[1]) for corner in coordinates]
ImageDraw.Draw(mask).polygon(coordinates, outline=1, fill=1)
bbox = mask.getbbox()

if bbox is None:
return None

mask = np.array(mask)

# Extract the polygonal area from the image
polygon_image = np.array(image)
# We mask out anything not in the polygon
polygon_image = image_array.copy()
polygon_image[mask == 0] = settings.RECOGNITION_PAD_VALUE
polygon_image = Image.fromarray(polygon_image)

rectangle = Image.new('RGB', (bbox[2] - bbox[0], bbox[3] - bbox[1]), 'white')
# Crop out the bbox, and ensure we pad the area outside the polygon with the pad value
cropped_polygon = polygon_image[bbox[1]:bbox[3], bbox[0]:bbox[2]]
rectangle = np.full((bbox[3] - bbox[1], bbox[2] - bbox[0], 3), settings.RECOGNITION_PAD_VALUE, dtype=np.uint8)
rectangle[:, :] = cropped_polygon

# Paste the polygon into the rectangle
rectangle.paste(polygon_image.crop(bbox), (0, 0))
rectangle_image = Image.fromarray(rectangle)

return rectangle
return rectangle_image

2 changes: 1 addition & 1 deletion surya/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
id2label = model.config.id2label

results = []
if len(images) == 1: # Ensures we don't parallelize with streamlit
if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit
for i in range(len(images)):
result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
results.append(result)
Expand Down
2 changes: 1 addition & 1 deletion surya/model/detection/segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TO
print("Warning: MPS may have poor results. This is a bug with MPS, see here - https://github.com/pytorch/pytorch/issues/84936")
model = model.to(device)
model = model.eval()
print(f"Loading detection model {checkpoint} on device {device} with dtype {dtype}")
print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
return model


Expand Down
2 changes: 1 addition & 1 deletion surya/model/ordering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH

model = model.to(device)
model = model.eval()
print(f"Loading reading order model {checkpoint} on device {device} with dtype {dtype}")
print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
return model
2 changes: 1 addition & 1 deletion surya/model/recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings

model = model.to(device)
model = model.eval()
print(f"Loading recognition model {checkpoint} on device {device} with dtype {dtype}")
print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
return model


Expand Down
34 changes: 27 additions & 7 deletions surya/ocr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from typing import List
from tqdm import tqdm

Expand All @@ -10,6 +11,7 @@
from surya.postprocessing.text import truncate_repetitions, sort_text_lines
from surya.recognition import batch_recognition
from surya.schema import TextLine, OCRResult
from surya.settings import settings


def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]:
Expand Down Expand Up @@ -60,20 +62,38 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model
return predictions_by_image


def parallel_slice_polys(det_pred, image):
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)
return slices


def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]:
det_predictions = batch_text_detection(images, det_model, det_processor)
if det_model.device.type == "cuda":
torch.cuda.empty_cache() # Empty cache from first model run

slice_map = []
all_slices = []

if settings.IN_STREAMLIT:
all_slices = [parallel_slice_polys(det_pred, image) for det_pred, image in zip(det_predictions, images)]
else:
futures = []
with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor:
for image_idx in range(len(images)):
future = executor.submit(parallel_slice_polys, det_predictions[image_idx], images[image_idx])
futures.append(future)

for future in futures:
all_slices.append(future.result())

slice_map = []
all_langs = []
for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)):
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)
slice_map.append(len(slices))
all_slices.extend(slices)
all_langs.extend([lang] * len(slices))
for idx, (slice, lang) in enumerate(zip(all_slices, langs)):
slice_map.append(len(slice))
all_langs.extend([lang] * len(slice))

all_slices = [slice for sublist in all_slices for slice in sublist]

rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size)

Expand Down
2 changes: 1 addition & 1 deletion surya/ordering.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import List, Optional
from typing import List
import torch
from PIL import Image

Expand Down
9 changes: 1 addition & 8 deletions surya/postprocessing/affinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,4 @@ def get_vertical_lines(image, processor_size, image_size, divisor=20, x_toleranc
# Always start with top left of page
vertical_lines[0].bbox[1] = 0

return vertical_lines


def get_horizontal_lines(affinity_map, processor_size, image_size) -> List[ColumnLine]:
horizontal_lines = get_detected_lines(affinity_map, horizontal=True)
for line in horizontal_lines:
line.rescale_bbox(processor_size, image_size)
return horizontal_lines
return vertical_lines
9 changes: 4 additions & 5 deletions surya/postprocessing/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def detect_boxes(linemap, text_threshold, low_text):

ret, text_score = cv2.threshold(linemap, low_text, 1, cv2.THRESH_BINARY)

text_score_comb = np.clip(text_score, 0, 1)
label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)
text_score_comb = np.clip(text_score, 0, 1).astype(np.uint8)
label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb, connectivity=4)

det = []
confidences = []
Expand Down Expand Up @@ -140,9 +140,8 @@ def detect_boxes(linemap, text_threshold, low_text):
box = np.roll(box, 4-startidx, 0)
box = np.array(box)

mask = np.zeros_like(linemap).astype(np.uint8)
cv2.fillPoly(mask, [np.int32(box)], 255)
mask = mask.astype(np.float32) / 255
mask = np.zeros_like(linemap, dtype=np.uint8)
cv2.fillPoly(mask, [np.int32(box)], 1)

roi = np.where(mask == 1, linemap, 0)
confidence = np.mean(roi[roi != 0])
Expand Down
1 change: 1 addition & 0 deletions surya/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor

output_text = []
confidences = []

for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
batch_langs = languages[i:i+batch_size]
has_math = ["_math" in lang for lang in batch_langs]
Expand Down
1 change: 0 additions & 1 deletion surya/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ class OCRResult(BaseModel):
class TextDetectionResult(BaseModel):
bboxes: List[PolygonBox]
vertical_lines: List[ColumnLine]
horizontal_lines: List[ColumnLine]
heatmap: Any
affinity_map: Any
image_bbox: List[float]
Expand Down
3 changes: 2 additions & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Settings(BaseSettings):
# General
TORCH_DEVICE: Optional[str] = None
IMAGE_DPI: int = 96
IN_STREAMLIT: bool = False # Whether we're running in streamlit

# Paths
DATA_DIR: str = "data"
Expand Down Expand Up @@ -69,7 +70,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
}
RECOGNITION_FONT_DL_BASE: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
RECOGNITION_PAD_VALUE: int = 0 # Should be 0 or 255
RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255

# Layout
LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout2"
Expand Down

0 comments on commit 7a65c45

Please sign in to comment.