From bdc244e573495f88e6d33a13e6c7784130d51be8 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Tue, 7 Jan 2025 19:56:39 -0500 Subject: [PATCH] Refactor CLI scripts --- README.md | 17 +++++---- detect_layout.py | 42 ++++++++-------------- detect_text.py | 44 +++++++++-------------- ocr_text.py | 67 +++++++++++++++--------------------- poetry.lock | 10 +++--- pyproject.toml | 1 + surya/common/cli/__init__.py | 0 surya/common/cli/config.py | 63 +++++++++++++++++++++++++++++++++ surya/common/polygon.py | 5 +++ surya/detection/affinity.py | 4 +-- surya/input/load.py | 46 ++++++++----------------- table_recognition.py | 50 ++++++++++----------------- 12 files changed, 176 insertions(+), 173 deletions(-) create mode 100644 surya/common/cli/__init__.py create mode 100644 surya/common/cli/config.py diff --git a/README.md b/README.md index 5c0d4994..37de982c 100644 --- a/README.md +++ b/README.md @@ -100,9 +100,8 @@ surya_ocr DATA_PATH - `--langs` is an optional (but recommended) argument that specifies the language(s) to use for OCR. You can comma separate multiple languages. Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`. - `--lang_file` if you want to use a different language for different PDFs/images, you can optionally specify languages in a file. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`. - `--images` will save images of the pages and detected text lines (optional) -- `--results_dir` specifies the directory to save results to instead of the default -- `--max` specifies the maximum number of pages to process if you don't want to process everything -- `--start_page` specifies the page number to start processing from +- `--output_dir` specifies the directory to save results to instead of the default +- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: @@ -165,8 +164,8 @@ surya_detect DATA_PATH - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected text lines (optional) -- `--max` specifies the maximum number of pages to process if you don't want to process everything -- `--results_dir` specifies the directory to save results to instead of the default +- `--output_dir` specifies the directory to save results to instead of the default +- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: @@ -206,8 +205,8 @@ surya_layout DATA_PATH - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected text lines (optional) -- `--max` specifies the maximum number of pages to process if you don't want to process everything -- `--results_dir` specifies the directory to save results to instead of the default +- `--output_dir` specifies the directory to save results to instead of the default +- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains: @@ -247,8 +246,8 @@ surya_table DATA_PATH - `DATA_PATH` can be an image, pdf, or folder of images/pdfs - `--images` will save images of the pages and detected table cells + rows and columns (optional) -- `--max` specifies the maximum number of pages to process if you don't want to process everything -- `--results_dir` specifies the directory to save results to instead of the default +- `--output_dir` specifies the directory to save results to instead of the default +- `--page_range` specifies the page range to process in the PDF, specified as a single number, a comma separated list, a range, or comma separated ranges - example: `0,5-10,20`. - `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible. - `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table. diff --git a/detect_layout.py b/detect_layout.py index 8c21c18b..5ed728e9 100644 --- a/detect_layout.py +++ b/detect_layout.py @@ -1,5 +1,5 @@ import time -import argparse +import click import copy import json from collections import defaultdict @@ -8,51 +8,39 @@ from surya.layout import LayoutPredictor from surya.postprocessing.heatmap import draw_polys_on_image from surya.settings import settings +from surya.common.cli.config import CLILoader import os - -def main(): - parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).") - parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.") - parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) - parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) - parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False) - parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) - args = parser.parse_args() +@click.command(help="Detect layout of an input file or folder (PDFs or image).") +@CLILoader.common_options +def main(input_path: str, **kwargs): + loader = CLILoader(input_path, kwargs) layout_predictor = LayoutPredictor() - if os.path.isdir(args.input_path): - images, names, _ = load_from_folder(args.input_path, args.max) - folder_name = os.path.basename(args.input_path) - else: - images, names, _ = load_from_file(args.input_path, args.max) - folder_name = os.path.basename(args.input_path).split(".")[0] - start = time.time() - layout_predictions = layout_predictor(images) - result_path = os.path.join(args.results_dir, folder_name) - os.makedirs(result_path, exist_ok=True) - if args.debug: + layout_predictions = layout_predictor(loader.images) + + if loader.debug: print(f"Layout took {time.time() - start} seconds") - if args.images: - for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)): + if loader.images: + for idx, (image, layout_pred, name) in enumerate(zip(loader.images, layout_predictions, loader.names)): polygons = [p.polygon for p in layout_pred.bboxes] labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes] bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels) - bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png")) + bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_layout.png")) predictions_by_page = defaultdict(list) - for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)): + for idx, (pred, name, image) in enumerate(zip(layout_predictions, loader.names, loader.images)): out_pred = pred.model_dump() out_pred["page"] = len(predictions_by_page[name]) + 1 predictions_by_page[name].append(out_pred) - with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(predictions_by_page, f, ensure_ascii=False) - print(f"Wrote results to {result_path}") + print(f"Wrote results to {loader.result_path}") if __name__ == "__main__": diff --git a/detect_text.py b/detect_text.py index 7ce3ea64..5ff78383 100644 --- a/detect_text.py +++ b/detect_text.py @@ -1,4 +1,4 @@ -import argparse +import click import copy import json import time @@ -7,57 +7,45 @@ from surya.input.load import load_from_folder, load_from_file from surya.detection import DetectionPredictor from surya.postprocessing.heatmap import draw_polys_on_image +from surya.common.cli.config import CLILoader from surya.settings import settings import os from tqdm import tqdm -def main(): - parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).") - parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.") - parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya")) - parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) - parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False) - parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) - args = parser.parse_args() +@click.command(help="Detect bboxes in an input file or folder (PDFs or image).") +@CLILoader.common_options +def main(input_path: str, **kwargs): + loader = CLILoader(input_path, kwargs) det_predictor = DetectionPredictor() - if os.path.isdir(args.input_path): - images, names, _ = load_from_folder(args.input_path, args.max) - folder_name = os.path.basename(args.input_path) - else: - images, names, _ = load_from_file(args.input_path, args.max) - folder_name = os.path.basename(args.input_path).split(".")[0] - start = time.time() - predictions = det_predictor(images, include_maps=args.debug) - result_path = os.path.join(args.results_dir, folder_name) - os.makedirs(result_path, exist_ok=True) + predictions = det_predictor(loader.images, include_maps=loader.debug) end = time.time() - if args.debug: + if loader.debug: print(f"Detection took {end - start} seconds") - if args.images: - for idx, (image, pred, name) in enumerate(zip(images, predictions, names)): + if loader.images: + for idx, (image, pred, name) in enumerate(zip(loader.images, predictions, loader.names)): polygons = [p.polygon for p in pred.bboxes] bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image)) - bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png")) + bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_bbox.png")) - if args.debug: + if loader.debug: heatmap = pred.heatmap - heatmap.save(os.path.join(result_path, f"{name}_{idx}_heat.png")) + heatmap.save(os.path.join(loader.result_path, f"{name}_{idx}_heat.png")) predictions_by_page = defaultdict(list) - for idx, (pred, name, image) in enumerate(zip(predictions, names, images)): + for idx, (pred, name, image) in enumerate(zip(predictions, loader.names, loader.images)): out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"]) out_pred["page"] = len(predictions_by_page[name]) + 1 predictions_by_page[name].append(out_pred) - with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(predictions_by_page, f, ensure_ascii=False) - print(f"Wrote results to {result_path}") + print(f"Wrote results to {loader.result_path}") if __name__ == "__main__": diff --git a/ocr_text.py b/ocr_text.py index 3fa09d48..85161f02 100644 --- a/ocr_text.py +++ b/ocr_text.py @@ -1,5 +1,5 @@ import os -import argparse +import click import json import time from collections import defaultdict @@ -9,74 +9,63 @@ from surya.input.load import load_from_folder, load_from_file, load_lang_file from surya.postprocessing.text import draw_text_on_image from surya.recognition import RecognitionPredictor +from surya.common.cli.config import CLILoader from surya.settings import settings +@click.command(help="Detect bboxes in an input file or folder (PDFs or image).") +@CLILoader.common_options +@click.option("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None) +@click.option("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None) +def main(input_path: str, langs: str, lang_file: str, **kwargs): + loader = CLILoader(input_path, kwargs, highres=True) -def main(): - parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).") - parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.") - parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya")) - parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) - parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0) - parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False) - parser.add_argument("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None) - parser.add_argument("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None) - parser.add_argument("--debug", action="store_true", help="Enable debug logging.", default=False) - args = parser.parse_args() - - if os.path.isdir(args.input_path): - images, names, _ = load_from_folder(args.input_path, args.max, args.start_page) - highres_images, _, _ = load_from_folder(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES) - folder_name = os.path.basename(args.input_path) - else: - images, names, _ = load_from_file(args.input_path, args.max, args.start_page) - highres_images, _, _ = load_from_file(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES) - folder_name = os.path.basename(args.input_path).split(".")[0] - - if args.lang_file: + if lang_file: # We got all of our language settings from a file - langs = load_lang_file(args.lang_file, names) + langs = load_lang_file(lang_file, loader.names) for lang in langs: replace_lang_with_code(lang) image_langs = langs - elif args.langs: + elif langs: # We got our language settings from the input - langs = args.langs.split(",") + langs = langs.split(",") replace_lang_with_code(langs) - image_langs = [langs] * len(images) + image_langs = [langs] * len(loader.images) else: - image_langs = [None] * len(images) + image_langs = [None] * len(loader.images) det_predictor = DetectionPredictor() rec_predictor = RecognitionPredictor() - result_path = os.path.join(args.results_dir, folder_name) - os.makedirs(result_path, exist_ok=True) - start = time.time() - predictions_by_image = rec_predictor(images, image_langs, det_predictor=det_predictor, highres_images=highres_images) - if args.debug: + predictions_by_image = rec_predictor( + loader.images, + image_langs, + det_predictor=det_predictor, + highres_images=loader.highres_images + ) + + if loader.debug: print(f"OCR took {time.time() - start:.2f} seconds") max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines]) print(f"Max chars: {max_chars}") - if args.images: - for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)): + if loader.images: + for idx, (name, image, pred, langs) in enumerate(zip(loader.names, loader.images, predictions_by_image, image_langs)): bboxes = [l.bbox for l in pred.text_lines] pred_text = [l.text for l in pred.text_lines] page_image = draw_text_on_image(bboxes, pred_text, image.size, langs) - page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png")) + page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png")) out_preds = defaultdict(list) - for name, pred, image in zip(names, predictions_by_image, images): + for name, pred, image in zip(loader.names, predictions_by_image, loader.images): out_pred = pred.model_dump() out_pred["page"] = len(out_preds[name]) + 1 out_preds[name].append(out_pred) - with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(out_preds, f, ensure_ascii=False) - print(f"Wrote results to {result_path}") + print(f"Wrote results to {loader.result_path}") if __name__ == "__main__": diff --git a/poetry.lock b/poetry.lock index 2ae4cb93..3aba7eb7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2391,10 +2391,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -2472,9 +2472,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4962,4 +4962,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "2a70ff4a7f9e53313d60cf3e78c8e37b34cb24449ac130c1dfeb6a54ddeeec6c" +content-hash = "19780a5ddafa234794316f2eebfd605aa875138ac81d019f167116812b0a2920" diff --git a/pyproject.toml b/pyproject.toml index 39b0bf84..fc590b2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ tabulate = "^0.9.0" filetype = "^1.2.0" ftfy = "^6.1.3" pdftext = "~0.4.1" +click = "^8.1.8" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0" diff --git a/surya/common/cli/__init__.py b/surya/common/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/surya/common/cli/config.py b/surya/common/cli/config.py new file mode 100644 index 00000000..4a40d054 --- /dev/null +++ b/surya/common/cli/config.py @@ -0,0 +1,63 @@ +from typing import List + +import click +import os +from surya.input.load import load_from_folder, load_from_file +from surya.settings import settings + + +class CLILoader: + def __init__(self, filepath: str, cli_options: dict, highres: bool = False): + self.page_range = cli_options.get("page_range") + if self.page_range: + self.page_range = self.parse_range_str(self.page_range) + self.filepath = filepath + self.config = cli_options + self.save_images = cli_options.get("images", False) + self.debug = cli_options.get("debug", False) + self.output_dir = cli_options.get("output_dir") + + self.load(highres) + + @staticmethod + def common_options(fn): + fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn) + fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn) + fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges. Example: 0,5-10,20")(fn) + fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn) + fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn) + return fn + + def load(self, highres: bool = False): + highres_images = None + if os.path.isdir(self.filepath): + images, names = load_from_folder(self.filepath, self.page_range) + folder_name = os.path.basename(self.filepath) + if highres: + highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) + else: + images, names = load_from_file(self.filepath, self.page_range) + folder_name = os.path.basename(self.filepath).split(".")[0] + if highres: + highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) + + + self.images = images + self.highres_images = highres_images + self.names = names + + self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name)) + os.makedirs(self.result_path, exist_ok=True) + + @staticmethod + def parse_range_str(range_str: str) -> List[int]: + range_lst = range_str.split(",") + page_lst = [] + for i in range_lst: + if "-" in i: + start, end = i.split("-") + page_lst += list(range(int(start), int(end) + 1)) + else: + page_lst.append(int(i)) + page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order + return page_lst \ No newline at end of file diff --git a/surya/common/polygon.py b/surya/common/polygon.py index 443c42bd..0f171684 100644 --- a/surya/common/polygon.py +++ b/surya/common/polygon.py @@ -65,6 +65,11 @@ def rescale(self, processor_size, image_size): corner[1] = int(corner[1] * height_scaler) self.polygon = new_corners + def round(self, divisor): + for corner in self.polygon: + corner[0] = int(corner[0] / divisor) * divisor + corner[1] = int(corner[1] / divisor) * divisor + def fit_to_bounds(self, bounds): new_corners = copy.deepcopy(self.polygon) for corner in new_corners: diff --git a/surya/detection/affinity.py b/surya/detection/affinity.py index fdadf7c9..2e264ae3 100644 --- a/surya/detection/affinity.py +++ b/surya/detection/affinity.py @@ -87,10 +87,10 @@ def get_detected_lines(image, slope_tol_deg=2, vertical=False, horizontal=False) def get_vertical_lines(image, processor_size, image_size, divisor=20, x_tolerance=40, y_tolerance=20) -> List[ColumnLine]: vertical_lines = get_detected_lines(image, vertical=True) for line in vertical_lines: - line.rescale_bbox(processor_size, image_size) + line.rescale(processor_size, image_size) vertical_lines = sorted(vertical_lines, key=lambda x: x.bbox[0]) for line in vertical_lines: - line.round_bbox(divisor) + line.round(divisor) # Merge adjacent line segments together to_remove = [] diff --git a/surya/input/load.py b/surya/input/load.py index 215a1512..2aae3b9b 100644 --- a/surya/input/load.py +++ b/surya/input/load.py @@ -1,3 +1,4 @@ +from typing import List import PIL from surya.input.processing import open_pdf, get_page_images @@ -12,73 +13,56 @@ def get_name_from_path(path): return os.path.basename(path).split(".")[0] -def load_pdf(pdf_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False, flatten_pdf=settings.FLATTEN_PDF): +def load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI): doc = open_pdf(pdf_path) last_page = len(doc) - if start_page: - assert start_page < last_page and start_page >= 0, f"Start page must be between 0 and {last_page}" + if page_range: + assert all([0 <= page < last_page for page in page_range]), f"Invalid page range: {page_range}" else: - start_page = 0 - - if max_pages: - assert max_pages >= 0, f"Max pages must be greater than 0" - last_page = min(start_page + max_pages, last_page) - - page_indices = list(range(start_page, last_page)) - images = get_page_images(doc, page_indices, dpi=dpi) - text_lines = [None] * len(page_indices) - if load_text_lines: - from surya.input.pdflines import get_page_text_lines # Putting import here because pypdfium2 causes warnings if its not the top import - text_lines = get_page_text_lines( - pdf_path, - page_indices, - [i.size for i in images], - flatten_pdf=flatten_pdf - ) + page_range = list(range(last_page)) + + images = get_page_images(doc, page_range, dpi=dpi) doc.close() - names = [get_name_from_path(pdf_path) for _ in page_indices] - return images, names, text_lines + names = [get_name_from_path(pdf_path) for _ in page_range] + return images, names def load_image(image_path): image = Image.open(image_path).convert("RGB") name = get_name_from_path(image_path) - return [image], [name], [None] + return [image], [name] -def load_from_file(input_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False, flatten_pdf=settings.FLATTEN_PDF): +def load_from_file(input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI): input_type = filetype.guess(input_path) if input_type and input_type.extension == "pdf": - return load_pdf(input_path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines, flatten_pdf=flatten_pdf) + return load_pdf(input_path, page_range, dpi=dpi) else: return load_image(input_path) -def load_from_folder(folder_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False, flatten_pdf=settings.FLATTEN_PDF): +def load_from_folder(folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI): image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".")] image_paths = [ip for ip in image_paths if not os.path.isdir(ip)] images = [] names = [] - text_lines = [] for path in image_paths: extension = filetype.guess(path) if extension and extension.extension == "pdf": - image, name, text_line = load_pdf(path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines, flatten_pdf=flatten_pdf) + image, name, text_line = load_pdf(path, page_range, dpi=dpi) images.extend(image) names.extend(name) - text_lines.extend(text_line) else: try: image, name, text_line = load_image(path) images.extend(image) names.extend(name) - text_lines.extend(text_line) except PIL.UnidentifiedImageError: print(f"Could not load image {path}") continue - return images, names, text_lines + return images, names def load_lang_file(lang_path, names): diff --git a/table_recognition.py b/table_recognition.py index be53c96a..8d79b506 100644 --- a/table_recognition.py +++ b/table_recognition.py @@ -1,9 +1,10 @@ import os -import argparse +import click import copy import json from collections import defaultdict +from surya.common.cli.config import CLILoader from surya.input.load import load_from_folder, load_from_file from surya.layout import LayoutPredictor from surya.table_rec import TableRecPredictor @@ -11,32 +12,19 @@ from surya.settings import settings from surya.postprocessing.util import rescale_bbox - -def main(): - parser = argparse.ArgumentParser(description="Detect tables in an input file or folder (PDFs or image).") - parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder.") - parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) - parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) - parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False) - parser.add_argument("--detect_boxes", action="store_true", help="Detect table boxes.", default=False) - parser.add_argument("--skip_table_detection", action="store_true", help="Tables are already cropped, so don't re-detect tables.", default=False) - args = parser.parse_args() +@click.command(help="Detect layout of an input file or folder (PDFs or image).") +@CLILoader.common_options +@click.option("--detect_boxes", is_flag=True, help="Detect table boxes.", default=False) +@click.option("--skip_table_detection", is_flag=True, help="Tables are already cropped, so don't re-detect tables.", default=False) +def main(input_path: str, detect_boxes: bool, skip_table_detection: bool, **kwargs): + loader = CLILoader(input_path, kwargs, highres=True) table_rec_predictor = TableRecPredictor() layout_predictor = LayoutPredictor() - if os.path.isdir(args.input_path): - images, _, _ = load_from_folder(args.input_path, args.max) - highres_images, names, _ = load_from_folder(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES) - folder_name = os.path.basename(args.input_path) - else: - images, _, _ = load_from_file(args.input_path, args.max) - highres_images, names, _ = load_from_file(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES) - folder_name = os.path.basename(args.input_path).split(".")[0] - pnums = [] prev_name = None - for i, name in enumerate(names): + for i, name in enumerate(loader.names): if prev_name is None or prev_name != name: pnums.append(0) else: @@ -44,14 +32,14 @@ def main(): prev_name = name - layout_predictions = layout_predictor(images) + layout_predictions = layout_predictor(loader.images) table_imgs = [] table_counts = [] - for layout_pred, img, highres_img in zip(layout_predictions, images, highres_images): + for layout_pred, img, highres_img in zip(layout_predictions, loader.images, loader.highres_images): # The table may already be cropped - if args.skip_table_detection: + if skip_table_detection: table_imgs.append(highres_img) table_counts.append(1) else: @@ -73,8 +61,6 @@ def main(): table_imgs.extend(page_table_imgs) table_preds = table_rec_predictor(table_imgs) - result_path = os.path.join(args.results_dir, folder_name) - os.makedirs(result_path, exist_ok=True) img_idx = 0 prev_count = 0 @@ -85,7 +71,7 @@ def main(): img_idx += 1 pred = table_preds[i] - orig_name = names[img_idx] + orig_name = loader.names[img_idx] pnum = pnums[img_idx] table_img = table_imgs[i] @@ -95,7 +81,7 @@ def main(): out_pred["table_idx"] = table_idx table_predictions[orig_name].append(out_pred) - if args.images: + if loader.images: rows = [l.bbox for l in pred.rows] cols = [l.bbox for l in pred.cols] row_labels = [f"Row {l.row_id}" for l in pred.rows] @@ -105,16 +91,16 @@ def main(): rc_image = copy.deepcopy(table_img) rc_image = draw_bboxes_on_image(rows, rc_image, labels=row_labels, label_font_size=20, color="blue") rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red") - rc_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png")) + rc_image.save(os.path.join(loader.result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png")) cell_image = copy.deepcopy(table_img) cell_image = draw_bboxes_on_image(cells, cell_image, color="green") - cell_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png")) + cell_image.save(os.path.join(loader.result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png")) - with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f: json.dump(table_predictions, f, ensure_ascii=False) - print(f"Wrote results to {result_path}") + print(f"Wrote results to {loader.result_path}") if __name__ == "__main__":