Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use 4 by 6 instead of 3 by N #7

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified checkpoints/driver.pth
Binary file not shown.
26 changes: 0 additions & 26 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,29 +175,3 @@ def outputs_to_action(output):
action = actions.ALL[position_index]

return action


def action_to_outputs(action):
"""
Converts an action into a target tensor.

This function takes an action (LEFT, RIGHT, or other) and converts it into a target tensor with three elements.
The tensor's elements correspond to the actions LEFT, forward, and RIGHT respectively. The element corresponding
to the given action is set to 1, and the others are set to 0.

Args:
action (str): The action to convert. Should be one of the actions defined in the `actions` class.

Returns:
torch.Tensor: A tensor of shape (3,) where the element corresponding to the given action is 1, and the others are 0.
"""
target = torch.zeros(7)

try:
action_index = actions.ALL.index(action)
except ValueError:
action_index = 0

target[action_index] = 1

return target
20 changes: 6 additions & 14 deletions mydriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,28 @@

def build_lane_view(world):
"""
Build a 3xN 2D array representation of the world based on the car's lane and x position.
Build a 6x4 2D array representation of the world based on the car's lane and x position.

Args:
world (World): An instance of the World class providing read-only
access to the current game state.

Returns:
list[list[str]]: 3xN array representation of the world view from the car, where N is the specified height.
list[list[str]]: 6x4 array representation of the world view from the car, where 4 is the specified height.
The bottom line is one line above the car's y position, and the top line is the line height lines above that.
The array provides a view of the world from the car's perspective, with the car's y position excluded.

Notes:
The function uses the car's y position to determine the vertical range of the 2D array.
The starting x-coordinate is determined by the car's lane. If the lane is 0, the starting x is 0. If the lane is 1, the starting x is 3.
The function also provides a wrapper around world.get to handle negative y values, returning an empty string for such cases.
"""
height = 4
width = 6
car_y = world.car.y

# Calculate the starting y-coordinate based on the car's y position and the desired height
start_y = car_y - height

# Wrapper around world.get to handle negative y values
def get_value(j, i):
if i < 0:
return ""
return world.get((j, i))

# Generate the 2D array from start_y up to world.car.y
array = [[get_value(j, i) for j in range(6)] for i in range(start_y, car_y)]
array = [
[world.get((j, i)) for j in range(width)] for i in range(car_y - height, car_y)
]

return array

Expand Down
28 changes: 27 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import torch.nn as nn
import torch.optim as optim

from model import DriverModel, action_to_outputs, actions, obstacles, view_to_inputs
from model import DriverModel, actions, obstacles, view_to_inputs

# Training parameters
num_epochs = 0
Expand Down Expand Up @@ -105,6 +105,32 @@ def driver_simulator(array, car_x):
return action


def action_to_outputs(action):
"""
Converts an action into a target tensor.

This function takes an action (LEFT, RIGHT, or other) and converts it into a target tensor with three elements.
The tensor's elements correspond to the actions LEFT, forward, and RIGHT respectively. The element corresponding
to the given action is set to 1, and the others are set to 0.

Args:
action (str): The action to convert. Should be one of the actions defined in the `actions` class.

Returns:
torch.Tensor: A tensor of shape (3,) where the element corresponding to the given action is 1, and the others are 0.
"""
target = torch.zeros(7)

try:
action_index = actions.ALL.index(action)
except ValueError:
action_index = 0

target[action_index] = 1

return target


def generate_batch(batch_size):
"""
Generates a batch of samples for training.
Expand Down
Loading