From a08f85398c569b874003dd122165dbf4c30d6d26 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 4 May 2026 21:56:16 +0200 Subject: [PATCH 1/5] refactor(gpu): centralize driver request validation Signed-off-by: Evan Lezar --- .../openshell-driver-kubernetes/src/driver.rs | 4 ++++ crates/openshell-driver-vm/src/driver.rs | 24 ++++++++++++------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 5a43eb980..a5fadabcb 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -204,6 +204,10 @@ impl KubernetesComputeDriver { pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); + self.validate_gpu_request(gpu_requested).await + } + + async fn validate_gpu_request(&self, gpu_requested: bool) -> Result<(), tonic::Status> { if gpu_requested && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 445905a1e..ed6b57a01 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -2577,15 +2577,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - if spec.gpu && !gpu_enabled { - return Err(Status::failed_precondition( - "GPU support is not enabled on this driver; start with --gpu", - )); - } - - if !spec.gpu && !spec.gpu_device.is_empty() { - return Err(Status::invalid_argument("gpu_device requires gpu=true")); - } + validate_gpu_request(spec.gpu, &spec.gpu_device, gpu_enabled)?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -2628,6 +2620,20 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } +#[allow(clippy::result_large_err)] +fn validate_gpu_request(gpu: bool, gpu_device: &str, gpu_enabled: bool) -> Result<(), Status> { + if gpu && !gpu_enabled { + return Err(Status::failed_precondition( + "GPU support is not enabled on this driver; start with --gpu", + )); + } + + if !gpu && !gpu_device.is_empty() { + return Err(Status::invalid_argument("gpu_device requires gpu=true")); + } + Ok(()) +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { From 787ae75c5b44a291fce05798975b9c3ba900f693 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 4 May 2026 22:12:28 +0200 Subject: [PATCH 2/5] refactor(vm): derive GPU device request once Signed-off-by: Evan Lezar --- crates/openshell-driver-vm/src/driver.rs | 31 +++++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index ed6b57a01..56e431f38 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -615,10 +615,11 @@ impl VmDriver { ))); } - let spec = sandbox.spec.as_ref(); - let is_gpu = spec.is_some_and(|s| s.gpu); - let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); - let gpu_bdf = if is_gpu { + let gpu_device = sandbox + .spec + .as_ref() + .and_then(|spec| requested_gpu_device(spec.gpu, &spec.gpu_device)); + let gpu_bdf = if let Some(gpu_device) = gpu_device { Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) } else { None @@ -2620,6 +2621,10 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } +fn requested_gpu_device(gpu: bool, gpu_device: &str) -> Option<&str> { + gpu.then_some(gpu_device) +} + #[allow(clippy::result_large_err)] fn validate_gpu_request(gpu: bool, gpu_device: &str, gpu_enabled: bool) -> Result<(), Status> { if gpu && !gpu_enabled { @@ -4538,6 +4543,24 @@ mod tests { assert!(err.message().contains("gpu_device requires gpu=true")); } + #[test] + fn requested_gpu_device_returns_none_without_gpu_request() { + assert_eq!(requested_gpu_device(false, ""), None); + } + + #[test] + fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { + assert_eq!(requested_gpu_device(true, ""), Some("")); + } + + #[test] + fn requested_gpu_device_returns_explicit_device_id() { + assert_eq!( + requested_gpu_device(true, "0000:2d:00.0"), + Some("0000:2d:00.0") + ); + } + #[test] fn validate_vm_sandbox_rejects_platform_config() { let sandbox = Sandbox { From 23bdbb13f896457862ed4405f0589b9bdd9b0978 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 4 May 2026 22:14:48 +0200 Subject: [PATCH 3/5] feat(gpu): introduce GPU request spec Signed-off-by: Evan Lezar --- architecture/compute-runtimes.md | 4 +- crates/openshell-cli/src/main.rs | 54 ++++++- crates/openshell-cli/src/run.rs | 92 ++++++++--- .../sandbox_create_lifecycle_integration.rs | 56 +++++++ crates/openshell-core/src/gpu.rs | 59 +++++-- crates/openshell-driver-docker/README.md | 2 +- crates/openshell-driver-docker/src/lib.rs | 30 ++-- crates/openshell-driver-docker/src/tests.rs | 48 ++++-- crates/openshell-driver-kubernetes/README.md | 6 +- .../openshell-driver-kubernetes/src/driver.rs | 146 +++++++++++++----- crates/openshell-driver-podman/README.md | 2 +- .../openshell-driver-podman/src/container.rs | 19 ++- crates/openshell-driver-podman/src/driver.rs | 29 +++- crates/openshell-driver-vm/src/driver.rs | 115 ++++++++++---- crates/openshell-server/src/compute/mod.rs | 73 +++++++-- crates/openshell-server/src/grpc/sandbox.rs | 2 +- .../openshell-server/src/grpc/validation.rs | 68 +++++++- docs/sandboxes/manage-sandboxes.mdx | 10 ++ proto/compute_driver.proto | 19 ++- proto/openshell.proto | 27 ++-- 20 files changed, 689 insertions(+), 172 deletions(-) diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 02891c03e..58f46eeb3 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -77,7 +77,9 @@ users. Custom sandbox images must include the agent runtime and any system dependencies, but they should not need to include the gateway. GPU-capable images must include the user-space libraries required by the workload. The -runtime still owns GPU device injection. +runtime still owns GPU device injection. GPU requests can include explicit +driver-native device IDs or a requested count; the gateway validates the public +request shape and each runtime enforces the GPU allocation modes it supports. ## Deployment Shape diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 2254f0c89..686103971 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1216,9 +1216,13 @@ enum SandboxCommands { /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu")] + #[arg(long, requires = "gpu", conflicts_with = "gpu_count")] gpu_device: Option, + /// Request a specific number of GPUs. Mutually exclusive with --gpu-device. + #[arg(long, value_parser = clap::value_parser!(u32).range(1..), conflicts_with = "gpu_device")] + gpu_count: Option, + /// CPU limit for the sandbox (for example: 500m, 1, 2.5). #[arg(long)] cpu: Option, @@ -2539,6 +2543,7 @@ async fn main() -> Result<()> { editor, gpu, gpu_device, + gpu_count, cpu, memory, providers, @@ -2608,6 +2613,7 @@ async fn main() -> Result<()> { keep, gpu, gpu_device.as_deref(), + gpu_count, cpu.as_deref(), memory.as_deref(), editor, @@ -4287,6 +4293,52 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_count_parses_without_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "2"]) + .expect("sandbox create --gpu-count should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, gpu_count, .. }), + .. + }) => { + assert!(!gpu); + assert_eq!(gpu_count, Some(2)); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "0"]); + + assert!( + result.is_err(), + "sandbox create --gpu-count 0 should be rejected" + ); + } + + #[test] + fn sandbox_create_gpu_count_conflicts_with_gpu_device() { + let result = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "--gpu-device", + "nvidia.com/gpu=0", + "--gpu-count", + "2", + ]); + + assert!( + result.is_err(), + "sandbox create should reject --gpu-count with --gpu-device" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 9988d46db..0f7a84d49 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -39,17 +39,18 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest, - ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, - ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, - ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider, - ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, - RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, - SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, - SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, - UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, - setting_value, tcp_forward_init, + GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuRequestSpec, + HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, + ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, + ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent, + PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest, + RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, + SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, + SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, + UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, + tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -1679,6 +1680,7 @@ pub async fn sandbox_create( keep: bool, gpu: bool, gpu_device: Option<&str>, + gpu_count: Option, cpu: Option<&str>, memory: Option<&str>, editor: Option, @@ -1732,7 +1734,8 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); + let requested_gpu = + gpu || gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu); let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; let inferred_types: Vec = if providers_v2_enabled { @@ -1763,8 +1766,7 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: requested_gpu, - gpu_device: gpu_device.unwrap_or_default().to_string(), + gpu: gpu_request_from_cli(requested_gpu, gpu_device, gpu_count), policy, providers: configured_providers, template, @@ -2189,6 +2191,20 @@ pub async fn sandbox_create( } } +fn gpu_request_from_cli( + requested_gpu: bool, + gpu_device: Option<&str>, + gpu_count: Option, +) -> Option { + requested_gpu.then(|| GpuRequestSpec { + device_id: gpu_device + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default(), + count: gpu_count, + }) +} + /// Resolved source for the `--from` flag on `sandbox create`. #[derive(Debug)] enum ResolvedSource { @@ -7438,10 +7454,10 @@ mod tests { dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header, format_gateway_select_items, format_provider_attachment_table, gateway_add, gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, - git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, - inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, - parse_credential_expiry_cli_value, parse_credential_expiry_pairs, parse_credential_pairs, - plaintext_gateway_is_remote, progress_step_from_metadata, + git_sync_files, gpu_request_from_cli, http_health_check, image_requests_gpu, + import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs, + parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs, + parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata, provider_profile_allows_refresh_bootstrap, provisioning_timeout_message, ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from, sandbox_should_persist, sandbox_upload_plan, service_expose_status_error, @@ -7924,6 +7940,46 @@ mod tests { } } + #[test] + fn gpu_request_from_cli_uses_presence_with_empty_device_ids_for_default_gpu() { + let request = + gpu_request_from_cli(true, None, None).expect("gpu request should be present"); + + assert!(request.device_id.is_empty()); + assert_eq!(request.count, None); + } + + #[test] + fn gpu_request_from_cli_maps_gpu_device_to_one_device_id() { + let request = gpu_request_from_cli(true, Some("0000:2d:00.0"), None) + .expect("gpu request should be present"); + + assert_eq!(request.device_id, vec!["0000:2d:00.0"]); + assert_eq!(request.count, None); + } + + #[test] + fn gpu_request_from_cli_maps_gpu_count() { + let request = gpu_request_from_cli(true, None, Some(2)).expect("gpu request should exist"); + + assert!(request.device_id.is_empty()); + assert_eq!(request.count, Some(2)); + } + + #[test] + fn gpu_request_from_cli_preserves_device_and_gpu_count_for_gateway_validation() { + let request = gpu_request_from_cli(true, Some("nvidia.com/gpu=0"), Some(2)) + .expect("gpu request should exist"); + + assert_eq!(request.device_id, vec!["nvidia.com/gpu=0"]); + assert_eq!(request.count, Some(2)); + } + + #[test] + fn gpu_request_from_cli_omits_gpu_request_when_not_requested() { + assert!(gpu_request_from_cli(false, Some("0"), None).is_none()); + } + #[test] fn resolve_from_classifies_existing_dockerfile_path() { let temp = tempfile::tempdir().expect("failed to create tempdir"); diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index aee91de56..2372100e3 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -787,6 +787,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, + None, &[], None, None, @@ -826,6 +827,7 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { true, false, None, + None, Some("500m"), Some("2Gi"), None, @@ -884,6 +886,52 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { assert!(!resources.fields.contains_key("requests")); } +#[tokio::test] +async fn sandbox_create_sends_gpu_count_request() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-count"), + None, + "openshell", + None, + true, + false, + None, + Some(2), + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let gpu = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.gpu.as_ref()) + .expect("GPU request should be sent"); + + assert!(gpu.device_id.is_empty()); + assert_eq!(gpu.count, Some(2)); +} + #[tokio::test] async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { let server = run_server().await; @@ -906,6 +954,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { None, None, None, + None, &[], None, None, @@ -963,6 +1012,7 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { None, None, None, + None, &[], None, None, @@ -1016,6 +1066,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { None, None, None, + None, &[], None, None, @@ -1061,6 +1112,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { None, None, None, + None, &[], None, None, @@ -1102,6 +1154,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1147,6 +1200,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1192,6 +1246,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, + None, &[], None, None, @@ -1237,6 +1292,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 5df8702ed..b79f3f59d 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -4,21 +4,19 @@ //! Shared GPU request helpers. use crate::config::CDI_GPU_DEVICE_ALL; +use crate::proto::compute::v1::GpuRequestSpec; -/// Resolve the existing GPU request fields into CDI device identifiers. +/// Resolve a driver GPU request into CDI device identifiers. /// -/// `None` means no GPU was requested. A GPU request with no explicit device -/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes -/// through unchanged. +/// `None` means no GPU was requested. Presence with no explicit device IDs +/// uses the CDI all-GPU request; otherwise the driver-native IDs pass through. #[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option> { - gpu.then(|| { - if gpu_device.is_empty() { - vec![CDI_GPU_DEVICE_ALL.to_string()] - } else { - vec![gpu_device.to_string()] - } - }) +pub fn cdi_gpu_device_ids(gpu: Option<&GpuRequestSpec>) -> Option> { + match gpu { + Some(gpu) if gpu.device_id.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + Some(gpu) => Some(gpu.device_id.clone()), + None => None, + } } #[cfg(test)] @@ -27,22 +25,51 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, ""), None); + assert_eq!(cdi_gpu_device_ids(None), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + let request = GpuRequestSpec { + device_id: vec![], + count: None, + }; + assert_eq!( - cdi_gpu_device_ids(true, ""), + cdi_gpu_device_ids(Some(&request)), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] - fn cdi_gpu_device_ids_passes_explicit_device_id_through() { + fn cdi_gpu_device_ids_passes_single_device_id_through() { + let request = GpuRequestSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }; + assert_eq!( - cdi_gpu_device_ids(true, "nvidia.com/gpu=0"), + cdi_gpu_device_ids(Some(&request)), Some(vec!["nvidia.com/gpu=0".to_string()]) ); } + + #[test] + fn cdi_gpu_device_ids_passes_multiple_device_ids_through() { + let request = GpuRequestSpec { + device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + count: None, + }; + + assert_eq!( + cdi_gpu_device_ids(Some(&request)), + Some(vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ]) + ); + } } diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index ea57f44e4..df4069059 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses the sandbox `gpu_device` value when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses explicit GPU request device IDs when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Count-based GPU requests are rejected until Docker CDI selection can map counts to concrete devices. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index e30ee7754..e18671067 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -34,11 +34,11 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, - StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, - WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + GetSandboxResponse, GpuRequestSpec, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::{Config, Error, Result as CoreResult}; use std::collections::HashMap; @@ -375,7 +375,7 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; + Self::validate_gpu_request(spec.gpu.as_ref(), config.supports_gpu)?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -409,8 +409,16 @@ impl DockerComputeDriver { )) } - fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { - if gpu && !supports_gpu { + fn validate_gpu_request( + gpu: Option<&GpuRequestSpec>, + supports_gpu: bool, + ) -> Result<(), Status> { + if gpu.is_some_and(|gpu| gpu.count.is_some()) { + return Err(Status::invalid_argument( + "docker compute driver does not support GPU count requests", + )); + } + if gpu.is_some() && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); @@ -1713,8 +1721,8 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: bool, gpu_device: &str) -> Option> { - cdi_gpu_device_ids(gpu, gpu_device).map(|device_ids| { +fn docker_gpu_device_requests(gpu: Option<&GpuRequestSpec>) -> Option> { + cdi_gpu_device_ids(gpu).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), device_ids: Some(device_ids), @@ -1765,7 +1773,7 @@ fn build_container_create_body( nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, pids_limit: docker_pids_limit(config.sandbox_pids_limit)?, - device_requests: docker_gpu_device_requests(spec.gpu, &spec.gpu_device), + device_requests: docker_gpu_device_requests(spec.gpu.as_ref()), binds: Some(build_binds(sandbox, config)?), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index c9b34ff8f..07a68177e 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -13,7 +13,7 @@ use openshell_core::progress::{ PROGRESS_STEP_STARTING_SANDBOX, }; use openshell_core::proto::compute::v1::{ - DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, + DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, GpuRequestSpec, }; use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -42,8 +42,7 @@ fn test_sandbox() -> DriverSandbox { resources: None, platform_config: None, }), - gpu: false, - gpu_device: String::new(), + gpu: None, sandbox_token: String::new(), }), status: None, @@ -605,7 +604,10 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { + device_id: vec![], + count: None, + }); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -613,6 +615,22 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { assert!(err.message().contains("Docker CDI")); } +#[test] +fn validate_sandbox_rejects_gpu_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { + device_id: vec![], + count: Some(2), + }); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("does not support GPU count")); +} + #[test] fn validate_sandbox_auth_requires_gateway_token() { let mut sandbox = test_sandbox(); @@ -640,7 +658,10 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { + device_id: vec![], + count: None, + }); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -658,13 +679,17 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { } #[test] -fn build_container_create_body_passes_explicit_cdi_device_id_through() { +fn build_container_create_body_passes_explicit_cdi_device_ids_through() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; - spec.gpu_device = "nvidia.com/gpu=0".to_string(); + sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { + device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + count: None, + }); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -677,7 +702,10 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { assert_eq!(request.driver.as_deref(), Some("cdi")); assert_eq!( request.device_ids.as_ref().unwrap(), - &vec!["nvidia.com/gpu=0".to_string()] + &vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ] ); } diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 1d45a1d83..329cde120 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -50,6 +50,6 @@ pods do not need direct external ingress for SSH. ## GPU Support When a sandbox requests GPU support, the driver checks node allocatable capacity -for `nvidia.com/gpu` and requests one GPU resource in the workload spec. The -sandbox image must provide the user-space libraries needed by the agent -workload. +for `nvidia.com/gpu` and sets the workload's `nvidia.com/gpu` resource limit. +Requests without an explicit count use one GPU. The sandbox image must provide +the user-space libraries needed by the agent workload. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index a5fadabcb..f4eb8fd96 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -25,7 +25,7 @@ use openshell_core::proto::compute::v1::{ DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + GetCapabilitiesResponse, GpuRequestSpec, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, }; use std::collections::BTreeMap; @@ -77,7 +77,11 @@ const SANDBOX_VERSION: &str = "v1alpha1"; pub const SANDBOX_KIND: &str = "Sandbox"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; -const GPU_RESOURCE_QUANTITY: &str = "1"; +const DEFAULT_GPU_COUNT: u32 = 1; + +fn gpu_has_explicit_device_ids(gpu: Option<&GpuRequestSpec>) -> bool { + gpu.is_some_and(|gpu| !gpu.device_id.is_empty()) +} // --------------------------------------------------------------------------- // Default workspace persistence (temporary — will be replaced by snapshotting) @@ -203,12 +207,20 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); - self.validate_gpu_request(gpu_requested).await + let gpu = sandbox.spec.as_ref().and_then(|spec| spec.gpu.as_ref()); + self.validate_gpu_request(gpu).await } - async fn validate_gpu_request(&self, gpu_requested: bool) -> Result<(), tonic::Status> { - if gpu_requested + async fn validate_gpu_request( + &self, + gpu: Option<&GpuRequestSpec>, + ) -> Result<(), tonic::Status> { + if gpu_has_explicit_device_ids(gpu) { + return Err(tonic::Status::invalid_argument( + "kubernetes compute driver does not support explicit GPU device IDs", + )); + } + if gpu.is_some() && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) })? @@ -300,6 +312,14 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { + if let Some(gpu) = sandbox.spec.as_ref().and_then(|spec| spec.gpu.as_ref()) + && gpu_has_explicit_device_ids(Some(gpu)) + { + return Err(KubernetesDriverError::Precondition( + "kubernetes compute driver does not support explicit GPU device IDs".to_string(), + )); + } + let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -1106,7 +1126,13 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s( + template, + spec.gpu.as_ref(), + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1138,7 +1164,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - spec.is_some_and(|s| s.gpu), + spec.and_then(|s| s.gpu.as_ref()), &pod_env, inject_workspace, params, @@ -1153,7 +1179,7 @@ fn sandbox_to_k8s_spec( fn sandbox_template_to_k8s( template: &SandboxTemplate, - gpu: bool, + gpu: Option<&GpuRequestSpec>, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, @@ -1207,7 +1233,7 @@ fn sandbox_template_to_k8s( if use_user_namespaces { spec.insert("hostUsers".to_string(), serde_json::json!(false)); - if gpu { + if gpu.is_some() { warn!( "GPU sandbox with user namespaces enabled — \ NVIDIA device plugin compatibility is unverified" @@ -1388,7 +1414,10 @@ fn image_pull_secret_refs(secrets: &[String]) -> Vec { .collect() } -fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu: Option<&GpuRequestSpec>, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API // Struct), then overlay the typed DriverResourceRequirements on top. @@ -1421,8 +1450,8 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_request(count: Option) -> GpuRequestSpec { + GpuRequestSpec { + device_id: vec![], + count, + } + } + #[test] fn kube_pulling_event_adds_image_progress_metadata() { let mut metadata = std::collections::HashMap::new(); @@ -1998,7 +2034,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - true, + Some(&gpu_request(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -2011,10 +2047,44 @@ mod tests { ); assert_eq!( pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) + ); + } + + #[test] + fn gpu_sandbox_uses_requested_gpu_count() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + Some(&gpu_request(Some(2))), + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], + serde_json::json!("2") ); } + #[test] + fn gpu_has_explicit_device_ids_only_when_ids_are_present() { + use openshell_core::proto::compute::v1::GpuRequestSpec; + + assert!(!gpu_has_explicit_device_ids(None)); + assert!(!gpu_has_explicit_device_ids(Some(&GpuRequestSpec { + device_id: vec![], + count: None, + }))); + assert!(gpu_has_explicit_device_ids(Some(&GpuRequestSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }))); + } + #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { @@ -2034,7 +2104,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_request(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -2066,7 +2136,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2094,7 +2164,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_request(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -2105,7 +2175,7 @@ mod tests { assert_eq!(limits["cpu"], serde_json::json!("2")); assert_eq!( limits[GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) ); } @@ -2125,7 +2195,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2148,7 +2218,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2173,7 +2243,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2196,7 +2266,7 @@ mod tests { }; sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2335,7 +2405,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), false, // user provided custom VCTs ¶ms, @@ -2373,7 +2443,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2438,7 +2508,7 @@ mod tests { let params = SandboxPodParams::default(); // cluster default is off let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2476,7 +2546,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2502,7 +2572,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2524,7 +2594,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2546,7 +2616,7 @@ mod tests { fn sandbox_template_omits_empty_image_pull_secrets() { let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, &SandboxPodParams::default(), @@ -2571,7 +2641,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2600,7 +2670,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2728,7 +2798,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, @@ -2789,7 +2859,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index 77b42ba37..e4183f75b 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | Sandbox `gpu_device` value when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | +| CDI GPU devices | Explicit GPU request device IDs when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-based GPU requests are rejected until Podman CDI selection can map counts to concrete devices. | The restricted agent child does not retain these supervisor privileges. diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index 13f053e93..c31f72022 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -379,8 +379,8 @@ fn podman_pids_limit(value: i64) -> Option { /// Build CDI GPU device list if GPU is requested. fn build_devices(sandbox: &DriverSandbox) -> Option> { - let spec = sandbox.spec.as_ref()?; - cdi_gpu_device_ids(spec.gpu, &spec.gpu_device).map(|device_ids| { + let gpu = sandbox.spec.as_ref().and_then(|spec| spec.gpu.as_ref()); + cdi_gpu_device_ids(gpu).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -808,11 +808,14 @@ mod tests { #[test] fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { use openshell_core::config::CDI_GPU_DEVICE_ALL; - use openshell_core::proto::compute::v1::DriverSandboxSpec; + use openshell_core::proto::compute::v1::{DriverSandboxSpec, GpuRequestSpec}; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: None, + }), ..Default::default() }); let config = test_config(); @@ -826,12 +829,14 @@ mod tests { #[test] fn container_spec_passes_explicit_cdi_device_id_through() { - use openshell_core::proto::compute::v1::DriverSandboxSpec; + use openshell_core::proto::compute::v1::{DriverSandboxSpec, GpuRequestSpec}; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, - gpu_device: "nvidia.com/gpu=0".to_string(), + gpu: Some(GpuRequestSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }), ..Default::default() }); let config = test_config(); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index e2deb1c63..f3f65747c 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -10,7 +10,7 @@ use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; use openshell_core::ComputeDriverError; -use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse, GpuRequestSpec}; use std::path::PathBuf; use std::time::Duration; use tracing::{info, warn}; @@ -280,12 +280,17 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); - Self::validate_gpu_request(gpu_requested) + let gpu = sandbox.spec.as_ref().and_then(|s| s.gpu.as_ref()); + Self::validate_gpu_request(gpu) } - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { - if gpu_requested && !Self::has_gpu_capacity() { + fn validate_gpu_request(gpu: Option<&GpuRequestSpec>) -> Result<(), ComputeDriverError> { + if gpu.is_some_and(|gpu| gpu.count.is_some()) { + return Err(ComputeDriverError::Precondition( + "podman compute driver does not support GPU count requests".to_string(), + )); + } + if gpu.is_some() && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), )); @@ -305,6 +310,7 @@ impl PodmanComputeDriver { "sandbox id is required".into(), )); } + self.validate_sandbox_create(sandbox)?; // Validate the composed container name early, before creating any // resources (volume), so we don't leave orphans when the name is @@ -667,6 +673,19 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_count() { + let err = PodmanComputeDriver::validate_gpu_request(Some(&GpuRequestSpec { + device_id: vec![], + count: Some(2), + })) + .expect_err("GPU count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("does not support GPU count")) + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 56e431f38..35c5fef50 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -32,11 +32,11 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, - GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, - ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, - ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, - WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, - compute_driver_server::ComputeDriver, watch_sandboxes_event, + GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, GpuRequestSpec, + ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, + ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, + WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, + WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_vfio::SysfsRoot; use prost::Message; @@ -618,7 +618,7 @@ impl VmDriver { let gpu_device = sandbox .spec .as_ref() - .and_then(|spec| requested_gpu_device(spec.gpu, &spec.gpu_device)); + .and_then(|spec| requested_gpu_device(spec.gpu.as_ref())); let gpu_bdf = if let Some(gpu_device) = gpu_device { Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) } else { @@ -2578,7 +2578,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - validate_gpu_request(spec.gpu, &spec.gpu_device, gpu_enabled)?; + validate_gpu_request(spec.gpu.as_ref(), gpu_enabled)?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -2621,20 +2621,29 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } -fn requested_gpu_device(gpu: bool, gpu_device: &str) -> Option<&str> { - gpu.then_some(gpu_device) +fn requested_gpu_device(gpu: Option<&GpuRequestSpec>) -> Option<&str> { + let gpu = gpu?; + Some(gpu.device_id.first().map_or("", String::as_str)) } #[allow(clippy::result_large_err)] -fn validate_gpu_request(gpu: bool, gpu_device: &str, gpu_enabled: bool) -> Result<(), Status> { - if gpu && !gpu_enabled { +fn validate_gpu_request(gpu: Option<&GpuRequestSpec>, gpu_enabled: bool) -> Result<(), Status> { + if gpu.is_some() && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", )); } - if !gpu && !gpu_device.is_empty() { - return Err(Status::invalid_argument("gpu_device requires gpu=true")); + if gpu.is_some_and(|gpu| gpu.count.is_some_and(|count| count > 1)) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU", + )); + } + + if gpu.is_some_and(|gpu| gpu.device_id.len() > 1) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU device ID", + )); } Ok(()) } @@ -4423,7 +4432,7 @@ mod tests { PROGRESS_COMPLETE_STEP_KEY, }; use openshell_core::proto::compute::v1::{ - DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, + DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, GpuRequestSpec, }; use prost_types::{Struct, Value, value::Kind}; use std::fs; @@ -4502,7 +4511,10 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: None, + }), ..Default::default() }), ..Default::default() @@ -4518,7 +4530,10 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: None, + }), ..Default::default() }), ..Default::default() @@ -4527,38 +4542,82 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + fn validate_vm_sandbox_accepts_gpu_count_one_when_enabled() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: Some(1), + }), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("gpu count one should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_greater_than_one() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: false, - gpu_device: "0000:2d:00.0".to_string(), + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: Some(2), + }), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("gpu count > 1 should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("at most one GPU")); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: Some(GpuRequestSpec { + device_id: vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], + count: None, + }), ..Default::default() }), ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device without gpu should be rejected"); + .expect_err("multiple GPU device IDs should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device requires gpu=true")); + assert!(err.message().contains("at most one GPU device ID")); } #[test] fn requested_gpu_device_returns_none_without_gpu_request() { - assert_eq!(requested_gpu_device(false, ""), None); + assert_eq!(requested_gpu_device(None), None); } #[test] fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { - assert_eq!(requested_gpu_device(true, ""), Some("")); + let gpu = GpuRequestSpec { + device_id: vec![], + count: None, + }; + + assert_eq!(requested_gpu_device(Some(&gpu)), Some("")); } #[test] - fn requested_gpu_device_returns_explicit_device_id() { - assert_eq!( - requested_gpu_device(true, "0000:2d:00.0"), - Some("0000:2d:00.0") - ); + fn requested_gpu_device_returns_first_explicit_device_id() { + let gpu = GpuRequestSpec { + device_id: vec!["0000:2d:00.0".to_string()], + count: None, + }; + + assert_eq!(requested_gpu_device(Some(&gpu)), Some("0000:2d:00.0")); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 0122f9178..ac85f7269 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -19,10 +19,10 @@ use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, - ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, - compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, + GpuRequestSpec as DriverGpuRequestSpec, ListSandboxesRequest, ValidateSandboxCreateRequest, + WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::proto::{ PlatformEvent, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, @@ -1267,8 +1267,10 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .template .as_ref() .map(driver_sandbox_template_from_public), - gpu: spec.gpu, - gpu_device: spec.gpu_device.clone(), + gpu: spec.gpu.as_ref().map(|gpu| DriverGpuRequestSpec { + device_id: gpu.device_id.clone(), + count: gpu.count, + }), sandbox_token: String::new(), } } @@ -1623,7 +1625,7 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu.is_some()); if !gpu_requested { return; } @@ -1781,6 +1783,7 @@ pub async fn new_test_runtime(store: Arc) -> ComputeRuntime { mod tests { use super::*; use futures::stream; + use openshell_core::proto::GpuRequestSpec; use openshell_core::proto::compute::v1::{ CreateSandboxResponse, DeleteSandboxResponse, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateResponse, @@ -1801,6 +1804,48 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_request_device_ids() { + let public = SandboxSpec { + gpu: Some(GpuRequestSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .gpu + .expect("driver GPU request should be present") + .device_id, + vec!["nvidia.com/gpu=0".to_string()] + ); + } + + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_count() { + let public = SandboxSpec { + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: Some(2), + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .gpu + .expect("driver GPU request should be present") + .count, + Some(2) + ); + } + fn struct_value( fields: impl IntoIterator, prost_types::Value)>, ) -> prost_types::Value { @@ -2258,7 +2303,10 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: true, + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: None, + }), ..Default::default() }), ); @@ -2289,7 +2337,7 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: false, + gpu: None, ..Default::default() }), ); @@ -2571,7 +2619,10 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: true, + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: None, + }), ..Default::default() }), ..sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning) @@ -2594,7 +2645,7 @@ mod tests { SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); + assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu.is_some())); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 198d5f04c..0c837c537 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -99,7 +99,7 @@ fn emit_sandbox_create_telemetry( }; openshell_core::telemetry::emit_sandbox_create( outcome, - spec.gpu, + spec.gpu.is_some(), spec.providers.len() as u64, spec.policy.is_some(), template_source, diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 53f292053..a9f1e984f 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -131,6 +131,11 @@ pub(super) fn validate_sandbox_spec( validate_sandbox_template(tmpl)?; } + // --- spec.gpu --- + if let Some(ref gpu) = spec.gpu { + validate_gpu_request(gpu)?; + } + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { let size = policy.encoded_len(); @@ -144,6 +149,18 @@ pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_gpu_request(gpu: &openshell_core::proto::GpuRequestSpec) -> Result<(), Status> { + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(Status::invalid_argument( + "gpu.count is mutually exclusive with gpu.device_id", + )); + } + if gpu.count == Some(0) { + return Err(Status::invalid_argument("gpu.count must be greater than 0")); + } + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. @@ -661,7 +678,7 @@ pub(super) fn level_matches(log_level: &str, min_level: &str) -> bool { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::SandboxSpec; + use openshell_core::proto::{GpuRequestSpec, SandboxSpec}; use std::collections::HashMap; use tonic::Code; @@ -687,12 +704,59 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { - gpu: true, + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: None, + }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: Some(2), + }), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-count-sandbox", &spec).is_ok()); + } + + #[test] + fn validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + gpu: Some(GpuRequestSpec { + device_id: vec![], + count: Some(0), + }), + ..Default::default() + }; + + let err = validate_sandbox_spec("gpu-count-sandbox", &spec).unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("count must be greater than 0")); + } + + #[test] + fn validate_sandbox_spec_rejects_gpu_count_with_device_id() { + let spec = SandboxSpec { + gpu: Some(GpuRequestSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: Some(1), + }), + ..Default::default() + }; + + let err = validate_sandbox_spec("gpu-count-sandbox", &spec).unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("mutually exclusive")); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 512abfd3d..0b8469612 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -51,10 +51,20 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +Request a specific number of GPUs with `--gpu-count`: + +```shell +openshell sandbox create --gpu-count 2 -- claude +``` + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. +Kubernetes gateways honor `--gpu-count` by setting the `nvidia.com/gpu` resource +limit. Docker and Podman support explicit CDI device IDs through `--gpu-device` +but do not support count-based selection yet. VM gateways accept only one GPU. + ### Custom Containers Use `--from` to create a sandbox from the base image, another pre-built sandbox name, a local directory, or a container image: diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 610d491c7..3ac04380c 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -77,18 +77,16 @@ message DriverSandbox { // Driver-owned provisioning inputs required to create a sandbox. message DriverSandboxSpec { + reserved 10; + // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. map environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + // Request GPU resources for this sandbox. Presence indicates a GPU request. + GpuRequestSpec gpu = 9; // Gateway-minted JWT identifying this sandbox to the gateway. Set by // the gateway on create; the driver materialises it via its native // secret mechanism (Docker/Podman/VM bind-mount a per-sandbox file; @@ -98,6 +96,15 @@ message DriverSandboxSpec { string sandbox_token = 11; } +// Driver-native GPU request details. +message GpuRequestSpec { + // Optional number of GPUs requested. Mutually exclusive with device_id. + optional uint32 count = 1; + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. + repeated string device_id = 2; +} + // Driver-owned runtime template consumed by the compute platform. // // This message describes the sandbox workload in backend-neutral terms. diff --git a/proto/openshell.proto b/proto/openshell.proto index f9b64618b..2dd51ba21 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -305,6 +305,9 @@ message Sandbox { // Desired sandbox configuration provided through the public API. message SandboxSpec { + reserved 10, 11; + reserved "gpu_device", "proposal_approval_mode"; + // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. @@ -315,18 +318,18 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; - // Field 11 was `proposal_approval_mode`. The approval mode is now a - // runtime setting (gateway or sandbox scope) read via UpdateConfig / - // GetSandboxConfig, so it can be flipped on a running sandbox and - // managed fleet-wide. - reserved 11; - reserved "proposal_approval_mode"; + // Request GPU resources for this sandbox. Presence indicates a GPU request. + GpuRequestSpec gpu = 9; +} + +// Public GPU request details. Device identifiers are interpreted by the +// selected compute driver. +message GpuRequestSpec { + // Optional number of GPUs requested. Mutually exclusive with device_id. + optional uint32 count = 1; + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. + repeated string device_id = 2; } // Public sandbox template mapped onto compute-driver template inputs. From 009a6ee320060857ee4a3fc52aa518bea0156409 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 3 Jun 2026 14:59:26 +0200 Subject: [PATCH 4/5] refactor(gpu): use resource requirements for GPU requests Signed-off-by: Evan Lezar --- architecture/compute-runtimes.md | 4 + crates/openshell-cli/src/run.rs | 95 +++++++++-------- .../sandbox_create_lifecycle_integration.rs | 5 +- crates/openshell-core/src/gpu.rs | 20 ++-- crates/openshell-driver-docker/README.md | 2 +- crates/openshell-driver-docker/src/lib.rs | 24 +++-- crates/openshell-driver-docker/src/tests.rs | 40 +++---- crates/openshell-driver-kubernetes/README.md | 8 +- .../openshell-driver-kubernetes/src/driver.rs | 63 ++++++----- crates/openshell-driver-podman/README.md | 2 +- .../openshell-driver-podman/src/container.rs | 30 ++++-- crates/openshell-driver-podman/src/driver.rs | 22 +++- crates/openshell-driver-vm/README.md | 5 +- crates/openshell-driver-vm/src/driver.rs | 87 ++++++++------- crates/openshell-server/src/compute/mod.rs | 100 ++++++++++++------ crates/openshell-server/src/grpc/sandbox.rs | 4 +- .../openshell-server/src/grpc/validation.rs | 61 +++++++---- docs/sandboxes/manage-sandboxes.mdx | 1 + proto/compute_driver.proto | 21 ++-- proto/openshell.proto | 23 ++-- 20 files changed, 377 insertions(+), 240 deletions(-) diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 58f46eeb3..f3accbfee 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -40,6 +40,10 @@ template resource limits. Docker and Podman apply them as runtime limits. Kubernetes mirrors each limit into the matching request. VM accepts the fields but currently ignores them. +GPU requests enter the driver layer through +`SandboxSpec.resource_requirements.gpu`. The compact interim shape supports a +default GPU request, GPU count, and driver-specific device IDs. + VM runtime state paths are derived only from driver-validated sandbox IDs matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a private `run/` directory plus Unix peer UID/PID checks. Standalone diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 0f7a84d49..038c821e3 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -39,18 +39,18 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuRequestSpec, + GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirement, HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest, - RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, - SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, - SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, - UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, - tcp_forward_init, + RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, + SandboxResourceRequirements, SandboxSpec, SandboxTemplate, ServiceEndpointResponse, + SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, + TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, + exec_sandbox_event, setting_value, tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -1766,7 +1766,11 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: gpu_request_from_cli(requested_gpu, gpu_device, gpu_count), + resource_requirements: resource_requirements_from_cli( + requested_gpu, + gpu_device, + gpu_count, + ), policy, providers: configured_providers, template, @@ -2191,17 +2195,19 @@ pub async fn sandbox_create( } } -fn gpu_request_from_cli( +fn resource_requirements_from_cli( requested_gpu: bool, gpu_device: Option<&str>, gpu_count: Option, -) -> Option { - requested_gpu.then(|| GpuRequestSpec { - device_id: gpu_device - .filter(|device_id| !device_id.is_empty()) - .map(|device_id| vec![device_id.to_string()]) - .unwrap_or_default(), - count: gpu_count, +) -> Option { + requested_gpu.then(|| SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: gpu_device + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default(), + count: gpu_count, + }), }) } @@ -7454,14 +7460,14 @@ mod tests { dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header, format_gateway_select_items, format_provider_attachment_table, gateway_add, gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, - git_sync_files, gpu_request_from_cli, http_health_check, image_requests_gpu, - import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs, - parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs, - parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata, + git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, + inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, + parse_credential_expiry_cli_value, parse_credential_expiry_pairs, parse_credential_pairs, + plaintext_gateway_is_remote, progress_step_from_metadata, provider_profile_allows_refresh_bootstrap, provisioning_timeout_message, ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from, - sandbox_should_persist, sandbox_upload_plan, service_expose_status_error, - service_url_for_gateway, + resource_requirements_from_cli, sandbox_should_persist, sandbox_upload_plan, + service_expose_status_error, service_url_for_gateway, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -7941,43 +7947,48 @@ mod tests { } #[test] - fn gpu_request_from_cli_uses_presence_with_empty_device_ids_for_default_gpu() { - let request = - gpu_request_from_cli(true, None, None).expect("gpu request should be present"); + fn resource_requirements_from_cli_uses_presence_for_default_gpu() { + let requirements = resource_requirements_from_cli(true, None, None) + .expect("resource requirements should be present"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); - assert!(request.device_id.is_empty()); - assert_eq!(request.count, None); + assert!(gpu.device_ids.is_empty()); + assert_eq!(gpu.count, None); } #[test] - fn gpu_request_from_cli_maps_gpu_device_to_one_device_id() { - let request = gpu_request_from_cli(true, Some("0000:2d:00.0"), None) - .expect("gpu request should be present"); + fn resource_requirements_from_cli_maps_gpu_device_to_one_device_id() { + let requirements = resource_requirements_from_cli(true, Some("0000:2d:00.0"), None) + .expect("resource requirements should be present"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); - assert_eq!(request.device_id, vec!["0000:2d:00.0"]); - assert_eq!(request.count, None); + assert_eq!(gpu.device_ids, vec!["0000:2d:00.0"]); + assert_eq!(gpu.count, None); } #[test] - fn gpu_request_from_cli_maps_gpu_count() { - let request = gpu_request_from_cli(true, None, Some(2)).expect("gpu request should exist"); + fn resource_requirements_from_cli_maps_gpu_count() { + let requirements = + resource_requirements_from_cli(true, None, Some(2)).expect("requirements should exist"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); - assert!(request.device_id.is_empty()); - assert_eq!(request.count, Some(2)); + assert!(gpu.device_ids.is_empty()); + assert_eq!(gpu.count, Some(2)); } #[test] - fn gpu_request_from_cli_preserves_device_and_gpu_count_for_gateway_validation() { - let request = gpu_request_from_cli(true, Some("nvidia.com/gpu=0"), Some(2)) - .expect("gpu request should exist"); + fn resource_requirements_from_cli_preserves_device_and_gpu_count_for_gateway_validation() { + let requirements = resource_requirements_from_cli(true, Some("nvidia.com/gpu=0"), Some(2)) + .expect("requirements should exist"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); - assert_eq!(request.device_id, vec!["nvidia.com/gpu=0"]); - assert_eq!(request.count, Some(2)); + assert_eq!(gpu.device_ids, vec!["nvidia.com/gpu=0"]); + assert_eq!(gpu.count, Some(2)); } #[test] - fn gpu_request_from_cli_omits_gpu_request_when_not_requested() { - assert!(gpu_request_from_cli(false, Some("0"), None).is_none()); + fn resource_requirements_from_cli_omits_gpu_request_when_not_requested() { + assert!(resource_requirements_from_cli(false, Some("0"), None).is_none()); } #[test] diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 2372100e3..0829a2512 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -925,10 +925,11 @@ async fn sandbox_create_sends_gpu_count_request() { let gpu = requests[0] .spec .as_ref() - .and_then(|spec| spec.gpu.as_ref()) + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| requirements.gpu.as_ref()) .expect("GPU request should be sent"); - assert!(gpu.device_id.is_empty()); + assert!(gpu.device_ids.is_empty()); assert_eq!(gpu.count, Some(2)); } diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index b79f3f59d..07e0fcc78 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -4,17 +4,17 @@ //! Shared GPU request helpers. use crate::config::CDI_GPU_DEVICE_ALL; -use crate::proto::compute::v1::GpuRequestSpec; +use crate::proto::compute::v1::DriverGpuResourceRequirement; /// Resolve a driver GPU request into CDI device identifiers. /// /// `None` means no GPU was requested. Presence with no explicit device IDs /// uses the CDI all-GPU request; otherwise the driver-native IDs pass through. #[must_use] -pub fn cdi_gpu_device_ids(gpu: Option<&GpuRequestSpec>) -> Option> { +pub fn cdi_gpu_device_ids(gpu: Option<&DriverGpuResourceRequirement>) -> Option> { match gpu { - Some(gpu) if gpu.device_id.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), - Some(gpu) => Some(gpu.device_id.clone()), + Some(gpu) if gpu.device_ids.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + Some(gpu) => Some(gpu.device_ids.clone()), None => None, } } @@ -30,8 +30,8 @@ mod tests { #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { - let request = GpuRequestSpec { - device_id: vec![], + let request = DriverGpuResourceRequirement { + device_ids: vec![], count: None, }; @@ -43,8 +43,8 @@ mod tests { #[test] fn cdi_gpu_device_ids_passes_single_device_id_through() { - let request = GpuRequestSpec { - device_id: vec!["nvidia.com/gpu=0".to_string()], + let request = DriverGpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], count: None, }; @@ -56,8 +56,8 @@ mod tests { #[test] fn cdi_gpu_device_ids_passes_multiple_device_ids_through() { - let request = GpuRequestSpec { - device_id: vec![ + let request = DriverGpuResourceRequirement { + device_ids: vec![ "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string(), ], diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index df4069059..b44c7056f 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses explicit GPU request device IDs when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Count-based GPU requests are rejected until Docker CDI selection can map counts to concrete devices. | +| CDI GPU request | Uses explicit `resource_requirements.gpu.device_ids` when set; otherwise requests all NVIDIA GPUs when `resource_requirements.gpu` is present and daemon CDI support is detected. Count-based GPU requests are rejected until Docker CDI selection can map counts to concrete devices. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index e18671067..e9b51c790 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -32,9 +32,9 @@ use openshell_core::progress::{ }; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, GpuRequestSpec, ListSandboxesRequest, ListSandboxesResponse, + DriverCondition, DriverGpuResourceRequirement, DriverPlatformEvent, DriverSandbox, + DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, + GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, @@ -375,7 +375,7 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu.as_ref(), config.supports_gpu)?; + Self::validate_gpu_request(driver_gpu_requirement(spec), config.supports_gpu)?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -410,7 +410,7 @@ impl DockerComputeDriver { } fn validate_gpu_request( - gpu: Option<&GpuRequestSpec>, + gpu: Option<&DriverGpuResourceRequirement>, supports_gpu: bool, ) -> Result<(), Status> { if gpu.is_some_and(|gpu| gpu.count.is_some()) { @@ -1721,7 +1721,17 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: Option<&GpuRequestSpec>) -> Option> { +fn driver_gpu_requirement( + spec: &openshell_core::proto::compute::v1::DriverSandboxSpec, +) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + +fn docker_gpu_device_requests( + gpu: Option<&DriverGpuResourceRequirement>, +) -> Option> { cdi_gpu_device_ids(gpu).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), @@ -1773,7 +1783,7 @@ fn build_container_create_body( nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, pids_limit: docker_pids_limit(config.sandbox_pids_limit)?, - device_requests: docker_gpu_device_requests(spec.gpu.as_ref()), + device_requests: docker_gpu_device_requests(driver_gpu_requirement(spec)), binds: Some(build_binds(sandbox, config)?), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index 07a68177e..308605fae 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -13,7 +13,8 @@ use openshell_core::progress::{ PROGRESS_STEP_STARTING_SANDBOX, }; use openshell_core::proto::compute::v1::{ - DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, GpuRequestSpec, + DriverGpuResourceRequirement, DriverResourceRequirements, DriverSandboxResourceRequirements, + DriverSandboxSpec, DriverSandboxTemplate, }; use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -23,6 +24,15 @@ use tempfile::TempDir; const TLS_MOUNT_DIR: &str = "/etc/openshell/tls/client"; static ENV_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); +fn gpu_resource_requirements( + device_ids: Vec, + count: Option, +) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { device_ids, count }), + } +} + fn test_sandbox() -> DriverSandbox { // Mirrors the gateway-supplied request: the public `Sandbox` API no // longer carries `namespace`, so the gateway elides the field and the @@ -42,8 +52,8 @@ fn test_sandbox() -> DriverSandbox { resources: None, platform_config: None, }), - gpu: None, sandbox_token: String::new(), + resource_requirements: None, }), status: None, } @@ -604,10 +614,8 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { - device_id: vec![], - count: None, - }); + sandbox.spec.as_mut().unwrap().resource_requirements = + Some(gpu_resource_requirements(vec![], None)); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -620,10 +628,8 @@ fn validate_sandbox_rejects_gpu_count() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { - device_id: vec![], - count: Some(2), - }); + sandbox.spec.as_mut().unwrap().resource_requirements = + Some(gpu_resource_requirements(vec![], Some(2))); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -658,10 +664,8 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { - device_id: vec![], - count: None, - }); + sandbox.spec.as_mut().unwrap().resource_requirements = + Some(gpu_resource_requirements(vec![], None)); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -683,13 +687,13 @@ fn build_container_create_body_passes_explicit_cdi_device_ids_through() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = Some(GpuRequestSpec { - device_id: vec![ + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements( + vec![ "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string(), ], - count: None, - }); + None, + )); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 329cde120..0ddbc9e2d 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -49,7 +49,7 @@ pods do not need direct external ingress for SSH. ## GPU Support -When a sandbox requests GPU support, the driver checks node allocatable capacity -for `nvidia.com/gpu` and sets the workload's `nvidia.com/gpu` resource limit. -Requests without an explicit count use one GPU. The sandbox image must provide -the user-space libraries needed by the agent workload. +When `resource_requirements.gpu` is present, the driver checks node allocatable +capacity for `nvidia.com/gpu` and sets the workload's `nvidia.com/gpu` resource +limit. Requests without an explicit count use one GPU. The sandbox image must +provide the user-space libraries needed by the agent workload. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index f4eb8fd96..c1e2c2a74 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -22,11 +22,12 @@ use openshell_core::progress::{ format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, }; use openshell_core::proto::compute::v1::{ - DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, - DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, - DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, GpuRequestSpec, WatchSandboxesDeletedEvent, WatchSandboxesEvent, - WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, + DriverCondition as SandboxCondition, DriverGpuResourceRequirement, + DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, + DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, + DriverSandboxTemplate as SandboxTemplate, GetCapabilitiesResponse, WatchSandboxesDeletedEvent, + WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, + watch_sandboxes_event, }; use std::collections::BTreeMap; use std::pin::Pin; @@ -79,8 +80,14 @@ pub const SANDBOX_KIND: &str = "Sandbox"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; const DEFAULT_GPU_COUNT: u32 = 1; -fn gpu_has_explicit_device_ids(gpu: Option<&GpuRequestSpec>) -> bool { - gpu.is_some_and(|gpu| !gpu.device_id.is_empty()) +fn driver_gpu_requirement(spec: &SandboxSpec) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + +fn gpu_has_explicit_device_ids(gpu: Option<&DriverGpuResourceRequirement>) -> bool { + gpu.is_some_and(|gpu| !gpu.device_ids.is_empty()) } // --------------------------------------------------------------------------- @@ -207,13 +214,13 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { - let gpu = sandbox.spec.as_ref().and_then(|spec| spec.gpu.as_ref()); + let gpu = sandbox.spec.as_ref().and_then(driver_gpu_requirement); self.validate_gpu_request(gpu).await } async fn validate_gpu_request( &self, - gpu: Option<&GpuRequestSpec>, + gpu: Option<&DriverGpuResourceRequirement>, ) -> Result<(), tonic::Status> { if gpu_has_explicit_device_ids(gpu) { return Err(tonic::Status::invalid_argument( @@ -312,7 +319,7 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { - if let Some(gpu) = sandbox.spec.as_ref().and_then(|spec| spec.gpu.as_ref()) + if let Some(gpu) = sandbox.spec.as_ref().and_then(driver_gpu_requirement) && gpu_has_explicit_device_ids(Some(gpu)) { return Err(KubernetesDriverError::Precondition( @@ -1128,7 +1135,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( template, - spec.gpu.as_ref(), + driver_gpu_requirement(spec), &pod_env, inject_workspace, params, @@ -1164,7 +1171,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - spec.and_then(|s| s.gpu.as_ref()), + spec.and_then(driver_gpu_requirement), &pod_env, inject_workspace, params, @@ -1179,7 +1186,7 @@ fn sandbox_to_k8s_spec( fn sandbox_template_to_k8s( template: &SandboxTemplate, - gpu: Option<&GpuRequestSpec>, + gpu: Option<&DriverGpuResourceRequirement>, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, @@ -1416,7 +1423,7 @@ fn image_pull_secret_refs(secrets: &[String]) -> Vec { fn container_resources( template: &SandboxTemplate, - gpu: Option<&GpuRequestSpec>, + gpu: Option<&DriverGpuResourceRequirement>, ) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API @@ -1716,9 +1723,9 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); - fn gpu_request(count: Option) -> GpuRequestSpec { - GpuRequestSpec { - device_id: vec![], + fn gpu_request(count: Option) -> DriverGpuResourceRequirement { + DriverGpuResourceRequirement { + device_ids: vec![], count, } } @@ -2072,17 +2079,19 @@ mod tests { #[test] fn gpu_has_explicit_device_ids_only_when_ids_are_present() { - use openshell_core::proto::compute::v1::GpuRequestSpec; - assert!(!gpu_has_explicit_device_ids(None)); - assert!(!gpu_has_explicit_device_ids(Some(&GpuRequestSpec { - device_id: vec![], - count: None, - }))); - assert!(gpu_has_explicit_device_ids(Some(&GpuRequestSpec { - device_id: vec!["nvidia.com/gpu=0".to_string()], - count: None, - }))); + assert!(!gpu_has_explicit_device_ids(Some( + &DriverGpuResourceRequirement { + device_ids: vec![], + count: None, + } + ))); + assert!(gpu_has_explicit_device_ids(Some( + &DriverGpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: None, + } + ))); } #[test] diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index e4183f75b..7bca6e653 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | Explicit GPU request device IDs when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-based GPU requests are rejected until Podman CDI selection can map counts to concrete devices. | +| CDI GPU devices | Explicit `resource_requirements.gpu.device_ids` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-based GPU requests are rejected until Podman CDI selection can map counts to concrete devices. | The restricted agent child does not retain these supervisor privileges. diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index c31f72022..1f5691872 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -379,7 +379,11 @@ fn podman_pids_limit(value: i64) -> Option { /// Build CDI GPU device list if GPU is requested. fn build_devices(sandbox: &DriverSandbox) -> Option> { - let gpu = sandbox.spec.as_ref().and_then(|spec| spec.gpu.as_ref()); + let gpu = sandbox.spec.as_ref().and_then(|spec| { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) + }); cdi_gpu_device_ids(gpu).map(|device_ids| { device_ids .into_iter() @@ -808,13 +812,17 @@ mod tests { #[test] fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { use openshell_core::config::CDI_GPU_DEVICE_ALL; - use openshell_core::proto::compute::v1::{DriverSandboxSpec, GpuRequestSpec}; + use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec, + }; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: None, + resource_requirements: Some(DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { + device_ids: vec![], + count: None, + }), }), ..Default::default() }); @@ -829,13 +837,17 @@ mod tests { #[test] fn container_spec_passes_explicit_cdi_device_id_through() { - use openshell_core::proto::compute::v1::{DriverSandboxSpec, GpuRequestSpec}; + use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec, + }; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec!["nvidia.com/gpu=0".to_string()], - count: None, + resource_requirements: Some(DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }), }), ..Default::default() }); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index f3f65747c..6f93bbff7 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -10,7 +10,9 @@ use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; use openshell_core::ComputeDriverError; -use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse, GpuRequestSpec}; +use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandbox, GetCapabilitiesResponse, +}; use std::path::PathBuf; use std::time::Duration; use tracing::{info, warn}; @@ -280,11 +282,13 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu = sandbox.spec.as_ref().and_then(|s| s.gpu.as_ref()); + let gpu = sandbox.spec.as_ref().and_then(driver_gpu_requirement); Self::validate_gpu_request(gpu) } - fn validate_gpu_request(gpu: Option<&GpuRequestSpec>) -> Result<(), ComputeDriverError> { + fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + ) -> Result<(), ComputeDriverError> { if gpu.is_some_and(|gpu| gpu.count.is_some()) { return Err(ComputeDriverError::Precondition( "podman compute driver does not support GPU count requests".to_string(), @@ -578,6 +582,14 @@ impl PodmanComputeDriver { } } +fn driver_gpu_requirement( + spec: &openshell_core::proto::compute::v1::DriverSandboxSpec, +) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + #[cfg(test)] impl PodmanComputeDriver { pub(crate) fn for_tests(config: PodmanComputeConfig) -> Self { @@ -675,8 +687,8 @@ mod tests { #[test] fn validate_gpu_request_rejects_count() { - let err = PodmanComputeDriver::validate_gpu_request(Some(&GpuRequestSpec { - device_id: vec![], + let err = PodmanComputeDriver::validate_gpu_request(Some(&DriverGpuResourceRequirement { + device_ids: vec![], count: Some(2), })) .expect_err("GPU count should be rejected"); diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index 724bde06c..c5860f9cd 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -52,8 +52,9 @@ sudo -E env "PATH=$PATH" mise run gateway:vm -- --gpu ``` GPU passthrough uses VFIO and requires host support for IOMMU, root privileges -for bind/unbind operations, and a compatible sandbox image. The public GPU -overview lives in the repository `README.md`. +for bind/unbind operations, and a compatible sandbox image. Sandbox GPU requests +arrive as `resource_requirements.gpu`; the VM driver accepts the default request, +one explicit device ID, or a count of one. Point the CLI at the gateway with one of: diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 35c5fef50..c60e525ed 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -30,13 +30,14 @@ use openshell_core::progress::{ }; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, - DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, - GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, GpuRequestSpec, - ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, - ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, - WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, - WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, + DriverCondition as SandboxCondition, DriverGpuResourceRequirement, + DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, + DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, GetCapabilitiesResponse, + GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_vfio::SysfsRoot; use prost::Message; @@ -618,7 +619,8 @@ impl VmDriver { let gpu_device = sandbox .spec .as_ref() - .and_then(|spec| requested_gpu_device(spec.gpu.as_ref())); + .and_then(driver_gpu_requirement) + .and_then(|gpu| requested_gpu_device(Some(gpu))); let gpu_bdf = if let Some(gpu_device) = gpu_device { Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) } else { @@ -2578,7 +2580,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - validate_gpu_request(spec.gpu.as_ref(), gpu_enabled)?; + validate_gpu_request(driver_gpu_requirement(spec), gpu_enabled)?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -2621,13 +2623,24 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } -fn requested_gpu_device(gpu: Option<&GpuRequestSpec>) -> Option<&str> { +fn driver_gpu_requirement( + spec: &openshell_core::proto::compute::v1::DriverSandboxSpec, +) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + +fn requested_gpu_device(gpu: Option<&DriverGpuResourceRequirement>) -> Option<&str> { let gpu = gpu?; - Some(gpu.device_id.first().map_or("", String::as_str)) + Some(gpu.device_ids.first().map_or("", String::as_str)) } #[allow(clippy::result_large_err)] -fn validate_gpu_request(gpu: Option<&GpuRequestSpec>, gpu_enabled: bool) -> Result<(), Status> { +fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + gpu_enabled: bool, +) -> Result<(), Status> { if gpu.is_some() && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", @@ -2640,7 +2653,7 @@ fn validate_gpu_request(gpu: Option<&GpuRequestSpec>, gpu_enabled: bool) -> Resu )); } - if gpu.is_some_and(|gpu| gpu.device_id.len() > 1) { + if gpu.is_some_and(|gpu| gpu.device_ids.len() > 1) { return Err(Status::invalid_argument( "vm compute driver supports at most one GPU device ID", )); @@ -4432,7 +4445,8 @@ mod tests { PROGRESS_COMPLETE_STEP_KEY, }; use openshell_core::proto::compute::v1::{ - DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, GpuRequestSpec, + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, + DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, }; use prost_types::{Struct, Value, value::Kind}; use std::fs; @@ -4444,6 +4458,15 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_resource_requirements( + device_ids: Vec, + count: Option, + ) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { device_ids, count }), + } + } + #[test] fn vm_pulling_layer_event_adds_progress_detail_metadata() { let mut event = platform_event( @@ -4511,10 +4534,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: None, - }), + resource_requirements: Some(gpu_resource_requirements(vec![], None)), ..Default::default() }), ..Default::default() @@ -4530,10 +4550,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: None, - }), + resource_requirements: Some(gpu_resource_requirements(vec![], None)), ..Default::default() }), ..Default::default() @@ -4546,10 +4563,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: Some(1), - }), + resource_requirements: Some(gpu_resource_requirements(vec![], Some(1))), ..Default::default() }), ..Default::default() @@ -4562,10 +4576,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: Some(2), - }), + resource_requirements: Some(gpu_resource_requirements(vec![], Some(2))), ..Default::default() }), ..Default::default() @@ -4581,10 +4592,10 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], - count: None, - }), + resource_requirements: Some(gpu_resource_requirements( + vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], + None, + )), ..Default::default() }), ..Default::default() @@ -4602,8 +4613,8 @@ mod tests { #[test] fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { - let gpu = GpuRequestSpec { - device_id: vec![], + let gpu = DriverGpuResourceRequirement { + device_ids: vec![], count: None, }; @@ -4612,8 +4623,8 @@ mod tests { #[test] fn requested_gpu_device_returns_first_explicit_device_id() { - let gpu = GpuRequestSpec { - device_id: vec!["0000:2d:00.0".to_string()], + let gpu = DriverGpuResourceRequirement { + device_ids: vec!["0000:2d:00.0".to_string()], count: None, }; diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index ac85f7269..666a7174a 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -17,12 +17,13 @@ use crate::tracing_bus::TracingLogBus; use futures::{Stream, StreamExt}; use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ - CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, - DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, - GpuRequestSpec as DriverGpuRequestSpec, ListSandboxesRequest, ValidateSandboxCreateRequest, - WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, - compute_driver_server::ComputeDriver, watch_sandboxes_event, + CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverGpuResourceRequirement, + DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, + DriverSandboxResourceRequirements, DriverSandboxSpec, DriverSandboxStatus, + DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, + ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, + compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, + watch_sandboxes_event, }; use openshell_core::proto::{ PlatformEvent, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, @@ -1267,14 +1268,28 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .template .as_ref() .map(driver_sandbox_template_from_public), - gpu: spec.gpu.as_ref().map(|gpu| DriverGpuRequestSpec { - device_id: gpu.device_id.clone(), - count: gpu.count, - }), + resource_requirements: spec + .resource_requirements + .as_ref() + .map(driver_resource_requirements_from_public), sandbox_token: String::new(), } } +fn driver_resource_requirements_from_public( + requirements: &openshell_core::proto::SandboxResourceRequirements, +) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: requirements + .gpu + .as_ref() + .map(|gpu| DriverGpuResourceRequirement { + device_ids: gpu.device_ids.clone(), + count: gpu.count, + }), + } +} + fn driver_sandbox_template_from_public(template: &SandboxTemplate) -> DriverSandboxTemplate { DriverSandboxTemplate { image: template.image.clone(), @@ -1625,7 +1640,12 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu.is_some()); + let gpu_requested = spec.is_some_and(|sandbox_spec| { + sandbox_spec + .resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()) + }); if !gpu_requested { return; } @@ -1783,11 +1803,11 @@ pub async fn new_test_runtime(store: Arc) -> ComputeRuntime { mod tests { use super::*; use futures::stream; - use openshell_core::proto::GpuRequestSpec; use openshell_core::proto::compute::v1::{ CreateSandboxResponse, DeleteSandboxResponse, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateResponse, }; + use openshell_core::proto::{GpuResourceRequirement, SandboxResourceRequirements}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; @@ -1807,9 +1827,11 @@ mod tests { #[test] fn driver_sandbox_spec_from_public_preserves_gpu_request_device_ids() { let public = SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec!["nvidia.com/gpu=0".to_string()], - count: None, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }), }), ..Default::default() }; @@ -1818,9 +1840,11 @@ mod tests { assert_eq!( driver + .resource_requirements + .expect("driver resource requirements should be present") .gpu - .expect("driver GPU request should be present") - .device_id, + .expect("driver GPU requirement should be present") + .device_ids, vec!["nvidia.com/gpu=0".to_string()] ); } @@ -1828,9 +1852,11 @@ mod tests { #[test] fn driver_sandbox_spec_from_public_preserves_gpu_count() { let public = SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: Some(2), + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: Some(2), + }), }), ..Default::default() }; @@ -1839,8 +1865,10 @@ mod tests { assert_eq!( driver + .resource_requirements + .expect("driver resource requirements should be present") .gpu - .expect("driver GPU request should be present") + .expect("driver GPU requirement should be present") .count, Some(2) ); @@ -2303,9 +2331,11 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: None, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: None, + }), }), ..Default::default() }), @@ -2334,13 +2364,7 @@ mod tests { ..Default::default() }); - rewrite_user_facing_conditions( - &mut status, - Some(&SandboxSpec { - gpu: None, - ..Default::default() - }), - ); + rewrite_user_facing_conditions(&mut status, Some(&SandboxSpec::default())); assert_eq!(status.unwrap().conditions[0].message, original); } @@ -2619,9 +2643,11 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: None, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: None, + }), }), ..Default::default() }), @@ -2645,7 +2671,11 @@ mod tests { SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu.is_some())); + assert!(stored.spec.as_ref().is_some_and(|spec| { + spec.resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()) + })); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 0c837c537..dec84c4e9 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -99,7 +99,9 @@ fn emit_sandbox_create_telemetry( }; openshell_core::telemetry::emit_sandbox_create( outcome, - spec.gpu.is_some(), + spec.resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()), spec.providers.len() as u64, spec.policy.is_some(), template_source, diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index a9f1e984f..0f3b3fd7c 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -131,9 +131,9 @@ pub(super) fn validate_sandbox_spec( validate_sandbox_template(tmpl)?; } - // --- spec.gpu --- - if let Some(ref gpu) = spec.gpu { - validate_gpu_request(gpu)?; + // --- spec.resource_requirements --- + if let Some(ref requirements) = spec.resource_requirements { + validate_resource_requirements(requirements)?; } // --- spec.policy serialized size --- @@ -149,14 +149,27 @@ pub(super) fn validate_sandbox_spec( Ok(()) } -fn validate_gpu_request(gpu: &openshell_core::proto::GpuRequestSpec) -> Result<(), Status> { - if gpu.count.is_some() && !gpu.device_id.is_empty() { +fn validate_resource_requirements( + requirements: &openshell_core::proto::SandboxResourceRequirements, +) -> Result<(), Status> { + if let Some(ref gpu) = requirements.gpu { + validate_gpu_requirement(gpu)?; + } + Ok(()) +} + +fn validate_gpu_requirement( + gpu: &openshell_core::proto::GpuResourceRequirement, +) -> Result<(), Status> { + if gpu.count.is_some() && !gpu.device_ids.is_empty() { return Err(Status::invalid_argument( - "gpu.count is mutually exclusive with gpu.device_id", + "resource_requirements.gpu.count is mutually exclusive with resource_requirements.gpu.device_ids", )); } if gpu.count == Some(0) { - return Err(Status::invalid_argument("gpu.count must be greater than 0")); + return Err(Status::invalid_argument( + "resource_requirements.gpu.count must be greater than 0", + )); } Ok(()) } @@ -678,7 +691,7 @@ pub(super) fn level_matches(log_level: &str, min_level: &str) -> bool { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::{GpuRequestSpec, SandboxSpec}; + use openshell_core::proto::{GpuResourceRequirement, SandboxResourceRequirements, SandboxSpec}; use std::collections::HashMap; use tonic::Code; @@ -704,9 +717,11 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: None, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: None, + }), }), ..Default::default() }; @@ -716,9 +731,11 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_count() { let spec = SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: Some(2), + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: Some(2), + }), }), ..Default::default() }; @@ -728,9 +745,11 @@ mod tests { #[test] fn validate_sandbox_spec_rejects_zero_gpu_count() { let spec = SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec![], - count: Some(0), + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: Some(0), + }), }), ..Default::default() }; @@ -744,9 +763,11 @@ mod tests { #[test] fn validate_sandbox_spec_rejects_gpu_count_with_device_id() { let spec = SandboxSpec { - gpu: Some(GpuRequestSpec { - device_id: vec!["nvidia.com/gpu=0".to_string()], - count: Some(1), + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: Some(1), + }), }), ..Default::default() }; diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 0b8469612..212f43f1d 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -64,6 +64,7 @@ updated Docker daemon capability. Kubernetes gateways honor `--gpu-count` by setting the `nvidia.com/gpu` resource limit. Docker and Podman support explicit CDI device IDs through `--gpu-device` but do not support count-based selection yet. VM gateways accept only one GPU. +In the API, these flags populate `SandboxSpec.resource_requirements.gpu`. ### Custom Containers diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 3ac04380c..79bff06e2 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -77,16 +77,14 @@ message DriverSandbox { // Driver-owned provisioning inputs required to create a sandbox. message DriverSandboxSpec { - reserved 10; - // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. map environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request GPU resources for this sandbox. Presence indicates a GPU request. - GpuRequestSpec gpu = 9; + // Portable resource requirements for this sandbox. + DriverSandboxResourceRequirements resource_requirements = 9; // Gateway-minted JWT identifying this sandbox to the gateway. Set by // the gateway on create; the driver materialises it via its native // secret mechanism (Docker/Podman/VM bind-mount a per-sandbox file; @@ -96,13 +94,20 @@ message DriverSandboxSpec { string sandbox_token = 11; } -// Driver-native GPU request details. -message GpuRequestSpec { - // Optional number of GPUs requested. Mutually exclusive with device_id. +// Driver-owned resource requirements for the sandbox workload. +message DriverSandboxResourceRequirements { + // GPU requirement for the sandbox. Presence indicates a GPU request. + DriverGpuResourceRequirement gpu = 1; +} + +// Driver-owned GPU resource requirement. Device identifiers are interpreted by +// the selected compute driver and are an interim compatibility surface. +message DriverGpuResourceRequirement { + // Optional number of GPUs requested. Mutually exclusive with device_ids. optional uint32 count = 1; // Optional driver-native device identifiers. Mutually exclusive with count. // Empty means the driver chooses its default GPU assignment behavior. - repeated string device_id = 2; + repeated string device_ids = 2; } // Driver-owned runtime template consumed by the compute platform. diff --git a/proto/openshell.proto b/proto/openshell.proto index 2dd51ba21..4731baa9e 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -305,9 +305,6 @@ message Sandbox { // Desired sandbox configuration provided through the public API. message SandboxSpec { - reserved 10, 11; - reserved "gpu_device", "proposal_approval_mode"; - // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. @@ -318,18 +315,24 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request GPU resources for this sandbox. Presence indicates a GPU request. - GpuRequestSpec gpu = 9; + // Portable resource requirements for this sandbox. + SandboxResourceRequirements resource_requirements = 9; +} + +// Public resource requirements for the sandbox workload. +message SandboxResourceRequirements { + // GPU requirement for the sandbox. Presence indicates a GPU request. + GpuResourceRequirement gpu = 1; } -// Public GPU request details. Device identifiers are interpreted by the -// selected compute driver. -message GpuRequestSpec { - // Optional number of GPUs requested. Mutually exclusive with device_id. +// Public GPU resource requirement. Device identifiers are interpreted by the +// selected compute driver and are an interim compatibility surface. +message GpuResourceRequirement { + // Optional number of GPUs requested. Mutually exclusive with device_ids. optional uint32 count = 1; // Optional driver-native device identifiers. Mutually exclusive with count. // Empty means the driver chooses its default GPU assignment behavior. - repeated string device_id = 2; + repeated string device_ids = 2; } // Public sandbox template mapped onto compute-driver template inputs. From e106f026a73f9ee4adfcc8ed5230cde02061508c Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 3 Jun 2026 15:07:17 +0200 Subject: [PATCH 5/5] feat(gpu): route device selection through driver config Signed-off-by: Evan Lezar --- Cargo.lock | 2 + architecture/compute-runtimes.md | 10 +- crates/openshell-cli/src/main.rs | 30 +- crates/openshell-cli/src/run.rs | 136 ++++++--- .../sandbox_create_lifecycle_integration.rs | 14 +- crates/openshell-core/src/gpu.rs | 156 +++++++++-- crates/openshell-driver-docker/Cargo.toml | 1 + crates/openshell-driver-docker/README.md | 2 +- crates/openshell-driver-docker/src/lib.rs | 71 ++++- crates/openshell-driver-docker/src/tests.rs | 154 ++++++++-- .../openshell-driver-kubernetes/src/driver.rs | 39 +-- crates/openshell-driver-podman/Cargo.toml | 1 + crates/openshell-driver-podman/README.md | 2 +- .../openshell-driver-podman/src/container.rs | 113 +++++++- crates/openshell-driver-podman/src/driver.rs | 86 +++++- crates/openshell-driver-vm/README.md | 3 +- crates/openshell-driver-vm/src/driver.rs | 265 +++++++++++++++--- crates/openshell-server/src/compute/mod.rs | 164 +++++++---- .../openshell-server/src/grpc/validation.rs | 62 ++-- docs/sandboxes/manage-sandboxes.mdx | 15 +- proto/compute_driver.proto | 15 +- proto/openshell.proto | 16 +- 22 files changed, 1081 insertions(+), 276 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4bc657be3..eee476aaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3481,6 +3481,7 @@ dependencies = [ "bytes", "futures", "openshell-core", + "prost-types", "serde", "tar", "temp-env", @@ -3528,6 +3529,7 @@ dependencies = [ "miette", "nix", "openshell-core", + "prost-types", "rustix 1.1.4", "serde", "serde_json", diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index f3accbfee..d1f91156d 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -42,7 +42,11 @@ but currently ignores them. GPU requests enter the driver layer through `SandboxSpec.resource_requirements.gpu`. The compact interim shape supports a -default GPU request, GPU count, and driver-specific device IDs. +default GPU request and GPU count. Exact driver-native device selection is +passed through the selected runtime's `driver_config` block; the gateway +selects that block but does not interpret the nested driver schema. Drivers +that support exact selection validate that the unique `gpu_device_ids` entry +count matches the portable GPU count. VM runtime state paths are derived only from driver-validated sandbox IDs matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a @@ -81,9 +85,7 @@ users. Custom sandbox images must include the agent runtime and any system dependencies, but they should not need to include the gateway. GPU-capable images must include the user-space libraries required by the workload. The -runtime still owns GPU device injection. GPU requests can include explicit -driver-native device IDs or a requested count; the gateway validates the public -request shape and each runtime enforces the GPU allocation modes it supports. +runtime still owns GPU device injection. ## Deployment Shape diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 686103971..490a4cd2c 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1215,8 +1215,8 @@ enum SandboxCommands { /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. - /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu", conflicts_with = "gpu_count")] + /// When omitted with --gpu, the driver uses its default GPU selection. + #[arg(long, conflicts_with = "gpu_count")] gpu_device: Option, /// Request a specific number of GPUs. Mutually exclusive with --gpu-device. @@ -4320,6 +4320,32 @@ mod tests { ); } + #[test] + fn sandbox_create_gpu_device_parses_without_gpu_flag() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu-device", + "nvidia.com/gpu=0", + ]) + .expect("sandbox create --gpu-device should parse without --gpu"); + + match cli.command { + Some(Commands::Sandbox { + command: + Some(SandboxCommands::Create { + gpu, gpu_device, .. + }), + .. + }) => { + assert!(!gpu); + assert_eq!(gpu_device.as_deref(), Some("nvidia.com/gpu=0")); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + #[test] fn sandbox_create_gpu_count_conflicts_with_gpu_device() { let result = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 038c821e3..3e3ce1e08 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1734,8 +1734,10 @@ pub async fn sandbox_create( } None => None, }; + let gpu_device_ids = gpu_device_ids_from_cli(gpu_device); + let effective_gpu_count = gpu_count_from_cli(gpu_count, &gpu_device_ids); let requested_gpu = - gpu || gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu); + gpu || effective_gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu); let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; let inferred_types: Vec = if providers_v2_enabled { @@ -1753,11 +1755,13 @@ pub async fn sandbox_create( let policy = load_sandbox_policy(policy)?; let resource_limits = build_sandbox_resource_limits(cpu, memory)?; + let driver_config = gpu_driver_config_from_cli(&gpu_device_ids); - let template = if image.is_some() || resource_limits.is_some() { + let template = if image.is_some() || resource_limits.is_some() || driver_config.is_some() { Some(SandboxTemplate { image: image.unwrap_or_default(), resources: resource_limits, + driver_config, ..SandboxTemplate::default() }) } else { @@ -1768,8 +1772,7 @@ pub async fn sandbox_create( spec: Some(SandboxSpec { resource_requirements: resource_requirements_from_cli( requested_gpu, - gpu_device, - gpu_count, + effective_gpu_count, ), policy, providers: configured_providers, @@ -2197,17 +2200,69 @@ pub async fn sandbox_create( fn resource_requirements_from_cli( requested_gpu: bool, - gpu_device: Option<&str>, gpu_count: Option, ) -> Option { - requested_gpu.then(|| SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: gpu_device - .filter(|device_id| !device_id.is_empty()) - .map(|device_id| vec![device_id.to_string()]) - .unwrap_or_default(), - count: gpu_count, - }), + requested_gpu.then_some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: gpu_count }), + }) +} + +fn gpu_device_ids_from_cli(gpu_device: Option<&str>) -> Vec { + gpu_device + .map(str::trim) + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default() +} + +fn gpu_count_from_cli(gpu_count: Option, gpu_device_ids: &[String]) -> Option { + if gpu_device_ids.is_empty() { + gpu_count + } else { + u32::try_from(gpu_device_ids.len()).ok() + } +} + +fn gpu_driver_config_from_cli(gpu_device_ids: &[String]) -> Option { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + fn string_value(value: &str) -> Value { + Value { + kind: Some(Kind::StringValue(value.to_string())), + } + } + + fn driver_block(gpu_device_ids: &[String]) -> Value { + Value { + kind: Some(Kind::StructValue(Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: gpu_device_ids + .iter() + .map(|device_id| string_value(device_id)) + .collect(), + })), + }, + )) + .collect(), + })), + } + } + + if gpu_device_ids.is_empty() { + return None; + } + + Some(Struct { + fields: [ + ("docker".to_string(), driver_block(gpu_device_ids)), + ("podman".to_string(), driver_block(gpu_device_ids)), + ("vm".to_string(), driver_block(gpu_device_ids)), + ] + .into_iter() + .collect(), }) } @@ -7460,7 +7515,8 @@ mod tests { dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header, format_gateway_select_items, format_provider_attachment_table, gateway_add, gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, - git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, + git_sync_files, gpu_count_from_cli, gpu_device_ids_from_cli, gpu_driver_config_from_cli, + http_health_check, image_requests_gpu, import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs, parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata, @@ -7946,49 +8002,65 @@ mod tests { } } + #[test] + fn gpu_device_ids_from_cli_trims_gpu_device() { + assert_eq!( + gpu_device_ids_from_cli(Some(" nvidia.com/gpu=0 ")), + vec!["nvidia.com/gpu=0".to_string()] + ); + } + + #[test] + fn gpu_device_ids_from_cli_omits_empty_device() { + assert!(gpu_device_ids_from_cli(Some(" ")).is_empty()); + assert!(gpu_device_ids_from_cli(None).is_empty()); + } + + #[test] + fn gpu_count_from_cli_uses_gpu_device_id_count() { + let device_ids = gpu_device_ids_from_cli(Some("nvidia.com/gpu=0")); + + assert_eq!(gpu_count_from_cli(None, &device_ids), Some(1)); + assert_eq!(gpu_count_from_cli(Some(2), &device_ids), Some(1)); + } + #[test] fn resource_requirements_from_cli_uses_presence_for_default_gpu() { - let requirements = resource_requirements_from_cli(true, None, None) + let requirements = resource_requirements_from_cli(true, None) .expect("resource requirements should be present"); let gpu = requirements.gpu.expect("GPU requirement should be present"); - assert!(gpu.device_ids.is_empty()); assert_eq!(gpu.count, None); } #[test] - fn resource_requirements_from_cli_maps_gpu_device_to_one_device_id() { - let requirements = resource_requirements_from_cli(true, Some("0000:2d:00.0"), None) - .expect("resource requirements should be present"); - let gpu = requirements.gpu.expect("GPU requirement should be present"); + fn gpu_driver_config_from_cli_maps_gpu_device_to_driver_blocks() { + let device_ids = gpu_device_ids_from_cli(Some("nvidia.com/gpu=0")); + let config = + gpu_driver_config_from_cli(&device_ids).expect("driver config should be present"); - assert_eq!(gpu.device_ids, vec!["0000:2d:00.0"]); - assert_eq!(gpu.count, None); + assert!(config.fields.contains_key("docker")); + assert!(config.fields.contains_key("podman")); + assert!(config.fields.contains_key("vm")); } #[test] fn resource_requirements_from_cli_maps_gpu_count() { let requirements = - resource_requirements_from_cli(true, None, Some(2)).expect("requirements should exist"); + resource_requirements_from_cli(true, Some(2)).expect("requirements should exist"); let gpu = requirements.gpu.expect("GPU requirement should be present"); - assert!(gpu.device_ids.is_empty()); assert_eq!(gpu.count, Some(2)); } #[test] - fn resource_requirements_from_cli_preserves_device_and_gpu_count_for_gateway_validation() { - let requirements = resource_requirements_from_cli(true, Some("nvidia.com/gpu=0"), Some(2)) - .expect("requirements should exist"); - let gpu = requirements.gpu.expect("GPU requirement should be present"); - - assert_eq!(gpu.device_ids, vec!["nvidia.com/gpu=0"]); - assert_eq!(gpu.count, Some(2)); + fn gpu_driver_config_from_cli_omits_empty_device() { + assert!(gpu_driver_config_from_cli(&[]).is_none()); } #[test] fn resource_requirements_from_cli_omits_gpu_request_when_not_requested() { - assert!(resource_requirements_from_cli(false, Some("0"), None).is_none()); + assert!(resource_requirements_from_cli(false, None).is_none()); } #[test] diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 0829a2512..37a5a682c 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -922,15 +922,23 @@ async fn sandbox_create_sends_gpu_count_request() { .expect("sandbox create should succeed"); let requests = create_requests(&server).await; - let gpu = requests[0] + let spec = requests[0] .spec .as_ref() - .and_then(|spec| spec.resource_requirements.as_ref()) + .expect("sandbox spec should be sent"); + let gpu = spec + .resource_requirements + .as_ref() .and_then(|requirements| requirements.gpu.as_ref()) .expect("GPU request should be sent"); - assert!(gpu.device_ids.is_empty()); assert_eq!(gpu.count, Some(2)); + assert!( + spec.template + .as_ref() + .and_then(|template| template.driver_config.as_ref()) + .is_none() + ); } #[tokio::test] diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 07e0fcc78..9f5e24adf 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -5,71 +5,179 @@ use crate::config::CDI_GPU_DEVICE_ALL; use crate::proto::compute::v1::DriverGpuResourceRequirement; +use std::collections::HashSet; /// Resolve a driver GPU request into CDI device identifiers. /// -/// `None` means no GPU was requested. Presence with no explicit device IDs -/// uses the CDI all-GPU request; otherwise the driver-native IDs pass through. +/// `None` means no GPU was requested. Presence with a positive count and +/// explicit device IDs passes those IDs through. Other present GPU requests use +/// the CDI all-GPU request. #[must_use] -pub fn cdi_gpu_device_ids(gpu: Option<&DriverGpuResourceRequirement>) -> Option> { +pub fn cdi_gpu_device_ids( + gpu: Option<&DriverGpuResourceRequirement>, + driver_config_device_ids: &[String], +) -> Option> { match gpu { - Some(gpu) if gpu.device_ids.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), - Some(gpu) => Some(gpu.device_ids.clone()), + Some(gpu) + if gpu.count.is_some_and(|count| count > 0) && !driver_config_device_ids.is_empty() => + { + Some(driver_config_device_ids.to_vec()) + } + Some(_) if driver_config_device_ids.is_empty() => { + Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) + } + Some(_) => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), None => None, } } +/// Validate that explicit driver GPU device IDs line up with the portable GPU count. +pub fn validate_gpu_device_ids_count( + gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], +) -> Result<(), String> { + if gpu_device_ids.is_empty() { + return Ok(()); + } + + let Some(count) = gpu.and_then(|gpu| gpu.count) else { + return Err( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count" + .to_string(), + ); + }; + if count == 0 { + return Err("resource_requirements.gpu.count must be greater than 0".to_string()); + } + + let unique = gpu_device_ids.iter().collect::>().len(); + if unique != gpu_device_ids.len() { + return Err( + "template.driver_config.gpu_device_ids must not contain duplicates".to_string(), + ); + } + if unique != count as usize { + return Err( + "template.driver_config.gpu_device_ids unique entry count must equal resource_requirements.gpu.count" + .to_string(), + ); + } + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(None), None); + assert_eq!(cdi_gpu_device_ids(None, &[]), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { - let request = DriverGpuResourceRequirement { - device_ids: vec![], - count: None, - }; + let request = DriverGpuResourceRequirement { count: None }; assert_eq!( - cdi_gpu_device_ids(Some(&request)), + cdi_gpu_device_ids(Some(&request), &[]), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] fn cdi_gpu_device_ids_passes_single_device_id_through() { - let request = DriverGpuResourceRequirement { - device_ids: vec!["nvidia.com/gpu=0".to_string()], - count: None, - }; + let request = DriverGpuResourceRequirement { count: Some(1) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; assert_eq!( - cdi_gpu_device_ids(Some(&request)), + cdi_gpu_device_ids(Some(&request), &device_ids), Some(vec!["nvidia.com/gpu=0".to_string()]) ); } #[test] fn cdi_gpu_device_ids_passes_multiple_device_ids_through() { - let request = DriverGpuResourceRequirement { - device_ids: vec![ - "nvidia.com/gpu=0".to_string(), - "nvidia.com/gpu=1".to_string(), - ], - count: None, - }; + let request = DriverGpuResourceRequirement { count: Some(2) }; + let device_ids = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; assert_eq!( - cdi_gpu_device_ids(Some(&request)), + cdi_gpu_device_ids(Some(&request), &device_ids), Some(vec![ "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string() ]) ); } + + #[test] + fn cdi_gpu_device_ids_ignores_device_ids_without_count() { + let request = DriverGpuResourceRequirement { count: None }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert_eq!( + cdi_gpu_device_ids(Some(&request), &device_ids), + Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) + ); + } + + #[test] + fn cdi_gpu_device_ids_ignores_device_ids_with_zero_count() { + let request = DriverGpuResourceRequirement { count: Some(0) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert_eq!( + cdi_gpu_device_ids(Some(&request), &device_ids), + Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) + ); + } + + #[test] + fn validate_gpu_device_ids_count_requires_gpu_count() { + let request = DriverGpuResourceRequirement { count: None }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } + + #[test] + fn validate_gpu_device_ids_count_rejects_zero_count() { + let request = DriverGpuResourceRequirement { count: Some(0) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } + + #[test] + fn validate_gpu_device_ids_count_accepts_matching_unique_ids() { + let request = DriverGpuResourceRequirement { count: Some(2) }; + let device_ids = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + + validate_gpu_device_ids_count(Some(&request), &device_ids).unwrap(); + } + + #[test] + fn validate_gpu_device_ids_count_rejects_duplicate_ids() { + let request = DriverGpuResourceRequirement { count: Some(1) }; + let device_ids = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=0".to_string(), + ]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } + + #[test] + fn validate_gpu_device_ids_count_rejects_count_mismatch() { + let request = DriverGpuResourceRequirement { count: Some(2) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } } diff --git a/crates/openshell-driver-docker/Cargo.toml b/crates/openshell-driver-docker/Cargo.toml index 0cdc205ed..1f5eabaca 100644 --- a/crates/openshell-driver-docker/Cargo.toml +++ b/crates/openshell-driver-docker/Cargo.toml @@ -19,6 +19,7 @@ futures = { workspace = true } tokio-stream = { workspace = true } tracing = { workspace = true } bytes = { workspace = true } +prost-types = { workspace = true } serde = { workspace = true } bollard = { version = "0.20" } tar = "0.4" diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index b44c7056f..c20658d07 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses explicit `resource_requirements.gpu.device_ids` when set; otherwise requests all NVIDIA GPUs when `resource_requirements.gpu` is present and daemon CDI support is detected. Count-based GPU requests are rejected until Docker CDI selection can map counts to concrete devices. | +| CDI GPU request | Uses explicit `template.driver_config.gpu_device_ids` when set and its unique entry count equals `resource_requirements.gpu.count`; otherwise requests all NVIDIA GPUs when `resource_requirements.gpu` is present and daemon CDI support is detected. Count-only GPU requests are rejected until Docker CDI selection can map counts to concrete devices. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index e9b51c790..ff5159491 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -25,7 +25,7 @@ use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, }; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{cdi_gpu_device_ids, validate_gpu_device_ids_count}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -375,7 +375,14 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(driver_gpu_requirement(spec), config.supports_gpu)?; + let gpu_device_ids = + docker_gpu_device_ids_from_driver_config(template.driver_config.as_ref()) + .map_err(Status::invalid_argument)?; + Self::validate_gpu_request( + driver_gpu_requirement(spec), + &gpu_device_ids, + config.supports_gpu, + )?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -411,9 +418,24 @@ impl DockerComputeDriver { fn validate_gpu_request( gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], supports_gpu: bool, ) -> Result<(), Status> { - if gpu.is_some_and(|gpu| gpu.count.is_some()) { + if gpu.is_none() && !gpu_device_ids.is_empty() { + return Err(Status::invalid_argument( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count", + )); + } + if let Some(gpu) = gpu + && gpu.count == Some(0) + { + return Err(Status::invalid_argument( + "resource_requirements.gpu.count must be greater than 0", + )); + } + if !gpu_device_ids.is_empty() { + validate_gpu_device_ids_count(gpu, gpu_device_ids).map_err(Status::invalid_argument)?; + } else if gpu.is_some_and(|gpu| gpu.count.is_some()) { return Err(Status::invalid_argument( "docker compute driver does not support GPU count requests", )); @@ -1729,10 +1751,44 @@ fn driver_gpu_requirement( .and_then(|requirements| requirements.gpu.as_ref()) } +fn docker_gpu_device_ids_from_driver_config( + driver_config: Option<&prost_types::Struct>, +) -> Result, String> { + use prost_types::value::Kind; + + let Some(config) = driver_config else { + return Ok(Vec::new()); + }; + if config.fields.is_empty() { + return Ok(Vec::new()); + } + + let Some(value) = config.fields.get("gpu_device_ids") else { + return Ok(Vec::new()); + }; + let Some(Kind::ListValue(list)) = value.kind.as_ref() else { + return Err("driver_config.gpu_device_ids must be a list of strings".to_string()); + }; + + list.values + .iter() + .enumerate() + .map(|(idx, value)| match value.kind.as_ref() { + Some(Kind::StringValue(device_id)) if !device_id.trim().is_empty() => { + Ok(device_id.clone()) + } + _ => Err(format!( + "driver_config.gpu_device_ids[{idx}] must be a non-empty string" + )), + }) + .collect() +} + fn docker_gpu_device_requests( gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], ) -> Option> { - cdi_gpu_device_ids(gpu).map(|device_ids| { + cdi_gpu_device_ids(gpu, gpu_device_ids).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), device_ids: Some(device_ids), @@ -1754,6 +1810,8 @@ fn build_container_create_body( .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox.spec.template is required"))?; let resource_limits = docker_resource_limits(template)?; + let gpu_device_ids = docker_gpu_device_ids_from_driver_config(template.driver_config.as_ref()) + .map_err(Status::invalid_argument)?; let mut labels = template.labels.clone(); labels.insert( LABEL_MANAGED_BY.to_string(), @@ -1783,7 +1841,10 @@ fn build_container_create_body( nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, pids_limit: docker_pids_limit(config.sandbox_pids_limit)?, - device_requests: docker_gpu_device_requests(driver_gpu_requirement(spec)), + device_requests: docker_gpu_device_requests( + driver_gpu_requirement(spec), + &gpu_device_ids, + ), binds: Some(build_binds(sandbox, config)?), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index 308605fae..449b9476f 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -24,15 +24,66 @@ use tempfile::TempDir; const TLS_MOUNT_DIR: &str = "/etc/openshell/tls/client"; static ENV_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); -fn gpu_resource_requirements( - device_ids: Vec, - count: Option, -) -> DriverSandboxResourceRequirements { +fn gpu_resource_requirements(count: Option) -> DriverSandboxResourceRequirements { DriverSandboxResourceRequirements { - gpu: Some(DriverGpuResourceRequirement { device_ids, count }), + gpu: Some(DriverGpuResourceRequirement { count }), } } +fn gpu_device_ids_driver_config(device_ids: &[&str]) -> prost_types::Struct { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: device_ids + .iter() + .map(|device_id| Value { + kind: Some(Kind::StringValue((*device_id).to_string())), + }) + .collect(), + })), + }, + )) + .collect(), + } +} + +#[test] +fn docker_gpu_device_ids_from_driver_config_ignores_unrelated_fields() { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + let config = Struct { + fields: [ + ( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: vec![Value { + kind: Some(Kind::StringValue("nvidia.com/gpu=0".to_string())), + }], + })), + }, + ), + ( + "future_field".to_string(), + Value { + kind: Some(Kind::StringValue("ignored".to_string())), + }, + ), + ] + .into_iter() + .collect(), + }; + + assert_eq!( + docker_gpu_device_ids_from_driver_config(Some(&config)).unwrap(), + vec!["nvidia.com/gpu=0".to_string()] + ); +} + fn test_sandbox() -> DriverSandbox { // Mirrors the gateway-supplied request: the public `Sandbox` API no // longer carries `namespace`, so the gateway elides the field and the @@ -51,6 +102,7 @@ fn test_sandbox() -> DriverSandbox { environment: HashMap::from([("TEMPLATE_ENV".to_string(), "template".to_string())]), resources: None, platform_config: None, + driver_config: None, }), sandbox_token: String::new(), resource_requirements: None, @@ -401,6 +453,7 @@ fn docker_resource_limits_rejects_requests() { memory_limit: String::new(), }), platform_config: None, + driver_config: None, }; let err = docker_resource_limits(&template).unwrap_err(); @@ -421,6 +474,7 @@ fn docker_resource_limits_applies_cpu_and_memory_limits() { ..Default::default() }), platform_config: None, + driver_config: None, }; let limits = docker_resource_limits(&template).unwrap(); @@ -614,8 +668,7 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().resource_requirements = - Some(gpu_resource_requirements(vec![], None)); + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements(None)); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -628,8 +681,7 @@ fn validate_sandbox_rejects_gpu_count() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().resource_requirements = - Some(gpu_resource_requirements(vec![], Some(2))); + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements(Some(2))); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -637,6 +689,72 @@ fn validate_sandbox_rejects_gpu_count() { assert!(err.message().contains("does not support GPU count")); } +#[test] +fn validate_sandbox_accepts_gpu_count_with_matching_device_ids() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(gpu_device_ids_driver_config(&[ + "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])); + + DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap(); +} + +#[test] +fn validate_sandbox_rejects_gpu_device_ids_without_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(None)); + spec.template.as_mut().unwrap().driver_config = + Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("requires resource_requirements.gpu.count") + ); +} + +#[test] +fn validate_sandbox_rejects_gpu_device_ids_with_zero_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(0))); + spec.template.as_mut().unwrap().driver_config = + Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("must be greater than 0")); +} + +#[test] +fn validate_sandbox_rejects_gpu_device_ids_count_mismatch() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(2))); + spec.template.as_mut().unwrap().driver_config = + Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("unique entry count")); +} + #[test] fn validate_sandbox_auth_requires_gateway_token() { let mut sandbox = test_sandbox(); @@ -664,8 +782,7 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().resource_requirements = - Some(gpu_resource_requirements(vec![], None)); + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements(None)); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -683,17 +800,16 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { } #[test] -fn build_container_create_body_passes_explicit_cdi_device_ids_through() { +fn build_container_create_body_passes_explicit_gpu_device_ids_through() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements( - vec![ - "nvidia.com/gpu=0".to_string(), - "nvidia.com/gpu=1".to_string(), - ], - None, - )); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(gpu_device_ids_driver_config(&[ + "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index c1e2c2a74..87e658fe8 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -86,10 +86,6 @@ fn driver_gpu_requirement(spec: &SandboxSpec) -> Option<&DriverGpuResourceRequir .and_then(|requirements| requirements.gpu.as_ref()) } -fn gpu_has_explicit_device_ids(gpu: Option<&DriverGpuResourceRequirement>) -> bool { - gpu.is_some_and(|gpu| !gpu.device_ids.is_empty()) -} - // --------------------------------------------------------------------------- // Default workspace persistence (temporary — will be replaced by snapshotting) // --------------------------------------------------------------------------- @@ -222,11 +218,6 @@ impl KubernetesComputeDriver { &self, gpu: Option<&DriverGpuResourceRequirement>, ) -> Result<(), tonic::Status> { - if gpu_has_explicit_device_ids(gpu) { - return Err(tonic::Status::invalid_argument( - "kubernetes compute driver does not support explicit GPU device IDs", - )); - } if gpu.is_some() && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) @@ -319,14 +310,6 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { - if let Some(gpu) = sandbox.spec.as_ref().and_then(driver_gpu_requirement) - && gpu_has_explicit_device_ids(Some(gpu)) - { - return Err(KubernetesDriverError::Precondition( - "kubernetes compute driver does not support explicit GPU device IDs".to_string(), - )); - } - let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -1724,10 +1707,7 @@ mod tests { std::sync::LazyLock::new(|| std::sync::Mutex::new(())); fn gpu_request(count: Option) -> DriverGpuResourceRequirement { - DriverGpuResourceRequirement { - device_ids: vec![], - count, - } + DriverGpuResourceRequirement { count } } #[test] @@ -2077,23 +2057,6 @@ mod tests { ); } - #[test] - fn gpu_has_explicit_device_ids_only_when_ids_are_present() { - assert!(!gpu_has_explicit_device_ids(None)); - assert!(!gpu_has_explicit_device_ids(Some( - &DriverGpuResourceRequirement { - device_ids: vec![], - count: None, - } - ))); - assert!(gpu_has_explicit_device_ids(Some( - &DriverGpuResourceRequirement { - device_ids: vec!["nvidia.com/gpu=0".to_string()], - count: None, - } - ))); - } - #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { diff --git a/crates/openshell-driver-podman/Cargo.toml b/crates/openshell-driver-podman/Cargo.toml index 6f2963d92..0ccff99f6 100644 --- a/crates/openshell-driver-podman/Cargo.toml +++ b/crates/openshell-driver-podman/Cargo.toml @@ -24,6 +24,7 @@ tokio-stream = { workspace = true } hyper = { workspace = true } hyper-util = { workspace = true } http-body-util = { workspace = true } +prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } clap = { workspace = true } diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index 7bca6e653..7bd8e58e2 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | Explicit `resource_requirements.gpu.device_ids` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-based GPU requests are rejected until Podman CDI selection can map counts to concrete devices. | +| CDI GPU devices | Explicit `template.driver_config.gpu_device_ids` when set and its unique entry count equals `resource_requirements.gpu.count`; otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-only GPU requests are rejected until Podman CDI selection can map counts to concrete devices. | The restricted agent child does not retain these supervisor privileges. diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index 1f5691872..af47082af 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -379,12 +379,19 @@ fn podman_pids_limit(value: i64) -> Option { /// Build CDI GPU device list if GPU is requested. fn build_devices(sandbox: &DriverSandbox) -> Option> { - let gpu = sandbox.spec.as_ref().and_then(|spec| { + let spec = sandbox.spec.as_ref(); + let gpu = spec.and_then(|spec| { spec.resource_requirements .as_ref() .and_then(|requirements| requirements.gpu.as_ref()) }); - cdi_gpu_device_ids(gpu).map(|device_ids| { + let gpu_device_ids = spec + .and_then(|spec| spec.template.as_ref()) + .and_then(|template| { + gpu_device_ids_from_driver_config(template.driver_config.as_ref()).ok() + }) + .unwrap_or_default(); + cdi_gpu_device_ids(gpu, &gpu_device_ids).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -392,6 +399,39 @@ fn build_devices(sandbox: &DriverSandbox) -> Option> { }) } +pub fn gpu_device_ids_from_driver_config( + driver_config: Option<&prost_types::Struct>, +) -> Result, String> { + use prost_types::value::Kind; + + let Some(config) = driver_config else { + return Ok(Vec::new()); + }; + if config.fields.is_empty() { + return Ok(Vec::new()); + } + + let Some(value) = config.fields.get("gpu_device_ids") else { + return Ok(Vec::new()); + }; + let Some(Kind::ListValue(list)) = value.kind.as_ref() else { + return Err("driver_config.gpu_device_ids must be a list of strings".to_string()); + }; + + list.values + .iter() + .enumerate() + .map(|(idx, value)| match value.kind.as_ref() { + Some(Kind::StringValue(device_id)) if !device_id.trim().is_empty() => { + Ok(device_id.clone()) + } + _ => Err(format!( + "driver_config.gpu_device_ids[{idx}] must be a non-empty string" + )), + }) + .collect() +} + /// Build the Podman container creation JSON spec. #[cfg(test)] #[must_use] @@ -703,6 +743,60 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_device_ids_driver_config(device_ids: &[&str]) -> prost_types::Struct { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: device_ids + .iter() + .map(|device_id| Value { + kind: Some(Kind::StringValue((*device_id).to_string())), + }) + .collect(), + })), + }, + )) + .collect(), + } + } + + #[test] + fn gpu_device_ids_from_driver_config_ignores_unrelated_fields() { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + let config = Struct { + fields: [ + ( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: vec![Value { + kind: Some(Kind::StringValue("nvidia.com/gpu=0".to_string())), + }], + })), + }, + ), + ( + "future_field".to_string(), + Value { + kind: Some(Kind::StringValue("ignored".to_string())), + }, + ), + ] + .into_iter() + .collect(), + }; + + assert_eq!( + gpu_device_ids_from_driver_config(Some(&config)).unwrap(), + vec!["nvidia.com/gpu=0".to_string()] + ); + } + #[test] fn parse_cpu_millicore() { assert_eq!(parse_cpu_to_microseconds("500m"), Some(50_000)); @@ -819,10 +913,7 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { resource_requirements: Some(DriverSandboxResourceRequirements { - gpu: Some(DriverGpuResourceRequirement { - device_ids: vec![], - count: None, - }), + gpu: Some(DriverGpuResourceRequirement { count: None }), }), ..Default::default() }); @@ -839,15 +930,17 @@ mod tests { fn container_spec_passes_explicit_cdi_device_id_through() { use openshell_core::proto::compute::v1::{ DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec, + DriverSandboxTemplate, }; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { + template: Some(DriverSandboxTemplate { + driver_config: Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])), + ..Default::default() + }), resource_requirements: Some(DriverSandboxResourceRequirements { - gpu: Some(DriverGpuResourceRequirement { - device_ids: vec!["nvidia.com/gpu=0".to_string()], - count: None, - }), + gpu: Some(DriverGpuResourceRequirement { count: Some(1) }), }), ..Default::default() }); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 6f93bbff7..1f6dcf3e5 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -10,6 +10,7 @@ use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; use openshell_core::ComputeDriverError; +use openshell_core::gpu::validate_gpu_device_ids_count; use openshell_core::proto::compute::v1::{ DriverGpuResourceRequirement, DriverSandbox, GetCapabilitiesResponse, }; @@ -282,14 +283,40 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu = sandbox.spec.as_ref().and_then(driver_gpu_requirement); - Self::validate_gpu_request(gpu) + let spec = sandbox.spec.as_ref(); + let gpu = spec.and_then(driver_gpu_requirement); + let gpu_device_ids = spec + .and_then(|spec| spec.template.as_ref()) + .map(|template| { + container::gpu_device_ids_from_driver_config(template.driver_config.as_ref()) + .map_err(ComputeDriverError::Precondition) + }) + .transpose()? + .unwrap_or_default(); + Self::validate_gpu_request(gpu, &gpu_device_ids) } fn validate_gpu_request( gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], ) -> Result<(), ComputeDriverError> { - if gpu.is_some_and(|gpu| gpu.count.is_some()) { + if gpu.is_none() && !gpu_device_ids.is_empty() { + return Err(ComputeDriverError::Precondition( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count" + .to_string(), + )); + } + if let Some(gpu) = gpu + && gpu.count == Some(0) + { + return Err(ComputeDriverError::Precondition( + "resource_requirements.gpu.count must be greater than 0".to_string(), + )); + } + if !gpu_device_ids.is_empty() { + validate_gpu_device_ids_count(gpu, gpu_device_ids) + .map_err(ComputeDriverError::Precondition)?; + } else if gpu.is_some_and(|gpu| gpu.count.is_some()) { return Err(ComputeDriverError::Precondition( "podman compute driver does not support GPU count requests".to_string(), )); @@ -687,10 +714,10 @@ mod tests { #[test] fn validate_gpu_request_rejects_count() { - let err = PodmanComputeDriver::validate_gpu_request(Some(&DriverGpuResourceRequirement { - device_ids: vec![], - count: Some(2), - })) + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: Some(2) }), + &[], + ) .expect_err("GPU count should be rejected"); assert!( @@ -698,6 +725,51 @@ mod tests { ); } + #[test] + fn validate_gpu_request_rejects_device_ids_without_count() { + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: None }), + &device_ids, + ) + .expect_err("device IDs without count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("requires resource_requirements.gpu.count")) + ); + } + + #[test] + fn validate_gpu_request_rejects_device_ids_with_zero_count() { + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: Some(0) }), + &device_ids, + ) + .expect_err("device IDs with zero count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("must be greater than 0")) + ); + } + + #[test] + fn validate_gpu_request_rejects_device_id_count_mismatch() { + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: Some(2) }), + &device_ids, + ) + .expect_err("device ID count mismatch should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("unique entry count")) + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index c5860f9cd..d5b0982c3 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -54,7 +54,8 @@ sudo -E env "PATH=$PATH" mise run gateway:vm -- --gpu GPU passthrough uses VFIO and requires host support for IOMMU, root privileges for bind/unbind operations, and a compatible sandbox image. Sandbox GPU requests arrive as `resource_requirements.gpu`; the VM driver accepts the default request, -one explicit device ID, or a count of one. +one driver-configured `gpu_device_ids` entry with a matching count of one, or a +count of one. Point the CLI at the gateway with one of: diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index c60e525ed..ad52b7101 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -24,6 +24,7 @@ use oci_client::manifest::{ }; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; +use openshell_core::gpu::validate_gpu_device_ids_count; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, @@ -619,10 +620,11 @@ impl VmDriver { let gpu_device = sandbox .spec .as_ref() - .and_then(driver_gpu_requirement) - .and_then(|gpu| requested_gpu_device(Some(gpu))); + .map(|spec| requested_gpu_device(driver_gpu_requirement(spec), spec.template.as_ref())) + .transpose()? + .flatten(); let gpu_bdf = if let Some(gpu_device) = gpu_device { - Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) + Some(self.assign_gpu_to_record(&sandbox.id, &gpu_device).await?) } else { None }; @@ -2580,7 +2582,11 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - validate_gpu_request(driver_gpu_requirement(spec), gpu_enabled)?; + validate_gpu_request( + driver_gpu_requirement(spec), + spec.template.as_ref(), + gpu_enabled, + )?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -2631,14 +2637,31 @@ fn driver_gpu_requirement( .and_then(|requirements| requirements.gpu.as_ref()) } -fn requested_gpu_device(gpu: Option<&DriverGpuResourceRequirement>) -> Option<&str> { - let gpu = gpu?; - Some(gpu.device_ids.first().map_or("", String::as_str)) +#[allow(clippy::result_large_err)] +fn requested_gpu_device( + gpu: Option<&DriverGpuResourceRequirement>, + template: Option<&openshell_core::proto::compute::v1::DriverSandboxTemplate>, +) -> Result, Status> { + let Some(_) = gpu else { + return Ok(None); + }; + let should_use_device_ids = gpu.is_some_and(|gpu| gpu.count.is_some_and(|count| count > 0)); + let configured = template + .and_then(|template| template.driver_config.as_ref()) + .map(vm_gpu_device_ids_from_driver_config) + .transpose()? + .unwrap_or_default(); + if should_use_device_ids { + Ok(Some(configured.first().cloned().unwrap_or_default())) + } else { + Ok(Some(String::new())) + } } #[allow(clippy::result_large_err)] fn validate_gpu_request( gpu: Option<&DriverGpuResourceRequirement>, + template: Option<&openshell_core::proto::compute::v1::DriverSandboxTemplate>, gpu_enabled: bool, ) -> Result<(), Status> { if gpu.is_some() && !gpu_enabled { @@ -2647,13 +2670,32 @@ fn validate_gpu_request( )); } + let gpu_device_ids = template + .and_then(|template| template.driver_config.as_ref()) + .map(vm_gpu_device_ids_from_driver_config) + .transpose()? + .unwrap_or_default(); + if gpu.is_none() && !gpu_device_ids.is_empty() { + return Err(Status::invalid_argument( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count", + )); + } + if let Some(gpu) = gpu + && gpu.count == Some(0) + { + return Err(Status::invalid_argument( + "resource_requirements.gpu.count must be greater than 0", + )); + } + if !gpu_device_ids.is_empty() { + validate_gpu_device_ids_count(gpu, &gpu_device_ids).map_err(Status::invalid_argument)?; + } if gpu.is_some_and(|gpu| gpu.count.is_some_and(|count| count > 1)) { return Err(Status::invalid_argument( "vm compute driver supports at most one GPU", )); } - - if gpu.is_some_and(|gpu| gpu.device_ids.len() > 1) { + if gpu_device_ids.len() > 1 { return Err(Status::invalid_argument( "vm compute driver supports at most one GPU device ID", )); @@ -2661,6 +2703,38 @@ fn validate_gpu_request( Ok(()) } +#[allow(clippy::result_large_err)] +fn vm_gpu_device_ids_from_driver_config( + driver_config: &prost_types::Struct, +) -> Result, Status> { + use prost_types::value::Kind; + + if driver_config.fields.is_empty() { + return Ok(Vec::new()); + } + let Some(value) = driver_config.fields.get("gpu_device_ids") else { + return Ok(Vec::new()); + }; + let Some(Kind::ListValue(list)) = value.kind.as_ref() else { + return Err(Status::invalid_argument( + "driver_config.gpu_device_ids must be a list of strings", + )); + }; + + list.values + .iter() + .enumerate() + .map(|(idx, value)| match value.kind.as_ref() { + Some(Kind::StringValue(gpu_device)) if !gpu_device.trim().is_empty() => { + Ok(gpu_device.clone()) + } + _ => Err(Status::invalid_argument(format!( + "driver_config.gpu_device_ids[{idx}] must be a non-empty string" + ))), + }) + .collect() +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { @@ -4458,12 +4532,28 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); - fn gpu_resource_requirements( - device_ids: Vec, - count: Option, - ) -> DriverSandboxResourceRequirements { + fn gpu_resource_requirements(count: Option) -> DriverSandboxResourceRequirements { DriverSandboxResourceRequirements { - gpu: Some(DriverGpuResourceRequirement { device_ids, count }), + gpu: Some(DriverGpuResourceRequirement { count }), + } + } + + fn vm_gpu_device_ids_config(gpu_devices: &[&str]) -> Struct { + Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(prost_types::ListValue { + values: gpu_devices + .iter() + .map(|gpu_device| Value { + kind: Some(Kind::StringValue((*gpu_device).to_string())), + }) + .collect(), + })), + }, + )) + .collect(), } } @@ -4534,7 +4624,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - resource_requirements: Some(gpu_resource_requirements(vec![], None)), + resource_requirements: Some(gpu_resource_requirements(None)), ..Default::default() }), ..Default::default() @@ -4550,7 +4640,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - resource_requirements: Some(gpu_resource_requirements(vec![], None)), + resource_requirements: Some(gpu_resource_requirements(None)), ..Default::default() }), ..Default::default() @@ -4563,7 +4653,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - resource_requirements: Some(gpu_resource_requirements(vec![], Some(1))), + resource_requirements: Some(gpu_resource_requirements(Some(1))), ..Default::default() }), ..Default::default() @@ -4576,7 +4666,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - resource_requirements: Some(gpu_resource_requirements(vec![], Some(2))), + resource_requirements: Some(gpu_resource_requirements(Some(2))), ..Default::default() }), ..Default::default() @@ -4588,14 +4678,99 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_multiple_gpu_device_ids() { + fn validate_vm_sandbox_accepts_gpu_count_with_matching_driver_config_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(1))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + validate_vm_sandbox(&sandbox, true) + .expect("gpu count with matching explicit device should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_driver_config_gpu_device_ids_without_gpu_request() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("gpu device without gpu request should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!( + err.message() + .contains("requires resource_requirements.gpu.count") + ); + } + + #[test] + fn validate_vm_sandbox_rejects_driver_config_gpu_device_id_count_mismatch() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - resource_requirements: Some(gpu_resource_requirements( - vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], - None, - )), + resource_requirements: Some(gpu_resource_requirements(Some(2))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("GPU device ID count mismatch should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("unique entry count")); + } + + #[test] + fn validate_vm_sandbox_rejects_driver_config_gpu_device_ids_with_zero_count() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(0))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("GPU device IDs with zero count should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("must be greater than 0")); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_driver_config_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(2))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&[ + "0000:2d:00.0", + "0000:3d:00.0", + ])), + ..Default::default() + }), ..Default::default() }), ..Default::default() @@ -4603,32 +4778,52 @@ mod tests { let err = validate_vm_sandbox(&sandbox, true) .expect_err("multiple GPU device IDs should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("at most one GPU device ID")); + assert!(err.message().contains("at most one GPU")); + } + + #[test] + fn vm_gpu_device_ids_from_driver_config_ignores_unrelated_fields() { + let mut config = vm_gpu_device_ids_config(&["0000:2d:00.0"]); + config.fields.insert( + "future_field".to_string(), + Value { + kind: Some(Kind::StringValue("ignored".to_string())), + }, + ); + + assert_eq!( + vm_gpu_device_ids_from_driver_config(&config).unwrap(), + vec!["0000:2d:00.0".to_string()] + ); } #[test] fn requested_gpu_device_returns_none_without_gpu_request() { - assert_eq!(requested_gpu_device(None), None); + assert_eq!(requested_gpu_device(None, None).unwrap(), None); } #[test] fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { - let gpu = DriverGpuResourceRequirement { - device_ids: vec![], - count: None, - }; + let gpu = DriverGpuResourceRequirement { count: None }; - assert_eq!(requested_gpu_device(Some(&gpu)), Some("")); + assert_eq!( + requested_gpu_device(Some(&gpu), None).unwrap(), + Some(String::new()) + ); } #[test] - fn requested_gpu_device_returns_first_explicit_device_id() { - let gpu = DriverGpuResourceRequirement { - device_ids: vec!["0000:2d:00.0".to_string()], - count: None, + fn requested_gpu_device_returns_driver_config_gpu_device_id() { + let gpu = DriverGpuResourceRequirement { count: Some(1) }; + let template = SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() }; - assert_eq!(requested_gpu_device(Some(&gpu)), Some("0000:2d:00.0")); + assert_eq!( + requested_gpu_device(Some(&gpu), Some(&template)).unwrap(), + Some("0000:2d:00.0".to_string()) + ); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 666a7174a..8413914bf 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -426,7 +426,7 @@ impl ComputeRuntime { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { - let driver_sandbox = driver_sandbox_from_public(sandbox); + let driver_sandbox = driver_sandbox_from_public(sandbox, self.driver_kind)?; self.driver .validate_sandbox_create(Request::new(ValidateSandboxCreateRequest { sandbox: Some(driver_sandbox), @@ -470,7 +470,7 @@ impl ComputeRuntime { } })?; - let mut driver_sandbox = driver_sandbox_from_public(&sandbox); + let mut driver_sandbox = driver_sandbox_from_public(&sandbox, self.driver_kind)?; if let Some(token) = sandbox_token && let Some(spec) = driver_sandbox.spec.as_mut() { @@ -552,12 +552,11 @@ impl ComputeRuntime { self.sandbox_watch_bus.notify(&id); self.cleanup_sandbox_owned_records(&sandbox).await; - let driver_sandbox = driver_sandbox_from_public(&sandbox); let deleted = self .driver .delete_sandbox(Request::new(DeleteSandboxRequest { - sandbox_id: driver_sandbox.id, - sandbox_name: driver_sandbox.name, + sandbox_id: sandbox.object_id().to_string(), + sandbox_name: sandbox.object_name().to_string(), })) .await .map(|response| response.into_inner().deleted) @@ -1250,54 +1249,92 @@ impl ComputeRuntime { } } -fn driver_sandbox_from_public(sandbox: &Sandbox) -> DriverSandbox { - DriverSandbox { +#[allow(clippy::result_large_err)] +fn driver_sandbox_from_public( + sandbox: &Sandbox, + driver_kind: Option, +) -> Result { + Ok(DriverSandbox { id: sandbox.object_id().to_string(), name: sandbox.object_name().to_string(), namespace: String::new(), // Namespace is set by the driver based on its config - spec: sandbox.spec.as_ref().map(driver_sandbox_spec_from_public), + spec: sandbox + .spec + .as_ref() + .map(|spec| driver_sandbox_spec_from_public(spec, driver_kind)) + .transpose()?, status: sandbox.status.as_ref().map(driver_status_from_public), - } + }) } -fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { - DriverSandboxSpec { +#[allow(clippy::result_large_err)] +fn driver_sandbox_spec_from_public( + spec: &SandboxSpec, + driver_kind: Option, +) -> Result { + Ok(DriverSandboxSpec { log_level: spec.log_level.clone(), environment: spec.environment.clone(), template: spec .template .as_ref() - .map(driver_sandbox_template_from_public), + .map(|template| driver_sandbox_template_from_public(template, driver_kind)) + .transpose()?, resource_requirements: spec .resource_requirements .as_ref() - .map(driver_resource_requirements_from_public), + .map(|requirements| driver_resource_requirements_from_public(*requirements)), sandbox_token: String::new(), - } + }) } fn driver_resource_requirements_from_public( - requirements: &openshell_core::proto::SandboxResourceRequirements, + requirements: openshell_core::proto::SandboxResourceRequirements, ) -> DriverSandboxResourceRequirements { DriverSandboxResourceRequirements { gpu: requirements .gpu .as_ref() - .map(|gpu| DriverGpuResourceRequirement { - device_ids: gpu.device_ids.clone(), - count: gpu.count, - }), + .map(|gpu| DriverGpuResourceRequirement { count: gpu.count }), } } -fn driver_sandbox_template_from_public(template: &SandboxTemplate) -> DriverSandboxTemplate { - DriverSandboxTemplate { +#[allow(clippy::result_large_err)] +fn driver_sandbox_template_from_public( + template: &SandboxTemplate, + driver_kind: Option, +) -> Result { + Ok(DriverSandboxTemplate { image: template.image.clone(), agent_socket_path: template.agent_socket.clone(), labels: template.labels.clone(), environment: template.environment.clone(), resources: extract_typed_resources(&template.resources), platform_config: build_platform_config(template), + driver_config: select_driver_config(&template.driver_config, driver_kind)?, + }) +} + +#[allow(clippy::result_large_err)] +fn select_driver_config( + driver_config: &Option, + driver_kind: Option, +) -> Result, Status> { + let Some(driver_kind) = driver_kind else { + return Ok(None); + }; + let Some(config) = driver_config.as_ref() else { + return Ok(None); + }; + let Some(value) = config.fields.get(driver_kind.as_str()) else { + return Ok(None); + }; + match value.kind.as_ref() { + Some(prost_types::value::Kind::StructValue(inner)) => Ok(Some(inner.clone())), + _ => Err(Status::invalid_argument(format!( + "template.driver_config.{} must be an object", + driver_kind.as_str() + ))), } } @@ -1825,43 +1862,78 @@ mod tests { } #[test] - fn driver_sandbox_spec_from_public_preserves_gpu_request_device_ids() { + fn driver_sandbox_spec_from_public_selects_matching_driver_config_block() { let public = SandboxSpec { - resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec!["nvidia.com/gpu=0".to_string()], - count: None, + template: Some(SandboxTemplate { + driver_config: Some(prost_types::Struct { + fields: [ + ( + "docker".to_string(), + struct_value([( + "gpu_device_ids", + prost_types::Value { + kind: Some(prost_types::value::Kind::ListValue( + prost_types::ListValue { + values: vec![string_value("nvidia.com/gpu=0")], + }, + )), + }, + )]), + ), + ( + "vm".to_string(), + struct_value([( + "gpu_device_ids", + prost_types::Value { + kind: Some(prost_types::value::Kind::ListValue( + prost_types::ListValue { + values: vec![string_value("0")], + }, + )), + }, + )]), + ), + ] + .into_iter() + .collect(), }), + ..Default::default() }), ..Default::default() }; - let driver = driver_sandbox_spec_from_public(&public); + let driver = + driver_sandbox_spec_from_public(&public, Some(ComputeDriverKind::Docker)).unwrap(); - assert_eq!( - driver - .resource_requirements - .expect("driver resource requirements should be present") - .gpu - .expect("driver GPU requirement should be present") - .device_ids, - vec!["nvidia.com/gpu=0".to_string()] - ); + let config = driver + .template + .expect("driver template should be present") + .driver_config + .expect("driver config should be selected"); + let device_ids = config + .fields + .get("gpu_device_ids") + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::ListValue(list)) => list.values.first(), + _ => None, + }) + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(value)) => Some(value.as_str()), + _ => None, + }); + assert_eq!(device_ids, Some("nvidia.com/gpu=0")); } #[test] fn driver_sandbox_spec_from_public_preserves_gpu_count() { let public = SandboxSpec { resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec![], - count: Some(2), - }), + gpu: Some(GpuResourceRequirement { count: Some(2) }), }), ..Default::default() }; - let driver = driver_sandbox_spec_from_public(&public); + let driver = driver_sandbox_spec_from_public(&public, None).unwrap(); assert_eq!( driver @@ -2332,10 +2404,7 @@ mod tests { &mut status, Some(&SandboxSpec { resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec![], - count: None, - }), + gpu: Some(GpuResourceRequirement { count: None }), }), ..Default::default() }), @@ -2644,10 +2713,7 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec![], - count: None, - }), + gpu: Some(GpuResourceRequirement { count: None }), }), ..Default::default() }), diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 0f3b3fd7c..d6afec1b6 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -133,7 +133,7 @@ pub(super) fn validate_sandbox_spec( // --- spec.resource_requirements --- if let Some(ref requirements) = spec.resource_requirements { - validate_resource_requirements(requirements)?; + validate_resource_requirements(*requirements)?; } // --- spec.policy serialized size --- @@ -150,22 +150,17 @@ pub(super) fn validate_sandbox_spec( } fn validate_resource_requirements( - requirements: &openshell_core::proto::SandboxResourceRequirements, + requirements: openshell_core::proto::SandboxResourceRequirements, ) -> Result<(), Status> { - if let Some(ref gpu) = requirements.gpu { + if let Some(gpu) = requirements.gpu { validate_gpu_requirement(gpu)?; } Ok(()) } fn validate_gpu_requirement( - gpu: &openshell_core::proto::GpuResourceRequirement, + gpu: openshell_core::proto::GpuResourceRequirement, ) -> Result<(), Status> { - if gpu.count.is_some() && !gpu.device_ids.is_empty() { - return Err(Status::invalid_argument( - "resource_requirements.gpu.count is mutually exclusive with resource_requirements.gpu.device_ids", - )); - } if gpu.count == Some(0) { return Err(Status::invalid_argument( "resource_requirements.gpu.count must be greater than 0", @@ -230,6 +225,14 @@ fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { ))); } } + if let Some(ref s) = tmpl.driver_config { + let size = s.encoded_len(); + if size > MAX_TEMPLATE_STRUCT_SIZE { + return Err(Status::invalid_argument(format!( + "template.driver_config serialized size exceeds maximum ({size} > {MAX_TEMPLATE_STRUCT_SIZE})" + ))); + } + } Ok(()) } @@ -691,7 +694,10 @@ pub(super) fn level_matches(log_level: &str, min_level: &str) -> bool { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::{GpuResourceRequirement, SandboxResourceRequirements, SandboxSpec}; + use openshell_core::proto::{ + GpuResourceRequirement, SandboxResourceRequirements, SandboxSpec, SandboxTemplate, + }; + use prost_types::{Struct, Value, value::Kind}; use std::collections::HashMap; use tonic::Code; @@ -718,10 +724,7 @@ mod tests { fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec![], - count: None, - }), + gpu: Some(GpuResourceRequirement { count: None }), }), ..Default::default() }; @@ -732,10 +735,7 @@ mod tests { fn validate_sandbox_spec_accepts_gpu_count() { let spec = SandboxSpec { resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec![], - count: Some(2), - }), + gpu: Some(GpuResourceRequirement { count: Some(2) }), }), ..Default::default() }; @@ -746,10 +746,7 @@ mod tests { fn validate_sandbox_spec_rejects_zero_gpu_count() { let spec = SandboxSpec { resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec![], - count: Some(0), - }), + gpu: Some(GpuResourceRequirement { count: Some(0) }), }), ..Default::default() }; @@ -761,21 +758,24 @@ mod tests { } #[test] - fn validate_sandbox_spec_rejects_gpu_count_with_device_id() { + fn validate_sandbox_spec_accepts_driver_config() { let spec = SandboxSpec { - resource_requirements: Some(SandboxResourceRequirements { - gpu: Some(GpuResourceRequirement { - device_ids: vec!["nvidia.com/gpu=0".to_string()], - count: Some(1), + template: Some(SandboxTemplate { + driver_config: Some(Struct { + fields: std::iter::once(( + "docker".to_string(), + Value { + kind: Some(Kind::StructValue(Struct::default())), + }, + )) + .collect(), }), + ..Default::default() }), ..Default::default() }; - let err = validate_sandbox_spec("gpu-count-sandbox", &spec).unwrap_err(); - - assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("mutually exclusive")); + assert!(validate_sandbox_spec("driver-config-sandbox", &spec).is_ok()); } #[test] diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 212f43f1d..fb530ecfa 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -57,14 +57,25 @@ Request a specific number of GPUs with `--gpu-count`: openshell sandbox create --gpu-count 2 -- claude ``` +Request a specific driver-native device with `--gpu-device`: + +```shell +openshell sandbox create --gpu-device nvidia.com/gpu=0 -- claude +``` + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. Kubernetes gateways honor `--gpu-count` by setting the `nvidia.com/gpu` resource limit. Docker and Podman support explicit CDI device IDs through `--gpu-device` -but do not support count-based selection yet. VM gateways accept only one GPU. -In the API, these flags populate `SandboxSpec.resource_requirements.gpu`. +but do not support count-based selection yet. The CLI sets the portable GPU +count to match the requested device ID. VM gateways accept only one GPU. In the +API, portable GPU presence and count populate +`SandboxSpec.resource_requirements.gpu`. Exact device selection is passed as +driver-owned `template.driver_config.gpu_device_ids`. Drivers that support +exact selection require the number of unique `gpu_device_ids` entries to match +`resource_requirements.gpu.count`. ### Custom Containers diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 79bff06e2..c062d52a4 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -100,14 +100,14 @@ message DriverSandboxResourceRequirements { DriverGpuResourceRequirement gpu = 1; } -// Driver-owned GPU resource requirement. Device identifiers are interpreted by -// the selected compute driver and are an interim compatibility surface. +// Driver-owned GPU resource requirement. message DriverGpuResourceRequirement { - // Optional number of GPUs requested. Mutually exclusive with device_ids. + reserved 2; + reserved "device_ids"; + + // Optional number of GPUs requested. When unset, presence means the driver + // chooses its default GPU assignment behavior. optional uint32 count = 1; - // Optional driver-native device identifiers. Mutually exclusive with count. - // Empty means the driver chooses its default GPU assignment behavior. - repeated string device_ids = 2; } // Driver-owned runtime template consumed by the compute platform. @@ -133,6 +133,9 @@ message DriverSandboxTemplate { // For the Kubernetes driver this carries fields such as runtimeClassName, // annotations, and volumeClaimTemplates. google.protobuf.Struct platform_config = 11; + // Caller-provided config for the selected driver only. This is the inner + // block from public SandboxTemplate.driver_config after gateway selection. + google.protobuf.Struct driver_config = 12; } // Typed compute-resource requirements. diff --git a/proto/openshell.proto b/proto/openshell.proto index 4731baa9e..e11891774 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -325,14 +325,14 @@ message SandboxResourceRequirements { GpuResourceRequirement gpu = 1; } -// Public GPU resource requirement. Device identifiers are interpreted by the -// selected compute driver and are an interim compatibility surface. +// Public GPU resource requirement. message GpuResourceRequirement { - // Optional number of GPUs requested. Mutually exclusive with device_ids. + reserved 2; + reserved "device_ids"; + + // Optional number of GPUs requested. When unset, presence means the driver + // chooses its default GPU assignment behavior. optional uint32 count = 1; - // Optional driver-native device identifiers. Mutually exclusive with count. - // Empty means the driver chooses its default GPU assignment behavior. - repeated string device_ids = 2; } // Public sandbox template mapped onto compute-driver template inputs. @@ -359,6 +359,10 @@ message SandboxTemplate { // available (beta through 1.35, GA in 1.36+) and a supporting runtime. // When unset, the cluster-wide default is used. optional bool user_namespaces = 10; + // Opaque driver-specific configuration provided by the caller. The gateway + // selects the block matching the active driver name and forwards only that + // inner block to the selected compute driver. + google.protobuf.Struct driver_config = 11; } // User-facing sandbox status derived by the gateway from compute-driver observations.