Compare commits
	
		
			4 Commits
		
	
	
		
			8b7927a3c5
			...
			fa61e27825
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| fa61e27825 | |||
| b60a0aba4f | |||
| ae02ddefb0 | |||
| f1fadd123f | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -10,3 +10,6 @@ wheels/ | |||||||
| .venv | .venv | ||||||
|  |  | ||||||
| .vscode | .vscode | ||||||
|  |  | ||||||
|  | records | ||||||
|  | *.rec | ||||||
							
								
								
									
										40
									
								
								scripts/example_bot.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								scripts/example_bot.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | |||||||
|  | from PyQt6.QtWidgets import QApplication | ||||||
|  |  | ||||||
|  | from src.bot import Bot | ||||||
|  | from src.command import CarControl | ||||||
|  | from src.recorder import RecorderWindow | ||||||
|  | from src.snapshot import Snapshot | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ExampleBot(Bot): | ||||||
|  |     def nn_infer(self, snapshot: Snapshot) -> list[tuple[CarControl, bool]]: | ||||||
|  |         #   Do smart NN inference here | ||||||
|  |         return [(CarControl.FORWARD, True)] | ||||||
|  |  | ||||||
|  |     def on_snapshot_received(self, snapshot: Snapshot): | ||||||
|  |         controls: list[tuple[CarControl, bool]] = self.nn_infer(snapshot) | ||||||
|  |         for control, active in controls: | ||||||
|  |             self.recorder.on_car_controlled(control, active) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     import sys | ||||||
|  |  | ||||||
|  |     def except_hook(cls, exception, traceback): | ||||||
|  |         sys.__excepthook__(cls, exception, traceback) | ||||||
|  |  | ||||||
|  |     sys.excepthook = except_hook | ||||||
|  |  | ||||||
|  |     app: QApplication = QApplication(sys.argv) | ||||||
|  |     recorder: RecorderWindow = RecorderWindow("localhost", 5000) | ||||||
|  |     bot: ExampleBot = ExampleBot(recorder) | ||||||
|  |  | ||||||
|  |     app.aboutToQuit.connect(recorder.shutdown) | ||||||
|  |     recorder.register_bot(bot) | ||||||
|  |     recorder.show() | ||||||
|  |  | ||||||
|  |     app.exec() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main() | ||||||
							
								
								
									
										16
									
								
								src/bot.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								src/bot.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
|  | from src.snapshot import Snapshot | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from src.recorder import RecorderWindow | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Bot: | ||||||
|  |     def __init__(self, recorder: RecorderWindow): | ||||||
|  |         self.recorder: RecorderWindow = recorder | ||||||
|  |  | ||||||
|  |     def on_snapshot_received(self, snapshot: Snapshot): | ||||||
|  |         pass | ||||||
							
								
								
									
										22
									
								
								src/car.py
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								src/car.py
									
									
									
									
									
								
							| @@ -8,7 +8,8 @@ from src.remote_controller import RemoteController | |||||||
| from src.utils import get_segments_intersection, segments_intersect | from src.utils import get_segments_intersection, segments_intersect | ||||||
| from src.vec import Vec | from src.vec import Vec | ||||||
|  |  | ||||||
| sign = lambda x: 0 if x == 0 else (-1 if x < 0 else 1) |  | ||||||
|  | def sign(x): return 0 if x == 0 else (-1 if x < 0 else 1) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Car: | class Car: | ||||||
| @@ -27,6 +28,8 @@ class Car: | |||||||
|     RAYS_MAX_DIST = 100 |     RAYS_MAX_DIST = 100 | ||||||
|  |  | ||||||
|     def __init__(self, pos: Vec, direction: Vec) -> None: |     def __init__(self, pos: Vec, direction: Vec) -> None: | ||||||
|  |         self.initial_pos: Vec = pos.copy() | ||||||
|  |         self.initial_dir: Vec = direction.copy() | ||||||
|         self.pos: Vec = pos |         self.pos: Vec = pos | ||||||
|         self.direction: Vec = direction |         self.direction: Vec = direction | ||||||
|         self.speed: float = 0 |         self.speed: float = 0 | ||||||
| @@ -77,7 +80,8 @@ class Car: | |||||||
|         if show_raycasts: |         if show_raycasts: | ||||||
|             pos: Vec = camera.world2screen(self.pos) |             pos: Vec = camera.world2screen(self.pos) | ||||||
|             for p in self.rays_end: |             for p in self.rays_end: | ||||||
|                 pygame.draw.line(surf, (255, 0, 0), pos, camera.world2screen(p), 2) |                 pygame.draw.line(surf, (255, 0, 0), pos, | ||||||
|  |                                  camera.world2screen(p), 2) | ||||||
|  |  | ||||||
|         pts: list[Vec] = self.get_corners() |         pts: list[Vec] = self.get_corners() | ||||||
|         pts = [camera.world2screen(p) for p in pts] |         pts = [camera.world2screen(p) for p in pts] | ||||||
| @@ -127,14 +131,17 @@ class Car: | |||||||
|                             n *= -1 |                             n *= -1 | ||||||
|                             dist = -dist |                             dist = -dist | ||||||
|                         self.speed = 0 |                         self.speed = 0 | ||||||
|                         self.pos = self.pos + n * (self.COLLISION_MARGIN - dist) |                         self.pos = self.pos + n * \ | ||||||
|  |                             (self.COLLISION_MARGIN - dist) | ||||||
|                         return |                         return | ||||||
|  |  | ||||||
|     def cast_rays(self, polygons: list[list[Vec]]): |     def cast_rays(self, polygons: list[list[Vec]]): | ||||||
|         for i in range(self.N_RAYS): |         for i in range(self.N_RAYS): | ||||||
|             angle: float = radians((i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV) |             angle: float = radians( | ||||||
|  |                 (i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV) | ||||||
|             p: Optional[Vec] = self.cast_ray(angle, polygons) |             p: Optional[Vec] = self.cast_ray(angle, polygons) | ||||||
|             self.rays[i] = self.RAYS_MAX_DIST if p is None else (p - self.pos).mag() |             self.rays[i] = self.RAYS_MAX_DIST if p is None else ( | ||||||
|  |                 p - self.pos).mag() | ||||||
|             self.rays_end[i] = self.pos if p is None else p |             self.rays_end[i] = self.pos if p is None else p | ||||||
|  |  | ||||||
|     def cast_ray(self, angle: float, polygons: list[list[Vec]]) -> Optional[Vec]: |     def cast_ray(self, angle: float, polygons: list[list[Vec]]) -> Optional[Vec]: | ||||||
| @@ -161,3 +168,8 @@ class Car: | |||||||
|                     dist = d |                     dist = d | ||||||
|                     closest = p |                     closest = p | ||||||
|         return closest |         return closest | ||||||
|  |  | ||||||
|  |     def reset(self): | ||||||
|  |         self.pos = self.initial_pos.copy() | ||||||
|  |         self.direction = self.initial_dir.copy() | ||||||
|  |         self.speed = 0 | ||||||
|   | |||||||
| @@ -5,10 +5,14 @@ from enum import IntEnum | |||||||
| import struct | import struct | ||||||
| from typing import Type | from typing import Type | ||||||
|  |  | ||||||
|  | from src.snapshot import Snapshot | ||||||
|  |  | ||||||
|  |  | ||||||
| class CommandType(IntEnum): | class CommandType(IntEnum): | ||||||
|     CAR_CONTROL = 0 |     CAR_CONTROL = 0 | ||||||
|     RECORDING = 1 |     RECORDING = 1 | ||||||
|  |     APPLY_SNAPSHOT = 2 | ||||||
|  |     RESET = 3 | ||||||
|  |  | ||||||
|  |  | ||||||
| class CarControl(IntEnum): | class CarControl(IntEnum): | ||||||
| @@ -30,8 +34,8 @@ class Command(abc.ABC): | |||||||
|             ) |             ) | ||||||
|         Command.REGISTRY[cls.TYPE] = cls |         Command.REGISTRY[cls.TYPE] = cls | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     def get_payload(self) -> bytes: | ||||||
|     def get_payload(self) -> bytes: ... |         return b"" | ||||||
|  |  | ||||||
|     def pack(self) -> bytes: |     def pack(self) -> bytes: | ||||||
|         payload: bytes = self.get_payload() |         payload: bytes = self.get_payload() | ||||||
| @@ -43,8 +47,8 @@ class Command(abc.ABC): | |||||||
|         return Command.REGISTRY[type].from_payload(data[1:]) |         return Command.REGISTRY[type].from_payload(data[1:]) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     @abc.abstractmethod |     def from_payload(cls, payload: bytes) -> Command: | ||||||
|     def from_payload(cls, payload: bytes) -> Command: ... |         return cls() | ||||||
|  |  | ||||||
|  |  | ||||||
| class ControlCommand(Command): | class ControlCommand(Command): | ||||||
| @@ -82,3 +86,24 @@ class RecordingCommand(Command): | |||||||
|     def from_payload(cls, payload: bytes) -> Command: |     def from_payload(cls, payload: bytes) -> Command: | ||||||
|         state: bool = struct.unpack(">B", payload)[0] |         state: bool = struct.unpack(">B", payload)[0] | ||||||
|         return RecordingCommand(state) |         return RecordingCommand(state) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ApplySnapshotCommand(Command): | ||||||
|  |     TYPE = CommandType.APPLY_SNAPSHOT | ||||||
|  |     __match_args__ = ("snapshot",) | ||||||
|  |  | ||||||
|  |     def __init__(self, snapshot: Snapshot) -> None: | ||||||
|  |         super().__init__() | ||||||
|  |         self.snapshot: Snapshot = snapshot | ||||||
|  |  | ||||||
|  |     def get_payload(self) -> bytes: | ||||||
|  |         return self.snapshot.pack() | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def from_payload(cls, payload: bytes) -> Command: | ||||||
|  |         snapshot: Snapshot = Snapshot.unpack(payload) | ||||||
|  |         return ApplySnapshotCommand(snapshot) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ResetCommand(Command): | ||||||
|  |     TYPE = CommandType.RESET | ||||||
|   | |||||||
| @@ -9,7 +9,8 @@ from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot | |||||||
| from PyQt6.QtGui import QKeyEvent | from PyQt6.QtGui import QKeyEvent | ||||||
| from PyQt6.QtWidgets import QMainWindow | from PyQt6.QtWidgets import QMainWindow | ||||||
|  |  | ||||||
| from src.command import CarControl, Command, ControlCommand, RecordingCommand | from src.bot import Bot | ||||||
|  | from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand | ||||||
| from src.record_file import RecordFile | from src.record_file import RecordFile | ||||||
| from src.recorder_ui import Ui_Recorder | from src.recorder_ui import Ui_Recorder | ||||||
| from src.snapshot import Snapshot | from src.snapshot import Snapshot | ||||||
| @@ -162,9 +163,11 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | |||||||
|         self.recordDataButton.clicked.connect(self.toggle_record) |         self.recordDataButton.clicked.connect(self.toggle_record) | ||||||
|         self.resetButton.clicked.connect(self.rollback) |         self.resetButton.clicked.connect(self.rollback) | ||||||
|  |  | ||||||
|  |         self.bot: Optional[Bot] = None | ||||||
|         self.autopiloting = False |         self.autopiloting = False | ||||||
|  |  | ||||||
|         self.autopilotButton.clicked.connect(self.toggle_autopilot) |         self.autopilotButton.clicked.connect(self.toggle_autopilot) | ||||||
|  |         self.autopilotButton.setDisabled(True) | ||||||
|  |  | ||||||
|         self.saveRecordButton.clicked.connect(self.save_record) |         self.saveRecordButton.clicked.connect(self.save_record) | ||||||
|  |  | ||||||
| @@ -203,7 +206,19 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | |||||||
|         self.send_command(RecordingCommand(self.recording)) |         self.send_command(RecordingCommand(self.recording)) | ||||||
|  |  | ||||||
|     def rollback(self): |     def rollback(self): | ||||||
|         pass |         rollback_by: int = self.forgetSnapshotNumber.value() | ||||||
|  |         rollback_by = max(0, min(rollback_by, len(self.snapshots) - 1)) | ||||||
|  |  | ||||||
|  |         self.snapshots = self.snapshots[:-rollback_by] | ||||||
|  |         self.nbrSnapshotSaved.setText(str(len(self.snapshots))) | ||||||
|  |  | ||||||
|  |         if len(self.snapshots) == 0: | ||||||
|  |             self.send_command(ResetCommand()) | ||||||
|  |         else: | ||||||
|  |             self.send_command(ApplySnapshotCommand(self.snapshots[-1])) | ||||||
|  |  | ||||||
|  |         if self.recording: | ||||||
|  |             self.toggle_record() | ||||||
|  |  | ||||||
|     def toggle_autopilot(self): |     def toggle_autopilot(self): | ||||||
|         self.autopiloting = not self.autopiloting |         self.autopiloting = not self.autopiloting | ||||||
| @@ -251,8 +266,15 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | |||||||
|         self.snapshots.append(snapshot) |         self.snapshots.append(snapshot) | ||||||
|         self.nbrSnapshotSaved.setText(str(len(self.snapshots))) |         self.nbrSnapshotSaved.setText(str(len(self.snapshots))) | ||||||
|  |  | ||||||
|  |         if self.autopiloting and self.bot is not None: | ||||||
|  |             self.bot.on_snapshot_received(snapshot) | ||||||
|  |  | ||||||
|     def shutdown(self): |     def shutdown(self): | ||||||
|         self.close_signal.emit() |         self.close_signal.emit() | ||||||
|  |  | ||||||
|     def send_command(self, command: Command): |     def send_command(self, command: Command): | ||||||
|         self.send_signal.emit(command) |         self.send_signal.emit(command) | ||||||
|  |  | ||||||
|  |     def register_bot(self, bot: Bot): | ||||||
|  |         self.bot = bot | ||||||
|  |         self.autopilotButton.setDisabled(False) | ||||||
|   | |||||||
| @@ -99,7 +99,7 @@ | |||||||
|       <item> |       <item> | ||||||
|        <widget class="QPushButton" name="resetButton"> |        <widget class="QPushButton" name="resetButton"> | ||||||
|         <property name="text"> |         <property name="text"> | ||||||
|          <string>Reset</string> |          <string>Rollback</string> | ||||||
|         </property> |         </property> | ||||||
|        </widget> |        </widget> | ||||||
|       </item> |       </item> | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| # Form implementation generated from reading ui file 'recorder.ui' | # Form implementation generated from reading ui file 'recorder.ui' | ||||||
| # | # | ||||||
| # Created by: PyQt6 UI code generator 6.8.0 | # Created by: PyQt6 UI code generator 6.8.1 | ||||||
| # | # | ||||||
| # WARNING: Any manual changes made to this file will be lost when pyuic6 is | # WARNING: Any manual changes made to this file will be lost when pyuic6 is | ||||||
| # run again.  Do not edit this file unless you know what you are doing. | # run again.  Do not edit this file unless you know what you are doing. | ||||||
| @@ -102,7 +102,7 @@ class Ui_Recorder(object): | |||||||
|         self.recordDataButton.setText(_translate("Recorder", "Record")) |         self.recordDataButton.setText(_translate("Recorder", "Record")) | ||||||
|         self.saveImgCheckBox.setText(_translate("Recorder", "Imgs")) |         self.saveImgCheckBox.setText(_translate("Recorder", "Imgs")) | ||||||
|         self.saveRecordButton.setText(_translate("Recorder", "Save")) |         self.saveRecordButton.setText(_translate("Recorder", "Save")) | ||||||
|         self.resetButton.setText(_translate("Recorder", "Reset")) |         self.resetButton.setText(_translate("Recorder", "Rollback")) | ||||||
|         self.nbrSnapshotSaved.setText(_translate("Recorder", "0")) |         self.nbrSnapshotSaved.setText(_translate("Recorder", "0")) | ||||||
|         self.autopilotButton.setText(_translate("Recorder", "AutoPilot\n" |         self.autopilotButton.setText(_translate("Recorder", "AutoPilot\n" | ||||||
| "OFF")) | "OFF")) | ||||||
|   | |||||||
| @@ -6,7 +6,7 @@ import struct | |||||||
| import threading | import threading | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING, Optional | ||||||
|  |  | ||||||
| from src.command import CarControl, Command, ControlCommand, RecordingCommand | from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand | ||||||
| from src.snapshot import Snapshot | from src.snapshot import Snapshot | ||||||
| from src.utils import RepeatTimer | from src.utils import RepeatTimer | ||||||
|  |  | ||||||
| @@ -119,6 +119,10 @@ class RemoteController: | |||||||
|                 self.set_control(control, active) |                 self.set_control(control, active) | ||||||
|             case RecordingCommand(state): |             case RecordingCommand(state): | ||||||
|                 self.recording = state |                 self.recording = state | ||||||
|  |             case ApplySnapshotCommand(snapshot): | ||||||
|  |                 snapshot.apply(self.car) | ||||||
|  |             case ResetCommand(): | ||||||
|  |                 self.car.reset() | ||||||
|  |  | ||||||
|     def set_control(self, control: CarControl, active: bool): |     def set_control(self, control: CarControl, active: bool): | ||||||
|         setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) |         setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) | ||||||
|   | |||||||
| @@ -94,3 +94,8 @@ class Snapshot: | |||||||
|             raycast_distances=car.rays.copy(), |             raycast_distances=car.rays.copy(), | ||||||
|             image=None |             image=None | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def apply(self, car: Car): | ||||||
|  |         car.pos = self.position.copy() | ||||||
|  |         car.direction = self.direction.copy() | ||||||
|  |         car.speed = 0 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user