Compare commits

...

4 Commits

10 changed files with 143 additions and 16 deletions

3
.gitignore vendored
View File

@@ -10,3 +10,6 @@ wheels/
.venv .venv
.vscode .vscode
records
*.rec

40
scripts/example_bot.py Normal file
View File

@@ -0,0 +1,40 @@
from PyQt6.QtWidgets import QApplication
from src.bot import Bot
from src.command import CarControl
from src.recorder import RecorderWindow
from src.snapshot import Snapshot
class ExampleBot(Bot):
def nn_infer(self, snapshot: Snapshot) -> list[tuple[CarControl, bool]]:
# Do smart NN inference here
return [(CarControl.FORWARD, True)]
def on_snapshot_received(self, snapshot: Snapshot):
controls: list[tuple[CarControl, bool]] = self.nn_infer(snapshot)
for control, active in controls:
self.recorder.on_car_controlled(control, active)
def main():
import sys
def except_hook(cls, exception, traceback):
sys.__excepthook__(cls, exception, traceback)
sys.excepthook = except_hook
app: QApplication = QApplication(sys.argv)
recorder: RecorderWindow = RecorderWindow("localhost", 5000)
bot: ExampleBot = ExampleBot(recorder)
app.aboutToQuit.connect(recorder.shutdown)
recorder.register_bot(bot)
recorder.show()
app.exec()
if __name__ == "__main__":
main()

16
src/bot.py Normal file
View File

@@ -0,0 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from src.snapshot import Snapshot
if TYPE_CHECKING:
from src.recorder import RecorderWindow
class Bot:
def __init__(self, recorder: RecorderWindow):
self.recorder: RecorderWindow = recorder
def on_snapshot_received(self, snapshot: Snapshot):
pass

View File

@@ -8,7 +8,8 @@ from src.remote_controller import RemoteController
from src.utils import get_segments_intersection, segments_intersect from src.utils import get_segments_intersection, segments_intersect
from src.vec import Vec from src.vec import Vec
sign = lambda x: 0 if x == 0 else (-1 if x < 0 else 1)
def sign(x): return 0 if x == 0 else (-1 if x < 0 else 1)
class Car: class Car:
@@ -27,6 +28,8 @@ class Car:
RAYS_MAX_DIST = 100 RAYS_MAX_DIST = 100
def __init__(self, pos: Vec, direction: Vec) -> None: def __init__(self, pos: Vec, direction: Vec) -> None:
self.initial_pos: Vec = pos.copy()
self.initial_dir: Vec = direction.copy()
self.pos: Vec = pos self.pos: Vec = pos
self.direction: Vec = direction self.direction: Vec = direction
self.speed: float = 0 self.speed: float = 0
@@ -77,7 +80,8 @@ class Car:
if show_raycasts: if show_raycasts:
pos: Vec = camera.world2screen(self.pos) pos: Vec = camera.world2screen(self.pos)
for p in self.rays_end: for p in self.rays_end:
pygame.draw.line(surf, (255, 0, 0), pos, camera.world2screen(p), 2) pygame.draw.line(surf, (255, 0, 0), pos,
camera.world2screen(p), 2)
pts: list[Vec] = self.get_corners() pts: list[Vec] = self.get_corners()
pts = [camera.world2screen(p) for p in pts] pts = [camera.world2screen(p) for p in pts]
@@ -127,14 +131,17 @@ class Car:
n *= -1 n *= -1
dist = -dist dist = -dist
self.speed = 0 self.speed = 0
self.pos = self.pos + n * (self.COLLISION_MARGIN - dist) self.pos = self.pos + n * \
(self.COLLISION_MARGIN - dist)
return return
def cast_rays(self, polygons: list[list[Vec]]): def cast_rays(self, polygons: list[list[Vec]]):
for i in range(self.N_RAYS): for i in range(self.N_RAYS):
angle: float = radians((i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV) angle: float = radians(
(i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV)
p: Optional[Vec] = self.cast_ray(angle, polygons) p: Optional[Vec] = self.cast_ray(angle, polygons)
self.rays[i] = self.RAYS_MAX_DIST if p is None else (p - self.pos).mag() self.rays[i] = self.RAYS_MAX_DIST if p is None else (
p - self.pos).mag()
self.rays_end[i] = self.pos if p is None else p self.rays_end[i] = self.pos if p is None else p
def cast_ray(self, angle: float, polygons: list[list[Vec]]) -> Optional[Vec]: def cast_ray(self, angle: float, polygons: list[list[Vec]]) -> Optional[Vec]:
@@ -161,3 +168,8 @@ class Car:
dist = d dist = d
closest = p closest = p
return closest return closest
def reset(self):
self.pos = self.initial_pos.copy()
self.direction = self.initial_dir.copy()
self.speed = 0

View File

@@ -5,10 +5,14 @@ from enum import IntEnum
import struct import struct
from typing import Type from typing import Type
from src.snapshot import Snapshot
class CommandType(IntEnum): class CommandType(IntEnum):
CAR_CONTROL = 0 CAR_CONTROL = 0
RECORDING = 1 RECORDING = 1
APPLY_SNAPSHOT = 2
RESET = 3
class CarControl(IntEnum): class CarControl(IntEnum):
@@ -30,8 +34,8 @@ class Command(abc.ABC):
) )
Command.REGISTRY[cls.TYPE] = cls Command.REGISTRY[cls.TYPE] = cls
@abc.abstractmethod def get_payload(self) -> bytes:
def get_payload(self) -> bytes: ... return b""
def pack(self) -> bytes: def pack(self) -> bytes:
payload: bytes = self.get_payload() payload: bytes = self.get_payload()
@@ -43,8 +47,8 @@ class Command(abc.ABC):
return Command.REGISTRY[type].from_payload(data[1:]) return Command.REGISTRY[type].from_payload(data[1:])
@classmethod @classmethod
@abc.abstractmethod def from_payload(cls, payload: bytes) -> Command:
def from_payload(cls, payload: bytes) -> Command: ... return cls()
class ControlCommand(Command): class ControlCommand(Command):
@@ -82,3 +86,24 @@ class RecordingCommand(Command):
def from_payload(cls, payload: bytes) -> Command: def from_payload(cls, payload: bytes) -> Command:
state: bool = struct.unpack(">B", payload)[0] state: bool = struct.unpack(">B", payload)[0]
return RecordingCommand(state) return RecordingCommand(state)
class ApplySnapshotCommand(Command):
TYPE = CommandType.APPLY_SNAPSHOT
__match_args__ = ("snapshot",)
def __init__(self, snapshot: Snapshot) -> None:
super().__init__()
self.snapshot: Snapshot = snapshot
def get_payload(self) -> bytes:
return self.snapshot.pack()
@classmethod
def from_payload(cls, payload: bytes) -> Command:
snapshot: Snapshot = Snapshot.unpack(payload)
return ApplySnapshotCommand(snapshot)
class ResetCommand(Command):
TYPE = CommandType.RESET

View File

@@ -9,7 +9,8 @@ from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
from PyQt6.QtGui import QKeyEvent from PyQt6.QtGui import QKeyEvent
from PyQt6.QtWidgets import QMainWindow from PyQt6.QtWidgets import QMainWindow
from src.command import CarControl, Command, ControlCommand, RecordingCommand from src.bot import Bot
from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand
from src.record_file import RecordFile from src.record_file import RecordFile
from src.recorder_ui import Ui_Recorder from src.recorder_ui import Ui_Recorder
from src.snapshot import Snapshot from src.snapshot import Snapshot
@@ -162,9 +163,11 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.recordDataButton.clicked.connect(self.toggle_record) self.recordDataButton.clicked.connect(self.toggle_record)
self.resetButton.clicked.connect(self.rollback) self.resetButton.clicked.connect(self.rollback)
self.bot: Optional[Bot] = None
self.autopiloting = False self.autopiloting = False
self.autopilotButton.clicked.connect(self.toggle_autopilot) self.autopilotButton.clicked.connect(self.toggle_autopilot)
self.autopilotButton.setDisabled(True)
self.saveRecordButton.clicked.connect(self.save_record) self.saveRecordButton.clicked.connect(self.save_record)
@@ -203,7 +206,19 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.send_command(RecordingCommand(self.recording)) self.send_command(RecordingCommand(self.recording))
def rollback(self): def rollback(self):
pass rollback_by: int = self.forgetSnapshotNumber.value()
rollback_by = max(0, min(rollback_by, len(self.snapshots) - 1))
self.snapshots = self.snapshots[:-rollback_by]
self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
if len(self.snapshots) == 0:
self.send_command(ResetCommand())
else:
self.send_command(ApplySnapshotCommand(self.snapshots[-1]))
if self.recording:
self.toggle_record()
def toggle_autopilot(self): def toggle_autopilot(self):
self.autopiloting = not self.autopiloting self.autopiloting = not self.autopiloting
@@ -251,8 +266,15 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.snapshots.append(snapshot) self.snapshots.append(snapshot)
self.nbrSnapshotSaved.setText(str(len(self.snapshots))) self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
if self.autopiloting and self.bot is not None:
self.bot.on_snapshot_received(snapshot)
def shutdown(self): def shutdown(self):
self.close_signal.emit() self.close_signal.emit()
def send_command(self, command: Command): def send_command(self, command: Command):
self.send_signal.emit(command) self.send_signal.emit(command)
def register_bot(self, bot: Bot):
self.bot = bot
self.autopilotButton.setDisabled(False)

View File

@@ -99,7 +99,7 @@
<item> <item>
<widget class="QPushButton" name="resetButton"> <widget class="QPushButton" name="resetButton">
<property name="text"> <property name="text">
<string>Reset</string> <string>Rollback</string>
</property> </property>
</widget> </widget>
</item> </item>

View File

@@ -1,6 +1,6 @@
# Form implementation generated from reading ui file 'recorder.ui' # Form implementation generated from reading ui file 'recorder.ui'
# #
# Created by: PyQt6 UI code generator 6.8.0 # Created by: PyQt6 UI code generator 6.8.1
# #
# WARNING: Any manual changes made to this file will be lost when pyuic6 is # WARNING: Any manual changes made to this file will be lost when pyuic6 is
# run again. Do not edit this file unless you know what you are doing. # run again. Do not edit this file unless you know what you are doing.
@@ -102,7 +102,7 @@ class Ui_Recorder(object):
self.recordDataButton.setText(_translate("Recorder", "Record")) self.recordDataButton.setText(_translate("Recorder", "Record"))
self.saveImgCheckBox.setText(_translate("Recorder", "Imgs")) self.saveImgCheckBox.setText(_translate("Recorder", "Imgs"))
self.saveRecordButton.setText(_translate("Recorder", "Save")) self.saveRecordButton.setText(_translate("Recorder", "Save"))
self.resetButton.setText(_translate("Recorder", "Reset")) self.resetButton.setText(_translate("Recorder", "Rollback"))
self.nbrSnapshotSaved.setText(_translate("Recorder", "0")) self.nbrSnapshotSaved.setText(_translate("Recorder", "0"))
self.autopilotButton.setText(_translate("Recorder", "AutoPilot\n" self.autopilotButton.setText(_translate("Recorder", "AutoPilot\n"
"OFF")) "OFF"))

View File

@@ -6,7 +6,7 @@ import struct
import threading import threading
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from src.command import CarControl, Command, ControlCommand, RecordingCommand from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand
from src.snapshot import Snapshot from src.snapshot import Snapshot
from src.utils import RepeatTimer from src.utils import RepeatTimer
@@ -119,6 +119,10 @@ class RemoteController:
self.set_control(control, active) self.set_control(control, active)
case RecordingCommand(state): case RecordingCommand(state):
self.recording = state self.recording = state
case ApplySnapshotCommand(snapshot):
snapshot.apply(self.car)
case ResetCommand():
self.car.reset()
def set_control(self, control: CarControl, active: bool): def set_control(self, control: CarControl, active: bool):
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)

View File

@@ -94,3 +94,8 @@ class Snapshot:
raycast_distances=car.rays.copy(), raycast_distances=car.rays.copy(),
image=None image=None
) )
def apply(self, car: Car):
car.pos = self.position.copy()
car.direction = self.direction.copy()
car.speed = 0