feat: add snapshot recording
This commit is contained in:
		| @@ -8,6 +8,7 @@ from typing import Type | |||||||
|  |  | ||||||
| class CommandType(IntEnum): | class CommandType(IntEnum): | ||||||
|     CAR_CONTROL = 0 |     CAR_CONTROL = 0 | ||||||
|  |     RECORDING = 1 | ||||||
|  |  | ||||||
|  |  | ||||||
| class CarControl(IntEnum): | class CarControl(IntEnum): | ||||||
| @@ -64,3 +65,20 @@ class ControlCommand(Command): | |||||||
|         active: bool = (value & 1) == 1 |         active: bool = (value & 1) == 1 | ||||||
|         control: int = value >> 1 |         control: int = value >> 1 | ||||||
|         return ControlCommand(CarControl(control), active) |         return ControlCommand(CarControl(control), active) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RecordingCommand(Command): | ||||||
|  |     TYPE = CommandType.RECORDING | ||||||
|  |     __match_args__ = ("state",) | ||||||
|  |  | ||||||
|  |     def __init__(self, state: bool) -> None: | ||||||
|  |         super().__init__() | ||||||
|  |         self.state: bool = state | ||||||
|  |  | ||||||
|  |     def get_payload(self) -> bytes: | ||||||
|  |         return struct.pack(">B", self.state) | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def from_payload(cls, payload: bytes) -> Command: | ||||||
|  |         state: bool = struct.unpack(">B", payload)[0] | ||||||
|  |         return RecordingCommand(state) | ||||||
|   | |||||||
| @@ -2,10 +2,10 @@ import socket | |||||||
| import struct | import struct | ||||||
|  |  | ||||||
| from PyQt6 import uic | from PyQt6 import uic | ||||||
| from PyQt6.QtCore import QObject, Qt, QThread, QTimer, pyqtSignal, pyqtSlot | from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot | ||||||
| from PyQt6.QtWidgets import QMainWindow | from PyQt6.QtWidgets import QMainWindow | ||||||
|  |  | ||||||
| from src.command import CarControl, Command, ControlCommand | from src.command import CarControl, Command, ControlCommand, RecordingCommand | ||||||
| from src.recorder_ui import Ui_Recorder | from src.recorder_ui import Ui_Recorder | ||||||
| from src.snapshot import Snapshot | from src.snapshot import Snapshot | ||||||
|  |  | ||||||
| @@ -18,7 +18,8 @@ class RecorderClient(QObject): | |||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.host: str = host |         self.host: str = host | ||||||
|         self.port: int = port |         self.port: int = port | ||||||
|         self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |         self.socket: socket.socket = socket.socket( | ||||||
|  |             socket.AF_INET, socket.SOCK_STREAM) | ||||||
|         self.timer: QTimer = QTimer(self) |         self.timer: QTimer = QTimer(self) | ||||||
|         self.timer.timeout.connect(self.poll_socket) |         self.timer.timeout.connect(self.poll_socket) | ||||||
|         self.connected: bool = False |         self.connected: bool = False | ||||||
| @@ -29,7 +30,7 @@ class RecorderClient(QObject): | |||||||
|         self.socket.setblocking(False) |         self.socket.setblocking(False) | ||||||
|         self.connected = True |         self.connected = True | ||||||
|         self.timer.start(50) |         self.timer.start(50) | ||||||
|         print(f"Connected to server") |         print("Connected to server") | ||||||
|  |  | ||||||
|     def poll_socket(self): |     def poll_socket(self): | ||||||
|         buffer: bytes = b"" |         buffer: bytes = b"" | ||||||
| @@ -155,7 +156,10 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | |||||||
|         self.send_command(ControlCommand(control, active)) |         self.send_command(ControlCommand(control, active)) | ||||||
|  |  | ||||||
|     def toggle_record(self): |     def toggle_record(self): | ||||||
|         pass |         self.recording = not self.recording | ||||||
|  |         self.recordDataButton.setText( | ||||||
|  |             "Recording..." if self.recording else "Record") | ||||||
|  |         self.send_command(RecordingCommand(self.recording)) | ||||||
|  |  | ||||||
|     def rollback(self): |     def rollback(self): | ||||||
|         pass |         pass | ||||||
|   | |||||||
| @@ -6,7 +6,9 @@ import struct | |||||||
| import threading | import threading | ||||||
| from typing import TYPE_CHECKING, Optional | from typing import TYPE_CHECKING, Optional | ||||||
|  |  | ||||||
| from src.command import CarControl, Command, ControlCommand | from src.command import CarControl, Command, ControlCommand, RecordingCommand | ||||||
|  | from src.snapshot import Snapshot | ||||||
|  | from src.utils import RepeatTimer | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from src.car import Car |     from src.car import Car | ||||||
| @@ -23,10 +25,13 @@ class RemoteController: | |||||||
|         CarControl.RIGHT: "right", |         CarControl.RIGHT: "right", | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     SNAPSHOT_INTERVAL = 0.1 | ||||||
|  |  | ||||||
|     def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None: |     def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None: | ||||||
|         self.car: Car = car |         self.car: Car = car | ||||||
|         self.port: int = port |         self.port: int = port | ||||||
|         self.server: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |         self.server: socket.socket = socket.socket( | ||||||
|  |             socket.AF_INET, socket.SOCK_STREAM) | ||||||
|         self.server_thread: threading.Thread = threading.Thread( |         self.server_thread: threading.Thread = threading.Thread( | ||||||
|             target=self.wait_for_connections, daemon=True |             target=self.wait_for_connections, daemon=True | ||||||
|         ) |         ) | ||||||
| @@ -34,6 +39,10 @@ class RemoteController: | |||||||
|         self.queue: queue.Queue[Command] = queue.Queue() |         self.queue: queue.Queue[Command] = queue.Queue() | ||||||
|         self.client_thread: Optional[threading.Thread] = None |         self.client_thread: Optional[threading.Thread] = None | ||||||
|         self.client: Optional[socket.socket] = None |         self.client: Optional[socket.socket] = None | ||||||
|  |         self.snapshot_timer: RepeatTimer = RepeatTimer( | ||||||
|  |             interval=self.SNAPSHOT_INTERVAL, function=self.take_snapshot) | ||||||
|  |         self.snapshot_timer.start() | ||||||
|  |         self.recording: bool = False | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def is_connected(self) -> bool: |     def is_connected(self) -> bool: | ||||||
| @@ -56,6 +65,7 @@ class RemoteController: | |||||||
|         if self.client: |         if self.client: | ||||||
|             self.client.close() |             self.client.close() | ||||||
|         self.server.close() |         self.server.close() | ||||||
|  |         self.snapshot_timer.cancel() | ||||||
|         self.running = False |         self.running = False | ||||||
|  |  | ||||||
|     def on_client_connected(self, conn: socket.socket): |     def on_client_connected(self, conn: socket.socket): | ||||||
| @@ -107,6 +117,18 @@ class RemoteController: | |||||||
|         match command: |         match command: | ||||||
|             case ControlCommand(control, active): |             case ControlCommand(control, active): | ||||||
|                 self.set_control(control, active) |                 self.set_control(control, active) | ||||||
|  |             case RecordingCommand(state): | ||||||
|  |                 self.recording = state | ||||||
|  |  | ||||||
|     def set_control(self, control: CarControl, active: bool): |     def set_control(self, control: CarControl, active: bool): | ||||||
|         setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) |         setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) | ||||||
|  |  | ||||||
|  |     def take_snapshot(self): | ||||||
|  |         if self.client is None: | ||||||
|  |             return | ||||||
|  |         if not self.recording: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         snapshot: Snapshot = Snapshot.from_car(self.car) | ||||||
|  |         payload: bytes = snapshot.pack() | ||||||
|  |         self.client.sendall(struct.pack(">I", len(payload)) + payload) | ||||||
|   | |||||||
| @@ -2,12 +2,15 @@ from __future__ import annotations | |||||||
|  |  | ||||||
| import struct | import struct | ||||||
| from dataclasses import dataclass, field | from dataclasses import dataclass, field | ||||||
| from typing import Optional | from typing import TYPE_CHECKING, Optional | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| from src.vec import Vec | from src.vec import Vec | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from src.car import Car | ||||||
|  |  | ||||||
|  |  | ||||||
| def iter_unpack(format, data): | def iter_unpack(format, data): | ||||||
|     nbr_bytes = struct.calcsize(format) |     nbr_bytes = struct.calcsize(format) | ||||||
| @@ -20,7 +23,8 @@ class Snapshot: | |||||||
|     position: Vec = field(default_factory=Vec) |     position: Vec = field(default_factory=Vec) | ||||||
|     direction: Vec = field(default_factory=Vec) |     direction: Vec = field(default_factory=Vec) | ||||||
|     speed: float = 0 |     speed: float = 0 | ||||||
|     raycast_distances: list[float] | tuple[float, ...] = field(default_factory=list) |     raycast_distances: list[float] | tuple[float, ...] = field( | ||||||
|  |         default_factory=list) | ||||||
|     image: Optional[np.ndarray] = None |     image: Optional[np.ndarray] = None | ||||||
|  |  | ||||||
|     def pack(self): |     def pack(self): | ||||||
| @@ -36,10 +40,12 @@ class Snapshot: | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         nbr_raycasts: int = len(self.raycast_distances) |         nbr_raycasts: int = len(self.raycast_distances) | ||||||
|         data += struct.pack(f">B{nbr_raycasts}f", nbr_raycasts, *self.raycast_distances) |         data += struct.pack(f">B{nbr_raycasts}f", | ||||||
|  |                             nbr_raycasts, *self.raycast_distances) | ||||||
|  |  | ||||||
|         if self.image is not None: |         if self.image is not None: | ||||||
|             data += struct.pack(">II", self.image.shape[0], self.image.shape[1]) |             data += struct.pack(">II", | ||||||
|  |                                 self.image.shape[0], self.image.shape[1]) | ||||||
|             data += self.image.tobytes() |             data += self.image.tobytes() | ||||||
|         else: |         else: | ||||||
|             data += struct.pack(">II", 0, 0) |             data += struct.pack(">II", 0, 0) | ||||||
| @@ -72,3 +78,19 @@ class Snapshot: | |||||||
|             raycast_distances=raycast_distances, |             raycast_distances=raycast_distances, | ||||||
|             image=image, |             image=image, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def from_car(car: Car) -> Snapshot: | ||||||
|  |         return Snapshot( | ||||||
|  |             controls=( | ||||||
|  |                 car.forward, | ||||||
|  |                 car.backward, | ||||||
|  |                 car.left, | ||||||
|  |                 car.right | ||||||
|  |             ), | ||||||
|  |             position=car.pos.copy(), | ||||||
|  |             direction=car.direction.copy(), | ||||||
|  |             speed=car.speed, | ||||||
|  |             raycast_distances=car.rays.copy(), | ||||||
|  |             image=None | ||||||
|  |         ) | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								src/utils.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								src/utils.py
									
									
									
									
									
								
							| @@ -1,10 +1,12 @@ | |||||||
| import os | import os | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from threading import Timer | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  |  | ||||||
| from src.vec import Vec | from src.vec import Vec | ||||||
|  |  | ||||||
| ROOT = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) | ROOT = Path(os.path.abspath(os.path.join( | ||||||
|  |     os.path.dirname(__file__), os.pardir))) | ||||||
|  |  | ||||||
|  |  | ||||||
| def orientation(a: Vec, b: Vec, c: Vec) -> float: | def orientation(a: Vec, b: Vec, c: Vec) -> float: | ||||||
| @@ -59,3 +61,9 @@ def get_segments_intersection(a1: Vec, a2: Vec, b1: Vec, b2: Vec) -> Optional[Ve | |||||||
|     if intersection.within(a1, a2) and intersection.within(b1, b2): |     if intersection.within(a1, a2) and intersection.within(b1, b2): | ||||||
|         return intersection |         return intersection | ||||||
|     return None |     return None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RepeatTimer(Timer): | ||||||
|  |     def run(self): | ||||||
|  |         while not self.finished.wait(self.interval): | ||||||
|  |             self.function(*self.args, **self.kwargs) | ||||||
|   | |||||||
| @@ -8,6 +8,9 @@ class Vec: | |||||||
|         self.x: float = x |         self.x: float = x | ||||||
|         self.y: float = y |         self.y: float = y | ||||||
|  |  | ||||||
|  |     def copy(self) -> Vec: | ||||||
|  |         return Vec(self.x, self.y) | ||||||
|  |  | ||||||
|     def __add__(self, other: float | Vec) -> Vec: |     def __add__(self, other: float | Vec) -> Vec: | ||||||
|         if isinstance(other, Vec): |         if isinstance(other, Vec): | ||||||
|             return Vec(self.x + other.x, self.y + other.y) |             return Vec(self.x + other.x, self.y + other.y) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user