"""Minimal RCSP client example.

This script is a small, end-to-end walkthrough of the Rokoko Command Server
Protocol described in the RCSP spec (/reference/rcsp/ on the docs site).
The goal is for the code to read in roughly the same order the spec does:

    1. Open a TCP socket to the RCSP server (default localhost:45451).
    2. Send the `ListDevices` top-level command and read the response.
    3. Send a `GetDeviceInfo` device command for the first device returned.

Run it:

    python rcsp_cmd_example.py --host localhost --port 45451
"""

from __future__ import annotations

import argparse
import enum
import itertools
import json
import socket
import sys
from typing import Any


# ── Protocol constants ────────────────────────────────────────────────────────
# These mirror the `struct header` definition in the "Message structure"
# section of the RCSP spec. Every message on the wire is an 8-byte header
# followed by a JSON payload.

DEFAULT_PORT: int = 45451    # TCP port the server listens on.
HEADER_MARKER: int = 0xDC    # First byte of every header ("DeviceCommands" sentinel).
HEADER_VERSION: int = 1      # Version of the binary header format.
HEADER_SIZE: int = 8         # Length of the header, in bytes.


class PayloadType(enum.IntEnum):
    """Value of the header's `type` byte: which of the four payload kinds follows.

    COMMAND travels client → server and always elicits exactly one response.
    The remaining three kinds travel server → client.
    """

    COMMAND = 1
    RESPONSE_OK = 2
    RESPONSE_ERROR = 3
    # Pub/sub notification; may interleave with responses.
    EVENT = 4


# ── Binary header ─────────────────────────────────────────────────────────────
# Layout (all little-endian):
#     marker         : uint8   (always 0xDC)
#     header_version : uint8   (currently 1)
#     header_size    : uint8   (== 8)
#     type           : uint8   (PayloadType)
#     payload_size   : uint32  (length of the JSON body that follows)

def build_header(payload_type: PayloadType, payload_size: int) -> bytes:
    """Serialize the 8-byte message header that precedes a JSON payload.

    Args:
        payload_type: Kind of payload that follows the header.
        payload_size: Length in bytes of the JSON body that follows.

    Returns:
        Marker, header version, header size and type (one byte each),
        followed by `payload_size` as a 4-byte little-endian integer.

    Raises:
        OverflowError: If `payload_size` is negative or does not fit in
            four bytes.
    """
    fixed_fields = bytes((HEADER_MARKER, HEADER_VERSION, HEADER_SIZE, int(payload_type)))
    size_field = payload_size.to_bytes(4, byteorder="little")
    return fixed_fields + size_field


def parse_header(header: bytes) -> tuple[PayloadType, int]:
    """Validate a raw message header and extract the payload type and size.

    Args:
        header: The raw header bytes; must be exactly `HEADER_SIZE` long.

    Returns:
        A `(payload_type, payload_size)` tuple, where `payload_size` is the
        little-endian uint32 stored in bytes 4-7.

    Raises:
        RuntimeError: If `header` has the wrong length, does not start with
            `HEADER_MARKER`, or declares a header size other than
            `HEADER_SIZE`.
        ValueError: If the type byte is not a known `PayloadType`.
    """
    if len(header) != HEADER_SIZE:
        raise RuntimeError(f"Short header: got {len(header)} bytes, expected {HEADER_SIZE}")
    if header[0] != HEADER_MARKER:
        raise RuntimeError(f"Bad marker: 0x{header[0]:02x} (expected 0x{HEADER_MARKER:02x})")
    # Byte 2 is the header's self-declared length. Only HEADER_SIZE bytes were
    # read, so any other value means the payload does not start where the
    # caller will look for it; fail loudly instead of desyncing the stream.
    if header[2] != HEADER_SIZE:
        raise RuntimeError(f"Unsupported header size: {header[2]} (expected {HEADER_SIZE})")
    # NOTE(review): header_version (byte 1) is not enforced here — confirm
    # against the spec whether clients must reject unknown versions.
    payload_type = PayloadType(header[3])
    payload_size = int.from_bytes(header[4:8], byteorder="little")
    return payload_type, payload_size


# ── Wire framing ──────────────────────────────────────────────────────────────
# TCP is a raw byte stream — it has no message boundaries of its own. We use
# the header's `payload_size` field to know exactly how many bytes of JSON to
# read for each message.

def _recv_exact(sock: socket.socket, nbytes: int) -> bytes:
    """Read exactly `nbytes` from the socket, looping until the buffer is full."""
    buf = bytearray()
    while len(buf) < nbytes:
        chunk = sock.recv(nbytes - len(buf))
        if not chunk:
            raise ConnectionError("Server closed the connection mid-message")
        buf.extend(chunk)
    return bytes(buf)


def send_message(sock: socket.socket, payload_type: PayloadType, payload: dict[str, Any]) -> None:
    """Frame `payload` as one message (header + JSON body) and send it.

    Header and body go out in a single `sendall` call.
    """
    encoded = json.dumps(payload).encode()
    header = build_header(payload_type, len(encoded))
    sock.sendall(header + encoded)


def recv_message(sock: socket.socket) -> tuple[PayloadType, dict[str, Any]]:
    """Read one complete message from the socket.

    Reads the fixed-size header first, then exactly as many body bytes as
    the header's `payload_size` field announces, and decodes them as JSON.

    Returns:
        A `(payload_type, payload)` tuple.
    """
    raw_header = _recv_exact(sock, HEADER_SIZE)
    payload_type, payload_size = parse_header(raw_header)
    raw_body = _recv_exact(sock, payload_size)
    return payload_type, json.loads(raw_body.decode())


# ── Command envelope ──────────────────────────────────────────────────────────
# Every command is wrapped in the `rcsp_cmd` object from the spec:
#
#     { "Command": <str>, "TrackId": <str>, "Version": <num>, "Arguments": <opt obj> }
#
# The server echoes `TrackId` in the matching response so a client can pair
# requests with replies. We just use a monotonically increasing counter.

_track_ids = itertools.count(1)


def make_command(name: str, arguments: dict[str, Any] | None = None) -> dict[str, Any]:
    cmd: dict[str, Any] = {
        "Command": name,
        "TrackId": str(next(_track_ids)),
        "Version": 1,
    }
    if arguments is not None:
        cmd["Arguments"] = arguments
    return cmd


# ── Round-trip helper ─────────────────────────────────────────────────────────
# Per the spec, every command MUST get exactly one response (ok or error), and
# responses arrive in the same order commands were sent. Events MAY arrive
# between a command and its response, so we skip them while waiting.

def send_command(sock: socket.socket, name: str,
                 arguments: dict[str, Any] | None = None) -> dict[str, Any]:
    """Issue one command and block until its response arrives.

    Event messages received while waiting are discarded.

    Args:
        sock: Connected socket to the server.
        name: Command name to send.
        arguments: Optional "Arguments" object for the command.

    Returns:
        The "Response" object of the ok-response, or an empty dict when the
        response carries no "Response" key.

    Raises:
        RuntimeError: If the server answers with an error response, or with
            a payload type that is neither a response nor an event.
    """
    request = make_command(name, arguments)
    send_message(sock, PayloadType.COMMAND, request)

    while True:
        ptype, body = recv_message(sock)

        if ptype == PayloadType.RESPONSE_OK:
            # An ok-response with no data may leave "Response" out entirely.
            return body.get("Response", {})

        if ptype == PayloadType.RESPONSE_ERROR:
            err = body.get("Error", {})
            raise RuntimeError(f"{name} failed: {err.get('Code')}: {err.get('Message')}")

        if ptype != PayloadType.EVENT:
            raise RuntimeError(f"Unexpected payload type from server: {ptype!r}")

        # It was an event: ignore it and keep waiting for the response.


# ── Demo ──────────────────────────────────────────────────────────────────────

def main(host: str, port: int) -> int:
    """Run the demo against the server at `host:port`.

    Connects with a 5-second timeout, prints the `ListDevices` response and,
    if at least one device is listed, the `GetDeviceInfo` response for the
    first one.

    Returns:
        0 once the demo has run to completion; failures propagate as
        exceptions.
    """
    with socket.create_connection((host, port), timeout=5.0) as sock:
        print(f"Connected to RCSP server at {host}:{port}\n")

        # Step 1: ListDevices is a top-level command and takes no Arguments.
        listing = send_command(sock, "ListDevices")
        print("ListDevices response:")
        print(json.dumps(listing, indent=2))

        devices = listing.get("Devices", [])
        if devices:
            # Step 2: GetDeviceInfo is a device command, so DeviceId goes in Arguments.
            device_id = devices[0]["DeviceId"]
            print(f"\nQuerying GetDeviceInfo for DeviceId={device_id}...")
            details = send_command(sock, "GetDeviceInfo", {"DeviceId": device_id})
            print("GetDeviceInfo response:")
            print(json.dumps(details, indent=2))
        else:
            print("\nNo devices connected — nothing more to demo.")

        return 0


if __name__ == "__main__":
    # Command-line entry point: parse --host/--port, then exit with main()'s status.
    cli = argparse.ArgumentParser(description="Minimal RCSP client example.")
    cli.add_argument(
        "--host",
        default="localhost",
        help="RCSP server host (default: localhost)",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help=f"RCSP server port (default: {DEFAULT_PORT})",
    )
    options = cli.parse_args()

    exit_code = main(options.host, options.port)
    sys.exit(exit_code)
