Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of error handling #4575

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion evaluation/EDA/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'success': test_result,
'final_message': final_message,
Expand Down
2 changes: 1 addition & 1 deletion evaluation/agent_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'agent_answer': agent_answer,
'final_answer': final_ans,
Expand Down
2 changes: 1 addition & 1 deletion evaluation/aider_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result=test_result,
)
return output
Expand Down
2 changes: 1 addition & 1 deletion evaluation/biocoder/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result=test_result,
)
return output
Expand Down
2 changes: 1 addition & 1 deletion evaluation/bird/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def execute_sql(db_path, sql):
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result=test_result,
)
return output
Expand Down
2 changes: 1 addition & 1 deletion evaluation/browsing_delegation/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'query': instance.instruction,
'action': last_delegate_action,
Expand Down
2 changes: 1 addition & 1 deletion evaluation/gaia/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result=test_result,
)
return output
Expand Down
2 changes: 1 addition & 1 deletion evaluation/gorilla/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'text': model_answer_raw,
'correct': correct,
Expand Down
2 changes: 1 addition & 1 deletion evaluation/gpqa/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def process_instance(
metadata=metadata,
history=state.history.compatibility_for_eval_history_pairs(),
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'result': test_result,
'found_answers': found_answers,
Expand Down
2 changes: 1 addition & 1 deletion evaluation/humanevalfix/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result=test_result,
)
return output
Expand Down
2 changes: 1 addition & 1 deletion evaluation/integration_tests/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result=test_result.model_dump(),
)
return output
Expand Down
2 changes: 1 addition & 1 deletion evaluation/logic_reasoning/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result=test_result,
)
return output
Expand Down
2 changes: 1 addition & 1 deletion evaluation/miniwob/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'reward': reward,
},
Expand Down
2 changes: 1 addition & 1 deletion evaluation/mint/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'success': task_state.success if task_state else False,
},
Expand Down
8 changes: 4 additions & 4 deletions evaluation/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,10 @@ def process_instance(

# if fatal error, throw EvalError to trigger re-run
if (
state.last_error
and 'fatal error during agent execution' in state.last_error
state.get_last_error()
and 'fatal error during agent execution' in state.get_last_error()
):
raise EvalException('Fatal error detected: ' + state.last_error)
raise EvalException('Fatal error detected: ' + state.get_last_error())

# ======= THIS IS SWE-Bench specific =======
# Get git patch
Expand Down Expand Up @@ -442,7 +442,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
)
return output

Expand Down
2 changes: 1 addition & 1 deletion evaluation/toolqa/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
)
return output

Expand Down
2 changes: 1 addition & 1 deletion evaluation/webarena/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def process_instance(
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
error=state.get_last_error() if state and state.get_last_error() else None,
test_result={
'reward': reward,
},
Expand Down
67 changes: 31 additions & 36 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,19 @@ async def update_state_after_step(self):
# update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset()
self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics)

async def report_error(self, message: str, exception: Exception | None = None):
"""Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.

This method should be called for a particular type of errors, which have:
- a user-friendly message, which will be shown in the chat box. This should not be a raw exception message.
- an ErrorObservation that can be sent to the LLM by the user role, with the exception message, so it can self-correct next time.
"""
self.state.last_error = message
if exception:
self.state.last_error += f': {exception}'
async def _react_to_error(
self,
message: str,
exception: Exception | None = None,
new_state: AgentState | None = None,
):
if new_state is not None:
await self.set_agent_state_to(new_state)
detail = str(exception) if exception is not None else ''
if exception is not None and isinstance(exception, litellm.AuthenticationError):
detail = 'Please check your credentials. Is your API key correct?'
self.event_stream.add_event(
ErrorObservation(f'{message}:{detail}'), EventSource.USER
detail += '\nPlease check your credentials. Is your API key correct?'
await self.event_stream.async_add_event(
ErrorObservation(f'{message}:{detail}'), EventSource.AGENT
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method was initially meant only for informing the user in the UI. But it had such a good name that we felt it should be used almost like error logging. 😂

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😂

TBH I don't think an AgentController should know anything about a UI

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it's not that it knows it, strictly speaking, just the ErrorObs created here has a message intended for the chat box / end user's eyes.

And by this time, it no longer has the info about the exception, except... it did, when it received the exception as a param, for some exception info intended for evals.

Copy link
Collaborator

@enyst enyst Oct 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you're right, and your way of looking at it might just solve this little strange thing. How about:

  • we don't do these UI-intended lines of code here in the controller at all
  • create an ErrorObservation
    • with all info potentially needed, e.g. for the backend / LLM
    • including exception info
  • put it in the stream
  • let the agent session interpret it and send it over to the UI with whatever info it makes of it
    • e.g. a dict with the right source, a certain message, perhaps remove stuff it doesn't want. 🤔

)

async def start_step_loop(self):
Expand All @@ -164,10 +162,11 @@ async def start_step_loop(self):
traceback.print_exc()
logger.error(f'Error while running the agent: {e}')
logger.error(traceback.format_exc())
await self.report_error(
'There was an unexpected error while running the agent', exception=e
await self._react_to_error(
'There was an unexpected error while running the agent',
exception=e,
new_state=AgentState.ERROR,
)
await self.set_agent_state_to(AgentState.ERROR)
break

await asyncio.sleep(0.1)
Expand Down Expand Up @@ -254,9 +253,6 @@ async def _handle_observation(self, observation: Observation):
if self.state.agent_state == AgentState.ERROR:
self.state.metrics.merge(self.state.local_metrics)
elif isinstance(observation, FatalErrorObservation):
await self.report_error(
'There was a fatal error during agent execution: ' + str(observation)
)
await self.set_agent_state_to(AgentState.ERROR)
self.state.metrics.merge(self.state.local_metrics)

Expand Down Expand Up @@ -330,7 +326,9 @@ async def set_agent_state_to(self, new_state: AgentState):
else:
confirmation_state = ActionConfirmationStatus.REJECTED
self._pending_action.confirmation_state = confirmation_state # type: ignore[attr-defined]
self.event_stream.add_event(self._pending_action, EventSource.AGENT)
await self.event_stream.async_add_event(
self._pending_action, EventSource.AGENT
)

self.state.agent_state = new_state
self.event_stream.add_event(
Expand Down Expand Up @@ -443,7 +441,7 @@ async def _step(self) -> None:
except (LLMMalformedActionError, LLMNoActionError, LLMResponseError) as e:
# report to the user
# and send the underlying exception to the LLM for self-correction
await self.report_error(str(e))
await self._react_to_error(str(e))
return

if action.runnable:
Expand All @@ -462,15 +460,15 @@ async def _step(self) -> None:
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)
self.event_stream.add_event(action, EventSource.AGENT)
await self.event_stream.async_add_event(action, EventSource.AGENT)

await self.update_state_after_step()
logger.info(action, extra={'msg_type': 'ACTION'})

if self._is_stuck():
# This needs to go BEFORE report_error to sync metrics
await self.set_agent_state_to(AgentState.ERROR)
await self.report_error('Agent got stuck in a loop')
await self._react_to_error(
'Agent got stuck in a loop', new_state=AgentState.ERROR
)

async def _delegate_step(self):
"""Executes a single step of the delegate agent."""
Expand All @@ -489,7 +487,7 @@ async def _delegate_step(self):
self.delegate = None
self.delegateAction = None

await self.report_error('Delegator agent encountered an error')
await self._react_to_error('Delegator agent encountered an error')
elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
logger.info(
f'[Agent Controller {self.id}] Delegate agent has finished execution'
Expand Down Expand Up @@ -518,7 +516,7 @@ async def _delegate_step(self):
# clean up delegate status
self.delegate = None
self.delegateAction = None
self.event_stream.add_event(obs, EventSource.AGENT)
await self.event_stream.async_add_event(obs, EventSource.AGENT)
return

async def _handle_traffic_control(
Expand All @@ -538,20 +536,17 @@ async def _handle_traffic_control(
else:
self.state.traffic_control_state = TrafficControlState.THROTTLING
if self.headless_mode:
# This needs to go BEFORE report_error to sync metrics
await self.set_agent_state_to(AgentState.ERROR)
# set to ERROR state if running in headless mode
# since user cannot resume on the web interface
await self.report_error(
await self._react_to_error(
f'Agent reached maximum {limit_type} in headless mode, task stopped. '
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}'
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}',
new_state=AgentState.ERROR,
)
else:
await self.set_agent_state_to(AgentState.PAUSED)
await self.report_error(
await self._react_to_error(
f'Agent reached maximum {limit_type}, task paused. '
f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. '
f'{TRAFFIC_CONTROL_REMINDER}'
f'{TRAFFIC_CONTROL_REMINDER}',
new_state=AgentState.PAUSED,
)
stop_step = True
return stop_step
Expand Down
13 changes: 9 additions & 4 deletions openhands/controller/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
MessageAction,
)
from openhands.events.action.agent import AgentFinishAction
from openhands.events.observation import ErrorObservation, FatalErrorObservation
from openhands.llm.metrics import Metrics
from openhands.memory.history import ShortTermHistory
from openhands.storage.files import FileStore
Expand Down Expand Up @@ -80,7 +81,6 @@ class State:
history: ShortTermHistory = field(default_factory=ShortTermHistory)
inputs: dict = field(default_factory=dict)
outputs: dict = field(default_factory=dict)
last_error: str | None = None
agent_state: AgentState = AgentState.LOADING
resume_state: AgentState | None = None
traffic_control_state: TrafficControlState = TrafficControlState.NORMAL
Expand Down Expand Up @@ -124,9 +124,6 @@ def restore_from_session(sid: str, file_store: FileStore) -> 'State':
else:
state.resume_state = None

# don't carry last_error anymore after restore
state.last_error = None

# first state after restore
state.agent_state = AgentState.LOADING
return state
Expand Down Expand Up @@ -157,6 +154,14 @@ def __setstate__(self, state):

# remove the restored data from the state if any

def get_last_error(self) -> str:
    """Return the content of the latest error observation in the history.

    The history is walked via ``get_events(reverse=True)`` and the first
    event that is an ``ErrorObservation`` or ``FatalErrorObservation``
    supplies the result.

    Returns:
        The ``content`` of the first matching event, or an empty string
        when no matching event is found.
    """
    error_kinds = (ErrorObservation, FatalErrorObservation)
    candidates = self.history.get_events(reverse=True)
    # next() stops at the first match, so no further events are consumed.
    return next(
        (evt.content for evt in candidates if isinstance(evt, error_kinds)),
        '',
    )

def get_current_user_intent(self):
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
last_user_message = None
Expand Down
4 changes: 4 additions & 0 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def add_event(self, event: Event, source: EventSource):
asyncio.run(self.async_add_event(event, source))

async def async_add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
)
enyst marked this conversation as resolved.
Show resolved Hide resolved
with self._lock:
event._id = self._cur_id # type: ignore [attr-defined]
self._cur_id += 1
Expand Down
Loading