Skip to content

Commit

Permalink
Fix minor issues
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 8, 2025
1 parent 29bd086 commit a1f500c
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Model weights will automatically download the first time you run surya.
I've included a streamlit app that lets you interactively try Surya on images or PDF files. Run it with:

```shell
pip install streamlit
pip install streamlit pdftext
surya_gui
```

Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ opencv-python = "^4.9.0.80"
tabulate = "^0.9.0"
filetype = "^1.2.0"
ftfy = "^6.1.3"
pdftext = "~0.4.1"
click = "^8.1.8"

[tool.poetry.group.dev.dependencies]
Expand All @@ -47,6 +46,7 @@ arabic-reshaper = "^3.0.0"
streamlit = "^1.31.0"
playwright = "^1.41.2"
pytest = "^8.3.4"
pdftext = "^0.4.1"

[tool.poetry.scripts]
surya_detect = "detect_text:main"
Expand Down
5 changes: 5 additions & 0 deletions surya/detection/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def model(
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype | str] = None
) -> EfficientViTForSemanticSegmentation:
if device is None:
device = settings.TORCH_DEVICE_MODEL
if dtype is None:
dtype = settings.MODEL_DTYPE

config = EfficientViTConfig.from_pretrained(self.checkpoint)
model = EfficientViTForSemanticSegmentation.from_pretrained(self.checkpoint, torch_dtype=dtype, config=config,
ignore_mismatched_sizes=True)
Expand Down
5 changes: 5 additions & 0 deletions surya/layout/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def model(
device=settings.TORCH_DEVICE_MODEL,
dtype=settings.MODEL_DTYPE
) -> SuryaLayoutModel:
if device is None:
device = settings.TORCH_DEVICE_MODEL
if dtype is None:
dtype = settings.MODEL_DTYPE

config = SuryaLayoutConfig.from_pretrained(self.checkpoint)
decoder_config = config.decoder
decoder = SuryaLayoutDecoderConfig(**decoder_config)
Expand Down
5 changes: 5 additions & 0 deletions surya/ocr_error/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def model(
device=settings.TORCH_DEVICE_MODEL,
dtype=settings.MODEL_DTYPE
) -> DistilBertForSequenceClassification:
if device is None:
device = settings.TORCH_DEVICE_MODEL
if dtype is None:
dtype = settings.MODEL_DTYPE

config = DistilBertConfig.from_pretrained(self.checkpoint)
model = DistilBertForSequenceClassification.from_pretrained(self.checkpoint, torch_dtype=dtype, config=config).to(
device).eval()
Expand Down
4 changes: 4 additions & 0 deletions surya/recognition/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def model(
device=settings.TORCH_DEVICE_MODEL,
dtype=settings.MODEL_DTYPE
) -> OCREncoderDecoderModel:
if device is None:
device = settings.TORCH_DEVICE_MODEL
if dtype is None:
dtype = settings.MODEL_DTYPE
config = SuryaOCRConfig.from_pretrained(self.checkpoint)
decoder_config = config.decoder
decoder = SuryaOCRDecoderConfig(**decoder_config)
Expand Down
4 changes: 3 additions & 1 deletion surya/table_rec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"],
MERGE_KEYS["merge_both"]],
is_header=z == 0
)
)
cell_id += 1
Expand Down Expand Up @@ -293,7 +294,8 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
colspan=1,
merge_up=False,
merge_down=False,
col_id=l
col_id=l,
is_header=z == 0
)
)
cell_id += 1
Expand Down
4 changes: 4 additions & 0 deletions surya/table_rec/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def model(
device=settings.TORCH_DEVICE_MODEL,
dtype=settings.MODEL_DTYPE
) -> TableRecEncoderDecoderModel:
if device is None:
device = settings.TORCH_DEVICE_MODEL
if dtype is None:
dtype = settings.MODEL_DTYPE
config = SuryaTableRecConfig.from_pretrained(self.checkpoint)
decoder_config = config.decoder
decoder = SuryaTableRecDecoderConfig(**decoder_config)
Expand Down
2 changes: 2 additions & 0 deletions surya/table_rec/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ class TableCell(PolygonBox):
colspan: int
within_row_id: int
cell_id: int
is_header: bool
rowspan: int | None = None
merge_up: bool = False
merge_down: bool = False
col_id: int | None = None
text: str | None = None

@property
def label(self):
Expand Down

0 comments on commit a1f500c

Please sign in to comment.