Compare commits
4 Commits
8542ee81e7
...
8b7927a3c5
| Author | SHA1 | Date | |
|---|---|---|---|
|
8b7927a3c5
|
|||
|
62de92e7a2
|
|||
|
8ad97785b8
|
|||
|
db112ada4c
|
@@ -8,6 +8,7 @@ from typing import Type
|
|||||||
|
|
||||||
class CommandType(IntEnum):
|
class CommandType(IntEnum):
|
||||||
CAR_CONTROL = 0
|
CAR_CONTROL = 0
|
||||||
|
RECORDING = 1
|
||||||
|
|
||||||
|
|
||||||
class CarControl(IntEnum):
|
class CarControl(IntEnum):
|
||||||
@@ -64,3 +65,20 @@ class ControlCommand(Command):
|
|||||||
active: bool = (value & 1) == 1
|
active: bool = (value & 1) == 1
|
||||||
control: int = value >> 1
|
control: int = value >> 1
|
||||||
return ControlCommand(CarControl(control), active)
|
return ControlCommand(CarControl(control), active)
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingCommand(Command):
|
||||||
|
TYPE = CommandType.RECORDING
|
||||||
|
__match_args__ = ("state",)
|
||||||
|
|
||||||
|
def __init__(self, state: bool) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.state: bool = state
|
||||||
|
|
||||||
|
def get_payload(self) -> bytes:
|
||||||
|
return struct.pack(">B", self.state)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_payload(cls, payload: bytes) -> Command:
|
||||||
|
state: bool = struct.unpack(">B", payload)[0]
|
||||||
|
return RecordingCommand(state)
|
||||||
|
|||||||
49
src/record_file.py
Normal file
49
src/record_file.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from src.snapshot import Snapshot
|
||||||
|
|
||||||
|
|
||||||
|
class RecordFile:
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
def __init__(self, path: str | Path, mode: Literal["w", "r"]) -> None:
|
||||||
|
self.path: str | Path = path
|
||||||
|
self.mode: Literal["w", "r"] = mode
|
||||||
|
self.file = open(self.path, self.mode + "b")
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
self.file.close()
|
||||||
|
|
||||||
|
def write_header(self, n_snapshots: int):
|
||||||
|
data: bytes = struct.pack(
|
||||||
|
">IId", self.VERSION, n_snapshots, time.time())
|
||||||
|
self.file.write(data)
|
||||||
|
|
||||||
|
def write_snapshots(self, snapshots: list[Snapshot]):
|
||||||
|
self.write_header(len(snapshots))
|
||||||
|
for snapshot in snapshots:
|
||||||
|
data: bytes = snapshot.pack()
|
||||||
|
self.file.write(struct.pack(">I", len(data)) + data)
|
||||||
|
|
||||||
|
def read_snapshots(self) -> list[Snapshot]:
|
||||||
|
version: int = struct.unpack(">I", self.file.read(4))[0]
|
||||||
|
if version != self.VERSION:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot parse record file with format version {version} (current version: {self.VERSION})")
|
||||||
|
|
||||||
|
n_snapshots: int
|
||||||
|
timestamp: float
|
||||||
|
n_snapshots, timestamp = struct.unpack(">Id", self.file.read(12))
|
||||||
|
snapshots: list[Snapshot] = []
|
||||||
|
|
||||||
|
for _ in range(n_snapshots):
|
||||||
|
size: int = struct.unpack(">I", self.file.read(4))[0]
|
||||||
|
snapshots.append(Snapshot.unpack(self.file.read(size)))
|
||||||
|
|
||||||
|
return snapshots
|
||||||
115
src/recorder.py
115
src/recorder.py
@@ -1,11 +1,16 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from PyQt6 import uic
|
from PyQt6 import uic
|
||||||
from PyQt6.QtCore import QObject, Qt, QThread, QTimer, pyqtSignal, pyqtSlot
|
from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
|
||||||
|
from PyQt6.QtGui import QKeyEvent
|
||||||
from PyQt6.QtWidgets import QMainWindow
|
from PyQt6.QtWidgets import QMainWindow
|
||||||
|
|
||||||
from src.command import CarControl, Command, ControlCommand
|
from src.command import CarControl, Command, ControlCommand, RecordingCommand
|
||||||
|
from src.record_file import RecordFile
|
||||||
from src.recorder_ui import Ui_Recorder
|
from src.recorder_ui import Ui_Recorder
|
||||||
from src.snapshot import Snapshot
|
from src.snapshot import Snapshot
|
||||||
|
|
||||||
@@ -18,9 +23,9 @@ class RecorderClient(QObject):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.host: str = host
|
self.host: str = host
|
||||||
self.port: int = port
|
self.port: int = port
|
||||||
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.socket: socket.socket = socket.socket(
|
||||||
self.timer: QTimer = QTimer(self)
|
socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.timer.timeout.connect(self.poll_socket)
|
self.timer: Optional[QTimer] = None
|
||||||
self.connected: bool = False
|
self.connected: bool = False
|
||||||
|
|
||||||
@pyqtSlot()
|
@pyqtSlot()
|
||||||
@@ -28,8 +33,10 @@ class RecorderClient(QObject):
|
|||||||
self.socket.connect((self.host, self.port))
|
self.socket.connect((self.host, self.port))
|
||||||
self.socket.setblocking(False)
|
self.socket.setblocking(False)
|
||||||
self.connected = True
|
self.connected = True
|
||||||
|
self.timer = QTimer(self)
|
||||||
|
self.timer.timeout.connect(self.poll_socket)
|
||||||
self.timer.start(50)
|
self.timer.start(50)
|
||||||
print(f"Connected to server")
|
print("Connected to server")
|
||||||
|
|
||||||
def poll_socket(self):
|
def poll_socket(self):
|
||||||
buffer: bytes = b""
|
buffer: bytes = b""
|
||||||
@@ -78,15 +85,37 @@ class RecorderClient(QObject):
|
|||||||
@pyqtSlot()
|
@pyqtSlot()
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
print("Shutting down client")
|
print("Shutting down client")
|
||||||
self.timer.stop()
|
if self.timer is not None:
|
||||||
|
self.timer.stop()
|
||||||
|
self.timer = None
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadedSaver(QThread):
|
||||||
|
def __init__(self, path: str | Path, snapshots: list[Snapshot]):
|
||||||
|
super().__init__()
|
||||||
|
self.path: str | Path = path
|
||||||
|
self.snapshots: list[Snapshot] = snapshots
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
with RecordFile(self.path, "w") as f:
|
||||||
|
f.write_snapshots(self.snapshots)
|
||||||
|
|
||||||
|
|
||||||
class RecorderWindow(Ui_Recorder, QMainWindow):
|
class RecorderWindow(Ui_Recorder, QMainWindow):
|
||||||
close_signal: pyqtSignal = pyqtSignal()
|
close_signal: pyqtSignal = pyqtSignal()
|
||||||
send_signal: pyqtSignal = pyqtSignal(object)
|
send_signal: pyqtSignal = pyqtSignal(object)
|
||||||
|
|
||||||
|
SAVE_DIR: Path = Path(__file__).parent.parent / "records"
|
||||||
|
|
||||||
|
COMMAND_DIRECTIONS: dict[str, CarControl] = {
|
||||||
|
"w": CarControl.FORWARD,
|
||||||
|
"s": CarControl.BACKWARD,
|
||||||
|
"d": CarControl.RIGHT,
|
||||||
|
"a": CarControl.LEFT,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, host: str, port: int) -> None:
|
def __init__(self, host: str, port: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -102,13 +131,6 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
|
|||||||
|
|
||||||
uic.load_ui.loadUi("src/recorder.ui", self)
|
uic.load_ui.loadUi("src/recorder.ui", self)
|
||||||
|
|
||||||
self.command_directions = {
|
|
||||||
"w": CarControl.FORWARD,
|
|
||||||
"s": CarControl.BACKWARD,
|
|
||||||
"d": CarControl.RIGHT,
|
|
||||||
"a": CarControl.LEFT,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.forwardButton.pressed.connect(
|
self.forwardButton.pressed.connect(
|
||||||
lambda: self.on_car_controlled(CarControl.FORWARD, True)
|
lambda: self.on_car_controlled(CarControl.FORWARD, True)
|
||||||
)
|
)
|
||||||
@@ -146,16 +168,39 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
|
|||||||
|
|
||||||
self.saveRecordButton.clicked.connect(self.save_record)
|
self.saveRecordButton.clicked.connect(self.save_record)
|
||||||
|
|
||||||
|
self.saving_worker: Optional[ThreadedSaver] = None
|
||||||
self.recording = False
|
self.recording = False
|
||||||
|
|
||||||
self.recorded_data = []
|
self.snapshots: list[Snapshot] = []
|
||||||
self.client_thread.start()
|
self.client_thread.start()
|
||||||
|
|
||||||
def on_car_controlled(self, control: CarControl, active: bool):
|
def on_car_controlled(self, control: CarControl, active: bool):
|
||||||
self.send_command(ControlCommand(control, active))
|
self.send_command(ControlCommand(control, active))
|
||||||
|
|
||||||
|
def keyPressEvent(self, event): # type: ignore
|
||||||
|
if event.isAutoRepeat():
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(event, QKeyEvent):
|
||||||
|
key_text = event.text()
|
||||||
|
ctrl: Optional[CarControl] = self.COMMAND_DIRECTIONS.get(key_text)
|
||||||
|
if ctrl is not None:
|
||||||
|
self.on_car_controlled(ctrl, True)
|
||||||
|
|
||||||
|
def keyReleaseEvent(self, event): # type: ignore
|
||||||
|
if event.isAutoRepeat():
|
||||||
|
return
|
||||||
|
if isinstance(event, QKeyEvent):
|
||||||
|
key_text = event.text()
|
||||||
|
ctrl: Optional[CarControl] = self.COMMAND_DIRECTIONS.get(key_text)
|
||||||
|
if ctrl is not None:
|
||||||
|
self.on_car_controlled(ctrl, False)
|
||||||
|
|
||||||
def toggle_record(self):
|
def toggle_record(self):
|
||||||
pass
|
self.recording = not self.recording
|
||||||
|
self.recordDataButton.setText(
|
||||||
|
"Recording..." if self.recording else "Record")
|
||||||
|
self.send_command(RecordingCommand(self.recording))
|
||||||
|
|
||||||
def rollback(self):
|
def rollback(self):
|
||||||
pass
|
pass
|
||||||
@@ -167,12 +212,44 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def save_record(self):
|
def save_record(self):
|
||||||
pass
|
if self.saving_worker is not None:
|
||||||
|
print("Already saving !")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(self.snapshots) == 0:
|
||||||
|
print("No data to save !")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.recording:
|
||||||
|
self.toggle_record()
|
||||||
|
|
||||||
|
self.saveRecordButton.setText("Saving ...")
|
||||||
|
|
||||||
|
self.SAVE_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
record_name: str = "record_%d.rec"
|
||||||
|
fid = 0
|
||||||
|
while os.path.exists(self.SAVE_DIR / (record_name % fid)):
|
||||||
|
fid += 1
|
||||||
|
|
||||||
|
self.saving_worker = ThreadedSaver(
|
||||||
|
self.SAVE_DIR / (record_name % fid), self.snapshots)
|
||||||
|
self.snapshots = []
|
||||||
|
self.nbrSnapshotSaved.setText("0")
|
||||||
|
self.saving_worker.finished.connect(self.on_record_save_done)
|
||||||
|
self.saving_worker.start()
|
||||||
|
|
||||||
|
def on_record_save_done(self):
|
||||||
|
if self.saving_worker is None:
|
||||||
|
return
|
||||||
|
print("Recorded data saved to", self.saving_worker.path)
|
||||||
|
self.saving_worker = None
|
||||||
|
self.saveRecordButton.setText("Save")
|
||||||
|
|
||||||
@pyqtSlot(Snapshot)
|
@pyqtSlot(Snapshot)
|
||||||
def on_snapshot_received(self, snapshot: Snapshot):
|
def on_snapshot_received(self, snapshot: Snapshot):
|
||||||
self.recorded_data.append(snapshot)
|
self.snapshots.append(snapshot)
|
||||||
self.nbrSnapshotSaved.setText(str(len(self.recorded_data)))
|
self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
self.close_signal.emit()
|
self.close_signal.emit()
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import struct
|
|||||||
import threading
|
import threading
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from src.command import CarControl, Command, ControlCommand
|
from src.command import CarControl, Command, ControlCommand, RecordingCommand
|
||||||
|
from src.snapshot import Snapshot
|
||||||
|
from src.utils import RepeatTimer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.car import Car
|
from src.car import Car
|
||||||
@@ -23,10 +25,13 @@ class RemoteController:
|
|||||||
CarControl.RIGHT: "right",
|
CarControl.RIGHT: "right",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SNAPSHOT_INTERVAL = 0.1
|
||||||
|
|
||||||
def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None:
|
def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None:
|
||||||
self.car: Car = car
|
self.car: Car = car
|
||||||
self.port: int = port
|
self.port: int = port
|
||||||
self.server: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.server: socket.socket = socket.socket(
|
||||||
|
socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.server_thread: threading.Thread = threading.Thread(
|
self.server_thread: threading.Thread = threading.Thread(
|
||||||
target=self.wait_for_connections, daemon=True
|
target=self.wait_for_connections, daemon=True
|
||||||
)
|
)
|
||||||
@@ -34,6 +39,10 @@ class RemoteController:
|
|||||||
self.queue: queue.Queue[Command] = queue.Queue()
|
self.queue: queue.Queue[Command] = queue.Queue()
|
||||||
self.client_thread: Optional[threading.Thread] = None
|
self.client_thread: Optional[threading.Thread] = None
|
||||||
self.client: Optional[socket.socket] = None
|
self.client: Optional[socket.socket] = None
|
||||||
|
self.snapshot_timer: RepeatTimer = RepeatTimer(
|
||||||
|
interval=self.SNAPSHOT_INTERVAL, function=self.take_snapshot)
|
||||||
|
self.snapshot_timer.start()
|
||||||
|
self.recording: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
@@ -56,6 +65,7 @@ class RemoteController:
|
|||||||
if self.client:
|
if self.client:
|
||||||
self.client.close()
|
self.client.close()
|
||||||
self.server.close()
|
self.server.close()
|
||||||
|
self.snapshot_timer.cancel()
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def on_client_connected(self, conn: socket.socket):
|
def on_client_connected(self, conn: socket.socket):
|
||||||
@@ -107,6 +117,18 @@ class RemoteController:
|
|||||||
match command:
|
match command:
|
||||||
case ControlCommand(control, active):
|
case ControlCommand(control, active):
|
||||||
self.set_control(control, active)
|
self.set_control(control, active)
|
||||||
|
case RecordingCommand(state):
|
||||||
|
self.recording = state
|
||||||
|
|
||||||
def set_control(self, control: CarControl, active: bool):
|
def set_control(self, control: CarControl, active: bool):
|
||||||
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)
|
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)
|
||||||
|
|
||||||
|
def take_snapshot(self):
|
||||||
|
if self.client is None:
|
||||||
|
return
|
||||||
|
if not self.recording:
|
||||||
|
return
|
||||||
|
|
||||||
|
snapshot: Snapshot = Snapshot.from_car(self.car)
|
||||||
|
payload: bytes = snapshot.pack()
|
||||||
|
self.client.sendall(struct.pack(">I", len(payload)) + payload)
|
||||||
|
|||||||
@@ -2,12 +2,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import struct
|
import struct
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.vec import Vec
|
from src.vec import Vec
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.car import Car
|
||||||
|
|
||||||
|
|
||||||
def iter_unpack(format, data):
|
def iter_unpack(format, data):
|
||||||
nbr_bytes = struct.calcsize(format)
|
nbr_bytes = struct.calcsize(format)
|
||||||
@@ -20,7 +23,8 @@ class Snapshot:
|
|||||||
position: Vec = field(default_factory=Vec)
|
position: Vec = field(default_factory=Vec)
|
||||||
direction: Vec = field(default_factory=Vec)
|
direction: Vec = field(default_factory=Vec)
|
||||||
speed: float = 0
|
speed: float = 0
|
||||||
raycast_distances: list[float] | tuple[float, ...] = field(default_factory=list)
|
raycast_distances: list[float] | tuple[float, ...] = field(
|
||||||
|
default_factory=list)
|
||||||
image: Optional[np.ndarray] = None
|
image: Optional[np.ndarray] = None
|
||||||
|
|
||||||
def pack(self):
|
def pack(self):
|
||||||
@@ -36,10 +40,12 @@ class Snapshot:
|
|||||||
)
|
)
|
||||||
|
|
||||||
nbr_raycasts: int = len(self.raycast_distances)
|
nbr_raycasts: int = len(self.raycast_distances)
|
||||||
data += struct.pack(f">B{nbr_raycasts}f", nbr_raycasts, *self.raycast_distances)
|
data += struct.pack(f">B{nbr_raycasts}f",
|
||||||
|
nbr_raycasts, *self.raycast_distances)
|
||||||
|
|
||||||
if self.image is not None:
|
if self.image is not None:
|
||||||
data += struct.pack(">II", self.image.shape[0], self.image.shape[1])
|
data += struct.pack(">II",
|
||||||
|
self.image.shape[0], self.image.shape[1])
|
||||||
data += self.image.tobytes()
|
data += self.image.tobytes()
|
||||||
else:
|
else:
|
||||||
data += struct.pack(">II", 0, 0)
|
data += struct.pack(">II", 0, 0)
|
||||||
@@ -72,3 +78,19 @@ class Snapshot:
|
|||||||
raycast_distances=raycast_distances,
|
raycast_distances=raycast_distances,
|
||||||
image=image,
|
image=image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_car(car: Car) -> Snapshot:
|
||||||
|
return Snapshot(
|
||||||
|
controls=(
|
||||||
|
car.forward,
|
||||||
|
car.backward,
|
||||||
|
car.left,
|
||||||
|
car.right
|
||||||
|
),
|
||||||
|
position=car.pos.copy(),
|
||||||
|
direction=car.direction.copy(),
|
||||||
|
speed=car.speed,
|
||||||
|
raycast_distances=car.rays.copy(),
|
||||||
|
image=None
|
||||||
|
)
|
||||||
|
|||||||
10
src/utils.py
10
src/utils.py
@@ -1,10 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from threading import Timer
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from src.vec import Vec
|
from src.vec import Vec
|
||||||
|
|
||||||
ROOT = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
|
ROOT = Path(os.path.abspath(os.path.join(
|
||||||
|
os.path.dirname(__file__), os.pardir)))
|
||||||
|
|
||||||
|
|
||||||
def orientation(a: Vec, b: Vec, c: Vec) -> float:
|
def orientation(a: Vec, b: Vec, c: Vec) -> float:
|
||||||
@@ -59,3 +61,9 @@ def get_segments_intersection(a1: Vec, a2: Vec, b1: Vec, b2: Vec) -> Optional[Ve
|
|||||||
if intersection.within(a1, a2) and intersection.within(b1, b2):
|
if intersection.within(a1, a2) and intersection.within(b1, b2):
|
||||||
return intersection
|
return intersection
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class RepeatTimer(Timer):
|
||||||
|
def run(self):
|
||||||
|
while not self.finished.wait(self.interval):
|
||||||
|
self.function(*self.args, **self.kwargs)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ class Vec:
|
|||||||
self.x: float = x
|
self.x: float = x
|
||||||
self.y: float = y
|
self.y: float = y
|
||||||
|
|
||||||
|
def copy(self) -> Vec:
|
||||||
|
return Vec(self.x, self.y)
|
||||||
|
|
||||||
def __add__(self, other: float | Vec) -> Vec:
|
def __add__(self, other: float | Vec) -> Vec:
|
||||||
if isinstance(other, Vec):
|
if isinstance(other, Vec):
|
||||||
return Vec(self.x + other.x, self.y + other.y)
|
return Vec(self.x + other.x, self.y + other.y)
|
||||||
|
|||||||
Reference in New Issue
Block a user