Skip to content

Commit

Permalink
Merge pull request #47 from cameron-a-johnson/dev/add-offline-validat…
Browse files Browse the repository at this point in the history
…ion-saving

write out prediction json at test step during training
  • Loading branch information
Purg authored Nov 21, 2024
2 parents 9705534 + f0c8372 commit 1c847d5
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tcn_hpl/callbacks/plot_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytorch_lightning.utilities.types import STEP_OUTPUT
from sklearn.metrics import confusion_matrix
import torch
import kwcoco

try:
from aim import Image
Expand Down Expand Up @@ -345,6 +346,7 @@ def on_test_batch_end(
self._val_all_targets.append(outputs["targets"].cpu())
self._val_all_source_vids.append(outputs["source_vid"].cpu())
self._val_all_source_frames.append(outputs["source_frame"].cpu())
self._preds_dset_output_fpath = self.output_dir / "tcn_activity_predictions.kwcoco.json"

def on_test_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
Expand All @@ -362,6 +364,36 @@ def on_test_epoch_end(
test_acc = pl_module.test_metrics.acc.compute()
test_f1 = pl_module.test_metrics.f1.compute()

# Create activity predictions KWCOCO JSON
truth_dset_fpath = trainer.datamodule.hparams["coco_test_activities"]
truth_dset = kwcoco.CocoDataset(truth_dset_fpath)
acts_dset = kwcoco.CocoDataset()
acts_dset.fpath = self._preds_dset_output_fpath
acts_dset.dataset['videos'] = truth_dset.dataset['videos']
acts_dset.dataset['images'] = truth_dset.dataset['images']
acts_dset.dataset['categories'] = truth_dset.dataset['categories']
acts_dset.index.build(acts_dset)
# Create numpy lookup tables
for i in range(len(all_preds)):
frame_index = all_source_frames[i].item()
video_id = all_source_vids[i].item()
# Now get the image_id that matches the frame_index and video_id.
sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)]
image_id = sorted_img_ids_for_one_video[frame_index]
# Sanity check: this image_id corresponds to the frame_index and video_id
assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index
assert acts_dset.index.imgs[image_id]['video_id'] == video_id

acts_dset.add_annotation(
image_id=image_id,
category_id=all_preds[i].item(),
score=all_probs[i][all_preds[i]].item(),
prob=all_probs[i].numpy().tolist(),
)
print(f"Dumping activities file to {acts_dset.fpath}")
acts_dset.dump(acts_dset.fpath, newlines=True)


#
# Plot per-video class predictions vs. GT across progressive frames in
# that video.
Expand Down

0 comments on commit 1c847d5

Please sign in to comment.