Black reformatting
Purg committed Oct 9, 2024
1 parent 19adc05 commit 5d3b526
Showing 2 changed files with 44 additions and 41 deletions.
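
Every hunk below is a pure formatting change from running Black over the two files. As a minimal sketch of the normalization involved (assuming Black's default style; the project may pin a particular version or line length), the same effect can be reproduced through Black's Python API:

    import black

    # One of the lines changed below; Black normalizes spacing around operators.
    source = 'if task == "r18" or task=="m3":\n    jump = 4\n'

    # black.Mode() selects the default style (88-character lines).
    formatted = black.format_str(source, mode=black.Mode())
    print(formatted, end="")
    # if task == "r18" or task == "m3":
    #     jump = 4

On the command line, running black tcn_hpl/data/add_gt_to_kwcoco.py tcn_hpl/data/tcn_dataset.py from the repository root rewrites both files in place the same way.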
65 changes: 33 additions & 32 deletions tcn_hpl/data/add_gt_to_kwcoco.py
@@ -6,8 +6,8 @@
 import tcn_hpl.utils.utils as utils
 import ubelt as ub
 
+
 def text_to_labels(text_file: str, num_frames: int, task: str, mapping: dict):
-
     # set background to everything first
     activity_gt_list = [0 for x in range(num_frames)]
     f = open(text_file, "r")
@@ -17,50 +17,47 @@ def text_to_labels(text_file: str, num_frames: int, task: str, mapping: dict):
         text_list = text.split("\t")
         if text_list[-1] == "":
             text_list = text_list[:-1]
 
         # this check handles inconsistencies in the GT we get from BBN
-        if task == "r18" or task=="m3":
+        if task == "r18" or task == "m3":
             jump = 4
-        elif task=="m2" or task=="m5":
+        elif task == "m2" or task == "m5":
             jump = 3
 
         for index in range(0, len(text_list), jump):
             triplet = text_list[index : index + jump]
             start_frame = int(triplet[0])
             end_frame = int(triplet[1])
-            desc = triplet[jump-1]
+            desc = triplet[jump - 1]
 
             gt_label = mapping[desc]
 
             if end_frame - 1 > num_frames:
                 ### address issue with GT activity labels
-                print(
-                    "Max frame in GT is larger than number of frames in the video"
-                )
+                print("Max frame in GT is larger than number of frames in the video")
 
-            for label_index in range(
-                start_frame, min(end_frame - 1, num_frames)
-            ):
+            for label_index in range(start_frame, min(end_frame - 1, num_frames)):
                 # print(f"label_index: {label_index}")
                 activity_gt_list[label_index] = gt_label
 
     return activity_gt_list
 
+
 def main(config_path: str):
-
     with open(config_path, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
     task_name = config["task"]
-    raw_data_root = f"{config['data_gen']['raw_data_root']}/{LAB_TASK_TO_NAME[task_name]}/"
-
+    raw_data_root = (
+        f"{config['data_gen']['raw_data_root']}/{LAB_TASK_TO_NAME[task_name]}/"
+    )
+
     dset = kwcoco.CocoDataset(config["data_gen"]["dataset_kwcoco"])
 
-
-    with open(config['data_gen']['activity_config_fn'], "r") as stream:
+    with open(config["data_gen"]["activity_config_fn"], "r") as stream:
         activity_config = yaml.safe_load(stream)
     activity_labels = activity_config["labels"]
 
     activity_labels_desc_mapping = {}
     activity_labels_label_mapping = {}
     for label in activity_labels:
@@ -75,15 +72,15 @@ def main(config_path: str):
         activity_labels_label_mapping[label_str] = label["id"]
         if label_str == "done":
             continue
 
     gt_paths_to_names_dict = {}
-    gt_paths = utils.dictionary_contents(raw_data_root, types=['*.txt'])
+    gt_paths = utils.dictionary_contents(raw_data_root, types=["*.txt"])
     for gt_path in gt_paths:
-        name = gt_path.split('/')[-1].split('.')[0]
+        name = gt_path.split("/")[-1].split(".")[0]
        gt_paths_to_names_dict[name] = gt_path
 
    print(gt_paths_to_names_dict)
 
    if not "activity_gt" in list(dset.imgs.values())[0].keys():
        print("adding activity ground truth to the dataset")
        for video_id in ub.ProgIter(dset.index.videos.keys()):
@@ -94,18 +91,20 @@ def main(config_path: str):
             else:
                 print(f"GT file does not exist for {video_name}. Continue...")
                 continue
 
             image_ids = dset.index.vidid_to_gids[video_id]
             num_frames = len(image_ids)
 
-            activity_gt_list = text_to_labels(gt_text, num_frames, task_name, activity_labels_desc_mapping)
-
+            activity_gt_list = text_to_labels(
+                gt_text, num_frames, task_name, activity_labels_desc_mapping
+            )
+
             for index, img_id in enumerate(image_ids):
                 im = dset.index.imgs[img_id]
                 frame_index = int(im["frame_index"])
                 dset.index.imgs[img_id]["activity_gt"] = activity_gt_list[frame_index]
                 # print(f"video: {video}")
 
     dset.dump("test.mscoco.json", newlines=True)
     # print(raw_data_root)
     # print(activity_labels)
@@ -117,9 +116,11 @@ def main(config_path: str):
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "--config", default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/r18/feat_v6.yaml", help=""
+        "--config",
+        default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/r18/feat_v6.yaml",
+        help="",
     )
 
     args = parser.parse_args()
 
-    main(args.config)
\ No newline at end of file
+    main(args.config)
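
For context on the ground truth that text_to_labels parses above: the GT text is split on tabs into fixed-width groups (4 fields for r18/m3, 3 for m2/m5) whose first two fields are the start and end frame and whose last field is the activity description looked up in the mapping dict. The sketch below walks through a hypothetical m2/m5-style line; the frame ranges, descriptions, and label ids are all made up for illustration:

    # Hypothetical GT line for an m2/m5 task (jump == 3); real BBN files may differ.
    text = "0\t120\tpour saline\t121\t300\tapply dressing\t"
    text_list = text.split("\t")
    if text_list[-1] == "":  # a trailing tab leaves an empty final field
        text_list = text_list[:-1]

    mapping = {"pour saline": 1, "apply dressing": 2}  # made-up label ids
    jump = 3
    for index in range(0, len(text_list), jump):
        triplet = text_list[index : index + jump]
        start_frame, end_frame = int(triplet[0]), int(triplet[1])
        print(start_frame, end_frame, mapping[triplet[jump - 1]])
    # 0 120 1
    # 121 300 2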
20 changes: 11 additions & 9 deletions tcn_hpl/data/tcn_dataset.py
@@ -56,7 +56,6 @@ def define_tcn_vector(inputs):
 
 
 class TCNDataset(Dataset):
-
     def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int):
         """
         Initializes the dataset.
@@ -74,15 +73,20 @@ def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int):
         # for offline training, pre-cut videos into clips according to window size for easy batching
         self.frames = []
 
-        logger.info(f"Generating dataset with {len(list(self.dset.index.videos.keys()))} videos")
-        pber = tqdm(self.dset.index.videos.keys(), total=len(list(self.dset.index.videos.keys())))
+        logger.info(
+            f"Generating dataset with {len(list(self.dset.index.videos.keys()))} videos"
+        )
+        pber = tqdm(
+            self.dset.index.videos.keys(),
+            total=len(list(self.dset.index.videos.keys())),
+        )
         for vid in pber:
             video_dict = self.dset.index.videos[vid]
 
             vid_frames = self.dset.index.vidid_to_gids[vid]
 
-            for index in range(0, len(vid_frames)-window_size-1, sample_rate):
-                video_slice = vid_frames[index: index+window_size]
+            for index in range(0, len(vid_frames) - window_size - 1, sample_rate):
+                video_slice = vid_frames[index : index + window_size]
                 window_frame_dicts = [self.dset.index.imgs[gid] for gid in video_slice]
 
                 # start_frame = window_frame_dicts[0]['frame_index']
@@ -122,12 +126,10 @@ def __len__(self):
 
 
 if __name__ == "__main__":
-    # Example usage:
+    # Example usage:
     kwcoco_path = "/data/PTG/medical/training/yolo_object_detector/detect/r18_all/r18_all_all_obj_results_with_dets_and_pose.mscoco.json"
 
-    dataset = TCNDataset(kwcoco_path=kwcoco_path,
-                         sample_rate=1,
-                         window_size=25)
+    dataset = TCNDataset(kwcoco_path=kwcoco_path, sample_rate=1, window_size=25)
 
     print(f"dataset: {len(dataset)}")
     data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
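
The loop rewritten in the hunk above is the offline batching scheme: each video's frame list is pre-cut into fixed-size clips of window_size frames, advanced by sample_rate frames per clip. A toy run with made-up sizes shows the resulting slices:

    # Stand-in for a video's per-frame image ids; real videos are far longer.
    vid_frames = list(range(10))
    window_size, sample_rate = 4, 2

    # Same bounds as TCNDataset.__init__ above: the range stops early enough
    # that every slice is a complete window.
    for index in range(0, len(vid_frames) - window_size - 1, sample_rate):
        print(vid_frames[index : index + window_size])
    # [0, 1, 2, 3]
    # [2, 3, 4, 5]
    # [4, 5, 6, 7]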