[Add] browser-use and main.py
This commit is contained in:
parent
08e64bdf45
commit
96914d44ac
221 changed files with 30952 additions and 1 deletions
329
browser-use/browser_use/agent/message_manager/service.py
Normal file
329
browser-use/browser_use/agent/message_manager/service.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from browser_use.agent.message_manager.views import MessageMetadata
|
||||
from browser_use.agent.prompts import AgentMessagePrompt
|
||||
from browser_use.agent.views import ActionResult, AgentOutput, AgentStepInfo, MessageManagerState
|
||||
from browser_use.browser.views import BrowserState
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageManagerSettings(BaseModel):
    """Configuration values read by MessageManager."""

    # Token budget that cut_messages compares the history's running total against.
    max_input_tokens: int = 128000
    # Divisor of the character-count based token estimate.
    estimated_characters_per_token: int = 3
    # Flat token cost charged for every image part of a message.
    image_tokens: int = 800
    # Forwarded to AgentMessagePrompt as include_attributes.
    include_attributes: list[str] = []
    # Optional extra context, added as one of the initial human messages.
    message_context: str | None = None
    # Placeholder name -> secret value; values are masked before messages are stored.
    sensitive_data: dict[str, str] | None = None
    # Optional file paths announced to the model in an initial message.
    available_file_paths: list[str] | None = None
|
||||
|
||||
|
||||
class MessageManager:
|
||||
def __init__(
    self,
    task: str,
    system_message: SystemMessage,
    settings: MessageManagerSettings | None = None,
    state: MessageManagerState | None = None,
):
    """Create a message manager for one task.

    Args:
        task: The task text shown to the model.
        system_message: System prompt stored as the first history message.
        settings: Limits and options; a fresh MessageManagerSettings is
            created when omitted.
        state: Existing state to resume from; a fresh MessageManagerState
            is created when omitted.
    """
    self.task = task
    # Defaults are built per call. Instances used as default arguments are
    # created once at definition time, so every manager constructed without
    # `state` would append to the very same history.
    self.settings = settings if settings is not None else MessageManagerSettings()
    self.state = state if state is not None else MessageManagerState()
    self.system_prompt = system_message

    # Only initialize messages if state is empty
    if not self.state.history.messages:
        self._init_messages()
|
||||
|
||||
def _init_messages(self) -> None:
    """Seed an empty history with the fixed opening messages.

    Order of the stored messages: system prompt, optional context, task,
    optional sensitive-data hint, the 'Example output:' marker, an example
    AgentOutput tool call with its tool reply, the memory-start marker and
    finally the optional list of usable file paths.
    """
    self._add_message_with_tokens(self.system_prompt, message_type='init')

    if self.settings.message_context:
        # Keep a separator between the label and the caller's text; plain
        # concatenation produced "Context for the task<text>".
        context_message = HumanMessage(content='Context for the task: ' + self.settings.message_context)
        self._add_message_with_tokens(context_message, message_type='init')

    task_message = HumanMessage(
        content=f'Your ultimate task is: """{self.task}""". If you achieved your ultimate task, stop everything and use the done action in the next step to complete the task. If not, continue as usual.'
    )
    self._add_message_with_tokens(task_message, message_type='init')

    if self.settings.sensitive_data:
        # Only the placeholder names are shown to the model, never the values.
        info = f'Here are placeholders for sensitive data: {list(self.settings.sensitive_data.keys())}'
        info += '\nTo use them, write <secret>the placeholder name</secret>'
        info_message = HumanMessage(content=info)
        self._add_message_with_tokens(info_message, message_type='init')

    placeholder_message = HumanMessage(content='Example output:')
    self._add_message_with_tokens(placeholder_message, message_type='init')

    # One worked example of the expected tool call, followed by its tool reply.
    example_tool_call = AIMessage(
        content='',
        tool_calls=[
            {
                'name': 'AgentOutput',
                'args': {
                    'current_state': {
                        'evaluation_previous_goal': """
Success - I successfully clicked on the 'Apple' link from the Google Search results page,
which directed me to the 'Apple' company homepage. This is a good start toward finding
the best place to buy a new iPhone as the Apple website often list iPhones for sale.
""".strip(),
                        'memory': """
I searched for 'iPhone retailers' on Google. From the Google Search results page,
I used the 'click_element_by_index' tool to click on element at index [45] labeled 'Best Buy' but calling
the tool did not direct me to a new page. I then used the 'click_element_by_index' tool to click
on element at index [82] labeled 'Apple' which redirected me to the 'Apple' company homepage.
Currently at step 3/15.
""".strip(),
                        'next_goal': """
Looking at reported structure of the current page, I can see the item '[127]<h3 iPhone/>'
in the content. I think this button will lead to more information and potentially prices
for iPhones. I'll click on the link at index [127] using the 'click_element_by_index'
tool and hope to see prices on the next page.
""".strip(),
                    },
                    'action': [{'click_element_by_index': {'index': 127}}],
                },
                'id': str(self.state.tool_id),
                'type': 'tool_call',
            },
        ],
    )
    self._add_message_with_tokens(example_tool_call, message_type='init')
    self.add_tool_message(content='Browser started', message_type='init')

    # This marker is stored without a message_type, unlike the messages around it.
    placeholder_message = HumanMessage(content='[Your task history memory starts here]')
    self._add_message_with_tokens(placeholder_message)

    if self.settings.available_file_paths:
        filepaths_msg = HumanMessage(content=f'Here are file paths you can use: {self.settings.available_file_paths}')
        self._add_message_with_tokens(filepaths_msg, message_type='init')
|
||||
|
||||
def add_new_task(self, new_task: str) -> None:
    """Tell the model about a replacement task and remember it as the current one."""
    announcement = HumanMessage(
        content=f'Your new ultimate task is: """{new_task}""". Take the previous context into account and finish your new ultimate task. '
    )
    self._add_message_with_tokens(announcement)
    self.task = new_task
|
||||
|
||||
@time_execution_sync('--add_state_message')
def add_state_message(
    self,
    state: BrowserState,
    result: list[ActionResult] | None = None,
    step_info: AgentStepInfo | None = None,
    use_vision=True,
) -> None:
    """Append the current browser state to the history as a human message.

    Results flagged include_in_memory are first stored as their own
    'Action result' / 'Action error' messages. When at least one result was
    stored that way, the state message is built without any results so the
    same content is not added twice.
    """
    stored_in_history = False
    for action_result in result if result else ():
        if not action_result.include_in_memory:
            continue

        if action_result.extracted_content:
            self._add_message_with_tokens(
                HumanMessage(content='Action result: ' + str(action_result.extracted_content))
            )

        if action_result.error:
            # Drop one trailing newline (written back onto the result object),
            # then keep only the final line of the error text.
            if action_result.error.endswith('\n'):
                action_result.error = action_result.error[:-1]
            final_line = action_result.error.split('\n')[-1]
            self._add_message_with_tokens(HumanMessage(content='Action error: ' + final_line))

        stored_in_history = True

    if stored_in_history:
        # The results already live in the history; do not repeat them below.
        result = None

    # Whatever is left in `result` travels inside the state message itself.
    prompt = AgentMessagePrompt(
        state,
        result,
        include_attributes=self.settings.include_attributes,
        step_info=step_info,
    )
    self._add_message_with_tokens(prompt.get_user_message(use_vision))
|
||||
|
||||
def add_model_output(self, model_output: AgentOutput) -> None:
    """Record the model output as an 'AgentOutput' tool call.

    The call is stored as an AIMessage with empty content; an empty tool
    reply is appended right after it via add_tool_message.
    """
    agent_output_call = {
        'name': 'AgentOutput',
        'args': model_output.model_dump(mode='json', exclude_unset=True),
        'id': str(self.state.tool_id),
        'type': 'tool_call',
    }
    self._add_message_with_tokens(AIMessage(content='', tool_calls=[agent_output_call]))

    # empty tool response
    self.add_tool_message(content='')
|
||||
|
||||
def add_plan(self, plan: str | None, position: int | None = None) -> None:
    """Store a non-empty plan as an AI message, optionally at `position`.

    Nothing is stored when `plan` is None or an empty string.
    """
    if not plan:
        return
    self._add_message_with_tokens(AIMessage(content=plan), position)
|
||||
|
||||
@time_execution_sync('--get_messages')
def get_messages(self) -> list[BaseMessage]:
    """Return the stored messages in order and log their token counts.

    The history is returned as it is; no trimming happens in this method.
    """
    entries = self.state.history.messages

    # debug which messages are in history with token count # log
    logger.debug(f'Messages in history: {len(entries)}:')
    running_total = 0
    for entry in entries:
        running_total += entry.metadata.tokens
        logger.debug(f'{entry.message.__class__.__name__} - Token count: {entry.metadata.tokens}')
    logger.debug(f'Total input tokens: {running_total}')

    return [entry.message for entry in entries]
|
||||
|
||||
def _add_message_with_tokens(
    self, message: BaseMessage, position: int | None = None, message_type: str | None = None
) -> None:
    """Mask secrets, estimate the token count and store the message.

    position: None for last, -1 for second last, etc.
    """
    # filter out sensitive data from the message
    if self.settings.sensitive_data:
        message = self._filter_sensitive_data(message)

    # The count is taken after masking, so it matches what is stored.
    self.state.history.add_message(
        message,
        MessageMetadata(tokens=self._count_tokens(message), message_type=message_type),
        position,
    )
|
||||
|
||||
@time_execution_sync('--filter_sensitive_data')
def _filter_sensitive_data(self, message: BaseMessage) -> BaseMessage:
    """Replace every configured secret value in the message by its placeholder tag.

    String content is rewritten directly; for list content only dict items
    with a 'text' key are rewritten. The message is modified in place and
    returned.
    """

    def replace_sensitive(value: str) -> str:
        if not self.settings.sensitive_data:
            return value

        # Keep only entries whose value is neither None nor empty
        valid_sensitive_data = {k: v for k, v in self.settings.sensitive_data.items() if v}

        # If there are no valid sensitive data entries, just return the original value
        if not valid_sensitive_data:
            logger.warning('No valid entries found in sensitive_data dictionary')
            return value

        # Mask the longest secrets first. With plain dict order a secret that
        # is a substring of another one (e.g. 'pass' and 'password123') could
        # be replaced first, leaving the rest of the longer secret in clear text.
        longest_first = sorted(valid_sensitive_data.items(), key=lambda entry: len(entry[1]), reverse=True)
        for key, val in longest_first:
            value = value.replace(val, f'<secret>{key}</secret>')

        return value

    if isinstance(message.content, str):
        message.content = replace_sensitive(message.content)
    elif isinstance(message.content, list):
        for i, item in enumerate(message.content):
            if isinstance(item, dict) and 'text' in item:
                item['text'] = replace_sensitive(item['text'])
                message.content[i] = item
    return message
|
||||
|
||||
def _count_tokens(self, message: BaseMessage) -> int:
|
||||
"""Count tokens in a message using the model's tokenizer"""
|
||||
tokens = 0
|
||||
if isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if 'image_url' in item:
|
||||
tokens += self.settings.image_tokens
|
||||
elif isinstance(item, dict) and 'text' in item:
|
||||
tokens += self._count_text_tokens(item['text'])
|
||||
else:
|
||||
msg = message.content
|
||||
if hasattr(message, 'tool_calls'):
|
||||
msg += str(message.tool_calls) # type: ignore
|
||||
tokens += self._count_text_tokens(msg)
|
||||
return tokens
|
||||
|
||||
def _count_text_tokens(self, text: str) -> int:
|
||||
"""Count tokens in a text string"""
|
||||
tokens = len(text) // self.settings.estimated_characters_per_token # Rough estimate if no tokenizer available
|
||||
return tokens
|
||||
|
||||
def cut_messages(self):
    """Shrink the newest history message until the history fits max_input_tokens.

    Does nothing when the history is within budget. Otherwise image parts of
    the last message are dropped first (settings.image_tokens each) and its
    text parts are joined into one string. If that is not enough, the tail of
    that text is cut proportionally to the tokens still over budget and the
    message is re-added as a HumanMessage.

    Returns:
        None in every case.

    Raises:
        ValueError: if more than 99% of the last message would have to go
            (including the case where it has no tokens left to cut).
    """
    diff = self.state.history.current_tokens - self.settings.max_input_tokens
    if diff <= 0:
        return None

    msg = self.state.history.messages[-1]

    # if list with image remove image
    if isinstance(msg.message.content, list):
        text = ''
        # Walk over a snapshot and rebuild the content afterwards. Removing
        # items from the list being iterated skips the element that follows
        # each removed image, which could silently lose the text part.
        for item in list(msg.message.content):
            if 'image_url' in item:
                diff -= self.settings.image_tokens
                msg.metadata.tokens -= self.settings.image_tokens
                self.state.history.current_tokens -= self.settings.image_tokens
                logger.debug(
                    f'Removed image with {self.settings.image_tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens}'
                )
            elif isinstance(item, dict) and 'text' in item:
                text += item['text']
        msg.message.content = text
        self.state.history.messages[-1] = msg

    if diff <= 0:
        return None

    # if still over, remove text from state message proportionally to the number of tokens needed with buffer
    # A message without tokens cannot absorb the overflow; treat it like the
    # "too long" case instead of failing with ZeroDivisionError.
    if msg.metadata.tokens > 0:
        proportion_to_remove = diff / msg.metadata.tokens
    else:
        proportion_to_remove = float('inf')
    if proportion_to_remove > 0.99:
        raise ValueError(
            f'Max token limit reached - history is too long - reduce the system prompt or task. '
            f'proportion_to_remove: {proportion_to_remove}'
        )
    logger.debug(
        f'Removing {proportion_to_remove * 100:.2f}% of the last message {proportion_to_remove * msg.metadata.tokens:.2f} / {msg.metadata.tokens:.2f} tokens)'
    )

    content = msg.message.content
    characters_to_remove = int(len(content) * proportion_to_remove)
    # Slice by an explicit end index: content[:-0] is the empty string, so a
    # rounded-down removal of 0 characters used to wipe the whole message.
    content = content[: len(content) - characters_to_remove]

    # remove tokens and old long message
    self.state.history.remove_last_state_message()

    # new message with updated content
    msg = HumanMessage(content=content)
    self._add_message_with_tokens(msg)

    last_msg = self.state.history.messages[-1]

    logger.debug(
        f'Added message with {last_msg.metadata.tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens} - total messages: {len(self.state.history.messages)}'
    )
|
||||
|
||||
def _remove_last_state_message(self) -> None:
|
||||
"""Remove last state message from history"""
|
||||
self.state.history.remove_last_state_message()
|
||||
|
||||
def add_tool_message(self, content: str, message_type: str | None = None) -> None:
    """Store a tool reply for the current tool id, then advance the id counter."""
    reply = ToolMessage(tool_call_id=str(self.state.tool_id), content=content)
    # The counter moves on only after the reply has been built with the old id.
    self.state.tool_id = self.state.tool_id + 1
    self._add_message_with_tokens(reply, message_type=message_type)
|
||||
237
browser-use/browser_use/agent/message_manager/tests.py
Normal file
237
browser-use/browser_use/agent/message_manager/tests.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
import pytest
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||
|
||||
from browser_use.agent.message_manager.service import MessageManager, MessageManagerSettings
|
||||
from browser_use.agent.views import ActionResult
|
||||
from browser_use.browser.views import BrowserState, TabInfo
|
||||
from browser_use.dom.views import DOMElementNode, DOMTextNode
|
||||
|
||||
|
||||
@pytest.fixture(
    params=[
        ChatOpenAI(model='gpt-4o-mini'),
        AzureChatOpenAI(model='gpt-4o', api_version='2024-02-15-preview'),
        ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=100, temperature=0.0, stop=None),
    ],
    ids=['gpt-4o-mini', 'gpt-4o', 'claude-3-5-sonnet'],
)
def message_manager(request: pytest.FixtureRequest):
    """Provide a MessageManager with a small token budget, once per listed model id."""
    # NOTE(review): request.param is never read, so the three chat models only
    # multiply the test ids; they are still instantiated when this module is
    # imported - presumably that needs provider credentials, confirm.
    small_budget = MessageManagerSettings(
        max_input_tokens=1000,
        estimated_characters_per_token=3,
        image_tokens=800,
    )
    return MessageManager(
        task='Test task',
        system_message=SystemMessage(content='Test actions'),
        settings=small_budget,
    )
|
||||
|
||||
|
||||
def test_initial_messages(message_manager: MessageManager):
    """Test that message manager initializes with the system, task and example messages"""
    messages = message_manager.get_messages()

    # Without message_context, sensitive_data and available_file_paths the
    # manager seeds six messages: system prompt, task, 'Example output:'
    # marker, example tool call, its tool reply and the memory-start marker.
    # The former expectation of exactly two messages no longer matches that.
    assert len(messages) == 6
    assert isinstance(messages[0], SystemMessage)
    assert isinstance(messages[1], HumanMessage)
    assert 'Test task' in messages[1].content
    assert isinstance(messages[3], AIMessage)
    assert isinstance(messages[-1], HumanMessage)
|
||||
|
||||
|
||||
def test_add_state_message(message_manager: MessageManager):
    """Test adding browser state message"""
    state = BrowserState(
        url='https://test.com',
        title='Test Page',
        element_tree=DOMElementNode(
            tag_name='div',
            attributes={},
            children=[],
            is_visible=True,
            parent=None,
            xpath='//div',
        ),
        selector_map={},
        tabs=[TabInfo(page_id=1, url='https://test.com', title='Test Page')],
    )
    # Measure against the seeded history instead of hard-coding its size:
    # the manager starts with more than the system and task messages.
    initial_count = len(message_manager.get_messages())

    message_manager.add_state_message(state)

    messages = message_manager.get_messages()
    assert len(messages) == initial_count + 1
    assert isinstance(messages[-1], HumanMessage)
    assert 'https://test.com' in messages[-1].content
|
||||
|
||||
|
||||
def test_add_state_with_memory_result(message_manager: MessageManager):
    """Test adding state with result that should be included in memory"""
    state = BrowserState(
        url='https://test.com',
        title='Test Page',
        element_tree=DOMElementNode(
            tag_name='div',
            attributes={},
            children=[],
            is_visible=True,
            parent=None,
            xpath='//div',
        ),
        selector_map={},
        tabs=[TabInfo(page_id=1, url='https://test.com', title='Test Page')],
    )
    result = ActionResult(extracted_content='Important content', include_in_memory=True)
    # Measure against the seeded history instead of hard-coding its size.
    initial_count = len(message_manager.get_messages())

    message_manager.add_state_message(state, [result])
    messages = message_manager.get_messages()

    # Two messages are appended: the extracted content, then the state.
    assert len(messages) == initial_count + 2
    assert 'Important content' in messages[-2].content
    assert isinstance(messages[-2], HumanMessage)
    assert isinstance(messages[-1], HumanMessage)
    assert 'Important content' not in messages[-1].content
|
||||
|
||||
|
||||
def test_add_state_with_non_memory_result(message_manager: MessageManager):
    """Test adding state with result that should not be included in memory"""
    state = BrowserState(
        url='https://test.com',
        title='Test Page',
        element_tree=DOMElementNode(
            tag_name='div',
            attributes={},
            children=[],
            is_visible=True,
            parent=None,
            xpath='//div',
        ),
        selector_map={},
        tabs=[TabInfo(page_id=1, url='https://test.com', title='Test Page')],
    )
    result = ActionResult(extracted_content='Temporary content', include_in_memory=False)
    # Measure against the seeded history instead of hard-coding its size.
    initial_count = len(message_manager.get_messages())

    message_manager.add_state_message(state, [result])
    messages = message_manager.get_messages()

    # Only the combined state+result message is appended.
    assert len(messages) == initial_count + 1
    assert 'Temporary content' in messages[-1].content
    assert isinstance(messages[-1], HumanMessage)
|
||||
|
||||
|
||||
@pytest.mark.skip('not sure how to fix this')
@pytest.mark.parametrize('max_tokens', [100000, 10000, 5000])
def test_token_overflow_handling_with_real_flow(message_manager: MessageManager, max_tokens):
    """Test handling of token overflow in a realistic message flow

    Currently skipped. Each iteration adds a growing browser state (with a
    result on every second step), checks the history, then swaps the state
    message for a model output and re-checks the token bookkeeping.
    """
    # Set more realistic token limit
    message_manager.settings.max_input_tokens = max_tokens

    # Create a long sequence of interactions
    for i in range(200):  # Simulate 200 steps of interaction
        # Create state with varying content length
        state = BrowserState(
            url=f'https://test{i}.com',
            title=f'Test Page {i}',
            element_tree=DOMElementNode(
                tag_name='div',
                attributes={},
                children=[
                    DOMTextNode(
                        text=f'Content {j} ' * (10 + i),  # Increasing content length
                        is_visible=True,
                        parent=None,
                    )
                    for j in range(5)  # Multiple DOM items
                ],
                is_visible=True,
                parent=None,
                xpath='//div',
            ),
            selector_map={j: f'//div[{j}]' for j in range(5)},
            tabs=[TabInfo(page_id=1, url=f'https://test{i}.com', title=f'Test Page {i}')],
        )

        # Alternate between different types of results
        result = None
        if i % 2 == 0:  # Every other iteration
            result = ActionResult(
                extracted_content=f'Important content from step {i}' * 5,
                include_in_memory=i % 4 == 0,  # Include in memory every 4th message
            )

        # Add state message
        if result:
            message_manager.add_state_message(state, [result])
        else:
            message_manager.add_state_message(state)

        try:
            messages = message_manager.get_messages()
        except ValueError as e:
            if 'Max token limit reached - history is too long' in str(e):
                return  # If error occurs, end the test
            else:
                raise e

        # A slack of 100 tokens above the limit is tolerated here.
        assert message_manager.state.history.current_tokens <= message_manager.settings.max_input_tokens + 100

        last_msg = messages[-1]
        assert isinstance(last_msg, HumanMessage)

        # Steps with an in-memory result: the message before the state is human too.
        if i % 4 == 0:
            assert isinstance(message_manager.state.history.messages[-2].message, HumanMessage)
        # Steps whose result is not kept in memory: check the state message itself.
        if i % 2 == 0 and not i % 4 == 0:
            if isinstance(last_msg.content, list):
                assert 'Current url: https://test' in last_msg.content[0]['text']
            else:
                assert 'Current url: https://test' in last_msg.content

        # Add model output every time
        from browser_use.agent.views import AgentBrain, AgentOutput
        from browser_use.controller.registry.views import ActionModel

        output = AgentOutput(
            current_state=AgentBrain(
                evaluation_previous_goal=f'Success in step {i}',
                memory=f'Memory from step {i}',
                next_goal=f'Goal for step {i + 1}',
            ),
            action=[ActionModel()],
        )
        # The state message is dropped before the model output is recorded.
        message_manager._remove_last_state_message()
        message_manager.add_model_output(output)

        # Get messages and verify after each addition
        messages = [m.message for m in message_manager.state.history.messages]

        # Verify token limit is respected

        # Verify essential messages are preserved
        assert isinstance(messages[0], SystemMessage)  # System prompt always first
        assert isinstance(messages[1], HumanMessage)  # Task always second
        assert 'Test task' in messages[1].content

        # Verify structure of latest messages
        # NOTE(review): MessageManager.add_model_output appends an empty tool
        # reply after the AIMessage and stores the output in tool_calls with
        # content='', so the two assertions below look unsatisfiable - confirm.
        assert isinstance(messages[-1], AIMessage)  # Last message should be model output
        assert f'step {i}' in messages[-1].content  # Should contain current step info

        # Log token usage for debugging
        # (both values are only referenced by the commented-out print below)
        token_usage = message_manager.state.history.current_tokens
        token_limit = message_manager.settings.max_input_tokens
        # print(f'Step {i}: Using {token_usage}/{token_limit} tokens')

        # go through all messages and verify that the token count and total tokens is correct
        total_tokens = 0
        real_tokens = []
        stored_tokens = []
        for msg in message_manager.state.history.messages:
            total_tokens += msg.metadata.tokens
            stored_tokens.append(msg.metadata.tokens)
            real_tokens.append(message_manager._count_tokens(msg.message))
        assert total_tokens == sum(real_tokens)
        assert stored_tokens == real_tokens
        assert message_manager.state.history.current_tokens == total_tokens
|
||||
|
||||
|
||||
# pytest -s browser_use/agent/message_manager/tests.py
|
||||
147
browser-use/browser_use/agent/message_manager/utils.py
Normal file
147
browser-use/browser_use/agent/message_manager/utils.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Regular expressions tried with re.match, i.e. anchored at the start of the model name.
MODELS_WITHOUT_TOOL_SUPPORT_PATTERNS = [
    'deepseek-reasoner',
    'deepseek-r1',
    '.*gemma.*-it',
]


def is_model_without_tool_support(model_name: str) -> bool:
    """Return True when the model name starts with a match of any listed pattern."""
    for pattern in MODELS_WITHOUT_TOOL_SUPPORT_PATTERNS:
        if re.match(pattern, model_name):
            return True
    return False
|
||||
|
||||
|
||||
def extract_json_from_model_output(content: str) -> dict:
    """Extract JSON from model output, handling both plain JSON and code-block-wrapped JSON.

    Args:
        content: Raw model output, optionally wrapped in a ``` fence.

    Returns:
        The parsed JSON object. A list holding exactly one dict is unwrapped
        to that dict.

    Raises:
        ValueError: if the text is not valid JSON.
        AssertionError: if the parsed JSON is not a dictionary.
    """
    try:
        # If content is wrapped in code blocks, extract just the JSON part
        if '```' in content:
            # Find the JSON content between code blocks
            content = content.split('```')[1]
            # Remove the language identifier if present (e.g. 'json\n'). The
            # first line is dropped only when it looks like a bare tag, so
            # JSON that starts directly after the fence is kept intact.
            first_line, newline, remainder = content.partition('\n')
            if newline and re.fullmatch(r'\s*[\w+#.-]*\s*', first_line):
                content = remainder
        # Parse the cleaned content
        result_dict = json.loads(content)

        # some models occasionally respond with a list containing one dict: https://github.com/browser-use/browser-use/issues/1458
        if isinstance(result_dict, list) and len(result_dict) == 1 and isinstance(result_dict[0], dict):
            result_dict = result_dict[0]

        # Raised explicitly so the check survives `python -O`; the exception
        # type is unchanged for existing callers.
        if not isinstance(result_dict, dict):
            raise AssertionError(f'Expected JSON dictionary in response, got JSON {type(result_dict)} instead')
        return result_dict
    except json.JSONDecodeError as e:
        logger.warning(f'Failed to parse model output: {content} {str(e)}')
        raise ValueError('Could not parse response.') from e
|
||||
|
||||
|
||||
def convert_input_messages(input_messages: list[BaseMessage], model_name: str | None) -> list[BaseMessage]:
    """Adapt the input messages for models without tool support.

    The input list is returned as it is when no model name is given or the
    model is not matched by is_model_without_tool_support. Otherwise tool
    related messages are converted and successive human messages, then
    successive AI messages, are merged.
    """
    needs_conversion = model_name is not None and is_model_without_tool_support(model_name)
    if not needs_conversion:
        return input_messages

    adapted = _convert_messages_for_non_function_calling_models(input_messages)
    for message_class in (HumanMessage, AIMessage):
        adapted = _merge_successive_messages(adapted, message_class)
    return adapted
|
||||
|
||||
|
||||
def _convert_messages_for_non_function_calling_models(input_messages: list[BaseMessage]) -> list[BaseMessage]:
    """Rewrite tool-related messages as plain human/AI messages.

    Human and system messages pass through, tool messages become human
    messages with the same content, and AI messages with tool calls become
    AI messages whose content is the JSON dump of those calls.

    Raises:
        ValueError: for any other message type.
    """
    converted = []
    for message in input_messages:
        if isinstance(message, (HumanMessage, SystemMessage)):
            converted.append(message)
            continue

        if isinstance(message, ToolMessage):
            converted.append(HumanMessage(content=message.content))
            continue

        if not isinstance(message, AIMessage):
            raise ValueError(f'Unknown message type: {type(message)}')

        # AI messages without tool calls are kept untouched.
        if message.tool_calls:
            converted.append(AIMessage(content=json.dumps(message.tool_calls)))
        else:
            converted.append(message)
    return converted
|
||||
|
||||
|
||||
def _merge_successive_messages(messages: list[BaseMessage], class_to_merge: type[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Some models like deepseek-reasoner dont allow multiple human messages in a row. This function merges them into one."""
|
||||
merged_messages = []
|
||||
streak = 0
|
||||
for message in messages:
|
||||
if isinstance(message, class_to_merge):
|
||||
streak += 1
|
||||
if streak > 1:
|
||||
if isinstance(message.content, list):
|
||||
merged_messages[-1].content += message.content[0]['text'] # type:ignore
|
||||
else:
|
||||
merged_messages[-1].content += message.content
|
||||
else:
|
||||
merged_messages.append(message)
|
||||
else:
|
||||
merged_messages.append(message)
|
||||
streak = 0
|
||||
return merged_messages
|
||||
|
||||
|
||||
def save_conversation(input_messages: list[BaseMessage], response: Any, target: str, encoding: str | None = None) -> None:
    """Write the input messages followed by the model response to `target`.

    Missing parent folders of `target` are created first; the file is
    opened in 'w' mode with the given encoding.
    """
    parent_dir = os.path.dirname(target)
    if parent_dir:
        # create folders if not exists
        os.makedirs(parent_dir, exist_ok=True)

    with open(target, 'w', encoding=encoding) as handle:
        _write_messages_to_file(handle, input_messages)
        _write_response_to_file(handle, response)
|
||||
|
||||
|
||||
def _write_messages_to_file(f: Any, messages: list[BaseMessage]) -> None:
|
||||
"""Write messages to conversation file"""
|
||||
for message in messages:
|
||||
f.write(f' {message.__class__.__name__} \n')
|
||||
|
||||
if isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if isinstance(item, dict) and item.get('type') == 'text':
|
||||
f.write(item['text'].strip() + '\n')
|
||||
elif isinstance(message.content, str):
|
||||
try:
|
||||
content = json.loads(message.content)
|
||||
f.write(json.dumps(content, indent=2) + '\n')
|
||||
except json.JSONDecodeError:
|
||||
f.write(message.content.strip() + '\n')
|
||||
|
||||
f.write('\n')
|
||||
|
||||
|
||||
def _write_response_to_file(f: Any, response: Any) -> None:
|
||||
"""Write model response to conversation file"""
|
||||
f.write(' RESPONSE\n')
|
||||
f.write(json.dumps(json.loads(response.model_dump_json(exclude_unset=True)), indent=2))
|
||||
135
browser-use/browser_use/agent/message_manager/views.py
Normal file
135
browser-use/browser_use/agent/message_manager/views.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from warnings import filterwarnings
|
||||
|
||||
from langchain_core._api import LangChainBetaWarning
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
|
||||
|
||||
filterwarnings('ignore', category=LangChainBetaWarning)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from browser_use.agent.views import AgentOutput
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
    """Metadata for a message"""

    # Token count attributed to the message; MessageHistory adds it to its running total.
    tokens: int = 0
    # Optional free-form tag for the message.
    message_type: str | None = None
|
||||
|
||||
|
||||
class ManagedMessage(BaseModel):
    """A message with its metadata"""

    message: BaseMessage
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # https://github.com/pydantic/pydantic/discussions/7558
    @model_serializer(mode='wrap')
    def to_json(self, original_dump):
        """
        Returns the serialized representation of the model.

        The default dump is produced first; its `message` entry is then
        replaced by langchain's `dumpd` form of the message.
        """
        data = original_dump(self)

        # NOTE: We override the message field to use langchain JSON serialization.
        data['message'] = dumpd(self.message)

        return data

    @model_validator(mode='before')
    @classmethod
    def validate(
        cls,
        value: Any,
        *,
        strict: bool | None = None,
        from_attributes: bool | None = None,
        context: Any | None = None,
    ) -> Any:
        """
        Custom validator that runs langchain's `load` on the `message` entry
        when the input is a dict.

        The input dict itself is left untouched; a shallow copy carrying the
        loaded message is validated instead.
        """
        if isinstance(value, dict) and 'message' in value:
            # NOTE: We use langchain's load to convert the serialized form back into a BaseMessage object.
            filterwarnings('ignore', category=LangChainBetaWarning)
            # Build a new dict instead of assigning into `value`, so data the
            # caller passed in is not modified as a side effect of validation.
            value = {**value, 'message': load(value['message'])}
        return value
|
||||
|
||||
|
||||
class MessageHistory(BaseModel):
    """Ordered messages with metadata plus a running token total"""

    messages: list[ManagedMessage] = Field(default_factory=list)
    current_tokens: int = 0

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def add_message(self, message: BaseMessage, metadata: MessageMetadata, position: int | None = None) -> None:
        """Store a message at the end, or insert it at `position`, and add its tokens to the total"""
        entry = ManagedMessage(message=message, metadata=metadata)
        if position is not None:
            self.messages.insert(position, entry)
        else:
            self.messages.append(entry)
        self.current_tokens += metadata.tokens

    def add_model_output(self, output: AgentOutput) -> None:
        """Store the output as an AgentOutput tool call followed by an empty tool reply"""
        agent_output_call = {
            'name': 'AgentOutput',
            'args': output.model_dump(mode='json', exclude_unset=True),
            'id': '1',
            'type': 'tool_call',
        }
        ai_message = AIMessage(content='', tool_calls=[agent_output_call])
        # Estimate tokens for tool calls
        self.add_message(ai_message, MessageMetadata(tokens=100))

        # Empty tool response, with an estimated token cost
        empty_reply = ToolMessage(content='', tool_call_id='1')
        self.add_message(empty_reply, MessageMetadata(tokens=10))

    def get_messages(self) -> list[BaseMessage]:
        """Return the bare messages in stored order"""
        return [entry.message for entry in self.messages]

    def get_total_tokens(self) -> int:
        """Return the running token total"""
        return self.current_tokens

    def remove_oldest_message(self) -> None:
        """Drop the first message that is not a system message, if there is one"""
        for index, entry in enumerate(self.messages):
            if isinstance(entry.message, SystemMessage):
                continue
            self.current_tokens -= entry.metadata.tokens
            del self.messages[index]
            break

    def remove_last_state_message(self) -> None:
        """Drop the newest message when it is a human message and more than two messages exist"""
        if len(self.messages) <= 2:
            return
        newest = self.messages[-1]
        if isinstance(newest.message, HumanMessage):
            self.current_tokens -= newest.metadata.tokens
            self.messages.pop()
|
||||
|
||||
|
||||
class MessageManagerState(BaseModel):
    """Holds the state for MessageManager"""

    # Message history; each instance gets its own empty MessageHistory by default.
    history: MessageHistory = Field(default_factory=MessageHistory)
    # Integer counter for tool ids; starts at 1.
    tool_id: int = 1

    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
Loading…
Add table
Add a link
Reference in a new issue