feat: add record saving
This commit is contained in:
		
							
								
								
									
										49
									
								
								src/record_file.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								src/record_file.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | |||||||
|  | from pathlib import Path | ||||||
|  | import struct | ||||||
|  | import time | ||||||
|  | from typing import Literal | ||||||
|  |  | ||||||
|  | from src.snapshot import Snapshot | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RecordFile: | ||||||
|  |     VERSION = 1 | ||||||
|  |  | ||||||
|  |     def __init__(self, path: str | Path, mode: Literal["w", "r"]) -> None: | ||||||
|  |         self.path: str | Path = path | ||||||
|  |         self.mode: Literal["w", "r"] = mode | ||||||
|  |         self.file = open(self.path, self.mode + "b") | ||||||
|  |  | ||||||
|  |     def __enter__(self): | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |     def __exit__(self, type, value, traceback): | ||||||
|  |         self.file.close() | ||||||
|  |  | ||||||
|  |     def write_header(self, n_snapshots: int): | ||||||
|  |         data: bytes = struct.pack( | ||||||
|  |             ">IId", self.VERSION, n_snapshots, time.time()) | ||||||
|  |         self.file.write(data) | ||||||
|  |  | ||||||
|  |     def write_snapshots(self, snapshots: list[Snapshot]): | ||||||
|  |         self.write_header(len(snapshots)) | ||||||
|  |         for snapshot in snapshots: | ||||||
|  |             data: bytes = snapshot.pack() | ||||||
|  |             self.file.write(struct.pack(">I", len(data)) + data) | ||||||
|  |  | ||||||
|  |     def read_snapshots(self) -> list[Snapshot]: | ||||||
|  |         version: int = struct.unpack(">I", self.file.read(4))[0] | ||||||
|  |         if version != self.VERSION: | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"Cannot parse record file with format version {version} (current version: {self.VERSION})") | ||||||
|  |  | ||||||
|  |         n_snapshots: int | ||||||
|  |         timestamp: float | ||||||
|  |         n_snapshots, timestamp = struct.unpack(">Id", self.file.read(12)) | ||||||
|  |         snapshots: list[Snapshot] = [] | ||||||
|  |  | ||||||
|  |         for _ in range(n_snapshots): | ||||||
|  |             size: int = struct.unpack(">I", self.file.read(4))[0] | ||||||
|  |             snapshots.append(Snapshot.unpack(self.file.read(size))) | ||||||
|  |  | ||||||
|  |         return snapshots | ||||||
| @@ -1,11 +1,15 @@ | |||||||
|  | import os | ||||||
|  | from pathlib import Path | ||||||
| import socket | import socket | ||||||
| import struct | import struct | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
| from PyQt6 import uic | from PyQt6 import uic | ||||||
| from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot | from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot | ||||||
| from PyQt6.QtWidgets import QMainWindow | from PyQt6.QtWidgets import QMainWindow | ||||||
|  |  | ||||||
| from src.command import CarControl, Command, ControlCommand, RecordingCommand | from src.command import CarControl, Command, ControlCommand, RecordingCommand | ||||||
|  | from src.record_file import RecordFile | ||||||
| from src.recorder_ui import Ui_Recorder | from src.recorder_ui import Ui_Recorder | ||||||
| from src.snapshot import Snapshot | from src.snapshot import Snapshot | ||||||
|  |  | ||||||
| @@ -84,10 +88,23 @@ class RecorderClient(QObject): | |||||||
|         self.socket.close() |         self.socket.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ThreadedSaver(QThread): | ||||||
|  |     def __init__(self, path: str | Path, snapshots: list[Snapshot]): | ||||||
|  |         super().__init__() | ||||||
|  |         self.path: str | Path = path | ||||||
|  |         self.snapshots: list[Snapshot] = snapshots | ||||||
|  |  | ||||||
|  |     def run(self): | ||||||
|  |         with RecordFile(self.path, "w") as f: | ||||||
|  |             f.write_snapshots(self.snapshots) | ||||||
|  |  | ||||||
|  |  | ||||||
| class RecorderWindow(Ui_Recorder, QMainWindow): | class RecorderWindow(Ui_Recorder, QMainWindow): | ||||||
|     close_signal: pyqtSignal = pyqtSignal() |     close_signal: pyqtSignal = pyqtSignal() | ||||||
|     send_signal: pyqtSignal = pyqtSignal(object) |     send_signal: pyqtSignal = pyqtSignal(object) | ||||||
|  |  | ||||||
|  |     SAVE_DIR: Path = Path(__file__).parent.parent / "records" | ||||||
|  |  | ||||||
|     def __init__(self, host: str, port: int) -> None: |     def __init__(self, host: str, port: int) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |  | ||||||
| @@ -147,9 +164,10 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | |||||||
|  |  | ||||||
|         self.saveRecordButton.clicked.connect(self.save_record) |         self.saveRecordButton.clicked.connect(self.save_record) | ||||||
|  |  | ||||||
|  |         self.saving_worker: Optional[ThreadedSaver] = None | ||||||
|         self.recording = False |         self.recording = False | ||||||
|  |  | ||||||
|         self.recorded_data = [] |         self.snapshots: list[Snapshot] = [] | ||||||
|         self.client_thread.start() |         self.client_thread.start() | ||||||
|  |  | ||||||
|     def on_car_controlled(self, control: CarControl, active: bool): |     def on_car_controlled(self, control: CarControl, active: bool): | ||||||
| @@ -171,12 +189,44 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def save_record(self): |     def save_record(self): | ||||||
|         pass |         if self.saving_worker is not None: | ||||||
|  |             print("Already saving !") | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         if len(self.snapshots) == 0: | ||||||
|  |             print("No data to save !") | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         if self.recording: | ||||||
|  |             self.toggle_record() | ||||||
|  |  | ||||||
|  |         self.saveRecordButton.setText("Saving ...") | ||||||
|  |  | ||||||
|  |         self.SAVE_DIR.mkdir(exist_ok=True) | ||||||
|  |  | ||||||
|  |         record_name: str = "record_%d.rec" | ||||||
|  |         fid = 0 | ||||||
|  |         while os.path.exists(self.SAVE_DIR / (record_name % fid)): | ||||||
|  |             fid += 1 | ||||||
|  |  | ||||||
|  |         self.saving_worker = ThreadedSaver( | ||||||
|  |             self.SAVE_DIR / (record_name % fid), self.snapshots) | ||||||
|  |         self.snapshots = [] | ||||||
|  |         self.nbrSnapshotSaved.setText("0") | ||||||
|  |         self.saving_worker.finished.connect(self.on_record_save_done) | ||||||
|  |         self.saving_worker.start() | ||||||
|  |  | ||||||
|  |     def on_record_save_done(self): | ||||||
|  |         if self.saving_worker is None: | ||||||
|  |             return | ||||||
|  |         print("Recorded data saved to", self.saving_worker.path) | ||||||
|  |         self.saving_worker = None | ||||||
|  |         self.saveRecordButton.setText("Save") | ||||||
|  |  | ||||||
|     @pyqtSlot(Snapshot) |     @pyqtSlot(Snapshot) | ||||||
|     def on_snapshot_received(self, snapshot: Snapshot): |     def on_snapshot_received(self, snapshot: Snapshot): | ||||||
|         self.recorded_data.append(snapshot) |         self.snapshots.append(snapshot) | ||||||
|         self.nbrSnapshotSaved.setText(str(len(self.recorded_data))) |         self.nbrSnapshotSaved.setText(str(len(self.snapshots))) | ||||||
|  |  | ||||||
|     def shutdown(self): |     def shutdown(self): | ||||||
|         self.close_signal.emit() |         self.close_signal.emit() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user