From baf92a4e8ef40adbfc57a001473a473b2e03afa3 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 5 Jun 2026 15:40:56 +0545 Subject: [PATCH 1/2] Support grpc communication with smg workers Minor Change Support grpc communication with smg router --- pyproject.toml | 2 + src/dstack/_internal/cli/commands/server.py | 3 + .../services/jobs/job_replica_grpc_client.py | 59 +++ .../services/runs/router_worker_sync.py | 353 ++++++++++++++++-- .../services/runs/test_router_worker_sync.py | 233 ++++++++++++ 5 files changed, 629 insertions(+), 21 deletions(-) create mode 100644 src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py create mode 100644 src/tests/_internal/server/services/runs/test_router_worker_sync.py diff --git a/pyproject.toml b/pyproject.toml index 4f09349ced..6067afb4bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,6 +200,8 @@ server = [ "python-json-logger>=3.1.0", "prometheus-client", "grpcio>=1.50", + "protobuf>=6.33.5", + "smg-grpc-proto>=0.4.7", ] aws = [ "boto3>=1.38.13", diff --git a/src/dstack/_internal/cli/commands/server.py b/src/dstack/_internal/cli/commands/server.py index a9040274dd..255f92e9a5 100644 --- a/src/dstack/_internal/cli/commands/server.py +++ b/src/dstack/_internal/cli/commands/server.py @@ -80,6 +80,9 @@ def _command(self, args: argparse.Namespace): os.environ["DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT"] = "1" if args.token: os.environ["DSTACK_SERVER_ADMIN_TOKEN"] = args.token + # Hide noisy "Other threads are currently calling into gRPC, skipping fork() handlers" + # messages in server logs. Users can still change this with GRPC_VERBOSITY. + os.environ.setdefault("GRPC_VERBOSITY", "ERROR") uvicorn_log_level = os.getenv("DSTACK_SERVER_UVICORN_LOG_LEVEL", "ERROR").lower() reload_disabled = os.getenv("DSTACK_SERVER_RELOAD_DISABLED") is not None diff --git a/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py new file mode 100644 index 0000000000..61825a9af0 --- /dev/null +++ b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py @@ -0,0 +1,59 @@ +"""SSH-tunneled gRPC channel target to a job's service port (UDS).""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import timedelta +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any + +import grpc + +from dstack._internal.core.services.ssh.tunnel import ( + SSH_DEFAULT_OPTIONS, + IPSocket, + SocketPair, + UnixSocket, +) +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.jobs import get_job_spec +from dstack._internal.server.services.ssh import container_ssh_tunnel +from dstack._internal.utils.common import get_or_error + +SSH_CONNECT_TIMEOUT = timedelta(seconds=10) +# Match router_worker_sync HTTP server_info cap (_MAX_SERVER_INFO_RESPONSE_BYTES). +_MAX_GRPC_MESSAGE_BYTES = 256 * 1024 +_GRPC_CHANNEL_OPTIONS = ( + ("grpc.max_receive_message_length", _MAX_GRPC_MESSAGE_BYTES), + ("grpc.max_send_message_length", _MAX_GRPC_MESSAGE_BYTES), +) + + +@asynccontextmanager +async def get_service_replica_grpc_client(job: JobModel) -> AsyncGenerator[Any, None]: + options = { + **SSH_DEFAULT_OPTIONS, + "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), + } + job_spec = get_job_spec(job) + with TemporaryDirectory() as temp_dir: + # Keep the same socket file name as the HTTP helper for consistency. + app_socket_path = (Path(temp_dir) / "replica.sock").absolute() + async with container_ssh_tunnel( + job=job, + forwarded_sockets=[ + SocketPair( + remote=IPSocket("localhost", get_or_error(job_spec.service_port)), + local=UnixSocket(app_socket_path), + ), + ], + options=options, + ): + target = f"unix://{app_socket_path}" + channel = grpc.aio.insecure_channel(target, options=_GRPC_CHANNEL_OPTIONS) + try: + yield channel + finally: + await channel.close() diff --git a/src/dstack/_internal/server/services/runs/router_worker_sync.py b/src/dstack/_internal/server/services/runs/router_worker_sync.py index 2fc9add74b..06bebfd577 100644 --- a/src/dstack/_internal/server/services/runs/router_worker_sync.py +++ b/src/dstack/_internal/server/services/runs/router_worker_sync.py @@ -4,7 +4,22 @@ from typing import Any, Dict, List, Literal, Optional, TypedDict from urllib.parse import urlsplit, urlunsplit -from httpx import AsyncClient, Response +import grpc +from google.protobuf.json_format import MessageToDict +from httpx import ( + AsyncClient, + ConnectError, + ConnectTimeout, + ReadTimeout, + RemoteProtocolError, + Response, +) +from smg_grpc_proto import ( + sglang_scheduler_pb2, + sglang_scheduler_pb2_grpc, + vllm_engine_pb2, + vllm_engine_pb2_grpc, +) from typing_extensions import NotRequired from dstack._internal.core.errors import SSHError @@ -12,6 +27,9 @@ from dstack._internal.core.models.runs import JobStatus, RunSpec, get_service_port from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_spec +from dstack._internal.server.services.jobs.job_replica_grpc_client import ( + get_service_replica_grpc_client, +) from dstack._internal.server.services.jobs.job_replica_http_client import ( get_service_replica_client, ) @@ -29,6 +47,7 @@ _MAX_WORKERS_RESPONSE_BYTES = 2 * 1024 * 1024 _MAX_WORKERS_COMMAND_ACK_BYTES = 64 * 1024 _MAX_WORKERS_LIST_ITEMS = 8192 +_GRPC_DISCOVERY_TIMEOUT = 30.0 class _ResponseTooLargeError(Exception): @@ -91,6 +110,15 @@ class _TargetWorker(TypedDict): url: str worker_type: str bootstrap_port: NotRequired[Optional[int]] + connection_mode: NotRequired[str] + runtime_type: NotRequired[str] + kv_connector: NotRequired[str] + kv_role: NotRequired[str] + + +_ConnectionMode = Literal["grpc", "http"] +_RuntimeType = Literal["sglang", "vllm"] +_GRPC_RUNTIME_TYPES: tuple[_RuntimeType, ...] = ("sglang", "vllm") def run_model_has_sglang_router_replica_group(run_model: RunModel) -> bool: @@ -121,6 +149,70 @@ def _normalize_worker_url(url: str) -> str: return urlunsplit((parts.scheme, parts.netloc, path, parts.query, parts.fragment)) +def _get_connection_mode_from_workers( + current_workers: List[dict], +) -> Optional[_ConnectionMode]: + # PD services register multiple workers (e.g. prefill and decode). We expect + # every listed worker to use the same connection_mode (all grpc or all http), + # not a mix of protocols on one router. + modes: set[str] = set() + for worker in current_workers: + mode = worker.get("connection_mode") + if isinstance(mode, str) and mode in ("http", "grpc"): + modes.add(mode) + if modes == {"grpc"}: + return "grpc" + if modes == {"http"}: + return "http" + return None + + +def _get_runtime_type_from_workers( + current_workers: List[dict], +) -> Optional[_RuntimeType]: + # We expect every listed gRPC worker to share the same runtime_type + # (all sglang or all vllm), not a mix of runtimes on one router. + runtimes: set[str] = set() + for worker in current_workers: + # For HTTP workers,there is no “pick vLLM vs SGLang gRPC stub” step, + # so runtime_type is irrelevant for HTTP workers. + if worker.get("connection_mode") != "grpc": + continue + runtime_type = worker.get("runtime_type") + if isinstance(runtime_type, str) and runtime_type in _GRPC_RUNTIME_TYPES: + runtimes.add(runtime_type) + if runtimes == {"sglang"}: + return "sglang" + if runtimes == {"vllm"}: + return "vllm" + return None + + +def _is_expected_router_workers_fetch_error(error: Exception) -> bool: + """SMG router may not accept HTTP yet during startup.""" + if isinstance( + error, + ( + RemoteProtocolError, + ConnectError, + ConnectTimeout, + ReadTimeout, + TimeoutError, + ), + ): + return True + if isinstance(error, OSError) and error.errno in {61, 111}: + return True + return False + + +def _log_router_workers_fetch_failure(error: Exception) -> None: + if _is_expected_router_workers_fetch_error(error): + logger.debug("Router /workers not ready yet: %r", error) + return + logger.exception("Error getting router /workers") + + async def _get_router_workers(client: AsyncClient) -> List[dict]: try: data = await _request_json_limited( @@ -144,8 +236,8 @@ async def _get_router_workers(client: AsyncClient) -> List[dict]: return [w for w in workers if isinstance(w, dict)] except _ResponseTooLargeError: logger.warning("Router /workers response exceeded size limit") - except Exception: - logger.exception("Error getting router /workers") + except Exception as e: + _log_router_workers_fetch_failure(e) return [] @@ -154,11 +246,24 @@ async def _add_worker_to_router( url: str, worker_type: str = "regular", bootstrap_port: Optional[int] = None, + *, + connection_mode: Optional[str] = None, + runtime_type: Optional[str] = None, + kv_connector: Optional[str] = None, + kv_role: Optional[str] = None, ) -> bool: try: payload: dict = {"url": url, "worker_type": worker_type} if bootstrap_port is not None: payload["bootstrap_port"] = bootstrap_port + if connection_mode is not None: + payload["connection_mode"] = connection_mode + if runtime_type is not None: + payload["runtime_type"] = runtime_type + if kv_connector is not None: + payload["kv_connector"] = kv_connector + if kv_role is not None: + payload["kv_role"] = kv_role body = await _request_json_limited( client, "POST", @@ -199,11 +304,12 @@ async def _remove_worker_from_router_by_id( async def _update_workers_in_router_replica( client: AsyncClient, target_workers: List[_TargetWorker], + *, + current_workers: List[dict], ) -> None: - current = await _get_router_workers(client) current_urls: set[str] = set() current_ids_by_norm_url: dict[str, str] = {} - for w in current: + for w in current_workers: u = w.get("url") if not isinstance(u, str) or not u: continue @@ -223,6 +329,10 @@ async def _update_workers_in_router_replica( tw["url"], tw["worker_type"], tw.get("bootstrap_port"), + connection_mode=tw.get("connection_mode"), + runtime_type=tw.get("runtime_type"), + kv_connector=tw.get("kv_connector"), + kv_role=tw.get("kv_role"), ) if not ok: logger.warning("Failed to add worker %s, continuing with others", tw["url"]) @@ -237,7 +347,46 @@ async def _update_workers_in_router_replica( logger.warning("Failed to remove worker %s, continuing with others", url) -async def _get_worker_payload(job_model: JobModel, worker_url: str) -> _WorkerPayloadResult: +def _payload_to_target_worker(payload: Dict[str, Any]) -> _TargetWorker: + entry: _TargetWorker = { + "url": payload["url"], + "worker_type": payload.get("worker_type", "regular"), + } + if payload.get("bootstrap_port") is not None: + entry["bootstrap_port"] = payload["bootstrap_port"] + if payload.get("connection_mode") is not None: + entry["connection_mode"] = payload["connection_mode"] + if payload.get("runtime_type") is not None: + entry["runtime_type"] = payload["runtime_type"] + if payload.get("kv_connector") is not None: + entry["kv_connector"] = payload["kv_connector"] + if payload.get("kv_role") is not None: + entry["kv_role"] = payload["kv_role"] + return entry + + +def _vllm_kv_role_to_worker_type(kv_role: str) -> str: + if kv_role == "kv_producer": + return "prefill" + if kv_role == "kv_consumer": + return "decode" + return "regular" + + +def _is_expected_grpc_discovery_error(error: Exception) -> bool: + """Expected while a gRPC worker is still starting or the wrong stub is probed.""" + if isinstance(error, grpc.aio.AioRpcError): + return error.code() in ( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.UNIMPLEMENTED, + ) + return False + + +async def _get_http_worker_payload( + job_model: JobModel, *, worker_url: str +) -> _WorkerPayloadResult: try: async with get_service_replica_client(job_model) as client: data = await _request_json_limited( @@ -258,29 +407,172 @@ async def _get_worker_payload(job_model: JobModel, worker_url: str) -> _WorkerPa "payload": { "url": worker_url, "worker_type": "prefill", + "connection_mode": "http", + "runtime_type": "sglang", "bootstrap_port": bootstrap_port, }, } if mode == "decode": return { "status": "ready", - "payload": {"url": worker_url, "worker_type": "decode"}, + "payload": { + "url": worker_url, + "worker_type": "decode", + "connection_mode": "http", + "runtime_type": "sglang", + }, } return { "status": "ready", - "payload": {"url": worker_url, "worker_type": "regular"}, + "payload": { + "url": worker_url, + "worker_type": "regular", + "connection_mode": "http", + "runtime_type": "sglang", + }, } except _ResponseTooLargeError: logger.warning("server_info response too large for worker %s", worker_url) + except RemoteProtocolError as e: + logger.debug("HTTP server_info not available for worker %s: %r", worker_url, e) except Exception as e: logger.exception("Could not fetch server_info for worker %s: %r", worker_url, e) return {"status": "not_ready", "payload": None} +async def _get_grpc_server_info( + channel: grpc.aio.Channel, + runtime_type: _RuntimeType, +) -> Any: + if runtime_type == "sglang": + stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel) + request = sglang_scheduler_pb2.GetServerInfoRequest() + else: + stub = vllm_engine_pb2_grpc.VllmEngineStub(channel) + request = vllm_engine_pb2.GetServerInfoRequest() + return await stub.GetServerInfo(request, timeout=_GRPC_DISCOVERY_TIMEOUT) + + +async def _discover_grpc_server_info( + channel: grpc.aio.Channel, +) -> tuple[Optional[_RuntimeType], Optional[Any]]: + # Bootstrap only: router workers list has no runtime_type yet. + for runtime_type in _GRPC_RUNTIME_TYPES: + try: + response = await _get_grpc_server_info(channel, runtime_type) + except Exception as e: + if _is_expected_grpc_discovery_error(e): + continue + raise + return runtime_type, response + return None, None + + +def _grpc_server_info_to_worker_payload( + worker_url: str, + runtime_type: _RuntimeType, + response: Any, +) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "url": worker_url, + "connection_mode": "grpc", + "runtime_type": runtime_type, + } + if runtime_type == "vllm": + kv_role = response.kv_role or "" + kv_connector = response.kv_connector or "" + payload["worker_type"] = _vllm_kv_role_to_worker_type(kv_role) + if kv_connector: + payload["kv_connector"] = kv_connector + if kv_role: + payload["kv_role"] = kv_role + return payload + + server_args = ( + MessageToDict(response.server_args, preserving_proto_field_name=True) + if response.server_args is not None + else {} + ) + mode = server_args.get("disaggregation_mode") + payload["worker_type"] = mode if mode in ("prefill", "decode") else "regular" + if payload["worker_type"] == "prefill": + bootstrap_port = server_args.get("disaggregation_bootstrap_port") + if bootstrap_port is not None: + payload["bootstrap_port"] = int(bootstrap_port) + return payload + + +async def _get_grpc_worker_payload( + job_model: JobModel, + *, + worker_url: str, + runtime_type: Optional[_RuntimeType] = None, +) -> _WorkerPayloadResult: + try: + async with get_service_replica_grpc_client(job_model) as channel: + if runtime_type is not None: + try: + response = await _get_grpc_server_info(channel, runtime_type) + except Exception as e: + if _is_expected_grpc_discovery_error(e): + logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) + return {"status": "not_ready", "payload": None} + raise + else: + runtime_type, response = await _discover_grpc_server_info(channel) + if runtime_type is None or response is None: + logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) + return {"status": "not_ready", "payload": None} + except Exception as e: + logger.exception( + "Could not fetch gRPC GetServerInfo for worker %s: %r", + worker_url, + e, + ) + return {"status": "not_ready", "payload": None} + + payload = _grpc_server_info_to_worker_payload(worker_url, runtime_type, response) + return {"status": "ready", "payload": payload} + + +async def _get_worker_payload( + job_model: JobModel, + *, + http_worker_url: str, + grpc_worker_url: str, + connection_mode: Optional[_ConnectionMode] = None, + runtime_type: Optional[_RuntimeType] = None, +) -> _WorkerPayloadResult: + if connection_mode == "grpc": + return await _get_grpc_worker_payload( + job_model, worker_url=grpc_worker_url, runtime_type=runtime_type + ) + if connection_mode == "http": + return await _get_http_worker_payload(job_model, worker_url=http_worker_url) + # Router workers list is empty and no connection_mode discovered. + try: + result = await _get_http_worker_payload(job_model, worker_url=http_worker_url) + except RemoteProtocolError as e: + logger.debug( + "HTTP server_info probe failed for %s (trying gRPC): %r", + http_worker_url, + e, + ) + result: _WorkerPayloadResult = {"status": "not_ready", "payload": None} + if result["status"] == "ready": + return result + return await _get_grpc_worker_payload( + job_model, worker_url=grpc_worker_url, runtime_type=runtime_type + ) + + async def _build_target_workers( run_model: RunModel, run_spec: RunSpec, replica_groups: list[ReplicaGroup], + *, + connection_mode: Optional[_ConnectionMode] = None, + runtime_type: Optional[_RuntimeType] = None, ) -> List[_TargetWorker]: payloads: List[_TargetWorker] = [] config = run_spec.configuration @@ -305,19 +597,23 @@ async def _build_target_workers( continue job_spec = get_job_spec(job) port = get_service_port(job_spec, config) - worker_url = f"http://{ip}:{port}" - result = await _get_worker_payload(job, worker_url) + http_worker_url = f"http://{ip}:{port}" + grpc_worker_url = f"grpc://{ip}:{port}" + result = await _get_worker_payload( + job, + http_worker_url=http_worker_url, + grpc_worker_url=grpc_worker_url, + connection_mode=connection_mode, + runtime_type=runtime_type, + ) if result["status"] == "ready" and result["payload"]: - p = result["payload"] - entry: _TargetWorker = { - "url": p["url"], - "worker_type": p.get("worker_type", "regular"), - } - if p.get("bootstrap_port") is not None: - entry["bootstrap_port"] = p["bootstrap_port"] - payloads.append(entry) + payloads.append(_payload_to_target_worker(result["payload"])) elif result["status"] == "not_ready": - logger.debug("Worker %s not ready", worker_url) + logger.debug( + "Worker not ready http=%s grpc=%s", + http_worker_url, + grpc_worker_url, + ) return payloads @@ -331,13 +627,28 @@ async def sync_router_workers_for_run_model(run_model: RunModel) -> None: if router_group is None: return - target_workers = await _build_target_workers(run_model, run_spec, replica_groups) router_job = _get_router_job(run_model, router_group) if router_job is None: return try: async with get_service_replica_client(router_job) as client: - await _update_workers_in_router_replica(client, target_workers) + current_workers = await _get_router_workers(client) + # connection_mode can be grpc or http, runtime_type can be sglang or vllm. + connection_mode = _get_connection_mode_from_workers(current_workers) + runtime_type = _get_runtime_type_from_workers(current_workers) + # Empty current_workers on first sync is expected. First syncprobes both connection_mode and + # runtime_type. Subsequent syncs don't need to probe again because connection_mode and runtime_type + # is already set in current_workers. + target_workers = await _build_target_workers( + run_model, + run_spec, + replica_groups, + connection_mode=connection_mode, + runtime_type=runtime_type, + ) + await _update_workers_in_router_replica( + client, target_workers, current_workers=current_workers + ) except SSHError as e: logger.warning( "%s: failed to sync workers with router: %r", diff --git a/src/tests/_internal/server/services/runs/test_router_worker_sync.py b/src/tests/_internal/server/services/runs/test_router_worker_sync.py new file mode 100644 index 0000000000..87e8fb4f37 --- /dev/null +++ b/src/tests/_internal/server/services/runs/test_router_worker_sync.py @@ -0,0 +1,233 @@ +from contextlib import asynccontextmanager, contextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from dstack._internal.server.services.runs.router_worker_sync import ( + _get_connection_mode_from_workers, + _get_grpc_worker_payload, + _get_runtime_type_from_workers, + _get_worker_payload, + _grpc_server_info_to_worker_payload, +) + + +class TestGetConnectionModeFromWorkers: + def test_grpc(self): + current = [{"connection_mode": "grpc"}] + assert _get_connection_mode_from_workers(current) == "grpc" + + def test_http(self): + current = [{"connection_mode": "http"}] + assert _get_connection_mode_from_workers(current) == "http" + + def test_mixed(self): + current = [{"connection_mode": "grpc"}, {"connection_mode": "http"}] + assert _get_connection_mode_from_workers(current) is None + + +class TestRuntimeTypeFromRouterWorkers: + def test_vllm_grpc_workers(self): + current = [{"connection_mode": "grpc", "runtime_type": "vllm"}] + assert _get_runtime_type_from_workers(current) == "vllm" + + def test_sglang_grpc_workers(self): + current = [{"connection_mode": "grpc", "runtime_type": "sglang"}] + assert _get_runtime_type_from_workers(current) == "sglang" + + def test_ignores_http_workers(self): + current = [{"connection_mode": "http", "runtime_type": "sglang"}] + assert _get_runtime_type_from_workers(current) is None + + def test_mixed_runtimes(self): + current = [ + {"connection_mode": "grpc", "runtime_type": "vllm"}, + {"connection_mode": "grpc", "runtime_type": "sglang"}, + ] + assert _get_runtime_type_from_workers(current) is None + + +class TestGrpcServerInfoToWorkerPayload: + def test_vllm_prefill(self): + response = MagicMock(kv_role="kv_producer", kv_connector="NixlConnector") + payload = _grpc_server_info_to_worker_payload("grpc://10.0.0.1:50051", "vllm", response) + assert payload["worker_type"] == "prefill" + assert payload["runtime_type"] == "vllm" + assert payload["kv_role"] == "kv_producer" + + def test_sglang_prefill(self): + server_args = MagicMock() + response = MagicMock(server_args=server_args) + with patch( + "dstack._internal.server.services.runs.router_worker_sync.MessageToDict", + return_value={ + "disaggregation_mode": "prefill", + "disaggregation_bootstrap_port": 8998, + }, + ): + payload = _grpc_server_info_to_worker_payload( + "grpc://10.0.0.1:8000", "sglang", response + ) + assert payload == { + "url": "grpc://10.0.0.1:8000", + "worker_type": "prefill", + "connection_mode": "grpc", + "runtime_type": "sglang", + "bootstrap_port": 8998, + } + + +@contextmanager +def _fake_vllm_grpc_proto(*, server_info: MagicMock): + stub = MagicMock() + stub.GetServerInfo = AsyncMock(return_value=server_info) + pb2 = MagicMock(GetServerInfoRequest=MagicMock(return_value="req")) + pb2_grpc = MagicMock(VllmEngineStub=MagicMock(return_value=stub)) + with ( + patch( + "dstack._internal.server.services.runs.router_worker_sync.vllm_engine_pb2", + pb2, + ), + patch( + "dstack._internal.server.services.runs.router_worker_sync.vllm_engine_pb2_grpc", + pb2_grpc, + ), + ): + yield + + +@contextmanager +def _fake_sglang_grpc_proto(*, server_info: MagicMock): + stub = MagicMock() + stub.GetServerInfo = AsyncMock(return_value=server_info) + pb2 = MagicMock(GetServerInfoRequest=MagicMock(return_value="req")) + pb2_grpc = MagicMock(SglangSchedulerStub=MagicMock(return_value=stub)) + with ( + patch( + "dstack._internal.server.services.runs.router_worker_sync.sglang_scheduler_pb2", + pb2, + ), + patch( + "dstack._internal.server.services.runs.router_worker_sync.sglang_scheduler_pb2_grpc", + pb2_grpc, + ), + ): + yield + + +@pytest.mark.asyncio +async def test_get_grpc_worker_payload_ready(): + job = MagicMock() + channel = MagicMock() + + @asynccontextmanager + async def _fake_grpc_client(_job): + yield channel + + server_info = MagicMock(kv_role="kv_producer", kv_connector="NixlConnector") + + with ( + _fake_vllm_grpc_proto(server_info=server_info), + patch( + "dstack._internal.server.services.runs.router_worker_sync.get_service_replica_grpc_client", + _fake_grpc_client, + ), + ): + result = await _get_grpc_worker_payload( + job, + worker_url="grpc://10.0.0.1:50051", + runtime_type="vllm", + ) + + assert result["status"] == "ready" + assert result["payload"] == { + "url": "grpc://10.0.0.1:50051", + "worker_type": "prefill", + "connection_mode": "grpc", + "runtime_type": "vllm", + "kv_connector": "NixlConnector", + "kv_role": "kv_producer", + } + + +@pytest.mark.asyncio +async def test_get_grpc_worker_payload_not_ready_on_error(): + job = MagicMock() + + @asynccontextmanager + async def _failing_client(_job): + raise OSError("ssh failed") + yield # pragma: no cover + + with patch( + "dstack._internal.server.services.runs.router_worker_sync.get_service_replica_grpc_client", + _failing_client, + ): + result = await _get_grpc_worker_payload(job, worker_url="grpc://10.0.0.1:50051") + + assert result == {"status": "not_ready", "payload": None} + + +@pytest.mark.asyncio +async def test_get_grpc_worker_payload_sglang_bootstrap(): + job = MagicMock() + channel = MagicMock() + sglang_server_info = MagicMock(server_args=MagicMock()) + + @asynccontextmanager + async def _fake_grpc_client(_job): + yield channel + + with ( + _fake_sglang_grpc_proto(server_info=sglang_server_info), + patch( + "dstack._internal.server.services.runs.router_worker_sync.MessageToDict", + return_value={ + "disaggregation_mode": "prefill", + "disaggregation_bootstrap_port": 8998, + }, + ), + patch( + "dstack._internal.server.services.runs.router_worker_sync" + ".get_service_replica_grpc_client", + _fake_grpc_client, + ), + ): + result = await _get_grpc_worker_payload(job, worker_url="grpc://10.0.0.1:8000") + + assert result["status"] == "ready" + assert result["payload"] == { + "url": "grpc://10.0.0.1:8000", + "worker_type": "prefill", + "connection_mode": "grpc", + "runtime_type": "sglang", + "bootstrap_port": 8998, + } + + +@pytest.mark.asyncio +async def test_get_worker_payload_grpc_preference_skips_http(): + job = MagicMock() + grpc_not_ready = {"status": "not_ready", "payload": None} + + with ( + patch( + "dstack._internal.server.services.runs.router_worker_sync._get_grpc_worker_payload", + new_callable=AsyncMock, + return_value=grpc_not_ready, + ) as grpc_mock, + patch( + "dstack._internal.server.services.runs.router_worker_sync._get_http_worker_payload", + new_callable=AsyncMock, + ) as http_mock, + ): + result = await _get_worker_payload( + job, + http_worker_url="http://10.0.0.1:8000", + grpc_worker_url="grpc://10.0.0.1:8000", + connection_mode="grpc", + ) + + assert result == grpc_not_ready + grpc_mock.assert_awaited_once() + http_mock.assert_not_awaited() From fce861b2ea9761395e050cb1f8bc00b6706e9228 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Thu, 11 Jun 2026 14:43:34 +0545 Subject: [PATCH 2/2] Resolve Review Comments --- .../services/jobs/job_replica_grpc_client.py | 2 - .../services/runs/router_worker_sync.py | 133 ++++++++---------- .../services/runs/test_router_worker_sync.py | 50 ++++--- 3 files changed, 82 insertions(+), 103 deletions(-) diff --git a/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py index 61825a9af0..bc6f6cffe9 100644 --- a/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py +++ b/src/dstack/_internal/server/services/jobs/job_replica_grpc_client.py @@ -1,7 +1,5 @@ """SSH-tunneled gRPC channel target to a job's service port (UDS).""" -from __future__ import annotations - from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from datetime import timedelta diff --git a/src/dstack/_internal/server/services/runs/router_worker_sync.py b/src/dstack/_internal/server/services/runs/router_worker_sync.py index 06bebfd577..910dc8d576 100644 --- a/src/dstack/_internal/server/services/runs/router_worker_sync.py +++ b/src/dstack/_internal/server/services/runs/router_worker_sync.py @@ -1,7 +1,7 @@ """Reconcile SGLang router /workers with dstack's registered worker replicas (async, SSH-tunneled).""" import json -from typing import Any, Dict, List, Literal, Optional, TypedDict +from typing import Any, List, Literal, Optional, TypedDict from urllib.parse import urlsplit, urlunsplit import grpc @@ -101,11 +101,6 @@ async def _request_json_limited( return None -class _WorkerPayloadResult(TypedDict): - status: Literal["ready", "not_ready"] - payload: Optional[Dict[str, Any]] - - class _TargetWorker(TypedDict): url: str worker_type: str @@ -116,6 +111,11 @@ class _TargetWorker(TypedDict): kv_role: NotRequired[str] +class _WorkerPayloadResult(TypedDict): + status: Literal["ready", "not_ready"] + worker: Optional[_TargetWorker] + + _ConnectionMode = Literal["grpc", "http"] _RuntimeType = Literal["sglang", "vllm"] _GRPC_RUNTIME_TYPES: tuple[_RuntimeType, ...] = ("sglang", "vllm") @@ -347,24 +347,6 @@ async def _update_workers_in_router_replica( logger.warning("Failed to remove worker %s, continuing with others", url) -def _payload_to_target_worker(payload: Dict[str, Any]) -> _TargetWorker: - entry: _TargetWorker = { - "url": payload["url"], - "worker_type": payload.get("worker_type", "regular"), - } - if payload.get("bootstrap_port") is not None: - entry["bootstrap_port"] = payload["bootstrap_port"] - if payload.get("connection_mode") is not None: - entry["connection_mode"] = payload["connection_mode"] - if payload.get("runtime_type") is not None: - entry["runtime_type"] = payload["runtime_type"] - if payload.get("kv_connector") is not None: - entry["kv_connector"] = payload["kv_connector"] - if payload.get("kv_role") is not None: - entry["kv_role"] = payload["kv_role"] - return entry - - def _vllm_kv_role_to_worker_type(kv_role: str) -> str: if kv_role == "kv_producer": return "prefill" @@ -384,9 +366,7 @@ def _is_expected_grpc_discovery_error(error: Exception) -> bool: return False -async def _get_http_worker_payload( - job_model: JobModel, *, worker_url: str -) -> _WorkerPayloadResult: +async def _get_http_worker(job_model: JobModel, *, worker_url: str) -> _WorkerPayloadResult: try: async with get_service_replica_client(job_model) as client: data = await _request_json_limited( @@ -398,24 +378,23 @@ async def _get_http_worker_payload( ) if isinstance(data, dict): if data.get("status") != "ready": - return {"status": "not_ready", "payload": None} + return {"status": "not_ready", "worker": None} mode = data.get("disaggregation_mode", "") if mode == "prefill": bootstrap_port = data.get("disaggregation_bootstrap_port") - return { - "status": "ready", - "payload": { - "url": worker_url, - "worker_type": "prefill", - "connection_mode": "http", - "runtime_type": "sglang", - "bootstrap_port": bootstrap_port, - }, + worker: _TargetWorker = { + "url": worker_url, + "worker_type": "prefill", + "connection_mode": "http", + "runtime_type": "sglang", } + if bootstrap_port is not None: + worker["bootstrap_port"] = bootstrap_port + return {"status": "ready", "worker": worker} if mode == "decode": return { "status": "ready", - "payload": { + "worker": { "url": worker_url, "worker_type": "decode", "connection_mode": "http", @@ -424,7 +403,7 @@ async def _get_http_worker_payload( } return { "status": "ready", - "payload": { + "worker": { "url": worker_url, "worker_type": "regular", "connection_mode": "http", @@ -437,7 +416,7 @@ async def _get_http_worker_payload( logger.debug("HTTP server_info not available for worker %s: %r", worker_url, e) except Exception as e: logger.exception("Could not fetch server_info for worker %s: %r", worker_url, e) - return {"status": "not_ready", "payload": None} + return {"status": "not_ready", "worker": None} async def _get_grpc_server_info( @@ -468,25 +447,25 @@ async def _discover_grpc_server_info( return None, None -def _grpc_server_info_to_worker_payload( +def _grpc_server_info_to_worker( worker_url: str, runtime_type: _RuntimeType, response: Any, -) -> Dict[str, Any]: - payload: Dict[str, Any] = { - "url": worker_url, - "connection_mode": "grpc", - "runtime_type": runtime_type, - } +) -> _TargetWorker: if runtime_type == "vllm": kv_role = response.kv_role or "" kv_connector = response.kv_connector or "" - payload["worker_type"] = _vllm_kv_role_to_worker_type(kv_role) + worker: _TargetWorker = { + "url": worker_url, + "connection_mode": "grpc", + "runtime_type": runtime_type, + "worker_type": _vllm_kv_role_to_worker_type(kv_role), + } if kv_connector: - payload["kv_connector"] = kv_connector + worker["kv_connector"] = kv_connector if kv_role: - payload["kv_role"] = kv_role - return payload + worker["kv_role"] = kv_role + return worker server_args = ( MessageToDict(response.server_args, preserving_proto_field_name=True) @@ -494,15 +473,21 @@ def _grpc_server_info_to_worker_payload( else {} ) mode = server_args.get("disaggregation_mode") - payload["worker_type"] = mode if mode in ("prefill", "decode") else "regular" - if payload["worker_type"] == "prefill": + worker_type = mode if mode in ("prefill", "decode") else "regular" + worker = { + "url": worker_url, + "connection_mode": "grpc", + "runtime_type": runtime_type, + "worker_type": worker_type, + } + if worker_type == "prefill": bootstrap_port = server_args.get("disaggregation_bootstrap_port") if bootstrap_port is not None: - payload["bootstrap_port"] = int(bootstrap_port) - return payload + worker["bootstrap_port"] = int(bootstrap_port) + return worker -async def _get_grpc_worker_payload( +async def _get_grpc_worker( job_model: JobModel, *, worker_url: str, @@ -516,26 +501,26 @@ async def _get_grpc_worker_payload( except Exception as e: if _is_expected_grpc_discovery_error(e): logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) - return {"status": "not_ready", "payload": None} + return {"status": "not_ready", "worker": None} raise else: runtime_type, response = await _discover_grpc_server_info(channel) if runtime_type is None or response is None: logger.debug("gRPC worker %s not ready (GetServerInfo)", worker_url) - return {"status": "not_ready", "payload": None} + return {"status": "not_ready", "worker": None} except Exception as e: logger.exception( "Could not fetch gRPC GetServerInfo for worker %s: %r", worker_url, e, ) - return {"status": "not_ready", "payload": None} + return {"status": "not_ready", "worker": None} - payload = _grpc_server_info_to_worker_payload(worker_url, runtime_type, response) - return {"status": "ready", "payload": payload} + worker = _grpc_server_info_to_worker(worker_url, runtime_type, response) + return {"status": "ready", "worker": worker} -async def _get_worker_payload( +async def _get_worker( job_model: JobModel, *, http_worker_url: str, @@ -544,26 +529,24 @@ async def _get_worker_payload( runtime_type: Optional[_RuntimeType] = None, ) -> _WorkerPayloadResult: if connection_mode == "grpc": - return await _get_grpc_worker_payload( + return await _get_grpc_worker( job_model, worker_url=grpc_worker_url, runtime_type=runtime_type ) if connection_mode == "http": - return await _get_http_worker_payload(job_model, worker_url=http_worker_url) + return await _get_http_worker(job_model, worker_url=http_worker_url) # Router workers list is empty and no connection_mode discovered. try: - result = await _get_http_worker_payload(job_model, worker_url=http_worker_url) + result = await _get_http_worker(job_model, worker_url=http_worker_url) except RemoteProtocolError as e: logger.debug( "HTTP server_info probe failed for %s (trying gRPC): %r", http_worker_url, e, ) - result: _WorkerPayloadResult = {"status": "not_ready", "payload": None} + result: _WorkerPayloadResult = {"status": "not_ready", "worker": None} if result["status"] == "ready": return result - return await _get_grpc_worker_payload( - job_model, worker_url=grpc_worker_url, runtime_type=runtime_type - ) + return await _get_grpc_worker(job_model, worker_url=grpc_worker_url, runtime_type=runtime_type) async def _build_target_workers( @@ -574,10 +557,10 @@ async def _build_target_workers( connection_mode: Optional[_ConnectionMode] = None, runtime_type: Optional[_RuntimeType] = None, ) -> List[_TargetWorker]: - payloads: List[_TargetWorker] = [] + workers: List[_TargetWorker] = [] config = run_spec.configuration if not isinstance(config, ServiceConfiguration): - return payloads + return workers for group in replica_groups: if group.router is not None: @@ -599,22 +582,22 @@ async def _build_target_workers( port = get_service_port(job_spec, config) http_worker_url = f"http://{ip}:{port}" grpc_worker_url = f"grpc://{ip}:{port}" - result = await _get_worker_payload( + result = await _get_worker( job, http_worker_url=http_worker_url, grpc_worker_url=grpc_worker_url, connection_mode=connection_mode, runtime_type=runtime_type, ) - if result["status"] == "ready" and result["payload"]: - payloads.append(_payload_to_target_worker(result["payload"])) + if result["status"] == "ready" and result["worker"]: + workers.append(result["worker"]) elif result["status"] == "not_ready": logger.debug( "Worker not ready http=%s grpc=%s", http_worker_url, grpc_worker_url, ) - return payloads + return workers async def sync_router_workers_for_run_model(run_model: RunModel) -> None: diff --git a/src/tests/_internal/server/services/runs/test_router_worker_sync.py b/src/tests/_internal/server/services/runs/test_router_worker_sync.py index 87e8fb4f37..2cf0275632 100644 --- a/src/tests/_internal/server/services/runs/test_router_worker_sync.py +++ b/src/tests/_internal/server/services/runs/test_router_worker_sync.py @@ -5,10 +5,10 @@ from dstack._internal.server.services.runs.router_worker_sync import ( _get_connection_mode_from_workers, - _get_grpc_worker_payload, + _get_grpc_worker, _get_runtime_type_from_workers, - _get_worker_payload, - _grpc_server_info_to_worker_payload, + _get_worker, + _grpc_server_info_to_worker, ) @@ -47,13 +47,13 @@ def test_mixed_runtimes(self): assert _get_runtime_type_from_workers(current) is None -class TestGrpcServerInfoToWorkerPayload: +class TestGrpcServerInfoToWorker: def test_vllm_prefill(self): response = MagicMock(kv_role="kv_producer", kv_connector="NixlConnector") - payload = _grpc_server_info_to_worker_payload("grpc://10.0.0.1:50051", "vllm", response) - assert payload["worker_type"] == "prefill" - assert payload["runtime_type"] == "vllm" - assert payload["kv_role"] == "kv_producer" + worker = _grpc_server_info_to_worker("grpc://10.0.0.1:50051", "vllm", response) + assert worker["worker_type"] == "prefill" + assert worker.get("runtime_type") == "vllm" + assert worker.get("kv_role") == "kv_producer" def test_sglang_prefill(self): server_args = MagicMock() @@ -65,10 +65,8 @@ def test_sglang_prefill(self): "disaggregation_bootstrap_port": 8998, }, ): - payload = _grpc_server_info_to_worker_payload( - "grpc://10.0.0.1:8000", "sglang", response - ) - assert payload == { + worker = _grpc_server_info_to_worker("grpc://10.0.0.1:8000", "sglang", response) + assert worker == { "url": "grpc://10.0.0.1:8000", "worker_type": "prefill", "connection_mode": "grpc", @@ -116,7 +114,7 @@ def _fake_sglang_grpc_proto(*, server_info: MagicMock): @pytest.mark.asyncio -async def test_get_grpc_worker_payload_ready(): +async def test_get_grpc_worker_ready(): job = MagicMock() channel = MagicMock() @@ -133,14 +131,14 @@ async def _fake_grpc_client(_job): _fake_grpc_client, ), ): - result = await _get_grpc_worker_payload( + result = await _get_grpc_worker( job, worker_url="grpc://10.0.0.1:50051", runtime_type="vllm", ) assert result["status"] == "ready" - assert result["payload"] == { + assert result["worker"] == { "url": "grpc://10.0.0.1:50051", "worker_type": "prefill", "connection_mode": "grpc", @@ -151,7 +149,7 @@ async def _fake_grpc_client(_job): @pytest.mark.asyncio -async def test_get_grpc_worker_payload_not_ready_on_error(): +async def test_get_grpc_worker_not_ready_on_error(): job = MagicMock() @asynccontextmanager @@ -163,13 +161,13 @@ async def _failing_client(_job): "dstack._internal.server.services.runs.router_worker_sync.get_service_replica_grpc_client", _failing_client, ): - result = await _get_grpc_worker_payload(job, worker_url="grpc://10.0.0.1:50051") + result = await _get_grpc_worker(job, worker_url="grpc://10.0.0.1:50051") - assert result == {"status": "not_ready", "payload": None} + assert result == {"status": "not_ready", "worker": None} @pytest.mark.asyncio -async def test_get_grpc_worker_payload_sglang_bootstrap(): +async def test_get_grpc_worker_sglang_bootstrap(): job = MagicMock() channel = MagicMock() sglang_server_info = MagicMock(server_args=MagicMock()) @@ -193,10 +191,10 @@ async def _fake_grpc_client(_job): _fake_grpc_client, ), ): - result = await _get_grpc_worker_payload(job, worker_url="grpc://10.0.0.1:8000") + result = await _get_grpc_worker(job, worker_url="grpc://10.0.0.1:8000") assert result["status"] == "ready" - assert result["payload"] == { + assert result["worker"] == { "url": "grpc://10.0.0.1:8000", "worker_type": "prefill", "connection_mode": "grpc", @@ -206,22 +204,22 @@ async def _fake_grpc_client(_job): @pytest.mark.asyncio -async def test_get_worker_payload_grpc_preference_skips_http(): +async def test_get_worker_grpc_preference_skips_http(): job = MagicMock() - grpc_not_ready = {"status": "not_ready", "payload": None} + grpc_not_ready = {"status": "not_ready", "worker": None} with ( patch( - "dstack._internal.server.services.runs.router_worker_sync._get_grpc_worker_payload", + "dstack._internal.server.services.runs.router_worker_sync._get_grpc_worker", new_callable=AsyncMock, return_value=grpc_not_ready, ) as grpc_mock, patch( - "dstack._internal.server.services.runs.router_worker_sync._get_http_worker_payload", + "dstack._internal.server.services.runs.router_worker_sync._get_http_worker", new_callable=AsyncMock, ) as http_mock, ): - result = await _get_worker_payload( + result = await _get_worker( job, http_worker_url="http://10.0.0.1:8000", grpc_worker_url="grpc://10.0.0.1:8000",