282 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			282 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| from pathlib import Path
 | |
| import socket
 | |
| import struct
 | |
| from typing import Optional
 | |
| 
 | |
| from PyQt6 import uic
 | |
| from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
 | |
| from PyQt6.QtGui import QKeyEvent
 | |
| from PyQt6.QtWidgets import QMainWindow
 | |
| 
 | |
| from src.bot import Bot
 | |
| from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand
 | |
| from src.record_file import RecordFile
 | |
| from src.recorder_ui import Ui_Recorder
 | |
| from src.snapshot import Snapshot
 | |
| 
 | |
| 
 | |
class RecorderClient(QObject):
    """TCP client that receives snapshots and sends commands.

    ``start`` connects and begins polling the socket from a ``QTimer``;
    ``send_command`` writes framed commands; ``shutdown`` stops polling and
    closes the socket.

    Framing (both directions, as implemented below): a 4-byte big-endian
    unsigned length followed by that many payload bytes.
    """

    # Maximum number of bytes requested per ``recv`` call.
    DATA_CHUNK_SIZE = 4096
    # Emitted once for every complete message decoded into a Snapshot.
    data_received: pyqtSignal = pyqtSignal(Snapshot)

    def __init__(self, host: str, port: int) -> None:
        super().__init__()
        self.host: str = host
        self.port: int = port
        self.socket: socket.socket = socket.socket(
            socket.AF_INET, socket.SOCK_STREAM)
        self.timer: Optional[QTimer] = None
        self.connected: bool = False
        # Bytes received but not yet consumed as a complete message.
        self.buffer: bytes = b""

    @pyqtSlot()
    def start(self):
        """Connect to ``(host, port)`` and poll the socket every 50 ms.

        A failed connection is reported and the client is shut down instead
        of letting the exception escape the slot.
        """
        try:
            # Blocking connect; non-blocking mode is enabled only afterwards
            # so that ``poll_socket`` can drain the socket without stalling.
            self.socket.connect((self.host, self.port))
        except OSError as e:
            print(f"Could not connect to {self.host}:{self.port}: {e}")
            self.shutdown()
            return
        self.socket.setblocking(False)
        self.connected = True
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.poll_socket)
        self.timer.start(50)
        print("Connected to server")

    def poll_socket(self):
        """Read all currently available data and emit every complete message.

        Stops quietly on ``BlockingIOError`` (nothing more to read right now).
        An orderly close by the peer (``recv`` returning ``b""``) or any other
        error shuts the client down.
        """
        if not self.connected:
            return

        try:
            while True:
                chunk: bytes = self.socket.recv(self.DATA_CHUNK_SIZE)
                if not chunk:
                    # The peer closed the connection. Without shutting down,
                    # every timer tick would get b"" again from a dead socket.
                    print("Server closed the connection")
                    self.shutdown()
                    return
                self.buffer += chunk
                self._drain_buffer()
        except BlockingIOError:
            pass
        except Exception as e:
            print(f"Socket error: {e}")
            self.shutdown()

    def _drain_buffer(self):
        """Pop and dispatch each complete length-prefixed message in ``buffer``."""
        while len(self.buffer) >= 4:
            msg_len: int = struct.unpack(">I", self.buffer[:4])[0]
            msg_end: int = 4 + msg_len
            if len(self.buffer) < msg_end:
                break  # Payload not fully received yet.
            message: bytes = self.buffer[4:msg_end]
            self.buffer = self.buffer[msg_end:]
            self.on_message(message)

    def on_message(self, message: bytes):
        """Decode one payload into a Snapshot and emit ``data_received``."""
        snapshot: Snapshot = Snapshot.unpack(message)
        self.data_received.emit(snapshot)

    @pyqtSlot(object)
    def send_command(self, command: Command):
        """Send ``command.pack()`` with a 4-byte big-endian length prefix.

        Any send error shuts the client down. Note: the socket is
        non-blocking after ``start``, so ``sendall`` may raise
        ``BlockingIOError`` if the send buffer is full; the stream could then
        be partially written, so shutting down is the safe reaction.
        """
        if self.connected:
            try:
                payload: bytes = command.pack()
                self.socket.sendall(struct.pack(">I", len(payload)) + payload)
            except Exception as e:
                print(f"An exception occured: {e}")
                self.shutdown()
        else:
            print("Not connected")

    @pyqtSlot()
    def shutdown(self):
        """Stop polling, mark the client disconnected and close the socket."""
        print("Shutting down client")
        if self.timer is not None:
            self.timer.stop()
            self.timer = None
        self.connected = False
        self.socket.close()
 | |
| 
 | |
| 
 | |
class ThreadedSaver(QThread):
    """Write a list of snapshots to a record file on a background thread.

    Failures are caught inside ``run``, printed, and stored in ``error``
    (``None`` after a successful run) rather than propagating out of the
    thread.
    """

    def __init__(self, path: str | Path, snapshots: list[Snapshot]):
        super().__init__()
        self.path: str | Path = path
        self.snapshots: list[Snapshot] = snapshots
        # Exception raised by the most recent run(), or None if it succeeded.
        self.error: Optional[Exception] = None

    def run(self):
        """Write ``self.snapshots`` to ``self.path`` via ``RecordFile`` ("w")."""
        self.error = None
        try:
            with RecordFile(self.path, "w") as f:
                f.write_snapshots(self.snapshots)
        except Exception as e:
            # Thread boundary: PyQt treats exceptions escaping Python code
            # called by Qt (such as run()) as unhandled, fatal by default.
            self.error = e
            print(f"Failed to save record to {self.path}: {e}")
 | |
| 
 | |
| 
 | |
class RecorderWindow(Ui_Recorder, QMainWindow):
    """Main window: drive the car, record snapshots, roll back and save them.

    Owns a ``RecorderClient`` moved to its own ``QThread``. Commands reach the
    client through ``send_signal``; snapshots come back through the client's
    ``data_received`` signal.
    """

    # Asks the client to shut down (connected to RecorderClient.shutdown).
    close_signal: pyqtSignal = pyqtSignal()
    # Carries a Command to RecorderClient.send_command.
    send_signal: pyqtSignal = pyqtSignal(object)

    # Directory recordings are written to; created on first save.
    SAVE_DIR: Path = Path(__file__).parent.parent / "records"

    # Keyboard bindings: lower-cased key text -> car control.
    COMMAND_DIRECTIONS: dict[str, CarControl] = {
        "w": CarControl.FORWARD,
        "s": CarControl.BACKWARD,
        "d": CarControl.RIGHT,
        "a": CarControl.LEFT,
    }

    def __init__(self, host: str, port: int) -> None:
        super().__init__()

        self.host: str = host
        self.port: int = port

        # The client runs in its own thread so its blocking connect and
        # periodic socket polling do not run on the UI thread.
        self.client_thread: QThread = QThread()
        self.client: RecorderClient = RecorderClient(self.host, self.port)
        self.client.data_received.connect(self.on_snapshot_received)
        self.client.moveToThread(self.client_thread)
        self.client_thread.started.connect(self.client.start)
        self.close_signal.connect(self.client.shutdown)
        self.send_signal.connect(self.client.send_command)

        # NOTE(review): path is relative to the current working directory.
        uic.loadUi("src/recorder.ui", self)

        self._bind_control_button(self.forwardButton, CarControl.FORWARD)
        self._bind_control_button(self.backwardButton, CarControl.BACKWARD)
        self._bind_control_button(self.rightButton, CarControl.RIGHT)
        self._bind_control_button(self.leftButton, CarControl.LEFT)

        self.recordDataButton.clicked.connect(self.toggle_record)
        self.resetButton.clicked.connect(self.rollback)

        self.bot: Optional[Bot] = None
        self.autopiloting = False

        self.autopilotButton.clicked.connect(self.toggle_autopilot)
        # Enabled once a bot is registered (see register_bot).
        self.autopilotButton.setDisabled(True)

        self.saveRecordButton.clicked.connect(self.save_record)

        self.saving_worker: Optional[ThreadedSaver] = None
        self.recording = False

        self.snapshots: list[Snapshot] = []
        self.client_thread.start()

    def _bind_control_button(self, button, control: CarControl) -> None:
        """Send ``control`` as active while ``button`` is held down."""
        button.pressed.connect(lambda: self.on_car_controlled(control, True))
        button.released.connect(lambda: self.on_car_controlled(control, False))

    def on_car_controlled(self, control: CarControl, active: bool):
        """Send a ControlCommand turning ``control`` on or off."""
        self.send_command(ControlCommand(control, active))

    def _control_for_key(self, event) -> Optional[CarControl]:
        """Return the car control bound to ``event``'s key, or ``None``.

        The key text is lower-cased so Shift/Caps Lock do not change the
        binding; otherwise a key released while Shift is held may report an
        upper-case letter and its "inactive" command would never be sent.
        """
        if not isinstance(event, QKeyEvent):
            return None
        return self.COMMAND_DIRECTIONS.get(event.text().lower())

    def keyPressEvent(self, event):  # type: ignore
        """Activate the control bound to the pressed key.

        Auto-repeat events for bound keys are ignored so holding a key sends
        a single command. Unbound keys go to the base class implementation.
        """
        ctrl = self._control_for_key(event)
        if ctrl is None:
            super().keyPressEvent(event)
        elif not event.isAutoRepeat():
            self.on_car_controlled(ctrl, True)

    def keyReleaseEvent(self, event):  # type: ignore
        """Deactivate the control bound to the released key."""
        ctrl = self._control_for_key(event)
        if ctrl is None:
            super().keyReleaseEvent(event)
        elif not event.isAutoRepeat():
            self.on_car_controlled(ctrl, False)

    def toggle_record(self):
        """Flip the recording state, relabel the button and notify the server."""
        self.recording = not self.recording
        self.recordDataButton.setText(
            "Recording..." if self.recording else "Record")
        self.send_command(RecordingCommand(self.recording))

    def rollback(self):
        """Drop the last N snapshots and restore the car to the newest one left.

        N is read from ``forgetSnapshotNumber`` and clamped so at least one
        snapshot is kept whenever any exist. If no snapshots remain a
        ResetCommand is sent, otherwise the newest remaining snapshot is
        applied. Recording, if active, is then stopped.
        """
        count = len(self.snapshots)
        rollback_by: int = max(
            0, min(self.forgetSnapshotNumber.value(), count - 1))

        # Slice by an explicit end index: ``snapshots[:-0]`` is
        # ``snapshots[:0]`` and would wipe the list when rollback_by is 0.
        self.snapshots = self.snapshots[:count - rollback_by]
        self.nbrSnapshotSaved.setText(str(len(self.snapshots)))

        if not self.snapshots:
            self.send_command(ResetCommand())
        else:
            self.send_command(ApplySnapshotCommand(self.snapshots[-1]))

        if self.recording:
            self.toggle_record()

    def toggle_autopilot(self):
        """Flip the autopilot flag and relabel the autopilot button."""
        self.autopiloting = not self.autopiloting
        self.autopilotButton.setText(
            "AutoPilot:\n" + ("ON" if self.autopiloting else "OFF")
        )

    def save_record(self):
        """Save the collected snapshots to the next free ``record_<n>.rec``.

        Stops recording first, hands the snapshot list to a ThreadedSaver and
        starts a fresh in-memory list. Does nothing if a save is already in
        progress or there is nothing to save.
        """
        if self.saving_worker is not None:
            print("Already saving !")
            return

        if not self.snapshots:
            print("No data to save !")
            return

        if self.recording:
            self.toggle_record()

        self.saveRecordButton.setText("Saving ...")

        self.SAVE_DIR.mkdir(parents=True, exist_ok=True)

        # Pick the first record index that is not already on disk.
        fid = 0
        while (self.SAVE_DIR / f"record_{fid}.rec").exists():
            fid += 1

        self.saving_worker = ThreadedSaver(
            self.SAVE_DIR / f"record_{fid}.rec", self.snapshots)
        # Rebind rather than clear: the worker keeps the old list.
        self.snapshots = []
        self.nbrSnapshotSaved.setText("0")
        self.saving_worker.finished.connect(self.on_record_save_done)
        self.saving_worker.start()

    def on_record_save_done(self):
        """Report the save result, release the worker and reset the button."""
        worker = self.saving_worker
        if worker is None:
            return
        # ``finished`` is emitted from the worker thread just before it
        # exits; wait for it so dropping the last reference below cannot
        # destroy a QThread that is still running.
        worker.wait()
        # A saver may record a failure in an ``error`` attribute.
        if getattr(worker, "error", None) is None:
            print("Recorded data saved to", worker.path)
        else:
            print("Recorded data could not be saved to", worker.path)
        self.saving_worker = None
        self.saveRecordButton.setText("Save")

    @pyqtSlot(Snapshot)
    def on_snapshot_received(self, snapshot: Snapshot):
        """Store a snapshot, update the counter, and feed the bot if piloting."""
        self.snapshots.append(snapshot)
        self.nbrSnapshotSaved.setText(str(len(self.snapshots)))

        if self.autopiloting and self.bot is not None:
            self.bot.on_snapshot_received(snapshot)

    def shutdown(self):
        """Ask the client to close its socket, then stop the client thread.

        The shutdown request is emitted before the thread's event loop is
        told to quit. ``wait()`` blocks until that thread has exited, which
        can include waiting for an in-progress blocking connect to return.
        """
        self.close_signal.emit()
        self.client_thread.quit()
        self.client_thread.wait()

    def send_command(self, command: Command):
        """Forward ``command`` to the client thread via ``send_signal``."""
        self.send_signal.emit(command)

    def register_bot(self, bot: Bot):
        """Attach the autopilot bot and enable the autopilot button."""
        self.bot = bot
        self.autopilotButton.setDisabled(False)
 |