Refactor CLI scripts

VikParuchuri committed Jan 8, 2025
1 parent 3cf4d29 commit bdc244e
Showing 12 changed files with 176 additions and 173 deletions.
17 changes: 8 additions & 9 deletions README.md
@@ -100,9 +100,8 @@ surya_ocr DATA_PATH
 - `--langs` is an optional (but recommended) argument that specifies the language(s) to use for OCR. You can comma separate multiple languages. Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`.
 - `--lang_file` if you want to use a different language for different PDFs/images, you can optionally specify languages in a file. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`.
 - `--images` will save images of the pages and detected text lines (optional)
-- `--results_dir` specifies the directory to save results to instead of the default
-- `--max` specifies the maximum number of pages to process if you don't want to process everything
-- `--start_page` specifies the page number to start processing from
+- `--output_dir` specifies the directory to save results to instead of the default
+- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
 
 The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
 
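As an aside on the new `--page_range` format (single pages, inclusive ranges, and comma-separated mixes of both), a spec like `0,5-10,20` expands naturally into a list of zero-based page indices. The sketch below is a hypothetical parser for illustration only; the function name `parse_page_range` and the inclusive-range behavior are assumptions, not taken from surya's code.

```python
# Hypothetical sketch of expanding a --page_range spec such as "0,5-10,20"
# into a sorted list of zero-based page indices. Not surya's actual parser.
def parse_page_range(spec: str) -> list[int]:
    pages = set()
    for part in spec.split(","):
        part = part.strip()
        if "-" in part:
            start, end = part.split("-")
            pages.update(range(int(start), int(end) + 1))  # ranges assumed inclusive
        else:
            pages.add(int(part))
    return sorted(pages)


print(parse_page_range("0,5-10,20"))  # [0, 5, 6, 7, 8, 9, 10, 20]
```
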
@@ -165,8 +164,8 @@ surya_detect DATA_PATH
 
 - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
 - `--images` will save images of the pages and detected text lines (optional)
-- `--max` specifies the maximum number of pages to process if you don't want to process everything
-- `--results_dir` specifies the directory to save results to instead of the default
+- `--output_dir` specifies the directory to save results to instead of the default
+- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
 
 The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
 
@@ -206,8 +205,8 @@ surya_layout DATA_PATH
 
 - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
 - `--images` will save images of the pages and detected text lines (optional)
-- `--max` specifies the maximum number of pages to process if you don't want to process everything
-- `--results_dir` specifies the directory to save results to instead of the default
+- `--output_dir` specifies the directory to save results to instead of the default
+- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
 
 The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
 
@@ -247,8 +246,8 @@ surya_table DATA_PATH
 
 - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
 - `--images` will save images of the pages and detected table cells + rows and columns (optional)
-- `--max` specifies the maximum number of pages to process if you don't want to process everything
-- `--results_dir` specifies the directory to save results to instead of the default
+- `--output_dir` specifies the directory to save results to instead of the default
+- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`.
 - `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible.
 - `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table.
 
42 changes: 15 additions & 27 deletions detect_layout.py
@@ -1,5 +1,5 @@
 import time
-import argparse
+import click
 import copy
 import json
 from collections import defaultdict
@@ -8,51 +8,39 @@
 from surya.layout import LayoutPredictor
 from surya.postprocessing.heatmap import draw_polys_on_image
 from surya.settings import settings
+from surya.common.cli.config import CLILoader
 import os
 
 
-def main():
-    parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).")
-    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.")
-    parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
-    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
-    parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
-    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
-    args = parser.parse_args()
+@click.command(help="Detect layout of an input file or folder (PDFs or image).")
+@CLILoader.common_options
+def main(input_path: str, **kwargs):
+    loader = CLILoader(input_path, kwargs)
 
     layout_predictor = LayoutPredictor()
 
-    if os.path.isdir(args.input_path):
-        images, names, _ = load_from_folder(args.input_path, args.max)
-        folder_name = os.path.basename(args.input_path)
-    else:
-        images, names, _ = load_from_file(args.input_path, args.max)
-        folder_name = os.path.basename(args.input_path).split(".")[0]
-
     start = time.time()
-    layout_predictions = layout_predictor(images)
-    result_path = os.path.join(args.results_dir, folder_name)
-    os.makedirs(result_path, exist_ok=True)
-    if args.debug:
+    layout_predictions = layout_predictor(loader.images)
+
+    if loader.debug:
         print(f"Layout took {time.time() - start} seconds")
 
-    if args.images:
-        for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)):
+    if loader.images:
+        for idx, (image, layout_pred, name) in enumerate(zip(loader.images, layout_predictions, loader.names)):
             polygons = [p.polygon for p in layout_pred.bboxes]
             labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes]
             bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels)
-            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png"))
+            bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_layout.png"))
 
     predictions_by_page = defaultdict(list)
-    for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)):
+    for idx, (pred, name, image) in enumerate(zip(layout_predictions, loader.names, loader.images)):
         out_pred = pred.model_dump()
         out_pred["page"] = len(predictions_by_page[name]) + 1
         predictions_by_page[name].append(out_pred)
 
-    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
+    with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)
 
-    print(f"Wrote results to {result_path}")
+    print(f"Wrote results to {loader.result_path}")
 
 
 if __name__ == "__main__":
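
The refactor above replaces each script's argparse boilerplate with `@click.command` plus a shared `@CLILoader.common_options` decorator, and the shared options arrive in `**kwargs` for the loader to consume. The snippet below is only a minimal sketch of that decorator-stacking pattern, under the assumption that `common_options` injects the input path argument and the shared flags documented in the README changes above; names and defaults here are illustrative, not surya's actual implementation.

```python
# Minimal sketch of a shared-options decorator for click CLIs: one decorator
# stacks the common argument/flags so each script only declares what is unique
# to it. Assumed pattern, not the actual CLILoader.common_options.
import click


def common_options(fn):
    # Each click decorator attaches a parameter to fn; click.command collects them.
    fn = click.argument("input_path", type=click.Path(exists=True))(fn)
    fn = click.option("--output_dir", type=click.Path(), default=None,
                      help="Directory to save results to.")(fn)
    fn = click.option("--page_range", type=str, default=None,
                      help="Page range, e.g. 0,5-10,20.")(fn)
    fn = click.option("--images", is_flag=True, help="Save debug images.")(fn)
    fn = click.option("--debug", is_flag=True, help="Enable debug mode.")(fn)
    return fn


@click.command()
@common_options
def main(input_path: str, **kwargs):
    # click passes every shared option into **kwargs; a loader object can consume them.
    click.echo(f"{input_path=} {kwargs=}")


if __name__ == "__main__":
    main()
```

The appeal of this pattern is that adding or renaming a shared flag becomes a change in one place rather than an edit to every script.
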
44 changes: 16 additions & 28 deletions detect_text.py
@@ -1,4 +1,4 @@
-import argparse
+import click
 import copy
 import json
 import time
@@ -7,57 +7,45 @@
 from surya.input.load import load_from_folder, load_from_file
 from surya.detection import DetectionPredictor
 from surya.postprocessing.heatmap import draw_polys_on_image
+from surya.common.cli.config import CLILoader
 from surya.settings import settings
 import os
 from tqdm import tqdm
 
 
-def main():
-    parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
-    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
-    parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
-    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
-    parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
-    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
-    args = parser.parse_args()
+@click.command(help="Detect bboxes in an input file or folder (PDFs or image).")
+@CLILoader.common_options
+def main(input_path: str, **kwargs):
+    loader = CLILoader(input_path, kwargs)
 
     det_predictor = DetectionPredictor()
 
-    if os.path.isdir(args.input_path):
-        images, names, _ = load_from_folder(args.input_path, args.max)
-        folder_name = os.path.basename(args.input_path)
-    else:
-        images, names, _ = load_from_file(args.input_path, args.max)
-        folder_name = os.path.basename(args.input_path).split(".")[0]
-
     start = time.time()
-    predictions = det_predictor(images, include_maps=args.debug)
-    result_path = os.path.join(args.results_dir, folder_name)
-    os.makedirs(result_path, exist_ok=True)
+    predictions = det_predictor(loader.images, include_maps=loader.debug)
     end = time.time()
-    if args.debug:
+    if loader.debug:
         print(f"Detection took {end - start} seconds")
 
-    if args.images:
-        for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
+    if loader.images:
+        for idx, (image, pred, name) in enumerate(zip(loader.images, predictions, loader.names)):
            polygons = [p.polygon for p in pred.bboxes]
            bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image))
-            bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))
+            bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_bbox.png"))
 
-            if args.debug:
+            if loader.debug:
                 heatmap = pred.heatmap
-                heatmap.save(os.path.join(result_path, f"{name}_{idx}_heat.png"))
+                heatmap.save(os.path.join(loader.result_path, f"{name}_{idx}_heat.png"))
 
     predictions_by_page = defaultdict(list)
-    for idx, (pred, name, image) in enumerate(zip(predictions, names, images)):
+    for idx, (pred, name, image) in enumerate(zip(predictions, loader.names, loader.images)):
         out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"])
         out_pred["page"] = len(predictions_by_page[name]) + 1
         predictions_by_page[name].append(out_pred)
 
-    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
+    with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(predictions_by_page, f, ensure_ascii=False)
 
-    print(f"Wrote results to {result_path}")
+    print(f"Wrote results to {loader.result_path}")
 
 
 if __name__ == "__main__":
67 changes: 28 additions & 39 deletions ocr_text.py
@@ -1,5 +1,5 @@
 import os
-import argparse
+import click
 import json
 import time
 from collections import defaultdict
@@ -9,74 +9,63 @@
 from surya.input.load import load_from_folder, load_from_file, load_lang_file
 from surya.postprocessing.text import draw_text_on_image
 from surya.recognition import RecognitionPredictor
+from surya.common.cli.config import CLILoader
 from surya.settings import settings
 
+@click.command(help="Detect bboxes in an input file or folder (PDFs or image).")
+@CLILoader.common_options
+@click.option("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None)
+@click.option("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None)
+def main(input_path: str, langs: str, lang_file: str, **kwargs):
+    loader = CLILoader(input_path, kwargs, highres=True)
 
-def main():
-    parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
-    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
-    parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
-    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
-    parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0)
-    parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
-    parser.add_argument("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None)
-    parser.add_argument("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None)
-    parser.add_argument("--debug", action="store_true", help="Enable debug logging.", default=False)
-    args = parser.parse_args()
-
-    if os.path.isdir(args.input_path):
-        images, names, _ = load_from_folder(args.input_path, args.max, args.start_page)
-        highres_images, _, _ = load_from_folder(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
-        folder_name = os.path.basename(args.input_path)
-    else:
-        images, names, _ = load_from_file(args.input_path, args.max, args.start_page)
-        highres_images, _, _ = load_from_file(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
-        folder_name = os.path.basename(args.input_path).split(".")[0]
-
-    if args.lang_file:
+    if lang_file:
         # We got all of our language settings from a file
-        langs = load_lang_file(args.lang_file, names)
+        langs = load_lang_file(lang_file, loader.names)
         for lang in langs:
             replace_lang_with_code(lang)
         image_langs = langs
-    elif args.langs:
+    elif langs:
         # We got our language settings from the input
-        langs = args.langs.split(",")
+        langs = langs.split(",")
         replace_lang_with_code(langs)
-        image_langs = [langs] * len(images)
+        image_langs = [langs] * len(loader.images)
     else:
-        image_langs = [None] * len(images)
+        image_langs = [None] * len(loader.images)
 
     det_predictor = DetectionPredictor()
     rec_predictor = RecognitionPredictor()
 
-    result_path = os.path.join(args.results_dir, folder_name)
-    os.makedirs(result_path, exist_ok=True)
-
     start = time.time()
-    predictions_by_image = rec_predictor(images, image_langs, det_predictor=det_predictor, highres_images=highres_images)
-    if args.debug:
+    predictions_by_image = rec_predictor(
+        loader.images,
+        image_langs,
+        det_predictor=det_predictor,
+        highres_images=loader.highres_images
+    )
+
+    if loader.debug:
         print(f"OCR took {time.time() - start:.2f} seconds")
         max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines])
         print(f"Max chars: {max_chars}")
 
-    if args.images:
-        for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
+    if loader.images:
+        for idx, (name, image, pred, langs) in enumerate(zip(loader.names, loader.images, predictions_by_image, image_langs)):
            bboxes = [l.bbox for l in pred.text_lines]
            pred_text = [l.text for l in pred.text_lines]
            page_image = draw_text_on_image(bboxes, pred_text, image.size, langs)
-            page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))
+            page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png"))
 
     out_preds = defaultdict(list)
-    for name, pred, image in zip(names, predictions_by_image, images):
+    for name, pred, image in zip(loader.names, predictions_by_image, loader.images):
         out_pred = pred.model_dump()
         out_pred["page"] = len(out_preds[name]) + 1
         out_preds[name].append(out_pred)
 
-    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
+    with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_preds, f, ensure_ascii=False)
 
-    print(f"Wrote results to {result_path}")
+    print(f"Wrote results to {loader.result_path}")
 
 
 if __name__ == "__main__":
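
The language handling above is the main per-script logic left after the refactor: `--lang_file` maps each input file to its own language list, `--langs` applies one comma-separated list to every page, and passing neither sends `None` through for every image. The snippet below is a standalone restatement of that branch for illustration, using a plain dict in place of surya's `load_lang_file` and `replace_lang_with_code` helpers; `resolve_image_langs` is a hypothetical name, not part of surya.

```python
# Standalone sketch of the language-resolution branch above, using a plain dict
# in place of surya's load_lang_file / replace_lang_with_code helpers.
# resolve_image_langs is a hypothetical name, not part of surya.
import json
from typing import Optional


def resolve_image_langs(names: list[str], langs: Optional[str],
                        lang_file: Optional[str]) -> list[Optional[list[str]]]:
    if lang_file:
        # Per-file languages, e.g. {"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}
        with open(lang_file, encoding="utf-8") as f:
            per_file = json.load(f)
        return [per_file.get(name) for name in names]
    if langs:
        # One shared, comma-separated language list applied to every page.
        shared = [lang.strip() for lang in langs.split(",")]
        return [shared] * len(names)
    # No hint given: pass None through for every page.
    return [None] * len(names)


print(resolve_image_langs(["page1", "page2"], "en,hi", None))  # [['en', 'hi'], ['en', 'hi']]
```
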
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ tabulate = "^0.9.0"
 filetype = "^1.2.0"
 ftfy = "^6.1.3"
 pdftext = "~0.4.1"
+click = "^8.1.8"
 
 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
Empty file added surya/common/cli/__init__.py
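
`CLILoader` itself presumably lives in a new `surya/common/cli/config.py`, which is not rendered in this view. Inferring only from how the refactored scripts use it (`loader.images`, `loader.names`, `loader.highres_images`, `loader.result_path`, `loader.debug`), a loader of roughly the following shape would fit. This is an assumed sketch, not the actual module; attributes beyond those read by the scripts above (such as `save_images`) are hypothetical.

```python
# Assumed sketch of a CLILoader-style helper, inferred from the call sites in
# the refactored scripts above. NOT the actual surya/common/cli/config.py.
import os

from surya.input.load import load_from_folder, load_from_file
from surya.settings import settings


class CLILoader:
    # common_options (the shared click flags) is sketched after detect_layout.py above.

    def __init__(self, input_path: str, cli_options: dict, highres: bool = False):
        self.debug = cli_options.get("debug", False)
        self.save_images = cli_options.get("images", False)  # hypothetical attribute name
        # --page_range parsing and page filtering would also live here (omitted).

        # Load one PIL image and one name per page, as the old per-script code did,
        # passing None for the page cap and the high-res DPI positionally as before.
        load = load_from_folder if os.path.isdir(input_path) else load_from_file
        self.images, self.names, _ = load(input_path, None)
        self.highres_images = None
        if highres:
            self.highres_images, _, _ = load(input_path, None, 0, settings.IMAGE_DPI_HIGHRES)

        # Mirror the old result-directory convention: <output_dir>/<input name>.
        folder_name = os.path.basename(input_path).split(".")[0]
        output_dir = cli_options.get("output_dir") or os.path.join(settings.RESULT_DIR, "surya")
        self.result_path = os.path.join(output_dir, folder_name)
        os.makedirs(self.result_path, exist_ok=True)
```

Centralizing the loading and output-path handling this way is what removes the repeated isdir/load/makedirs blocks from every script above.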
