diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py
index 850702bf8..1a4206e31 100644
--- a/tcn_hpl/callbacks/plot_metrics.py
+++ b/tcn_hpl/callbacks/plot_metrics.py
@@ -10,6 +10,7 @@
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 from sklearn.metrics import confusion_matrix
 import torch
+import kwcoco
 
 try:
     from aim import Image
@@ -345,6 +346,7 @@ def on_test_batch_end(
         self._val_all_targets.append(outputs["targets"].cpu())
         self._val_all_source_vids.append(outputs["source_vid"].cpu())
         self._val_all_source_frames.append(outputs["source_frame"].cpu())
+        self._preds_dset_output_fpath = self.output_dir / "tcn_activity_predictions.kwcoco.json"
 
     def on_test_epoch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
@@ -362,6 +364,36 @@
         test_acc = pl_module.test_metrics.acc.compute()
         test_f1 = pl_module.test_metrics.f1.compute()
 
+        # Create activity predictions KWCOCO JSON
+        truth_dset_fpath = trainer.datamodule.hparams["coco_test_activities"]
+        truth_dset = kwcoco.CocoDataset(truth_dset_fpath)
+        acts_dset = kwcoco.CocoDataset()
+        acts_dset.fpath = self._preds_dset_output_fpath
+        acts_dset.dataset['videos'] = truth_dset.dataset['videos']
+        acts_dset.dataset['images'] = truth_dset.dataset['images']
+        acts_dset.dataset['categories'] = truth_dset.dataset['categories']
+        acts_dset.index.build(acts_dset)
+        # Add one activity classification annotation per prediction.
+        for i in range(len(all_preds)):
+            frame_index = all_source_frames[i].item()
+            video_id = all_source_vids[i].item()
+            # Get the image_id that matches this frame_index and video_id.
+            sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)]
+            image_id = sorted_img_ids_for_one_video[frame_index]
+            # Sanity check: this image_id corresponds to the frame_index and video_id.
+            assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index
+            assert acts_dset.index.imgs[image_id]['video_id'] == video_id
+
+            acts_dset.add_annotation(
+                image_id=image_id,
+                category_id=all_preds[i].item(),
+                score=all_probs[i][all_preds[i]].item(),
+                prob=all_probs[i].numpy().tolist(),
+            )
+        print(f"Dumping activities file to {acts_dset.fpath}")
+        acts_dset.dump(acts_dset.fpath, newlines=True)
+
+        #
         # Plot per-video class predictions vs. GT across progressive frames in
         # that video.
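
A minimal sketch of how the dumped predictions file might be consumed downstream. The filename and the score/prob annotation fields come from the diff above; the working-directory path and the assumption that category_id indexes directly into the prob vector (as the add_annotation call implies) are mine:

import kwcoco
import numpy as np

# Load the predictions file written by on_test_epoch_end. Adjust the path to
# wherever output_dir points in your run (assumed here to be the CWD).
preds = kwcoco.CocoDataset("tcn_activity_predictions.kwcoco.json")

for gid, aids in preds.index.gid_to_aids.items():
    for aid in aids:
        ann = preds.index.anns[aid]
        probs = np.asarray(ann["prob"])
        # Assumption: category ids are indices into the prob vector, so the
        # argmax of the stored probabilities recovers the winning class.
        assert int(probs.argmax()) == ann["category_id"]
        print(gid, ann["category_id"], ann["score"])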