Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions codex-rs/core/src/session/turn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ use crate::stream_events_utils::handle_output_item_done;
use crate::stream_events_utils::last_assistant_message_from_item;
use crate::stream_events_utils::mark_thread_memory_mode_polluted_if_external_context;
use crate::stream_events_utils::raw_assistant_output_text_from_item;
use crate::stream_events_utils::record_completed_response_item;
use crate::stream_events_utils::record_completed_response_item_with_finalized_facts;
use crate::tasks::emit_compact_metric;
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::parallel::ToolCallRuntime;
use crate::tools::registry::ToolArgumentDiffConsumer;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouterParams;
use crate::tools::router::extension_tool_executors;
use crate::tools::spec_plan::search_tool_enabled;
Expand Down Expand Up @@ -1923,6 +1925,8 @@ async fn try_run_sampling_request(
.await??;
let mut in_flight: FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>> =
FuturesOrdered::new();
let mut staged_tool_calls: Vec<(ResponseItem, ToolCall)> = Vec::new();
let mut discard_staged_tool_futures = false;
let mut needs_follow_up = false;
let mut last_agent_message: Option<String> = None;
let mut active_item: Option<TurnItem> = None;
Expand Down Expand Up @@ -2067,8 +2071,15 @@ async fn try_run_sampling_request(
Ok(output_result) => output_result,
Err(err) => break Err(err),
};
if let Some(tool_future) = output_result.tool_future {
in_flight.push_back(tool_future);
if output_result.should_discard_response_tool_futures() {
discard_staged_tool_futures = true;
staged_tool_calls.clear();
}
if let (Some(tool_item), Some(tool_call)) =
(output_result.tool_item, output_result.tool_future)
&& !discard_staged_tool_futures
{
staged_tool_calls.push((tool_item, tool_call));
}
if let Some(agent_message) = output_result.last_agent_message {
last_agent_message = Some(agent_message);
Expand Down Expand Up @@ -2212,6 +2223,22 @@ async fn try_run_sampling_request(
if let Some(false) = end_turn {
needs_follow_up = true;
}
if !discard_staged_tool_futures {
for (tool_item, tool_call) in staged_tool_calls.drain(..) {
record_completed_response_item(
sess.as_ref(),
turn_context.as_ref(),
&tool_item,
)
.await;
let tool_future = Box::pin(
tool_runtime
.clone()
.handle_tool_call(tool_call, cancellation_token.child_token()),
);
in_flight.push_back(tool_future);
}
}
break Ok(SamplingRequestResult {
needs_follow_up,
last_agent_message,
Expand Down
153 changes: 135 additions & 18 deletions codex-rs/core/src/stream_events_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::pin::Pin;
use std::sync::Arc;

use base64::Engine;
Expand All @@ -16,7 +15,10 @@ use crate::function_tool::FunctionCallError;
use crate::parse_turn_item;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::session::turn_context::TurnEnvironment;
use crate::tools::context::ToolPayload;
use crate::tools::parallel::ToolCallRuntime;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_memories_read::citations::parse_memory_citation;
use codex_memories_read::citations::thread_ids_from_memory_citation;
Expand All @@ -31,7 +33,6 @@ use codex_protocol::models::ResponseItem;
use codex_rollout::state_db;
use codex_utils_absolute_path::AbsolutePathBuf;
use codex_utils_stream_parser::strip_proposed_plan_blocks;
use futures::Future;
use tracing::debug;
use tracing::instrument;
use tracing::warn;
Expand Down Expand Up @@ -303,21 +304,28 @@ async fn record_stage1_output_usage_for_memory_citation(
/// Handle a completed output item from the model stream, recording it and
/// queuing any tool execution futures. This records items immediately so
/// history and rollout stay in sync even if the turn is later cancelled.
pub(crate) type InFlightFuture<'f> =
Pin<Box<dyn Future<Output = Result<ResponseInputItem>> + Send + 'f>>;

#[derive(Default)]
pub(crate) struct OutputItemResult {
pub last_agent_message: Option<String>,
pub needs_follow_up: bool,
pub tool_future: Option<InFlightFuture<'static>>,
pub tool_future: Option<ToolCall>,
pub tool_item: Option<ResponseItem>,
respond_to_model: bool,
}

impl OutputItemResult {
pub(crate) fn should_discard_response_tool_futures(&self) -> bool {
self.respond_to_model
}
}

pub(crate) struct HandleOutputCtx {
pub sess: Arc<Session>,
pub turn_context: Arc<TurnContext>,
pub turn_store: Arc<ExtensionData>,
#[allow(dead_code)]
pub tool_runtime: ToolCallRuntime,
#[allow(dead_code)]
pub cancellation_token: CancellationToken,
}

Expand Down Expand Up @@ -411,7 +419,7 @@ pub(crate) async fn handle_output_item_done(
let plan_mode = ctx.turn_context.collaboration_mode.mode == ModeKind::Plan;

match ToolRouter::build_tool_call(item.clone()) {
// The model emitted a tool call; log it, persist the item immediately, and queue the tool execution.
// The model emitted a tool call; log it and prepare execution after the full response validates.
Ok(Some(call)) => {
ctx.sess
.input_queue
Expand All @@ -429,18 +437,27 @@ pub(crate) async fn handle_output_item_done(
payload_preview
);

record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;

let cancellation_token = ctx.cancellation_token.child_token();
let tool_future: InFlightFuture<'static> = Box::pin(
ctx.tool_runtime
.clone()
.handle_tool_call(call, cancellation_token),
);
if let Some(response) =
preflight_direct_tool_call(ctx.turn_context.as_ref(), &call).await
{
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
if let Some(response_item) = response_input_to_response_item(&response) {
ctx.sess
.record_conversation_items(
&ctx.turn_context,
std::slice::from_ref(&response_item),
)
.await;
}

output.needs_follow_up = true;
output.tool_future = Some(tool_future);
output.needs_follow_up = true;
output.respond_to_model = true;
} else {
output.needs_follow_up = true;
output.tool_future = Some(call);
output.tool_item = Some(item);
}
}
// No tool call: convert messages/reasoning into turn items and mark them as complete.
Ok(None) => {
Expand Down Expand Up @@ -504,6 +521,7 @@ pub(crate) async fn handle_output_item_done(
}

output.needs_follow_up = true;
output.respond_to_model = true;
}
// A fatal error occurred; surface it back into history.
Err(FunctionCallError::Fatal(message)) => {
Expand All @@ -514,6 +532,105 @@ pub(crate) async fn handle_output_item_done(
Ok(output)
}

async fn preflight_direct_tool_call(
turn_context: &TurnContext,
call: &ToolCall,
) -> Option<ResponseInputItem> {
if call.tool_name.namespace.is_some() || call.tool_name.name != "apply_patch" {
return None;
}
let ToolPayload::Custom { input } = &call.payload else {
return None;
};

let args = match codex_apply_patch::parse_patch(input) {
Ok(args) => args,
Err(parse_error) => {
return Some(custom_tool_call_failure(
call,
format!("apply_patch verification failed: {parse_error}"),
));
}
};

let turn_environment =
match preflight_apply_patch_environment(turn_context, args.environment_id.as_deref()) {
Ok(Some(turn_environment)) => turn_environment,
Ok(None) => {
return Some(custom_tool_call_failure(
call,
"apply_patch is unavailable in this session".to_string(),
));
}
Err(message) => {
return Some(custom_tool_call_failure(call, message));
}
};
let cwd = turn_environment.cwd.clone();
let fs = turn_environment.environment.get_filesystem();
let sandbox =
turn_context.file_system_sandbox_context(/*additional_permissions*/ None, &cwd);

match codex_apply_patch::verify_apply_patch_args(args, &cwd, fs.as_ref(), Some(&sandbox)).await
{
codex_apply_patch::MaybeApplyPatchVerified::Body(_) => None,
codex_apply_patch::MaybeApplyPatchVerified::CorrectnessError(parse_error) => {
Some(custom_tool_call_failure(
call,
format!("apply_patch verification failed: {parse_error}"),
))
}
codex_apply_patch::MaybeApplyPatchVerified::ShellParseError(error) => {
tracing::trace!("Failed to parse apply_patch input, {error:?}");
Some(custom_tool_call_failure(
call,
"apply_patch handler received invalid patch input".to_string(),
))
}
codex_apply_patch::MaybeApplyPatchVerified::NotApplyPatch => {
Some(custom_tool_call_failure(
call,
"apply_patch handler received non-apply_patch input".to_string(),
))
}
}
}

fn preflight_apply_patch_environment<'a>(
turn_context: &'a TurnContext,
environment_id: Option<&str>,
) -> std::result::Result<Option<&'a TurnEnvironment>, String> {
if environment_id.is_some() && turn_context.environments.turn_environments.len() <= 1 {
return Err("apply_patch environment selection is unavailable for this turn".to_string());
}

environment_id.map_or_else(
|| Ok(turn_context.environments.primary()),
|environment_id| {
turn_context
.environments
.turn_environments
.iter()
.find(|environment| environment.environment_id == environment_id)
.map(Some)
.ok_or_else(|| {
format!("unknown turn environment id `{environment_id}` for apply_patch")
})
},
)
}

fn custom_tool_call_failure(call: &ToolCall, message: String) -> ResponseInputItem {
ResponseInputItem::CustomToolCallOutput {
call_id: call.call_id.clone(),
name: None,
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::Text(message),
success: Some(false),
},
}
}

pub(crate) async fn handle_non_tool_response_item(
sess: &Session,
turn_context: &TurnContext,
Expand Down
57 changes: 57 additions & 0 deletions codex-rs/core/tests/suite/apply_patch_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,63 @@ async fn apply_patch_cli_rejects_repeated_identical_payload_before_reapplying()
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn apply_patch_cli_does_not_execute_valid_sibling_when_later_sibling_is_malformed()
-> Result<()> {
skip_if_no_network!(Ok(()));

let harness = apply_patch_harness().await?;

let valid_call_id = "apply-valid-sibling";
let invalid_call_id = "apply-invalid-sibling";
let valid_patch = "*** Begin Patch\n*** Add File: staged.txt\n+created\n*** End Patch";
let response_mock = mount_sse_sequence(
harness.server(),
vec![
sse(vec![
ev_response_created("resp-1"),
ev_apply_patch_custom_tool_call(valid_call_id, valid_patch),
ev_apply_patch_custom_tool_call(invalid_call_id, "not a patch"),
ev_completed("resp-1"),
]),
sse(vec![
ev_assistant_message("msg-1", "retrying"),
ev_completed("resp-2"),
]),
],
)
.await;

harness
.submit("apply a valid patch with a malformed sibling")
.await?;

let requests = response_mock.requests();
assert_eq!(requests.len(), 2);
assert!(
requests[1]
.custom_tool_call_output_content_and_success(invalid_call_id)
.is_some(),
"malformed sibling should still be reported back to the model"
);
assert!(
requests[1]
.inputs_of_type("custom_tool_call_output")
.iter()
.all(
|item| item.get("call_id").and_then(serde_json::Value::as_str)
!= Some(valid_call_id)
),
"valid sibling should not execute after the response batch is invalidated"
);
assert!(
fs::metadata(harness.cwd_abs().join("staged.txt")).is_err(),
"valid sibling must not create staged.txt"
);

Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn apply_patch_cli_skips_duplicate_streamed_call_id() -> Result<()> {
skip_if_no_network!(Ok(()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
source: core/tests/suite/pending_input.rs
assertion_line: 232
expression: snapshot
---
Scenario: /responses POST bodies (input only, redacted like other suite snapshots)

## First request
00:message/developer:<PERMISSIONS_INSTRUCTIONS>
01:message/user:<ENVIRONMENT_CONTEXT:cwd=<CWD>>
02:message/user:first prompt

## Second request
00:message/developer:<PERMISSIONS_INSTRUCTIONS>
01:message/user:<ENVIRONMENT_CONTEXT:cwd=<CWD>>
02:message/user:first prompt
03:reasoning:summary=thinking:encrypted=true
04:message/assistant:first answer
05:function_call/shell
06:function_call_output:unsupported call: shell
07:message/user:second prompt