Skip to content

Commit

Permalink
Avoid overlapping cells
Browse files: browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 16, 2025
1 parent 3278e52 commit 9f4bbee
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions surya/table_rec/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -262,6 +262,9 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
if colspan == 1:
# Skip single column cells
continue
if PolygonBox(polygon=polygon).height < row.height * .9:
# Spanning cell must cover most of the row
continue

spanning_cells.append(
TableCell(
Expand All @@ -281,17 +284,24 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si

# Add cells - either add spanning cells (multiple cols), or generate a cell based on row/col
used_spanning_cells = set()
skip_columns = 0
for l, col in enumerate(columns):
if skip_columns:
skip_columns -= 1
continue
cell_polygon = row.intersection_polygon(col)
cell_added = False
for zz, spanning_cell in enumerate(spanning_cells):
intersection_pct = PolygonBox(polygon=cell_polygon).intersection_pct(spanning_cell)
if intersection_pct > .5:
cell_polygonbox = PolygonBox(polygon=cell_polygon)
intersection_pct = cell_polygonbox.intersection_pct(spanning_cell)
# Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns)
if intersection_pct > .9 and spanning_cell.width > cell_polygonbox.width:
cell_added = True
if zz not in used_spanning_cells:
used_spanning_cells.add(zz)
spanning_cell.col_id = l
cells.append(spanning_cell)
skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell

if not cell_added:
cells.append(
Expand Down

0 comments on commit 9f4bbee

Please sign in to comment.