Skip to content

Commit

Permalink
Add model support for headers
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 9, 2025
1 parent a1f500c commit d2a0352
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
2 changes: 1 addition & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"

# Table Rec
- TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/surya_tablerec"
+ TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/table_rec_3"
TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
TABLE_REC_MAX_BOXES: int = 150
TABLE_REC_BATCH_SIZE: Optional[int] = None
Expand Down
13 changes: 9 additions & 4 deletions surya/table_rec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def batch_table_recognition(
"category": CATEGORY_TO_ID["Table"],
"colspan": 0,
"merges": 0,
+ "is_header": 0
})

output_order = []
Expand Down Expand Up @@ -184,6 +185,7 @@ def batch_table_recognition(
"category": row_prediction["category"],
"colspan": 0,
"merges": 0,
+ "is_header": row_prediction["is_header"]
})
row_encoder_hidden_states.append(encoder_hidden_states[j])
idx_map.append(j)
Expand All @@ -193,6 +195,7 @@ def batch_table_recognition(
"category": row_prediction["category"],
"colspan": 0,
"merges": 0,
+ "is_header": row_prediction["is_header"]
})

# Re-inference to predict cells
Expand Down Expand Up @@ -234,7 +237,8 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
columns.append(
TableCol(
polygon=polygon,
- col_id=z
+ col_id=z,
+ is_header=col_prediction["is_header"]
)
)

Expand All @@ -244,7 +248,8 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
polygon = self.processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
row = TableRow(
polygon=polygon,
- row_id=z
+ row_id=z,
+ is_header=row_prediction["is_header"]
)
rows.append(row)

Expand All @@ -264,7 +269,7 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"],
MERGE_KEYS["merge_both"]],
- is_header=z == 0
+ is_header=row.is_header or z == 0
)
)
cell_id += 1
Expand Down Expand Up @@ -295,7 +300,7 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
merge_up=False,
merge_down=False,
col_id=l,
- is_header=z == 0
+ is_header=row.is_header or col.is_header or z == 0
)
)
cell_id += 1
Expand Down
11 changes: 9 additions & 2 deletions surya/table_rec/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,18 @@
}
CATEGORY_TO_ID = {v: k for k, v in ID_TO_CATEGORY.items()}

+ ID_TO_HEADER = {
+ 0: "None",
+ 1: "Header"
+ }
+ HEADER_TO_ID = {v: k for k, v in ID_TO_HEADER.items()}
+
BOX_PROPERTIES = [
("bbox", 6, "regression"),
("category", len(ID_TO_CATEGORY), "classification"),
("merges", len(MERGE_KEYS), "classification"),
- ("colspan", 1, "regression")
+ ("colspan", 1, "regression"),
+ ("is_header", len(ID_TO_HEADER), "classification")
]


Expand Down Expand Up @@ -132,7 +139,7 @@ class SuryaTableRecDecoderConfig(PretrainedConfig):

def __init__(
self,
- num_hidden_layers=10,
+ num_hidden_layers=6,
vocab_size=BOX_DIM + 1,
bbox_size=BOX_DIM,
hidden_size=512,
Expand Down
2 changes: 2 additions & 0 deletions surya/table_rec/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def label(self):

class TableRow(PolygonBox):
row_id: int
+ is_header: bool

@property
def label(self):
Expand All @@ -32,6 +33,7 @@ def label(self):

class TableCol(PolygonBox):
col_id: int
+ is_header: bool

@property
def label(self):
Expand Down

0 comments on commit d2a0352

Please sign in to comment.