forked from njacquelin/sports_field_registration
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmake_dataset.py
230 lines (189 loc) · 8.03 KB
/
make_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
'''
input video
output is basketball homography dataset
'''
import cv2
import numpy as np
import copy
import os
import matplotlib.pyplot as plt
WHITE = (255, 255, 255)
IMG_WIDTH = 1280
IMG_HEIGHT = 720
# Reading input video, setting up save directories
# overarching is dataset/ncaa_bball
# file structure is dataset/ncaa_bball/[annotations or images]/game name/[.npy (annotations) or .png (images)]
input_video_path = "C:/Users/simon/OneDrive/Desktop/senior ds capstone/video_capstone/20240106_duke_notredame_h1.mp4"
video_name = os.path.basename(input_video_path)
video_name_without_extension = os.path.splitext(video_name)[0]
# images will be output as jpg here
img_directory = os.path.join('dataset', 'ncaa_bball', 'images', video_name_without_extension)
# corresponding homography matrices will be output as npy here
homography_directory = os.path.join('dataset', 'ncaa_bball', 'annotations', video_name_without_extension)
print(homography_directory)
# update the train.txt file, which contains all the games that will be used in training
# in our case we're not gonna test the model so we put every game in the train.txt
train_file_path = os.path.join('dataset', 'ncaa_bball', 'train.txt')
if os.path.exists(train_file_path):
print(f'File {train_file_path} already exists')
else:
with open(train_file_path, 'w') as file:
pass
print(f"File {train_file_path} created.")
# now we have to add the game to the txt file
# Read the file into a list of lines
with open(train_file_path, 'r') as file:
lines = [line.strip() for line in file.readlines()]
# Check if the line exists and add it if not
if video_name_without_extension not in lines:
lines.append(video_name_without_extension)
# Write the updated list back to the file
with open(train_file_path, 'w') as file:
for line in lines:
file.write(f"{line}\n")
print(f'Game {video_name_without_extension} added to {train_file_path}.')
else:
print(f'Game {video_name_without_extension} already in {train_file_path}.')
# check whether img and homography folders for this specific video exist
if not os.path.exists(img_directory):
os.makedirs(img_directory)
print(f"Directory '{img_directory}' created.")
else:
print(f"Directory '{img_directory}' already exists.")
if not os.path.exists(homography_directory):
os.makedirs(homography_directory)
print(f"Directory '{homography_directory}' created.")
else:
print(f"Directory '{homography_directory}' already exists.")
cap = cv2.VideoCapture(input_video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
court_width = 94 # length of bball court
court_height = 50 # width of bball court
skip_length = 3 # we read 1 frame every <skip_length> seconds
# if you want to sample frames more frequently, you can reduce this value
frame_skip = int(fps * skip_length)
# Define a callback function to get mouse click coordinates
def click_event(event, x, y, flags, param):
# TODO give option to redo last click
# TODO give option to skip go to next frame, (might be wrong camera angle)
if event == cv2.EVENT_LBUTTONDOWN:
print(f"Clicked at: ({x}, {y})")
src_points.append((x, y))
cv2.circle(show_frame, (x, y), 5, WHITE, -1)
court_point = input("Enter corresponding court diagram point: ")
while True:
try:
dst_points.append(court_diagram_points[court_point])
break
except KeyError as e:
court_point = input("Not valid court point. Enter corresponding court diagram point again: ")
if len(src_points) >= 4:
should_continue = input("Continue annotating? (y/n): ")
if should_continue.lower() != 'y':
param[0] = False # Stop annotating
# Predefined court diagram points
# court is 94 by 50, (0,0) is top left
court_diagram_points = {
# whole court
'top left': (0, 0),
'top middle': (47, 0),
'top right': (94, 0),
'bottom right': (94, 50),
'bottom middle': (47, 50),
'bottom left': (0, 50),
# left key
'lk top left': (0, 19),
'lk top right': (19, 19),
'lk bottom right': (19, 31),
'lk bottom left': (0, 31),
# right key
'rk top left': (75, 19),
'rk top right': (94, 19),
'rk bottom right': (94, 31),
'rk bottom left': (75, 31),
# center circle
'cc left': (41, 25),
'cc top': (47, 19),
'cc right': (53, 25),
'cc bottom': (47, 31),
# other
# 3pt line
# NOTE: as of ~2022 (?) NCAA mens and womens have same 3 pt line
'3 top left': (0, 40.125 / 12),
# where the 3 pt line intersects the baseline on the upper half of left side of court
'3 top right': (94, 40.125 / 12),
# where the 3 pt line intersects the baseline on the upper half of right side of court
'3 bottom right': (94, 50 - 40.125 / 12),
# where the 3 pt line intersects the baseline on the bottom half of right side of court
'3 bottom left': (0, 50 - 40.125 / 12),
# where the 3 pt line intersects the baseline on the bottom half of left side of court
'3 top key left': (22 + 64.75 / 12, 25), # top of the key 3 on left side of court
'3 top key right': (94 - (22 + 64.75 / 12), 25) # top of the key 3 on right side of court
}
### Display the coordinates and their names on diagram for reference during annotation
# Separate the x and y coordinates
x_coords = [coord[0] for coord in court_diagram_points.values()]
y_coords = [coord[1] for coord in court_diagram_points.values()]
# Plot the points
plt.scatter(x_coords, y_coords)
# Optionally, add labels for each point
for label, (x, y) in court_diagram_points.items():
plt.annotate(label, (x, y))
# Invert the y-axis
plt.gca().invert_yaxis()
# Set the same spacing for x and y axes
plt.axis('equal')
# Show the plot
plt.show()
# iterate through the input video frame by frame and annotate each frame
ret_val, frame = cap.read()
frame = cv2.resize(frame, (IMG_WIDTH, IMG_HEIGHT))
cv2.namedWindow('Frame', cv2.WINDOW_NORMAL)
cv2.moveWindow('Frame', 0, 0)
while ret_val:
current_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
src_points = []
dst_points = []
annotating = [True] # Use a mutable object to allow modification inside the callback
show_frame = copy.copy(frame) # draw circles on copy of frame, but save original frame
cv2.imshow('Frame', show_frame)
cv2.waitKey(1)
annotate = True
if input('do you want to annotate this frame? (y/n): ').lower() == 'y':
annotate = True
else:
print(f'skipping frame {current_frame}')
annotate = False
if annotate:
while annotating[0]:
cv2.imshow('Frame', show_frame)
cv2.setMouseCallback('Frame', click_event, annotating)
cv2.waitKey(1) # Wait for a key press for 1 ms
# now we need to scale dst pts to the image size
x_scale = IMG_WIDTH / court_width
y_scale = IMG_HEIGHT / court_height
dst_points = [(x_scale * x, y_scale * y) for (x, y) in dst_points]
# Now we have a list of at least 4 (src, dst) pairs
print("Source points:", src_points)
print("Destination points:", dst_points)
# Compute homography
homography_matrix, _ = cv2.findHomography(np.array(src_points), np.array(dst_points), cv2.RANSAC, 10)
print(homography_matrix)
img_save_file = f'{img_directory}/frame_{current_frame}.jpg' # jpg over png to save space
homo_save_file = f'{homography_directory}/frame_{current_frame}.npy'
print(f'saving image to: {img_save_file}')
print(f'saving homography to: {homo_save_file}')
cv2.imwrite(img_save_file, frame)
np.save(homo_save_file, homography_matrix)
if input('continue to next frame? (y/n): ').lower() == 'y':
pass
else:
print('ending annotation script. you annotated x frames in this session!')
break
# Read next frame
cap.set(cv2.CAP_PROP_POS_FRAMES, int(current_frame + frame_skip))
ret_val, frame = cap.read()
frame = cv2.resize(frame, (IMG_WIDTH, IMG_HEIGHT))
cv2.destroyAllWindows()
cap.release()