Compare commits
4 Commits
8542ee81e7
...
8b7927a3c5
| Author | SHA1 | Date | |
|---|---|---|---|
|
8b7927a3c5
|
|||
|
62de92e7a2
|
|||
|
8ad97785b8
|
|||
|
db112ada4c
|
@@ -8,6 +8,7 @@ from typing import Type
|
||||
|
||||
class CommandType(IntEnum):
|
||||
CAR_CONTROL = 0
|
||||
RECORDING = 1
|
||||
|
||||
|
||||
class CarControl(IntEnum):
|
||||
@@ -64,3 +65,20 @@ class ControlCommand(Command):
|
||||
active: bool = (value & 1) == 1
|
||||
control: int = value >> 1
|
||||
return ControlCommand(CarControl(control), active)
|
||||
|
||||
|
||||
class RecordingCommand(Command):
|
||||
TYPE = CommandType.RECORDING
|
||||
__match_args__ = ("state",)
|
||||
|
||||
def __init__(self, state: bool) -> None:
|
||||
super().__init__()
|
||||
self.state: bool = state
|
||||
|
||||
def get_payload(self) -> bytes:
|
||||
return struct.pack(">B", self.state)
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: bytes) -> Command:
|
||||
state: bool = struct.unpack(">B", payload)[0]
|
||||
return RecordingCommand(state)
|
||||
|
||||
49
src/record_file.py
Normal file
49
src/record_file.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
import struct
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from src.snapshot import Snapshot
|
||||
|
||||
|
||||
class RecordFile:
|
||||
VERSION = 1
|
||||
|
||||
def __init__(self, path: str | Path, mode: Literal["w", "r"]) -> None:
|
||||
self.path: str | Path = path
|
||||
self.mode: Literal["w", "r"] = mode
|
||||
self.file = open(self.path, self.mode + "b")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.file.close()
|
||||
|
||||
def write_header(self, n_snapshots: int):
|
||||
data: bytes = struct.pack(
|
||||
">IId", self.VERSION, n_snapshots, time.time())
|
||||
self.file.write(data)
|
||||
|
||||
def write_snapshots(self, snapshots: list[Snapshot]):
|
||||
self.write_header(len(snapshots))
|
||||
for snapshot in snapshots:
|
||||
data: bytes = snapshot.pack()
|
||||
self.file.write(struct.pack(">I", len(data)) + data)
|
||||
|
||||
def read_snapshots(self) -> list[Snapshot]:
|
||||
version: int = struct.unpack(">I", self.file.read(4))[0]
|
||||
if version != self.VERSION:
|
||||
raise ValueError(
|
||||
f"Cannot parse record file with format version {version} (current version: {self.VERSION})")
|
||||
|
||||
n_snapshots: int
|
||||
timestamp: float
|
||||
n_snapshots, timestamp = struct.unpack(">Id", self.file.read(12))
|
||||
snapshots: list[Snapshot] = []
|
||||
|
||||
for _ in range(n_snapshots):
|
||||
size: int = struct.unpack(">I", self.file.read(4))[0]
|
||||
snapshots.append(Snapshot.unpack(self.file.read(size)))
|
||||
|
||||
return snapshots
|
||||
113
src/recorder.py
113
src/recorder.py
@@ -1,11 +1,16 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import socket
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
from PyQt6 import uic
|
||||
from PyQt6.QtCore import QObject, Qt, QThread, QTimer, pyqtSignal, pyqtSlot
|
||||
from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
|
||||
from PyQt6.QtGui import QKeyEvent
|
||||
from PyQt6.QtWidgets import QMainWindow
|
||||
|
||||
from src.command import CarControl, Command, ControlCommand
|
||||
from src.command import CarControl, Command, ControlCommand, RecordingCommand
|
||||
from src.record_file import RecordFile
|
||||
from src.recorder_ui import Ui_Recorder
|
||||
from src.snapshot import Snapshot
|
||||
|
||||
@@ -18,9 +23,9 @@ class RecorderClient(QObject):
|
||||
super().__init__()
|
||||
self.host: str = host
|
||||
self.port: int = port
|
||||
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.timer: QTimer = QTimer(self)
|
||||
self.timer.timeout.connect(self.poll_socket)
|
||||
self.socket: socket.socket = socket.socket(
|
||||
socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.timer: Optional[QTimer] = None
|
||||
self.connected: bool = False
|
||||
|
||||
@pyqtSlot()
|
||||
@@ -28,8 +33,10 @@ class RecorderClient(QObject):
|
||||
self.socket.connect((self.host, self.port))
|
||||
self.socket.setblocking(False)
|
||||
self.connected = True
|
||||
self.timer = QTimer(self)
|
||||
self.timer.timeout.connect(self.poll_socket)
|
||||
self.timer.start(50)
|
||||
print(f"Connected to server")
|
||||
print("Connected to server")
|
||||
|
||||
def poll_socket(self):
|
||||
buffer: bytes = b""
|
||||
@@ -78,15 +85,37 @@ class RecorderClient(QObject):
|
||||
@pyqtSlot()
|
||||
def shutdown(self):
|
||||
print("Shutting down client")
|
||||
if self.timer is not None:
|
||||
self.timer.stop()
|
||||
self.timer = None
|
||||
self.connected = False
|
||||
self.socket.close()
|
||||
|
||||
|
||||
class ThreadedSaver(QThread):
|
||||
def __init__(self, path: str | Path, snapshots: list[Snapshot]):
|
||||
super().__init__()
|
||||
self.path: str | Path = path
|
||||
self.snapshots: list[Snapshot] = snapshots
|
||||
|
||||
def run(self):
|
||||
with RecordFile(self.path, "w") as f:
|
||||
f.write_snapshots(self.snapshots)
|
||||
|
||||
|
||||
class RecorderWindow(Ui_Recorder, QMainWindow):
|
||||
close_signal: pyqtSignal = pyqtSignal()
|
||||
send_signal: pyqtSignal = pyqtSignal(object)
|
||||
|
||||
SAVE_DIR: Path = Path(__file__).parent.parent / "records"
|
||||
|
||||
COMMAND_DIRECTIONS: dict[str, CarControl] = {
|
||||
"w": CarControl.FORWARD,
|
||||
"s": CarControl.BACKWARD,
|
||||
"d": CarControl.RIGHT,
|
||||
"a": CarControl.LEFT,
|
||||
}
|
||||
|
||||
def __init__(self, host: str, port: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -102,13 +131,6 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
|
||||
|
||||
uic.load_ui.loadUi("src/recorder.ui", self)
|
||||
|
||||
self.command_directions = {
|
||||
"w": CarControl.FORWARD,
|
||||
"s": CarControl.BACKWARD,
|
||||
"d": CarControl.RIGHT,
|
||||
"a": CarControl.LEFT,
|
||||
}
|
||||
|
||||
self.forwardButton.pressed.connect(
|
||||
lambda: self.on_car_controlled(CarControl.FORWARD, True)
|
||||
)
|
||||
@@ -146,16 +168,39 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
|
||||
|
||||
self.saveRecordButton.clicked.connect(self.save_record)
|
||||
|
||||
self.saving_worker: Optional[ThreadedSaver] = None
|
||||
self.recording = False
|
||||
|
||||
self.recorded_data = []
|
||||
self.snapshots: list[Snapshot] = []
|
||||
self.client_thread.start()
|
||||
|
||||
def on_car_controlled(self, control: CarControl, active: bool):
|
||||
self.send_command(ControlCommand(control, active))
|
||||
|
||||
def keyPressEvent(self, event): # type: ignore
|
||||
if event.isAutoRepeat():
|
||||
return
|
||||
|
||||
if isinstance(event, QKeyEvent):
|
||||
key_text = event.text()
|
||||
ctrl: Optional[CarControl] = self.COMMAND_DIRECTIONS.get(key_text)
|
||||
if ctrl is not None:
|
||||
self.on_car_controlled(ctrl, True)
|
||||
|
||||
def keyReleaseEvent(self, event): # type: ignore
|
||||
if event.isAutoRepeat():
|
||||
return
|
||||
if isinstance(event, QKeyEvent):
|
||||
key_text = event.text()
|
||||
ctrl: Optional[CarControl] = self.COMMAND_DIRECTIONS.get(key_text)
|
||||
if ctrl is not None:
|
||||
self.on_car_controlled(ctrl, False)
|
||||
|
||||
def toggle_record(self):
|
||||
pass
|
||||
self.recording = not self.recording
|
||||
self.recordDataButton.setText(
|
||||
"Recording..." if self.recording else "Record")
|
||||
self.send_command(RecordingCommand(self.recording))
|
||||
|
||||
def rollback(self):
|
||||
pass
|
||||
@@ -167,12 +212,44 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
|
||||
)
|
||||
|
||||
def save_record(self):
|
||||
pass
|
||||
if self.saving_worker is not None:
|
||||
print("Already saving !")
|
||||
return
|
||||
|
||||
if len(self.snapshots) == 0:
|
||||
print("No data to save !")
|
||||
return
|
||||
|
||||
if self.recording:
|
||||
self.toggle_record()
|
||||
|
||||
self.saveRecordButton.setText("Saving ...")
|
||||
|
||||
self.SAVE_DIR.mkdir(exist_ok=True)
|
||||
|
||||
record_name: str = "record_%d.rec"
|
||||
fid = 0
|
||||
while os.path.exists(self.SAVE_DIR / (record_name % fid)):
|
||||
fid += 1
|
||||
|
||||
self.saving_worker = ThreadedSaver(
|
||||
self.SAVE_DIR / (record_name % fid), self.snapshots)
|
||||
self.snapshots = []
|
||||
self.nbrSnapshotSaved.setText("0")
|
||||
self.saving_worker.finished.connect(self.on_record_save_done)
|
||||
self.saving_worker.start()
|
||||
|
||||
def on_record_save_done(self):
|
||||
if self.saving_worker is None:
|
||||
return
|
||||
print("Recorded data saved to", self.saving_worker.path)
|
||||
self.saving_worker = None
|
||||
self.saveRecordButton.setText("Save")
|
||||
|
||||
@pyqtSlot(Snapshot)
|
||||
def on_snapshot_received(self, snapshot: Snapshot):
|
||||
self.recorded_data.append(snapshot)
|
||||
self.nbrSnapshotSaved.setText(str(len(self.recorded_data)))
|
||||
self.snapshots.append(snapshot)
|
||||
self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
|
||||
|
||||
def shutdown(self):
|
||||
self.close_signal.emit()
|
||||
|
||||
@@ -6,7 +6,9 @@ import struct
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.command import CarControl, Command, ControlCommand
|
||||
from src.command import CarControl, Command, ControlCommand, RecordingCommand
|
||||
from src.snapshot import Snapshot
|
||||
from src.utils import RepeatTimer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.car import Car
|
||||
@@ -23,10 +25,13 @@ class RemoteController:
|
||||
CarControl.RIGHT: "right",
|
||||
}
|
||||
|
||||
SNAPSHOT_INTERVAL = 0.1
|
||||
|
||||
def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None:
|
||||
self.car: Car = car
|
||||
self.port: int = port
|
||||
self.server: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.server: socket.socket = socket.socket(
|
||||
socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.server_thread: threading.Thread = threading.Thread(
|
||||
target=self.wait_for_connections, daemon=True
|
||||
)
|
||||
@@ -34,6 +39,10 @@ class RemoteController:
|
||||
self.queue: queue.Queue[Command] = queue.Queue()
|
||||
self.client_thread: Optional[threading.Thread] = None
|
||||
self.client: Optional[socket.socket] = None
|
||||
self.snapshot_timer: RepeatTimer = RepeatTimer(
|
||||
interval=self.SNAPSHOT_INTERVAL, function=self.take_snapshot)
|
||||
self.snapshot_timer.start()
|
||||
self.recording: bool = False
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
@@ -56,6 +65,7 @@ class RemoteController:
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.server.close()
|
||||
self.snapshot_timer.cancel()
|
||||
self.running = False
|
||||
|
||||
def on_client_connected(self, conn: socket.socket):
|
||||
@@ -107,6 +117,18 @@ class RemoteController:
|
||||
match command:
|
||||
case ControlCommand(control, active):
|
||||
self.set_control(control, active)
|
||||
case RecordingCommand(state):
|
||||
self.recording = state
|
||||
|
||||
def set_control(self, control: CarControl, active: bool):
|
||||
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)
|
||||
|
||||
def take_snapshot(self):
|
||||
if self.client is None:
|
||||
return
|
||||
if not self.recording:
|
||||
return
|
||||
|
||||
snapshot: Snapshot = Snapshot.from_car(self.car)
|
||||
payload: bytes = snapshot.pack()
|
||||
self.client.sendall(struct.pack(">I", len(payload)) + payload)
|
||||
|
||||
@@ -2,12 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.vec import Vec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.car import Car
|
||||
|
||||
|
||||
def iter_unpack(format, data):
|
||||
nbr_bytes = struct.calcsize(format)
|
||||
@@ -20,7 +23,8 @@ class Snapshot:
|
||||
position: Vec = field(default_factory=Vec)
|
||||
direction: Vec = field(default_factory=Vec)
|
||||
speed: float = 0
|
||||
raycast_distances: list[float] | tuple[float, ...] = field(default_factory=list)
|
||||
raycast_distances: list[float] | tuple[float, ...] = field(
|
||||
default_factory=list)
|
||||
image: Optional[np.ndarray] = None
|
||||
|
||||
def pack(self):
|
||||
@@ -36,10 +40,12 @@ class Snapshot:
|
||||
)
|
||||
|
||||
nbr_raycasts: int = len(self.raycast_distances)
|
||||
data += struct.pack(f">B{nbr_raycasts}f", nbr_raycasts, *self.raycast_distances)
|
||||
data += struct.pack(f">B{nbr_raycasts}f",
|
||||
nbr_raycasts, *self.raycast_distances)
|
||||
|
||||
if self.image is not None:
|
||||
data += struct.pack(">II", self.image.shape[0], self.image.shape[1])
|
||||
data += struct.pack(">II",
|
||||
self.image.shape[0], self.image.shape[1])
|
||||
data += self.image.tobytes()
|
||||
else:
|
||||
data += struct.pack(">II", 0, 0)
|
||||
@@ -72,3 +78,19 @@ class Snapshot:
|
||||
raycast_distances=raycast_distances,
|
||||
image=image,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_car(car: Car) -> Snapshot:
|
||||
return Snapshot(
|
||||
controls=(
|
||||
car.forward,
|
||||
car.backward,
|
||||
car.left,
|
||||
car.right
|
||||
),
|
||||
position=car.pos.copy(),
|
||||
direction=car.direction.copy(),
|
||||
speed=car.speed,
|
||||
raycast_distances=car.rays.copy(),
|
||||
image=None
|
||||
)
|
||||
|
||||
10
src/utils.py
10
src/utils.py
@@ -1,10 +1,12 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from threading import Timer
|
||||
from typing import Optional
|
||||
|
||||
from src.vec import Vec
|
||||
|
||||
ROOT = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
|
||||
ROOT = Path(os.path.abspath(os.path.join(
|
||||
os.path.dirname(__file__), os.pardir)))
|
||||
|
||||
|
||||
def orientation(a: Vec, b: Vec, c: Vec) -> float:
|
||||
@@ -59,3 +61,9 @@ def get_segments_intersection(a1: Vec, a2: Vec, b1: Vec, b2: Vec) -> Optional[Ve
|
||||
if intersection.within(a1, a2) and intersection.within(b1, b2):
|
||||
return intersection
|
||||
return None
|
||||
|
||||
|
||||
class RepeatTimer(Timer):
|
||||
def run(self):
|
||||
while not self.finished.wait(self.interval):
|
||||
self.function(*self.args, **self.kwargs)
|
||||
|
||||
@@ -8,6 +8,9 @@ class Vec:
|
||||
self.x: float = x
|
||||
self.y: float = y
|
||||
|
||||
def copy(self) -> Vec:
|
||||
return Vec(self.x, self.y)
|
||||
|
||||
def __add__(self, other: float | Vec) -> Vec:
|
||||
if isinstance(other, Vec):
|
||||
return Vec(self.x + other.x, self.y + other.y)
|
||||
|
||||
Reference in New Issue
Block a user