Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ server = [
"python-json-logger>=3.1.0",
"prometheus-client",
"grpcio>=1.50",
"protobuf>=6.33.5",
"smg-grpc-proto>=0.4.7",
]
aws = [
"boto3>=1.38.13",
Expand Down
3 changes: 3 additions & 0 deletions src/dstack/_internal/cli/commands/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def _command(self, args: argparse.Namespace):
os.environ["DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT"] = "1"
if args.token:
os.environ["DSTACK_SERVER_ADMIN_TOKEN"] = args.token
# Hide noisy "Other threads are currently calling into gRPC, skipping fork() handlers"
# messages in server logs. Users can still change this with GRPC_VERBOSITY.
os.environ.setdefault("GRPC_VERBOSITY", "ERROR")
uvicorn_log_level = os.getenv("DSTACK_SERVER_UVICORN_LOG_LEVEL", "ERROR").lower()
reload_disabled = os.getenv("DSTACK_SERVER_RELOAD_DISABLED") is not None

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""SSH-tunneled gRPC channel target to a job's service port (UDS)."""

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import grpc

from dstack._internal.core.services.ssh.tunnel import (
SSH_DEFAULT_OPTIONS,
IPSocket,
SocketPair,
UnixSocket,
)
from dstack._internal.server.models import JobModel
from dstack._internal.server.services.jobs import get_job_spec
from dstack._internal.server.services.ssh import container_ssh_tunnel
from dstack._internal.utils.common import get_or_error

# How long to wait for the SSH connection before giving up.
SSH_CONNECT_TIMEOUT = timedelta(seconds=10)

# 256 KiB cap on a single gRPC message, in either direction.
# Match router_worker_sync HTTP server_info cap (_MAX_SERVER_INFO_RESPONSE_BYTES).
_MAX_GRPC_MESSAGE_BYTES = 1024 * 256

# Channel arguments applying the cap above to both received and sent messages.
_GRPC_CHANNEL_OPTIONS = tuple(
    (option_name, _MAX_GRPC_MESSAGE_BYTES)
    for option_name in (
        "grpc.max_receive_message_length",
        "grpc.max_send_message_length",
    )
)


@asynccontextmanager
async def get_service_replica_grpc_client(job: JobModel) -> AsyncGenerator[Any, None]:
    """Yield a gRPC channel that reaches the job's service port through an SSH tunnel.

    The job's service port (``localhost`` on the remote side) is forwarded to a Unix
    domain socket placed in a temporary directory, and a ``grpc.aio`` channel is opened
    against that socket. On exit the channel is closed first, then the tunnel context
    is left and the temporary directory is removed.

    Args:
        job: The job whose service port should be reached.

    Yields:
        The channel returned by ``grpc.aio.insecure_channel``.
    """
    # Start from the shared SSH defaults and pin the connect timeout (whole seconds).
    ssh_options = dict(SSH_DEFAULT_OPTIONS)
    ssh_options["ConnectTimeout"] = str(int(SSH_CONNECT_TIMEOUT.total_seconds()))

    spec = get_job_spec(job)

    with TemporaryDirectory() as tunnel_dir:
        # Keep the same socket file name as the HTTP helper for consistency.
        socket_path = Path(tunnel_dir).joinpath("replica.sock").absolute()
        # NOTE(review): get_or_error presumably raises when service_port is unset —
        # confirm against its definition.
        forwarding = SocketPair(
            remote=IPSocket("localhost", get_or_error(spec.service_port)),
            local=UnixSocket(socket_path),
        )
        async with container_ssh_tunnel(
            job=job,
            forwarded_sockets=[forwarding],
            options=ssh_options,
        ):
            # The path is absolute, so the target takes the form "unix:///...".
            channel = grpc.aio.insecure_channel(
                f"unix://{socket_path}",
                options=_GRPC_CHANNEL_OPTIONS,
            )
            try:
                yield channel
            finally:
                # Close the channel before the tunnel and socket directory go away.
                await channel.close()
Loading
Loading