class SandboxSession(BaseSandboxSession):
    """Wrap sandbox operations in audit events and SDK tracing spans when tracing is active.

    Pure delegating proxy: every operation forwards to the wrapped inner
    session; selected operations are additionally wrapped by
    ``@instrumented_op`` to emit start/finish audit events and trace spans.
    """
    # The wrapped backend session every operation delegates to.
    _inner: BaseSandboxSession
    # Sink registry used to emit and flush audit events.
    _instrumentation: Instrumentation
    # Monotonically increasing sequence number stamped onto emitted events.
    _seq: int
def __init__(
self,
inner: BaseSandboxSession,
*,
instrumentation: Instrumentation | None = None,
dependencies: Dependencies | None = None,
) -> None:
self._inner = inner
self._inner.set_dependencies(dependencies)
self._instrumentation = instrumentation or Instrumentation()
self._seq = 0
self._bind_session_to_sinks()
def _bind_session_to_sinks(self) -> None:
# Bind sinks to the *inner* session to avoid recursive instrumentation loops.
for sink in self._instrumentation.sinks:
sinks: list[object]
if isinstance(sink, ChainedSink):
sinks = list(sink.sinks)
else:
sinks = [sink]
for s in sinks:
if isinstance(s, SandboxSessionBoundSink):
s.bind(self._inner)
@property
def state(self) -> SandboxSessionState:
return self._inner.state
@state.setter
def state(self, value: SandboxSessionState) -> None: # pragma: no cover
self._inner.state = value
@property
def dependencies(self) -> Dependencies:
return self._inner.dependencies
def set_dependencies(self, dependencies: Dependencies | None) -> None:
self._inner.set_dependencies(dependencies)
async def _aclose_dependencies(self) -> None:
await self._inner._aclose_dependencies()
def _set_concurrency_limits(self, limits: SandboxConcurrencyLimits) -> None:
super()._set_concurrency_limits(limits)
self._inner._set_concurrency_limits(limits)
def normalize_path(self, path: Path | str) -> Path:
return self._inner.normalize_path(path)
def supports_pty(self) -> bool:
return self._inner.supports_pty()
async def aclose(self) -> None:
try:
await super().aclose()
finally:
await self._instrumentation.flush()
def _next_seq(self) -> int:
self._seq += 1
return self._seq
async def _emit_start_event(
self,
*,
op: OpName,
span_id: str,
parent_span_id: str | None,
trace_id: str | None,
data: dict[str, object] | None = None,
) -> None:
await self._instrumentation.emit(
SandboxSessionStartEvent(
session_id=self.state.session_id,
seq=self._next_seq(),
op=op,
span_id=span_id,
parent_span_id=parent_span_id,
trace_id=trace_id,
data=data or {},
)
)
def _trace_span_data(self, *, op: OpName) -> dict[str, object]:
return {
"sandbox.backend": type(self._inner).__module__.rsplit(".", 1)[-1],
"sandbox.operation": op,
"sandbox.session.id": str(self.state.session_id),
"session_id": str(self.state.session_id),
}
def _apply_trace_finish_data(
self,
*,
span: Span[Any] | None,
op: OpName,
ok: bool,
data: dict[str, object] | None,
exc: BaseException | None,
) -> None:
if span is None:
return
trace_data = span.span_data.data
trace_data.update(self._trace_span_data(op=op))
if data is not None:
if "alive" in data:
trace_data["alive"] = data["alive"]
if "exit_code" in data:
trace_data["exit_code"] = data["exit_code"]
if "process.exit.code" in data:
trace_data["process.exit.code"] = data["process.exit.code"]
if "server.port" in data:
trace_data["server.port"] = data["server.port"]
if "server.address" in data:
trace_data["server.address"] = data["server.address"]
if exc is not None:
trace_data["error.type"] = type(exc).__name__
trace_data["error_type"] = type(exc).__name__
error_data: dict[str, object] = {"operation": op}
if isinstance(exc, SandboxError):
trace_data["error_code"] = exc.error_code
error_data["error_code"] = exc.error_code
span.set_error({"message": type(exc).__name__, "data": error_data})
return
if not ok:
if op == "exec":
trace_data["error.type"] = "ExecNonZeroError"
error_data = {"operation": op}
if data is not None and "exit_code" in data:
error_data["exit_code"] = data["exit_code"]
span.set_error(
{
"message": "Sandbox operation returned an unsuccessful result.",
"data": error_data,
}
)
    async def _annotate(
        self,
        *,
        op: OpName,
        start_data: dict[str, object] | None,
        run: Callable[[], Coroutine[object, object, T]],
        finish_data: Callable[[T], dict[str, object]] | None = None,
        ok: Callable[[T], bool] | None = None,
        outputs: Callable[[T], tuple[bytes | None, bytes | None]] | None = None,
    ) -> T:
        """Run *run* inside an optional trace span, emitting start/finish audit events.

        Args:
            op: Operation name used for the span name and event tagging.
            start_data: Payload for the start event; also used as finish data
                when *finish_data* is not supplied, and on the error path.
            run: Zero-argument coroutine factory performing the operation.
            finish_data: Maps the result to the finish-event payload.
            ok: Maps the result to a success flag; defaults to success.
            outputs: Maps the result to (stdout, stderr) bytes for the
                finish event.

        Returns:
            Whatever *run* returns; exceptions from *run* are re-raised after
            the failure finish event is emitted.
        """
        # Only open a real SDK span when tracing is active; otherwise the
        # null context yields None and span annotation becomes a no-op.
        span_cm = (
            custom_span(
                name=f"sandbox.{op}",
                data=self._trace_span_data(op=op),
            )
            if _supports_trace_spans()
            else nullcontext(None)
        )
        with span_cm as trace_span:
            span_id, parent_span_id, trace_id = _audit_trace_ids(trace_span)
            # Start event is emitted before the operation so a crash still
            # leaves a start/finish pair (finish carries the error).
            await self._emit_start_event(
                op=op,
                span_id=span_id,
                parent_span_id=parent_span_id,
                trace_id=trace_id,
                data=start_data,
            )
            t0 = time.monotonic()
            try:
                value = await run()
            except Exception as e:
                # Failure path: annotate the span and emit the finish event
                # with the start payload (there is no result to derive from),
                # then propagate the original exception.
                duration_ms = (time.monotonic() - t0) * 1000.0
                self._apply_trace_finish_data(
                    span=trace_span,
                    op=op,
                    ok=False,
                    data=start_data,
                    exc=e,
                )
                await self._emit_finish_event(
                    op=op,
                    span_id=span_id,
                    parent_span_id=parent_span_id,
                    trace_id=trace_id,
                    duration_ms=duration_ms,
                    ok=False,
                    exc=e,
                    data=start_data,
                    stdout=None,
                    stderr=None,
                )
                raise
            # Success path: derive finish payload, success flag, and captured
            # outputs from the result via the optional callbacks.
            data_finish = finish_data(value) if finish_data is not None else start_data
            ok_value = ok(value) if ok is not None else True
            stdout, stderr = outputs(value) if outputs is not None else (None, None)
            duration_ms = (time.monotonic() - t0) * 1000.0
            self._apply_trace_finish_data(
                span=trace_span,
                op=op,
                ok=ok_value,
                data=data_finish,
                exc=None,
            )
            await self._emit_finish_event(
                op=op,
                span_id=span_id,
                parent_span_id=parent_span_id,
                trace_id=trace_id,
                duration_ms=duration_ms,
                ok=ok_value,
                exc=None,
                data=data_finish,
                stdout=stdout,
                stderr=stderr,
            )
            return value
async def _emit_finish_event(
self,
*,
op: OpName,
span_id: str,
parent_span_id: str | None,
trace_id: str | None,
duration_ms: float,
ok: bool,
exc: BaseException | None,
data: dict[str, object] | None,
stdout: bytes | None,
stderr: bytes | None,
) -> None:
event = SandboxSessionFinishEvent(
session_id=self.state.session_id,
seq=self._next_seq(),
op=op,
span_id=span_id,
parent_span_id=parent_span_id,
trace_id=trace_id,
data=data or {},
ok=ok,
duration_ms=duration_ms,
)
if exc is not None:
event.error_type = type(exc).__name__
event.error_message = str(exc)
if isinstance(exc, SandboxError):
event.error_code = exc.error_code
# Preserve raw bytes so Instrumentation can apply per-op/per-sink policies later.
# Decoding here would force one global formatting decision before sink-specific redaction
# and truncation rules have a chance to run.
event.stdout_bytes = stdout
event.stderr_bytes = stderr
await self._instrumentation.emit(event)
@instrumented_op("start")
async def start(self) -> None:
await self._inner.start()
@instrumented_op("stop")
async def stop(self) -> None:
await self._inner.stop()
@instrumented_op("shutdown")
async def shutdown(self) -> None:
await self._inner.shutdown()
@instrumented_op(
"exec",
data=_exec_start_data,
finish_data=_exec_finish_data,
ok=lambda result: cast(ExecResult, result).ok(),
outputs=lambda result: (
cast(ExecResult, result).stdout,
cast(ExecResult, result).stderr,
),
)
async def exec(
self,
*command: str | Path,
timeout: float | None = None,
shell: bool | list[str] = True,
user: str | User | None = None,
) -> ExecResult:
return await self._inner.exec(*command, timeout=timeout, shell=shell, user=user)
async def _exec_internal(
self,
*command: str | Path,
timeout: float | None = None,
) -> ExecResult:
raise NotImplementedError("this should never be invoked")
async def _resolve_exposed_port(self, port: int) -> ExposedPortEndpoint:
_ = port
raise NotImplementedError("this should never be invoked")
async def pty_exec_start(
self,
*command: str | Path,
timeout: float | None = None,
shell: bool | list[str] = True,
user: str | User | None = None,
tty: bool = False,
yield_time_s: float | None = None,
max_output_tokens: int | None = None,
) -> PtyExecUpdate:
return await self._inner.pty_exec_start(
*command,
timeout=timeout,
shell=shell,
user=user,
tty=tty,
yield_time_s=yield_time_s,
max_output_tokens=max_output_tokens,
)
async def pty_write_stdin(
self,
*,
session_id: int,
chars: str,
yield_time_s: float | None = None,
max_output_tokens: int | None = None,
) -> PtyExecUpdate:
return await self._inner.pty_write_stdin(
session_id=session_id,
chars=chars,
yield_time_s=yield_time_s,
max_output_tokens=max_output_tokens,
)
async def pty_terminate_all(self) -> None:
await self._inner.pty_terminate_all()
async def _normalize_path_for_io(self, path: Path | str) -> Path:
return await self._inner._normalize_path_for_io(path)
async def ls(
self,
path: Path | str,
*,
user: str | User | None = None,
) -> list[FileEntry]:
return await self._inner.ls(path, user=user)
async def rm(
self,
path: Path | str,
*,
recursive: bool = False,
user: str | User | None = None,
) -> None:
await self._inner.rm(path, recursive=recursive, user=user)
async def mkdir(
self,
path: Path | str,
*,
parents: bool = False,
user: str | User | None = None,
) -> None:
await self._inner.mkdir(path, parents=parents, user=user)
@instrumented_op("read", data=_read_start_data)
async def read(self, path: Path, *, user: str | User | None = None) -> io.IOBase:
return await self._inner.read(path, user=user)
@instrumented_op("write", data=_write_start_data)
async def write(
self,
path: Path,
data: io.IOBase,
*,
user: str | User | None = None,
) -> None:
await self._inner.write(path, data, user=user)
@instrumented_op(
"running",
finish_data=_running_finish_data,
ok=lambda _alive: True,
)
async def running(self) -> bool:
return await self._inner.running()
@instrumented_op(
"resolve_exposed_port",
data=_resolve_exposed_port_start_data,
finish_data=_resolve_exposed_port_finish_data,
ok=lambda _result: True,
)
async def resolve_exposed_port(self, port: int) -> ExposedPortEndpoint:
return await self._inner.resolve_exposed_port(port)
@instrumented_op(
"persist_workspace",
data=_persist_start_data,
finish_data=_persist_finish_data,
)
async def persist_workspace(self) -> io.IOBase:
return await self._inner.persist_workspace()
@instrumented_op(
"hydrate_workspace",
data=_hydrate_start_data,
)
async def hydrate_workspace(self, data: io.IOBase) -> None:
await self._inner.hydrate_workspace(data)