Skip to content

Commit

Permalink
Simplify model
Browse files Browse the repository at this point in the history
Signed-off-by: Yaacov Zamir <[email protected]>
  • Loading branch information
yaacov committed Jan 24, 2024
1 parent 210730b commit 95c10e4
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 35 deletions.
Binary file modified checkpoints/driver.pth
Binary file not shown.
10 changes: 5 additions & 5 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class DriverModel(nn.Module):
Architecture:
- Input layer: Size determined by
3 (width) x 4 (height) x 7 (possible obstacles) + 3 (car current lane)= 87 neurons.
6 (width) x 4 (height) x 7 (possible obstacles) + 6 (car current lane)= 174 neurons.
- Hidden Layer 1: 512 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 2: 256 neurons, followed by batch normalization and 50% dropout.
- Hidden Layer 3: 128 neurons, followed by batch normalization and 50% dropout.
Expand All @@ -71,13 +71,13 @@ class DriverModel(nn.Module):
- Dropout with a rate of 50% is applied after each hidden layer to prevent overfitting.
Note:
- The model expects a flattened version of the 3x4x7 input tensor, which should be reshaped to (batch_size, 84) before being passed to the model.
- The model expects a flattened version of the 6x4x7+6 input tensor, which should be reshaped to (batch_size, 174) before being passed to the model.
"""

def __init__(self):
super(DriverModel, self).__init__()

self.fc1 = nn.Linear(3 * 4 * 7 + 3, 512)
self.fc1 = nn.Linear(6 * 4 * 7 + 6, 512)
self.bn1 = nn.BatchNorm1d(512)
self.dropout1 = nn.Dropout(0.5)

Expand Down Expand Up @@ -128,7 +128,7 @@ def view_to_inputs(array, car_lane):
Args:
array (list[list[str]]): 2D array representation of the world with obstacles as strings.
car_lane (int): current lane of the car, can be 0, 1 or 2.
car_lane (int): current lane of the car, can be 0..5 inclusive.
Returns:
torch.Tensor: A tensor of shape (1, height * width * num_obstacle_types) suitable for model input.
Expand All @@ -150,7 +150,7 @@ def view_to_inputs(array, car_lane):
tensor[i, j, OBSTACLE_TO_INDEX[obstacle]] = 1

world_tensor = tensor.view(-1)
car_lane_tensor = torch.tensor([0, 0, 0])
car_lane_tensor = torch.zeros(6)
car_lane_tensor[car_lane] = 1

return torch.cat((world_tensor, car_lane_tensor))
Expand Down
29 changes: 5 additions & 24 deletions mydriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,13 @@
# ----------------------------------------------------------------------------------


def build_lane_view(world, height, lane):
def build_lane_view(world):
"""
Build a 3xN 2D array representation of the world based on the car's lane and x position.
Args:
world (World): An instance of the World class providing read-only
access to the current game state.
height (int): The height of the returned 2D array.
lane (int) : The car's current lane, 0 or 1. This determines the starting x-coordinate for the 2D array.
Returns:
list[list[str]]: 3xN array representation of the world view from the car, where N is the specified height.
Expand All @@ -68,11 +66,9 @@ def build_lane_view(world, height, lane):
The starting x-coordinate is determined by the car's lane. If the lane is 0, the starting x is 0. If the lane is 1, the starting x is 3.
The function also provides a wrapper around world.get to handle negative y values, returning an empty string for such cases.
"""
height = 4
car_y = world.car.y

# Determine the starting x-coordinate for the 2D array based on the car's x position
start_x = 0 if lane == 0 else 3

# Calculate the starting y-coordinate based on the car's y position and the desired height
start_y = car_y - height

Expand All @@ -84,7 +80,7 @@ def get_value(j, i):

# Generate the 2D array from start_y up to world.car.y
array = [
[get_value(j + start_x, i) for j in [0, 1, 2]] for i in range(start_y, car_y)
[get_value(j, i) for j in range(6)] for i in range(start_y, car_y)
]

return array
Expand All @@ -94,30 +90,15 @@ def drive(world):
"""
Determine the appropriate driving action based on the current state of the world.
The function first constructs a 3xN 2D view of the world based on the car's position.
This view is then converted to an input tensor format suitable for the model.
Depending on the car's x position within its lane, the function uses one of two models (`model_x1` or `model_x0`)
to predict the best action. If the world view was flipped (because the car is on the rightmost side of its lane),
the action might be flipped back (e.g., from LEFT to RIGHT or vice versa).
Args:
world (World): An instance of the World class providing read-only access to the current game state.
Returns:
str: The determined action for the car to take. Possible actions include those defined in the `actions` class.
Notes:
The function uses two models (`model_x1` and `model_x0`) to predict actions based on the car's x position within its lane.
The `flip_world` flag determines if the world view was flipped horizontally, which affects the final action decision.
"""
view_height = 4

lane = 0 if world.car.x < 3 else 1
x_in_lane = world.car.x % 3

# Convert real world input, into a tensor
view = build_lane_view(world, view_height, lane)
input_tensor = view_to_inputs(view, x_in_lane).unsqueeze(0)
view = build_lane_view(world)
input_tensor = view_to_inputs(view, world.car.x).unsqueeze(0)

# Use neural network model to get the outputs tensor
output = model(input_tensor)
Expand Down
12 changes: 6 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@

def generate_obstacle_array():
"""
Generates a 4x3 2D array with random obstacles.
Generates a 4x6 2D array with random obstacles.
Returns:
list[list[str]]: 4x3 2D array with random obstacles.
list[list[str]]: 4x6 2D array with random obstacles.
"""
array = [["" for _ in range(3)] for _ in range(4)]
array = [["" for _ in range(6)] for _ in range(4)]

for i in range(4):
obstacle = random.choice(list(OBSTACLE_TO_INDEX.keys()))
position = random.randint(0, 2)
position = random.randint(0, 5)
array[i][position] = obstacle

return array
Expand Down Expand Up @@ -115,7 +115,7 @@ def generate_batch(batch_size):
inputs = []
targets = []
for _ in range(batch_size):
car_x = random.choice([0, 1, 2])
car_x = random.choice([0, 1, 2, 3, 4, 5])
array = generate_obstacle_array()
correct_output = driver_simulator(array, car_x)

Expand Down Expand Up @@ -172,7 +172,7 @@ def main():
"--checkpoint-out", default="", help="Path to the output checkpoint file."
)
parser.add_argument(
"--num-epochs", type=int, default=25, help="Number of epochs for training."
"--num-epochs", type=int, default=80, help="Number of epochs for training."
)
parser.add_argument(
"--batch-size", type=int, default=250, help="Batch size for training."
Expand Down

0 comments on commit 95c10e4

Please sign in to comment.