Compare commits
	
		
			4 Commits
		
	
	
		
			8b7927a3c5
			...
			fa61e27825
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| fa61e27825 | |||
| b60a0aba4f | |||
| ae02ddefb0 | |||
| f1fadd123f | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -10,3 +10,6 @@ wheels/ | ||||
| .venv | ||||
|  | ||||
| .vscode | ||||
|  | ||||
| records | ||||
| *.rec | ||||
							
								
								
									
										40
									
								
								scripts/example_bot.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								scripts/example_bot.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| from PyQt6.QtWidgets import QApplication | ||||
|  | ||||
| from src.bot import Bot | ||||
| from src.command import CarControl | ||||
| from src.recorder import RecorderWindow | ||||
| from src.snapshot import Snapshot | ||||
|  | ||||
|  | ||||
| class ExampleBot(Bot): | ||||
|     def nn_infer(self, snapshot: Snapshot) -> list[tuple[CarControl, bool]]: | ||||
|         #   Do smart NN inference here | ||||
|         return [(CarControl.FORWARD, True)] | ||||
|  | ||||
|     def on_snapshot_received(self, snapshot: Snapshot): | ||||
|         controls: list[tuple[CarControl, bool]] = self.nn_infer(snapshot) | ||||
|         for control, active in controls: | ||||
|             self.recorder.on_car_controlled(control, active) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     import sys | ||||
|  | ||||
|     def except_hook(cls, exception, traceback): | ||||
|         sys.__excepthook__(cls, exception, traceback) | ||||
|  | ||||
|     sys.excepthook = except_hook | ||||
|  | ||||
|     app: QApplication = QApplication(sys.argv) | ||||
|     recorder: RecorderWindow = RecorderWindow("localhost", 5000) | ||||
|     bot: ExampleBot = ExampleBot(recorder) | ||||
|  | ||||
|     app.aboutToQuit.connect(recorder.shutdown) | ||||
|     recorder.register_bot(bot) | ||||
|     recorder.show() | ||||
|  | ||||
|     app.exec() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
							
								
								
									
										16
									
								
								src/bot.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								src/bot.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from src.snapshot import Snapshot | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from src.recorder import RecorderWindow | ||||
|  | ||||
|  | ||||
| class Bot: | ||||
|     def __init__(self, recorder: RecorderWindow): | ||||
|         self.recorder: RecorderWindow = recorder | ||||
|  | ||||
|     def on_snapshot_received(self, snapshot: Snapshot): | ||||
|         pass | ||||
							
								
								
									
										22
									
								
								src/car.py
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								src/car.py
									
									
									
									
									
								
							| @@ -8,7 +8,8 @@ from src.remote_controller import RemoteController | ||||
| from src.utils import get_segments_intersection, segments_intersect | ||||
| from src.vec import Vec | ||||
|  | ||||
| sign = lambda x: 0 if x == 0 else (-1 if x < 0 else 1) | ||||
|  | ||||
| def sign(x): return 0 if x == 0 else (-1 if x < 0 else 1) | ||||
|  | ||||
|  | ||||
| class Car: | ||||
| @@ -27,6 +28,8 @@ class Car: | ||||
|     RAYS_MAX_DIST = 100 | ||||
|  | ||||
|     def __init__(self, pos: Vec, direction: Vec) -> None: | ||||
|         self.initial_pos: Vec = pos.copy() | ||||
|         self.initial_dir: Vec = direction.copy() | ||||
|         self.pos: Vec = pos | ||||
|         self.direction: Vec = direction | ||||
|         self.speed: float = 0 | ||||
| @@ -77,7 +80,8 @@ class Car: | ||||
|         if show_raycasts: | ||||
|             pos: Vec = camera.world2screen(self.pos) | ||||
|             for p in self.rays_end: | ||||
|                 pygame.draw.line(surf, (255, 0, 0), pos, camera.world2screen(p), 2) | ||||
|                 pygame.draw.line(surf, (255, 0, 0), pos, | ||||
|                                  camera.world2screen(p), 2) | ||||
|  | ||||
|         pts: list[Vec] = self.get_corners() | ||||
|         pts = [camera.world2screen(p) for p in pts] | ||||
| @@ -127,14 +131,17 @@ class Car: | ||||
|                             n *= -1 | ||||
|                             dist = -dist | ||||
|                         self.speed = 0 | ||||
|                         self.pos = self.pos + n * (self.COLLISION_MARGIN - dist) | ||||
|                         self.pos = self.pos + n * \ | ||||
|                             (self.COLLISION_MARGIN - dist) | ||||
|                         return | ||||
|  | ||||
|     def cast_rays(self, polygons: list[list[Vec]]): | ||||
|         for i in range(self.N_RAYS): | ||||
|             angle: float = radians((i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV) | ||||
|             angle: float = radians( | ||||
|                 (i / (self.N_RAYS - 1) - 0.5) * self.RAYS_FOV) | ||||
|             p: Optional[Vec] = self.cast_ray(angle, polygons) | ||||
|             self.rays[i] = self.RAYS_MAX_DIST if p is None else (p - self.pos).mag() | ||||
|             self.rays[i] = self.RAYS_MAX_DIST if p is None else ( | ||||
|                 p - self.pos).mag() | ||||
|             self.rays_end[i] = self.pos if p is None else p | ||||
|  | ||||
|     def cast_ray(self, angle: float, polygons: list[list[Vec]]) -> Optional[Vec]: | ||||
| @@ -161,3 +168,8 @@ class Car: | ||||
|                     dist = d | ||||
|                     closest = p | ||||
|         return closest | ||||
|  | ||||
|     def reset(self): | ||||
|         self.pos = self.initial_pos.copy() | ||||
|         self.direction = self.initial_dir.copy() | ||||
|         self.speed = 0 | ||||
|   | ||||
| @@ -5,10 +5,14 @@ from enum import IntEnum | ||||
| import struct | ||||
| from typing import Type | ||||
|  | ||||
| from src.snapshot import Snapshot | ||||
|  | ||||
|  | ||||
| class CommandType(IntEnum): | ||||
|     CAR_CONTROL = 0 | ||||
|     RECORDING = 1 | ||||
|     APPLY_SNAPSHOT = 2 | ||||
|     RESET = 3 | ||||
|  | ||||
|  | ||||
| class CarControl(IntEnum): | ||||
| @@ -30,8 +34,8 @@ class Command(abc.ABC): | ||||
|             ) | ||||
|         Command.REGISTRY[cls.TYPE] = cls | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def get_payload(self) -> bytes: ... | ||||
|     def get_payload(self) -> bytes: | ||||
|         return b"" | ||||
|  | ||||
|     def pack(self) -> bytes: | ||||
|         payload: bytes = self.get_payload() | ||||
| @@ -43,8 +47,8 @@ class Command(abc.ABC): | ||||
|         return Command.REGISTRY[type].from_payload(data[1:]) | ||||
|  | ||||
|     @classmethod | ||||
|     @abc.abstractmethod | ||||
|     def from_payload(cls, payload: bytes) -> Command: ... | ||||
|     def from_payload(cls, payload: bytes) -> Command: | ||||
|         return cls() | ||||
|  | ||||
|  | ||||
| class ControlCommand(Command): | ||||
| @@ -82,3 +86,24 @@ class RecordingCommand(Command): | ||||
|     def from_payload(cls, payload: bytes) -> Command: | ||||
|         state: bool = struct.unpack(">B", payload)[0] | ||||
|         return RecordingCommand(state) | ||||
|  | ||||
|  | ||||
| class ApplySnapshotCommand(Command): | ||||
|     TYPE = CommandType.APPLY_SNAPSHOT | ||||
|     __match_args__ = ("snapshot",) | ||||
|  | ||||
|     def __init__(self, snapshot: Snapshot) -> None: | ||||
|         super().__init__() | ||||
|         self.snapshot: Snapshot = snapshot | ||||
|  | ||||
|     def get_payload(self) -> bytes: | ||||
|         return self.snapshot.pack() | ||||
|  | ||||
|     @classmethod | ||||
|     def from_payload(cls, payload: bytes) -> Command: | ||||
|         snapshot: Snapshot = Snapshot.unpack(payload) | ||||
|         return ApplySnapshotCommand(snapshot) | ||||
|  | ||||
|  | ||||
| class ResetCommand(Command): | ||||
|     TYPE = CommandType.RESET | ||||
|   | ||||
| @@ -9,7 +9,8 @@ from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot | ||||
| from PyQt6.QtGui import QKeyEvent | ||||
| from PyQt6.QtWidgets import QMainWindow | ||||
|  | ||||
| from src.command import CarControl, Command, ControlCommand, RecordingCommand | ||||
| from src.bot import Bot | ||||
| from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand | ||||
| from src.record_file import RecordFile | ||||
| from src.recorder_ui import Ui_Recorder | ||||
| from src.snapshot import Snapshot | ||||
| @@ -162,9 +163,11 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | ||||
|         self.recordDataButton.clicked.connect(self.toggle_record) | ||||
|         self.resetButton.clicked.connect(self.rollback) | ||||
|  | ||||
|         self.bot: Optional[Bot] = None | ||||
|         self.autopiloting = False | ||||
|  | ||||
|         self.autopilotButton.clicked.connect(self.toggle_autopilot) | ||||
|         self.autopilotButton.setDisabled(True) | ||||
|  | ||||
|         self.saveRecordButton.clicked.connect(self.save_record) | ||||
|  | ||||
| @@ -203,7 +206,19 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | ||||
|         self.send_command(RecordingCommand(self.recording)) | ||||
|  | ||||
|     def rollback(self): | ||||
|         pass | ||||
|         rollback_by: int = self.forgetSnapshotNumber.value() | ||||
|         rollback_by = max(0, min(rollback_by, len(self.snapshots) - 1)) | ||||
|  | ||||
|         self.snapshots = self.snapshots[:-rollback_by] | ||||
|         self.nbrSnapshotSaved.setText(str(len(self.snapshots))) | ||||
|  | ||||
|         if len(self.snapshots) == 0: | ||||
|             self.send_command(ResetCommand()) | ||||
|         else: | ||||
|             self.send_command(ApplySnapshotCommand(self.snapshots[-1])) | ||||
|  | ||||
|         if self.recording: | ||||
|             self.toggle_record() | ||||
|  | ||||
|     def toggle_autopilot(self): | ||||
|         self.autopiloting = not self.autopiloting | ||||
| @@ -251,8 +266,15 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | ||||
|         self.snapshots.append(snapshot) | ||||
|         self.nbrSnapshotSaved.setText(str(len(self.snapshots))) | ||||
|  | ||||
|         if self.autopiloting and self.bot is not None: | ||||
|             self.bot.on_snapshot_received(snapshot) | ||||
|  | ||||
|     def shutdown(self): | ||||
|         self.close_signal.emit() | ||||
|  | ||||
|     def send_command(self, command: Command): | ||||
|         self.send_signal.emit(command) | ||||
|  | ||||
|     def register_bot(self, bot: Bot): | ||||
|         self.bot = bot | ||||
|         self.autopilotButton.setDisabled(False) | ||||
|   | ||||
| @@ -99,7 +99,7 @@ | ||||
|       <item> | ||||
|        <widget class="QPushButton" name="resetButton"> | ||||
|         <property name="text"> | ||||
|          <string>Reset</string> | ||||
|          <string>Rollback</string> | ||||
|         </property> | ||||
|        </widget> | ||||
|       </item> | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| # Form implementation generated from reading ui file 'recorder.ui' | ||||
| # | ||||
| # Created by: PyQt6 UI code generator 6.8.0 | ||||
| # Created by: PyQt6 UI code generator 6.8.1 | ||||
| # | ||||
| # WARNING: Any manual changes made to this file will be lost when pyuic6 is | ||||
| # run again.  Do not edit this file unless you know what you are doing. | ||||
| @@ -102,7 +102,7 @@ class Ui_Recorder(object): | ||||
|         self.recordDataButton.setText(_translate("Recorder", "Record")) | ||||
|         self.saveImgCheckBox.setText(_translate("Recorder", "Imgs")) | ||||
|         self.saveRecordButton.setText(_translate("Recorder", "Save")) | ||||
|         self.resetButton.setText(_translate("Recorder", "Reset")) | ||||
|         self.resetButton.setText(_translate("Recorder", "Rollback")) | ||||
|         self.nbrSnapshotSaved.setText(_translate("Recorder", "0")) | ||||
|         self.autopilotButton.setText(_translate("Recorder", "AutoPilot\n" | ||||
| "OFF")) | ||||
|   | ||||
| @@ -6,7 +6,7 @@ import struct | ||||
| import threading | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| from src.command import CarControl, Command, ControlCommand, RecordingCommand | ||||
| from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand | ||||
| from src.snapshot import Snapshot | ||||
| from src.utils import RepeatTimer | ||||
|  | ||||
| @@ -119,6 +119,10 @@ class RemoteController: | ||||
|                 self.set_control(control, active) | ||||
|             case RecordingCommand(state): | ||||
|                 self.recording = state | ||||
|             case ApplySnapshotCommand(snapshot): | ||||
|                 snapshot.apply(self.car) | ||||
|             case ResetCommand(): | ||||
|                 self.car.reset() | ||||
|  | ||||
|     def set_control(self, control: CarControl, active: bool): | ||||
|         setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) | ||||
|   | ||||
| @@ -94,3 +94,8 @@ class Snapshot: | ||||
|             raycast_distances=car.rays.copy(), | ||||
|             image=None | ||||
|         ) | ||||
|  | ||||
|     def apply(self, car: Car): | ||||
|         car.pos = self.position.copy() | ||||
|         car.direction = self.direction.copy() | ||||
|         car.speed = 0 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user