class OpenAIResponsesModel(Model):
"""
Implementation of `Model` that uses the OpenAI Responses API.
"""
def __init__(
self,
model: str | ChatModel,
openai_client: AsyncOpenAI,
*,
model_is_explicit: bool = True,
) -> None:
self.model = model
self._model_is_explicit = model_is_explicit
self._client = openai_client
def _non_null_or_omit(self, value: Any) -> Any:
return value if value is not None else omit
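    # Illustrative note (not part of the implementation): _non_null_or_omit maps None to the
    # SDK's `omit` sentinel so that unset settings are excluded from the request payload
    # entirely rather than being sent as explicit nulls. For example:
    #
    #     self._non_null_or_omit(None)   # -> omit (field left out of the API call)
    #     self._non_null_or_omit(0.2)    # -> 0.2  (field sent as-is)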
async def get_response(
self,
system_instructions: str | None,
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None = None,
conversation_id: str | None = None,
prompt: ResponsePromptParam | None = None,
) -> ModelResponse:
with response_span(disabled=tracing.is_disabled()) as span_response:
try:
response = await self._fetch_response(
system_instructions,
input,
model_settings,
tools,
output_schema,
handoffs,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
stream=False,
prompt=prompt,
)
if _debug.DONT_LOG_MODEL_DATA:
logger.debug("LLM responded")
else:
                    output_dump = json.dumps(
                        [x.model_dump() for x in response.output],
                        indent=2,
                        ensure_ascii=False,
                    )
                    logger.debug(f"LLM resp:\n{output_dump}\n")
usage = (
Usage(
requests=1,
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
total_tokens=response.usage.total_tokens,
input_tokens_details=response.usage.input_tokens_details,
output_tokens_details=response.usage.output_tokens_details,
)
if response.usage
else Usage()
)
if tracing.include_data():
span_response.span_data.response = response
span_response.span_data.input = input
except Exception as e:
span_response.set_error(
SpanError(
message="Error getting response",
data={
"error": str(e) if tracing.include_data() else e.__class__.__name__,
},
)
)
request_id = e.request_id if isinstance(e, APIStatusError) else None
logger.error(f"Error getting response: {e}. (request_id: {request_id})")
raise
return ModelResponse(
output=response.output,
usage=usage,
response_id=response.id,
)
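    # Usage sketch (illustrative only, not part of the implementation). Assumes a default
    # ModelSettings() and disabled tracing via ModelTracing.DISABLED; the model name is an
    # arbitrary example:
    #
    #     model = OpenAIResponsesModel(model="gpt-4.1", openai_client=AsyncOpenAI())
    #     result = await model.get_response(
    #         system_instructions="You are a helpful assistant.",
    #         input="Hello!",
    #         model_settings=ModelSettings(),
    #         tools=[],
    #         output_schema=None,
    #         handoffs=[],
    #         tracing=ModelTracing.DISABLED,
    #     )
    #     print(result.output, result.usage.total_tokens)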
async def stream_response(
self,
system_instructions: str | None,
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None = None,
conversation_id: str | None = None,
prompt: ResponsePromptParam | None = None,
) -> AsyncIterator[ResponseStreamEvent]:
"""
Yields a partial message as it is generated, as well as the usage information.
"""
with response_span(disabled=tracing.is_disabled()) as span_response:
try:
stream = await self._fetch_response(
system_instructions,
input,
model_settings,
tools,
output_schema,
handoffs,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
stream=True,
prompt=prompt,
)
final_response: Response | None = None
async for chunk in stream:
if isinstance(chunk, ResponseCompletedEvent):
final_response = chunk.response
yield chunk
if final_response and tracing.include_data():
span_response.span_data.response = final_response
span_response.span_data.input = input
except Exception as e:
span_response.set_error(
SpanError(
message="Error streaming response",
data={
"error": str(e) if tracing.include_data() else e.__class__.__name__,
},
)
)
logger.error(f"Error streaming response: {e}")
raise
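    # Streaming sketch (illustrative only). Events are raw Responses API stream events; the
    # ResponseCompletedEvent at the end carries the full response, including usage:
    #
    #     async for event in model.stream_response(
    #         system_instructions=None,
    #         input="Hello!",
    #         model_settings=ModelSettings(),
    #         tools=[],
    #         output_schema=None,
    #         handoffs=[],
    #         tracing=ModelTracing.DISABLED,
    #     ):
    #         if isinstance(event, ResponseCompletedEvent):
    #             final_response = event.response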
@overload
async def _fetch_response(
self,
system_instructions: str | None,
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
previous_response_id: str | None,
conversation_id: str | None,
stream: Literal[True],
prompt: ResponsePromptParam | None = None,
) -> AsyncStream[ResponseStreamEvent]: ...
@overload
async def _fetch_response(
self,
system_instructions: str | None,
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
previous_response_id: str | None,
conversation_id: str | None,
stream: Literal[False],
prompt: ResponsePromptParam | None = None,
) -> Response: ...
async def _fetch_response(
self,
system_instructions: str | None,
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
previous_response_id: str | None = None,
conversation_id: str | None = None,
stream: Literal[True] | Literal[False] = False,
prompt: ResponsePromptParam | None = None,
) -> Response | AsyncStream[ResponseStreamEvent]:
list_input = ItemHelpers.input_to_new_input_list(input)
list_input = _to_dump_compatible(list_input)
list_input = self._remove_openai_responses_api_incompatible_fields(list_input)
if model_settings.parallel_tool_calls and tools:
parallel_tool_calls: bool | Omit = True
elif model_settings.parallel_tool_calls is False:
parallel_tool_calls = False
else:
parallel_tool_calls = omit
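        # parallel_tool_calls is three-state: True when explicitly enabled and tools are
        # present, False when explicitly disabled, and `omit` (excluded from the request)
        # otherwise, so the API-side default applies.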
tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
converted_tools = Converter.convert_tools(tools, handoffs)
converted_tools_payload = _to_dump_compatible(converted_tools.tools)
response_format = Converter.get_response_format(output_schema)
should_omit_model = prompt is not None and not self._model_is_explicit
model_param: str | ChatModel | Omit = self.model if not should_omit_model else omit
should_omit_tools = prompt is not None and len(converted_tools_payload) == 0
tools_param: list[ToolParam] | Omit = (
converted_tools_payload if not should_omit_tools else omit
)
include_set: set[str] = set(converted_tools.includes)
if model_settings.response_include is not None:
include_set.update(model_settings.response_include)
if model_settings.top_logprobs is not None:
include_set.add("message.output_text.logprobs")
include = cast(list[ResponseIncludable], list(include_set))
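        # The include list merges the include values required by the converted tools with any
        # caller-supplied response_include values. Requesting top_logprobs additionally requires
        # the "message.output_text.logprobs" include so that logprobs appear in the output.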
if _debug.DONT_LOG_MODEL_DATA:
logger.debug("Calling LLM")
else:
input_json = json.dumps(
list_input,
indent=2,
ensure_ascii=False,
)
tools_json = json.dumps(
converted_tools_payload,
indent=2,
ensure_ascii=False,
)
logger.debug(
f"Calling LLM {self.model} with input:\n"
f"{input_json}\n"
f"Tools:\n{tools_json}\n"
f"Stream: {stream}\n"
f"Tool choice: {tool_choice}\n"
f"Response format: {response_format}\n"
f"Previous response id: {previous_response_id}\n"
f"Conversation id: {conversation_id}\n"
)
extra_args = dict(model_settings.extra_args or {})
if model_settings.top_logprobs is not None:
extra_args["top_logprobs"] = model_settings.top_logprobs
if model_settings.verbosity is not None:
if response_format is not omit:
response_format["verbosity"] = model_settings.verbosity # type: ignore [index]
else:
response_format = {"verbosity": model_settings.verbosity}
stream_param: Literal[True] | Omit = True if stream else omit
response = await self._client.responses.create(
previous_response_id=self._non_null_or_omit(previous_response_id),
conversation=self._non_null_or_omit(conversation_id),
instructions=self._non_null_or_omit(system_instructions),
model=model_param,
input=list_input,
include=include,
tools=tools_param,
prompt=self._non_null_or_omit(prompt),
temperature=self._non_null_or_omit(model_settings.temperature),
top_p=self._non_null_or_omit(model_settings.top_p),
truncation=self._non_null_or_omit(model_settings.truncation),
max_output_tokens=self._non_null_or_omit(model_settings.max_tokens),
tool_choice=tool_choice,
parallel_tool_calls=parallel_tool_calls,
stream=cast(Any, stream_param),
extra_headers=self._merge_headers(model_settings),
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
text=response_format,
store=self._non_null_or_omit(model_settings.store),
prompt_cache_retention=self._non_null_or_omit(model_settings.prompt_cache_retention),
reasoning=self._non_null_or_omit(model_settings.reasoning),
metadata=self._non_null_or_omit(model_settings.metadata),
**extra_args,
)
return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response)
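    # Note on extra_args (illustrative): ModelSettings.extra_args is expanded into extra keyword
    # arguments for responses.create(), so request parameters not modeled by ModelSettings can be
    # passed through. A hypothetical example, assuming the parameter is accepted by the API:
    #
    #     settings = ModelSettings(extra_args={"service_tier": "flex"})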
def _remove_openai_responses_api_incompatible_fields(self, list_input: list[Any]) -> list[Any]:
"""
        Remove or transform input items that are incompatible with the OpenAI Responses API.

        This transformation does not guarantee that items from other provider interactions
        will be accepted by the OpenAI Responses API. If no item carries truthy provider_data,
        the input is returned unchanged; otherwise every item is passed through
        _clean_item_for_openai.
This function handles the following incompatibilities:
- provider_data: Removes fields specific to other providers (e.g., Gemini, Claude).
- Fake IDs: Removes temporary IDs (FAKE_RESPONSES_ID) that should not be sent to OpenAI.
- Reasoning items: Filters out provider-specific reasoning items entirely.
"""
# Early return optimization: if no item has provider_data, return unchanged.
has_provider_data = any(
isinstance(item, dict) and item.get("provider_data") for item in list_input
)
if not has_provider_data:
return list_input
result = []
for item in list_input:
cleaned = self._clean_item_for_openai(item)
if cleaned is not None:
result.append(cleaned)
return result
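    # Worked example (illustrative): given input items such as
    #
    #     [
    #         {"type": "reasoning", "provider_data": {...}},                 # dropped entirely
    #         {"type": "message", "id": FAKE_RESPONSES_ID, "role": "user",
    #          "content": "hi", "provider_data": {...}},                     # id/provider_data removed
    #     ]
    #
    # the cleaned list keeps only {"type": "message", "role": "user", "content": "hi"}.
    # Note that _clean_item_for_openai mutates matching dicts in place.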
def _clean_item_for_openai(self, item: Any) -> Any | None:
# Only process dict items
if not isinstance(item, dict):
return item
# Filter out reasoning items with provider_data (provider-specific reasoning).
if item.get("type") == "reasoning" and item.get("provider_data"):
return None
# Remove fake response ID.
if item.get("id") == FAKE_RESPONSES_ID:
del item["id"]
# Remove provider_data field.
if "provider_data" in item:
del item["provider_data"]
return item
    def _get_client(self) -> AsyncOpenAI:
        # Fall back to a default AsyncOpenAI client if one was not supplied.
        if self._client is None:
            self._client = AsyncOpenAI()
        return self._client
def _merge_headers(self, model_settings: ModelSettings):
return {
**_HEADERS,
**(model_settings.extra_headers or {}),
**(_HEADERS_OVERRIDE.get() or {}),
}
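    # Header precedence (later entries win in the dict merge): the module-level _HEADERS base is
    # overridden by ModelSettings.extra_headers, which in turn is overridden by any per-call
    # values from _HEADERS_OVERRIDE.get().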