From d9a242a5fa0af63cea2a9cdf852541695b260410 Mon Sep 17 00:00:00 2001 From: Peter Woolery Date: Mon, 11 May 2026 06:52:46 -0700 Subject: [PATCH] feat(server): add OTA check and firmware download endpoints Implements /ota/check (version comparison + sig_b64 payload) and /ota/firmware (binary stream) using the same _impl pattern as camera_endpoint.py. HMAC auth left commented pending main app wiring. 6/6 tests passing. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- server/ota_endpoint.py | 102 ++++++++++++++++++++++++++++++++++++ server/test_ota_endpoint.py | 70 +++++++++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 server/ota_endpoint.py create mode 100644 server/test_ota_endpoint.py diff --git a/server/ota_endpoint.py b/server/ota_endpoint.py new file mode 100644 index 0000000..0f76c2a --- /dev/null +++ b/server/ota_endpoint.py @@ -0,0 +1,102 @@ +# server/ota_endpoint.py +""" +OTA firmware update endpoints. + +To register in the server main app: + from server.ota_endpoint import router as ota_router + app.include_router(ota_router) + +Route handlers have HMAC auth commented out pending import-path confirmation: + from .main import verify_device_hmac # adjust to actual module +Then uncomment the Depends lines in ota_check() and ota_firmware(). + +Firmware artifacts expected under FIRMWARE_DIR (default: server/firmware/): + current.bin — raw firmware binary + current.sig — 64-byte r‖s Ed25519/ECDSA signature + manifest.json — {"version": "X.Y.Z", "size": N, "sha256": "hex..."} +""" +import base64 +import json +from pathlib import Path + +from fastapi import APIRouter +from fastapi.responses import FileResponse + +FIRMWARE_DIR = Path(__file__).parent / "firmware" + +router = APIRouter(prefix="/ota", tags=["ota"]) + + +class FirmwareNotFoundError(Exception): + pass + + +def _parse_version(v: str) -> tuple: + """Parse semver string to comparable tuple; returns (0,0,0) on malformed input.""" + try: + parts = v.strip().split(".") + return tuple(int(x) for x in parts) + except (ValueError, AttributeError): + return (0, 0, 0) + + +def ota_check_impl(current_version: str, firmware_dir: Path = FIRMWARE_DIR) -> dict: + """ + Compare device's current_version against staged manifest. + Returns {"update": False} when no update is available or manifest is missing. + Returns full update payload when server version is strictly newer. + """ + manifest_path = firmware_dir / "manifest.json" + if not manifest_path.exists(): + return {"update": False} + + manifest = json.loads(manifest_path.read_text()) + if _parse_version(manifest["version"]) <= _parse_version(current_version): + return {"update": False} + + sig_path = firmware_dir / "current.sig" + sig_b64 = base64.b64encode(sig_path.read_bytes()).decode() + + return { + "update": True, + "version": manifest["version"], + "size": manifest["size"], + "sha256": manifest["sha256"], + "sig_b64": sig_b64, + } + + +def ota_firmware_impl(firmware_dir: Path = FIRMWARE_DIR) -> bytes: + """ + Return raw firmware binary bytes. + Raises FirmwareNotFoundError if current.bin is absent. + """ + bin_path = firmware_dir / "current.bin" + if not bin_path.exists(): + raise FirmwareNotFoundError("No firmware staged") + return bin_path.read_bytes() + + +@router.get("/check") +async def ota_check( + version: str, + # device_id: str = Depends(verify_device_hmac), # uncomment when wiring into app +): + """Check whether a firmware update is available for the given device version.""" + return ota_check_impl(current_version=version) + + +@router.get("/firmware") +async def ota_firmware( + # device_id: str = Depends(verify_device_hmac), # uncomment when wiring into app +): + """Stream the staged firmware binary to the device.""" + try: + ota_firmware_impl() # validate existence before streaming + except FirmwareNotFoundError: + from fastapi import HTTPException + raise HTTPException(status_code=404, detail="No firmware available") + return FileResponse( + FIRMWARE_DIR / "current.bin", + media_type="application/octet-stream", + ) diff --git a/server/test_ota_endpoint.py b/server/test_ota_endpoint.py new file mode 100644 index 0000000..a85bc9b --- /dev/null +++ b/server/test_ota_endpoint.py @@ -0,0 +1,70 @@ +# server/test_ota_endpoint.py +import base64 +import hashlib +import json +from pathlib import Path + +import pytest + +from server.ota_endpoint import ota_check_impl, ota_firmware_impl + + +def write_firmware(firmware_dir: Path, version: str, data: bytes = b"fake_fw") -> None: + sig = bytes(64) # zero sig (not validated server-side) + manifest = { + "version": version, + "size": len(data), + "sha256": hashlib.sha256(data).hexdigest(), + } + (firmware_dir / "current.bin").write_bytes(data) + (firmware_dir / "current.sig").write_bytes(sig) + (firmware_dir / "manifest.json").write_text(json.dumps(manifest)) + + +@pytest.fixture(autouse=True) +def patch_firmware_dir(tmp_path, monkeypatch): + import server.ota_endpoint as mod + monkeypatch.setattr(mod, "FIRMWARE_DIR", tmp_path) + yield tmp_path + + +def test_check_no_update_same_version(tmp_path): + write_firmware(tmp_path, "1.0.0") + result = ota_check_impl(current_version="1.0.0", firmware_dir=tmp_path) + assert result["update"] is False + + +def test_check_no_update_newer_local(tmp_path): + write_firmware(tmp_path, "1.0.0") + result = ota_check_impl(current_version="1.1.0", firmware_dir=tmp_path) + assert result["update"] is False + + +def test_check_update_available(tmp_path): + write_firmware(tmp_path, "1.1.0", data=b"new firmware") + result = ota_check_impl(current_version="1.0.0", firmware_dir=tmp_path) + assert result["update"] is True + assert result["version"] == "1.1.0" + assert result["size"] == len(b"new firmware") + assert "sha256" in result + assert "sig_b64" in result + sig_bytes = base64.b64decode(result["sig_b64"]) + assert len(sig_bytes) == 64 + + +def test_check_no_manifest(tmp_path): + result = ota_check_impl(current_version="1.0.0", firmware_dir=tmp_path) + assert result["update"] is False + + +def test_firmware_endpoint_returns_binary(tmp_path): + fw_data = b"firmware binary content" + write_firmware(tmp_path, "1.1.0", data=fw_data) + content = ota_firmware_impl(firmware_dir=tmp_path) + assert content == fw_data + + +def test_firmware_endpoint_missing_raises(tmp_path): + import server.ota_endpoint as mod + with pytest.raises(mod.FirmwareNotFoundError): + ota_firmware_impl(firmware_dir=tmp_path)