Compare commits

...

4 Commits

10 changed files with 143 additions and 16 deletions

5
.gitignore vendored
View File

@@ -9,4 +9,7 @@ wheels/
# Virtual environments
.venv
.vscode
.vscode
records
*.rec

40
scripts/example_bot.py Normal file
View File

@@ -0,0 +1,40 @@
from PyQt6.QtWidgets import QApplication
from src.bot import Bot
from src.command import CarControl
from src.recorder import RecorderWindow
from src.snapshot import Snapshot
class ExampleBot(Bot):
def nn_infer(self, snapshot: Snapshot) -> list[tuple[CarControl, bool]]:
# Do smart NN inference here
return [(CarControl.FORWARD, True)]
def on_snapshot_received(self, snapshot: Snapshot):
controls: list[tuple[CarControl, bool]] = self.nn_infer(snapshot)
for control, active in controls:
self.recorder.on_car_controlled(control, active)
def main():
import sys
def except_hook(cls, exception, traceback):
sys.__excepthook__(cls, exception, traceback)
sys.excepthook = except_hook
app: QApplication = QApplication(sys.argv)
recorder: RecorderWindow = RecorderWindow("localhost", 5000)
bot: ExampleBot = ExampleBot(recorder)
app.aboutToQuit.connect(recorder.shutdown)
recorder.register_bot(bot)
recorder.show()
app.exec()
if __name__ == "__main__":
main()

16
src/bot.py Normal file
View File

@@ -0,0 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from src.snapshot import Snapshot
if TYPE_CHECKING:
from src.recorder import RecorderWindow
class Bot:
def __init__(self, recorder: RecorderWindow):
self.recorder: RecorderWindow = recorder
def on_snapshot_received(self, snapshot: Snapshot):
pass

View File

@@ -8,7 +8,8 @@ from src.remote_controller import RemoteController
from src.utils import get_segments_intersection, segments_intersect
from src.vec import Vec
sign = lambda x: 0 if x == 0 else (-1 if x < 0 else 1)
def sign(x): return 0 if x == 0 else (-1 if x < 0 else 1)
class Car:
@@ -27,6 +28,8 @@ class Car:
RAYS_MAX_DIST = 100
def __init__(self, pos: Vec, direction: Vec) -> None:
self.initial_pos: Vec = pos.copy()
self.initial_dir: Vec = direction.copy()
self.pos: Vec = pos
self.direction: Vec = direction
self.speed: float = 0
@@ -77,7 +80,8 @@ class Car:
if show_raycasts:
pos: Vec = camera.world2screen(self.pos)
for p in self.rays_end:
pygame.draw.line(surf, (255, 0, 0), pos, camera.world2screen(p), 2)
pygame.draw.line(surf, (255, 0, 0), pos,
camera.world2screen(p), 2)
pts: list[Vec] = self.get_corners()
pts = [camera.world2screen(p) for p in pts]
@@ -127,14 +131,17 @@ class Car:
n *= -1
dist = -dist
self.speed = 0
self.pos = self.pos + n * (self.COLLISION_MARGIN - dist)
self.pos = self.pos + n * \
(self.COLLISION_MARGIN - dist)
return
def cast_rays(self, polygons: list[list[Vec]]):
for i in range(self.N_RAYS):
angle: float = radians((i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV)
angle: float = radians(
(i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV)
p: Optional[Vec] = self.cast_ray(angle, polygons)
self.rays[i] = self.RAYS_MAX_DIST if p is None else (p - self.pos).mag()
self.rays[i] = self.RAYS_MAX_DIST if p is None else (
p - self.pos).mag()
self.rays_end[i] = self.pos if p is None else p
def cast_ray(self, angle: float, polygons: list[list[Vec]]) -> Optional[Vec]:
@@ -161,3 +168,8 @@ class Car:
dist = d
closest = p
return closest
def reset(self):
self.pos = self.initial_pos.copy()
self.direction = self.initial_dir.copy()
self.speed = 0

View File

@@ -5,10 +5,14 @@ from enum import IntEnum
import struct
from typing import Type
from src.snapshot import Snapshot
class CommandType(IntEnum):
CAR_CONTROL = 0
RECORDING = 1
APPLY_SNAPSHOT = 2
RESET = 3
class CarControl(IntEnum):
@@ -30,8 +34,8 @@ class Command(abc.ABC):
)
Command.REGISTRY[cls.TYPE] = cls
@abc.abstractmethod
def get_payload(self) -> bytes: ...
def get_payload(self) -> bytes:
return b""
def pack(self) -> bytes:
payload: bytes = self.get_payload()
@@ -43,8 +47,8 @@ class Command(abc.ABC):
return Command.REGISTRY[type].from_payload(data[1:])
@classmethod
@abc.abstractmethod
def from_payload(cls, payload: bytes) -> Command: ...
def from_payload(cls, payload: bytes) -> Command:
return cls()
class ControlCommand(Command):
@@ -82,3 +86,24 @@ class RecordingCommand(Command):
def from_payload(cls, payload: bytes) -> Command:
state: bool = struct.unpack(">B", payload)[0]
return RecordingCommand(state)
class ApplySnapshotCommand(Command):
TYPE = CommandType.APPLY_SNAPSHOT
__match_args__ = ("snapshot",)
def __init__(self, snapshot: Snapshot) -> None:
super().__init__()
self.snapshot: Snapshot = snapshot
def get_payload(self) -> bytes:
return self.snapshot.pack()
@classmethod
def from_payload(cls, payload: bytes) -> Command:
snapshot: Snapshot = Snapshot.unpack(payload)
return ApplySnapshotCommand(snapshot)
class ResetCommand(Command):
TYPE = CommandType.RESET

View File

@@ -9,7 +9,8 @@ from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
from PyQt6.QtGui import QKeyEvent
from PyQt6.QtWidgets import QMainWindow
from src.command import CarControl, Command, ControlCommand, RecordingCommand
from src.bot import Bot
from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand
from src.record_file import RecordFile
from src.recorder_ui import Ui_Recorder
from src.snapshot import Snapshot
@@ -162,9 +163,11 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.recordDataButton.clicked.connect(self.toggle_record)
self.resetButton.clicked.connect(self.rollback)
self.bot: Optional[Bot] = None
self.autopiloting = False
self.autopilotButton.clicked.connect(self.toggle_autopilot)
self.autopilotButton.setDisabled(True)
self.saveRecordButton.clicked.connect(self.save_record)
@@ -203,7 +206,19 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.send_command(RecordingCommand(self.recording))
def rollback(self):
pass
rollback_by: int = self.forgetSnapshotNumber.value()
rollback_by = max(0, min(rollback_by, len(self.snapshots) - 1))
self.snapshots = self.snapshots[:-rollback_by]
self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
if len(self.snapshots) == 0:
self.send_command(ResetCommand())
else:
self.send_command(ApplySnapshotCommand(self.snapshots[-1]))
if self.recording:
self.toggle_record()
def toggle_autopilot(self):
self.autopiloting = not self.autopiloting
@@ -251,8 +266,15 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.snapshots.append(snapshot)
self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
if self.autopiloting and self.bot is not None:
self.bot.on_snapshot_received(snapshot)
def shutdown(self):
self.close_signal.emit()
def send_command(self, command: Command):
self.send_signal.emit(command)
def register_bot(self, bot: Bot):
self.bot = bot
self.autopilotButton.setDisabled(False)

View File

@@ -99,7 +99,7 @@
<item>
<widget class="QPushButton" name="resetButton">
<property name="text">
<string>Reset</string>
<string>Rollback</string>
</property>
</widget>
</item>

View File

@@ -1,6 +1,6 @@
# Form implementation generated from reading ui file 'recorder.ui'
#
# Created by: PyQt6 UI code generator 6.8.0
# Created by: PyQt6 UI code generator 6.8.1
#
# WARNING: Any manual changes made to this file will be lost when pyuic6 is
# run again. Do not edit this file unless you know what you are doing.
@@ -102,7 +102,7 @@ class Ui_Recorder(object):
self.recordDataButton.setText(_translate("Recorder", "Record"))
self.saveImgCheckBox.setText(_translate("Recorder", "Imgs"))
self.saveRecordButton.setText(_translate("Recorder", "Save"))
self.resetButton.setText(_translate("Recorder", "Reset"))
self.resetButton.setText(_translate("Recorder", "Rollback"))
self.nbrSnapshotSaved.setText(_translate("Recorder", "0"))
self.autopilotButton.setText(_translate("Recorder", "AutoPilot\n"
"OFF"))

View File

@@ -6,7 +6,7 @@ import struct
import threading
from typing import TYPE_CHECKING, Optional
from src.command import CarControl, Command, ControlCommand, RecordingCommand
from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand
from src.snapshot import Snapshot
from src.utils import RepeatTimer
@@ -119,6 +119,10 @@ class RemoteController:
self.set_control(control, active)
case RecordingCommand(state):
self.recording = state
case ApplySnapshotCommand(snapshot):
snapshot.apply(self.car)
case ResetCommand():
self.car.reset()
def set_control(self, control: CarControl, active: bool):
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)

View File

@@ -94,3 +94,8 @@ class Snapshot:
raycast_distances=car.rays.copy(),
image=None
)
def apply(self, car: Car):
car.pos = self.position.copy()
car.direction = self.direction.copy()
car.speed = 0