[Add] browser-use and main.py
This commit is contained in:
parent
08e64bdf45
commit
96914d44ac
221 changed files with 30952 additions and 1 deletions
370
browser-use/browser_use/agent/gif.py
Normal file
370
browser-use/browser_use/agent/gif.py
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from browser_use.agent.views import AgentHistoryList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL import Image, ImageFont
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decode_unicode_escapes_to_utf8(text: str) -> str:
|
||||
"""Handle decoding any unicode escape sequences embedded in a string (needed to render non-ASCII languages like chinese or arabic in the GIF overlay text)"""
|
||||
|
||||
if r'\u' not in text:
|
||||
# doesn't have any escape sequences that need to be decoded
|
||||
return text
|
||||
|
||||
try:
|
||||
# Try to decode Unicode escape sequences
|
||||
return text.encode('latin1').decode('unicode_escape')
|
||||
except (UnicodeEncodeError, UnicodeDecodeError):
|
||||
# logger.debug(f"Failed to decode unicode escape sequences while generating gif text: {text}")
|
||||
return text
|
||||
|
||||
|
||||
def create_history_gif(
|
||||
task: str,
|
||||
history: AgentHistoryList,
|
||||
#
|
||||
output_path: str = 'agent_history.gif',
|
||||
duration: int = 3000,
|
||||
show_goals: bool = True,
|
||||
show_task: bool = True,
|
||||
show_logo: bool = False,
|
||||
font_size: int = 40,
|
||||
title_font_size: int = 56,
|
||||
goal_font_size: int = 44,
|
||||
margin: int = 40,
|
||||
line_spacing: float = 1.5,
|
||||
) -> None:
|
||||
"""Create a GIF from the agent's history with overlaid task and goal text."""
|
||||
if not history.history:
|
||||
logger.warning('No history to create GIF from')
|
||||
return
|
||||
|
||||
from PIL import Image, ImageFont
|
||||
|
||||
images = []
|
||||
|
||||
# if history is empty or first screenshot is None, we can't create a gif
|
||||
if not history.history or not history.history[0].state.screenshot:
|
||||
logger.warning('No history or first screenshot to create GIF from')
|
||||
return
|
||||
|
||||
# Try to load nicer fonts
|
||||
try:
|
||||
# Try different font options in order of preference
|
||||
# ArialUni is a font that comes with Office and can render most non-alphabet characters
|
||||
font_options = [
|
||||
'Microsoft YaHei', # 微软雅黑
|
||||
'SimHei', # 黑体
|
||||
'SimSun', # 宋体
|
||||
'Noto Sans CJK SC', # 思源黑体
|
||||
'WenQuanYi Micro Hei', # 文泉驿微米黑
|
||||
'Helvetica',
|
||||
'Arial',
|
||||
'DejaVuSans',
|
||||
'Verdana',
|
||||
]
|
||||
font_loaded = False
|
||||
|
||||
for font_name in font_options:
|
||||
try:
|
||||
if platform.system() == 'Windows':
|
||||
# Need to specify the abs font path on Windows
|
||||
font_name = os.path.join(os.getenv('WIN_FONT_DIR', 'C:\\Windows\\Fonts'), font_name + '.ttf')
|
||||
regular_font = ImageFont.truetype(font_name, font_size)
|
||||
title_font = ImageFont.truetype(font_name, title_font_size)
|
||||
goal_font = ImageFont.truetype(font_name, goal_font_size)
|
||||
font_loaded = True
|
||||
break
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
if not font_loaded:
|
||||
raise OSError('No preferred fonts found')
|
||||
|
||||
except OSError:
|
||||
regular_font = ImageFont.load_default()
|
||||
title_font = ImageFont.load_default()
|
||||
|
||||
goal_font = regular_font
|
||||
|
||||
# Load logo if requested
|
||||
logo = None
|
||||
if show_logo:
|
||||
try:
|
||||
logo = Image.open('./static/browser-use.png')
|
||||
# Resize logo to be small (e.g., 40px height)
|
||||
logo_height = 150
|
||||
aspect_ratio = logo.width / logo.height
|
||||
logo_width = int(logo_height * aspect_ratio)
|
||||
logo = logo.resize((logo_width, logo_height), Image.Resampling.LANCZOS)
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not load logo: {e}')
|
||||
|
||||
# Create task frame if requested
|
||||
if show_task and task:
|
||||
task_frame = _create_task_frame(
|
||||
task,
|
||||
history.history[0].state.screenshot,
|
||||
title_font, # type: ignore
|
||||
regular_font, # type: ignore
|
||||
logo,
|
||||
line_spacing,
|
||||
)
|
||||
images.append(task_frame)
|
||||
|
||||
# Process each history item
|
||||
for i, item in enumerate(history.history, 1):
|
||||
if not item.state.screenshot:
|
||||
continue
|
||||
|
||||
# Convert base64 screenshot to PIL Image
|
||||
img_data = base64.b64decode(item.state.screenshot)
|
||||
image = Image.open(io.BytesIO(img_data))
|
||||
|
||||
if show_goals and item.model_output:
|
||||
image = _add_overlay_to_image(
|
||||
image=image,
|
||||
step_number=i,
|
||||
goal_text=item.model_output.current_state.next_goal,
|
||||
regular_font=regular_font, # type: ignore
|
||||
title_font=title_font, # type: ignore
|
||||
margin=margin,
|
||||
logo=logo,
|
||||
)
|
||||
|
||||
images.append(image)
|
||||
|
||||
if images:
|
||||
# Save the GIF
|
||||
images[0].save(
|
||||
output_path,
|
||||
save_all=True,
|
||||
append_images=images[1:],
|
||||
duration=duration,
|
||||
loop=0,
|
||||
optimize=False,
|
||||
)
|
||||
logger.info(f'Created GIF at {output_path}')
|
||||
else:
|
||||
logger.warning('No images found in history to create GIF')
|
||||
|
||||
|
||||
def _create_task_frame(
|
||||
task: str,
|
||||
first_screenshot: str,
|
||||
title_font: ImageFont.FreeTypeFont,
|
||||
regular_font: ImageFont.FreeTypeFont,
|
||||
logo: Image.Image | None = None,
|
||||
line_spacing: float = 1.5,
|
||||
) -> Image.Image:
|
||||
"""Create initial frame showing the task."""
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
img_data = base64.b64decode(first_screenshot)
|
||||
template = Image.open(io.BytesIO(img_data))
|
||||
image = Image.new('RGB', template.size, (0, 0, 0))
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
# Calculate vertical center of image
|
||||
center_y = image.height // 2
|
||||
|
||||
# Draw task text with dynamic font size based on task length
|
||||
margin = 140 # Increased margin
|
||||
max_width = image.width - (2 * margin)
|
||||
|
||||
# Dynamic font size calculation based on task length
|
||||
# Start with base font size (regular + 16)
|
||||
base_font_size = regular_font.size + 16
|
||||
min_font_size = max(regular_font.size - 10, 16) # Don't go below 16pt
|
||||
max_font_size = base_font_size # Cap at the base font size
|
||||
|
||||
# Calculate dynamic font size based on text length and complexity
|
||||
# Longer texts get progressively smaller fonts
|
||||
text_length = len(task)
|
||||
if text_length > 200:
|
||||
# For very long text, reduce font size logarithmically
|
||||
font_size = max(base_font_size - int(10 * (text_length / 200)), min_font_size)
|
||||
else:
|
||||
font_size = base_font_size
|
||||
|
||||
larger_font = ImageFont.truetype(regular_font.path, font_size)
|
||||
|
||||
# Generate wrapped text with the calculated font size
|
||||
wrapped_text = _wrap_text(task, larger_font, max_width)
|
||||
|
||||
# Calculate line height with spacing
|
||||
line_height = larger_font.size * line_spacing
|
||||
|
||||
# Split text into lines and draw with custom spacing
|
||||
lines = wrapped_text.split('\n')
|
||||
total_height = line_height * len(lines)
|
||||
|
||||
# Start position for first line
|
||||
text_y = center_y - (total_height / 2) + 50 # Shifted down slightly
|
||||
|
||||
for line in lines:
|
||||
# Get line width for centering
|
||||
line_bbox = draw.textbbox((0, 0), line, font=larger_font)
|
||||
text_x = (image.width - (line_bbox[2] - line_bbox[0])) // 2
|
||||
|
||||
draw.text(
|
||||
(text_x, text_y),
|
||||
line,
|
||||
font=larger_font,
|
||||
fill=(255, 255, 255),
|
||||
)
|
||||
text_y += line_height
|
||||
|
||||
# Add logo if provided (top right corner)
|
||||
if logo:
|
||||
logo_margin = 20
|
||||
logo_x = image.width - logo.width - logo_margin
|
||||
image.paste(logo, (logo_x, logo_margin), logo if logo.mode == 'RGBA' else None)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _add_overlay_to_image(
|
||||
image: Image.Image,
|
||||
step_number: int,
|
||||
goal_text: str,
|
||||
regular_font: ImageFont.FreeTypeFont,
|
||||
title_font: ImageFont.FreeTypeFont,
|
||||
margin: int,
|
||||
logo: Image.Image | None = None,
|
||||
display_step: bool = True,
|
||||
text_color: tuple[int, int, int, int] = (255, 255, 255, 255),
|
||||
text_box_color: tuple[int, int, int, int] = (0, 0, 0, 255),
|
||||
) -> Image.Image:
|
||||
"""Add step number and goal overlay to an image."""
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
goal_text = decode_unicode_escapes_to_utf8(goal_text)
|
||||
image = image.convert('RGBA')
|
||||
txt_layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(txt_layer)
|
||||
if display_step:
|
||||
# Add step number (bottom left)
|
||||
step_text = str(step_number)
|
||||
step_bbox = draw.textbbox((0, 0), step_text, font=title_font)
|
||||
step_width = step_bbox[2] - step_bbox[0]
|
||||
step_height = step_bbox[3] - step_bbox[1]
|
||||
|
||||
# Position step number in bottom left
|
||||
x_step = margin + 10 # Slight additional offset from edge
|
||||
y_step = image.height - margin - step_height - 10 # Slight offset from bottom
|
||||
|
||||
# Draw rounded rectangle background for step number
|
||||
padding = 20 # Increased padding
|
||||
step_bg_bbox = (
|
||||
x_step - padding,
|
||||
y_step - padding,
|
||||
x_step + step_width + padding,
|
||||
y_step + step_height + padding,
|
||||
)
|
||||
draw.rounded_rectangle(
|
||||
step_bg_bbox,
|
||||
radius=15, # Add rounded corners
|
||||
fill=text_box_color,
|
||||
)
|
||||
|
||||
# Draw step number
|
||||
draw.text(
|
||||
(x_step, y_step),
|
||||
step_text,
|
||||
font=title_font,
|
||||
fill=text_color,
|
||||
)
|
||||
|
||||
# Draw goal text (centered, bottom)
|
||||
max_width = image.width - (4 * margin)
|
||||
wrapped_goal = _wrap_text(goal_text, title_font, max_width)
|
||||
goal_bbox = draw.multiline_textbbox((0, 0), wrapped_goal, font=title_font)
|
||||
goal_width = goal_bbox[2] - goal_bbox[0]
|
||||
goal_height = goal_bbox[3] - goal_bbox[1]
|
||||
|
||||
# Center goal text horizontally, place above step number
|
||||
x_goal = (image.width - goal_width) // 2
|
||||
y_goal = y_step - goal_height - padding * 4 # More space between step and goal
|
||||
|
||||
# Draw rounded rectangle background for goal
|
||||
padding_goal = 25 # Increased padding for goal
|
||||
goal_bg_bbox = (
|
||||
x_goal - padding_goal, # Remove extra space for logo
|
||||
y_goal - padding_goal,
|
||||
x_goal + goal_width + padding_goal,
|
||||
y_goal + goal_height + padding_goal,
|
||||
)
|
||||
draw.rounded_rectangle(
|
||||
goal_bg_bbox,
|
||||
radius=15, # Add rounded corners
|
||||
fill=text_box_color,
|
||||
)
|
||||
|
||||
# Draw goal text
|
||||
draw.multiline_text(
|
||||
(x_goal, y_goal),
|
||||
wrapped_goal,
|
||||
font=title_font,
|
||||
fill=text_color,
|
||||
align='center',
|
||||
)
|
||||
|
||||
# Add logo if provided (top right corner)
|
||||
if logo:
|
||||
logo_layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
||||
logo_margin = 20
|
||||
logo_x = image.width - logo.width - logo_margin
|
||||
logo_layer.paste(logo, (logo_x, logo_margin), logo if logo.mode == 'RGBA' else None)
|
||||
txt_layer = Image.alpha_composite(logo_layer, txt_layer)
|
||||
|
||||
# Composite and convert
|
||||
result = Image.alpha_composite(image, txt_layer)
|
||||
return result.convert('RGB')
|
||||
|
||||
|
||||
def _wrap_text(text: str, font: ImageFont.FreeTypeFont, max_width: int) -> str:
|
||||
"""
|
||||
Wrap text to fit within a given width.
|
||||
|
||||
Args:
|
||||
text: Text to wrap
|
||||
font: Font to use for text
|
||||
max_width: Maximum width in pixels
|
||||
|
||||
Returns:
|
||||
Wrapped text with newlines
|
||||
"""
|
||||
text = decode_unicode_escapes_to_utf8(text)
|
||||
words = text.split()
|
||||
lines = []
|
||||
current_line = []
|
||||
|
||||
for word in words:
|
||||
current_line.append(word)
|
||||
line = ' '.join(current_line)
|
||||
bbox = font.getbbox(line)
|
||||
if bbox[2] > max_width:
|
||||
if len(current_line) == 1:
|
||||
lines.append(current_line.pop())
|
||||
else:
|
||||
current_line.pop()
|
||||
lines.append(' '.join(current_line))
|
||||
current_line = [word]
|
||||
|
||||
if current_line:
|
||||
lines.append(' '.join(current_line))
|
||||
|
||||
return '\n'.join(lines)
|
||||
4
browser-use/browser_use/agent/memory/__init__.py
Normal file
4
browser-use/browser_use/agent/memory/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from browser_use.agent.memory.service import Memory
|
||||
from browser_use.agent.memory.views import MemoryConfig
|
||||
|
||||
__all__ = ['Memory', 'MemoryConfig']
|
||||
151
browser-use/browser_use/agent/memory/service.py
Normal file
151
browser-use/browser_use/agent/memory/service.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import convert_to_openai_messages
|
||||
|
||||
from browser_use.agent.memory.views import MemoryConfig
|
||||
from browser_use.agent.message_manager.service import MessageManager
|
||||
from browser_use.agent.message_manager.views import ManagedMessage, MessageMetadata
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Memory:
|
||||
"""
|
||||
Manages procedural memory for agents.
|
||||
|
||||
This class implements a procedural memory management system using Mem0 that transforms agent interaction history
|
||||
into concise, structured representations at specified intervals. It serves to optimize context window
|
||||
utilization during extended task execution by converting verbose historical information into compact,
|
||||
yet comprehensive memory constructs that preserve essential operational knowledge.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_manager: MessageManager,
|
||||
llm: BaseChatModel,
|
||||
config: MemoryConfig | None = None,
|
||||
):
|
||||
self.message_manager = message_manager
|
||||
self.llm = llm
|
||||
|
||||
# Initialize configuration with defaults based on the LLM if not provided
|
||||
if config is None:
|
||||
self.config = MemoryConfig(llm_instance=llm, agent_id=f'agent_{id(self)}')
|
||||
|
||||
# Set appropriate embedder based on LLM type
|
||||
llm_class = llm.__class__.__name__
|
||||
if llm_class == 'ChatOpenAI':
|
||||
self.config.embedder_provider = 'openai'
|
||||
self.config.embedder_model = 'text-embedding-3-small'
|
||||
self.config.embedder_dims = 1536
|
||||
elif llm_class == 'ChatGoogleGenerativeAI':
|
||||
self.config.embedder_provider = 'gemini'
|
||||
self.config.embedder_model = 'models/text-embedding-004'
|
||||
self.config.embedder_dims = 768
|
||||
elif llm_class == 'ChatOllama':
|
||||
self.config.embedder_provider = 'ollama'
|
||||
self.config.embedder_model = 'nomic-embed-text'
|
||||
self.config.embedder_dims = 512
|
||||
else:
|
||||
# Ensure LLM instance is set in the config
|
||||
self.config = MemoryConfig(**dict(config)) # re-validate untrusted user-provided config
|
||||
self.config.llm_instance = llm
|
||||
|
||||
# Check for required packages
|
||||
try:
|
||||
# also disable mem0's telemetry when ANONYMIZED_TELEMETRY=False
|
||||
if os.getenv('ANONYMIZED_TELEMETRY', 'true').lower()[0] in 'fn0':
|
||||
os.environ['MEM0_TELEMETRY'] = 'False'
|
||||
from mem0 import Memory as Mem0Memory
|
||||
except ImportError:
|
||||
raise ImportError('mem0 is required when enable_memory=True. Please install it with `pip install mem0`.')
|
||||
|
||||
if self.config.embedder_provider == 'huggingface':
|
||||
try:
|
||||
# check that required package is installed if huggingface is used
|
||||
from sentence_transformers import SentenceTransformer # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'sentence_transformers is required when enable_memory=True and embedder_provider="huggingface". Please install it with `pip install sentence-transformers`.'
|
||||
)
|
||||
|
||||
# Initialize Mem0 with the configuration
|
||||
self.mem0 = Mem0Memory.from_config(config_dict=self.config.full_config_dict)
|
||||
|
||||
@time_execution_sync('--create_procedural_memory')
|
||||
def create_procedural_memory(self, current_step: int) -> None:
|
||||
"""
|
||||
Create a procedural memory if needed based on the current step.
|
||||
|
||||
Args:
|
||||
current_step: The current step number of the agent
|
||||
"""
|
||||
logger.info(f'Creating procedural memory at step {current_step}')
|
||||
|
||||
# Get all messages
|
||||
all_messages = self.message_manager.state.history.messages
|
||||
|
||||
# Separate messages into those to keep as-is and those to process for memory
|
||||
new_messages = []
|
||||
messages_to_process = []
|
||||
|
||||
for msg in all_messages:
|
||||
if isinstance(msg, ManagedMessage) and msg.metadata.message_type in {'init', 'memory'}:
|
||||
# Keep system and memory messages as they are
|
||||
new_messages.append(msg)
|
||||
else:
|
||||
if len(msg.message.content) > 0:
|
||||
messages_to_process.append(msg)
|
||||
|
||||
# Need at least 2 messages to create a meaningful summary
|
||||
if len(messages_to_process) <= 1:
|
||||
logger.info('Not enough non-memory messages to summarize')
|
||||
return
|
||||
# Create a procedural memory
|
||||
memory_content = self._create([m.message for m in messages_to_process], current_step)
|
||||
|
||||
if not memory_content:
|
||||
logger.warning('Failed to create procedural memory')
|
||||
return
|
||||
|
||||
# Replace the processed messages with the consolidated memory
|
||||
memory_message = HumanMessage(content=memory_content)
|
||||
memory_tokens = self.message_manager._count_tokens(memory_message)
|
||||
memory_metadata = MessageMetadata(tokens=memory_tokens, message_type='memory')
|
||||
|
||||
# Calculate the total tokens being removed
|
||||
removed_tokens = sum(m.metadata.tokens for m in messages_to_process)
|
||||
|
||||
# Add the memory message
|
||||
new_messages.append(ManagedMessage(message=memory_message, metadata=memory_metadata))
|
||||
|
||||
# Update the history
|
||||
self.message_manager.state.history.messages = new_messages
|
||||
self.message_manager.state.history.current_tokens -= removed_tokens
|
||||
self.message_manager.state.history.current_tokens += memory_tokens
|
||||
logger.info(f'Messages consolidated: {len(messages_to_process)} messages converted to procedural memory')
|
||||
|
||||
def _create(self, messages: list[BaseMessage], current_step: int) -> str | None:
|
||||
parsed_messages = convert_to_openai_messages(messages)
|
||||
try:
|
||||
results = self.mem0.add(
|
||||
messages=parsed_messages,
|
||||
agent_id=self.config.agent_id,
|
||||
memory_type='procedural_memory',
|
||||
metadata={'step': current_step},
|
||||
)
|
||||
if len(results.get('results', [])):
|
||||
return results.get('results', [])[0].get('memory')
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating procedural memory: {e}')
|
||||
return None
|
||||
67
browser-use/browser_use/agent/memory/views.py
Normal file
67
browser-use/browser_use/agent/memory/views.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
"""Configuration for procedural memory."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True, validate_default=True, revalidate_instances='always', validate_assignment=True
|
||||
)
|
||||
|
||||
# Memory settings
|
||||
agent_id: str = Field(default='browser_use_agent', min_length=1)
|
||||
memory_interval: int = Field(default=10, gt=1, lt=100)
|
||||
|
||||
# Embedder settings
|
||||
embedder_provider: Literal['openai', 'gemini', 'ollama', 'huggingface'] = 'huggingface'
|
||||
embedder_model: str = Field(min_length=2, default='all-MiniLM-L6-v2')
|
||||
embedder_dims: int = Field(default=384, gt=10, lt=10000)
|
||||
|
||||
# LLM settings - the LLM instance can be passed separately
|
||||
llm_provider: Literal['langchain'] = 'langchain'
|
||||
llm_instance: BaseChatModel | None = None
|
||||
|
||||
# Vector store settings
|
||||
vector_store_provider: Literal['faiss'] = 'faiss'
|
||||
vector_store_base_path: str = Field(default='/tmp/mem0')
|
||||
|
||||
@property
|
||||
def vector_store_path(self) -> str:
|
||||
"""Returns the full vector store path for the current configuration. e.g. /tmp/mem0_384_faiss"""
|
||||
return f'{self.vector_store_base_path}_{self.embedder_dims}_{self.vector_store_provider}'
|
||||
|
||||
@property
|
||||
def embedder_config_dict(self) -> dict[str, Any]:
|
||||
"""Returns the embedder configuration dictionary."""
|
||||
return {
|
||||
'provider': self.embedder_provider,
|
||||
'config': {'model': self.embedder_model, 'embedding_dims': self.embedder_dims},
|
||||
}
|
||||
|
||||
@property
|
||||
def llm_config_dict(self) -> dict[str, Any]:
|
||||
"""Returns the LLM configuration dictionary."""
|
||||
return {'provider': self.llm_provider, 'config': {'model': self.llm_instance}}
|
||||
|
||||
@property
|
||||
def vector_store_config_dict(self) -> dict[str, Any]:
|
||||
"""Returns the vector store configuration dictionary."""
|
||||
return {
|
||||
'provider': self.vector_store_provider,
|
||||
'config': {
|
||||
'embedding_model_dims': self.embedder_dims,
|
||||
'path': self.vector_store_path,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def full_config_dict(self) -> dict[str, dict[str, Any]]:
|
||||
"""Returns the complete configuration dictionary for Mem0."""
|
||||
return {
|
||||
'embedder': self.embedder_config_dict,
|
||||
'llm': self.llm_config_dict,
|
||||
'vector_store': self.vector_store_config_dict,
|
||||
}
|
||||
329
browser-use/browser_use/agent/message_manager/service.py
Normal file
329
browser-use/browser_use/agent/message_manager/service.py
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from browser_use.agent.message_manager.views import MessageMetadata
|
||||
from browser_use.agent.prompts import AgentMessagePrompt
|
||||
from browser_use.agent.views import ActionResult, AgentOutput, AgentStepInfo, MessageManagerState
|
||||
from browser_use.browser.views import BrowserState
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageManagerSettings(BaseModel):
|
||||
max_input_tokens: int = 128000
|
||||
estimated_characters_per_token: int = 3
|
||||
image_tokens: int = 800
|
||||
include_attributes: list[str] = []
|
||||
message_context: str | None = None
|
||||
sensitive_data: dict[str, str] | None = None
|
||||
available_file_paths: list[str] | None = None
|
||||
|
||||
|
||||
class MessageManager:
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
system_message: SystemMessage,
|
||||
settings: MessageManagerSettings = MessageManagerSettings(),
|
||||
state: MessageManagerState = MessageManagerState(),
|
||||
):
|
||||
self.task = task
|
||||
self.settings = settings
|
||||
self.state = state
|
||||
self.system_prompt = system_message
|
||||
|
||||
# Only initialize messages if state is empty
|
||||
if len(self.state.history.messages) == 0:
|
||||
self._init_messages()
|
||||
|
||||
def _init_messages(self) -> None:
|
||||
"""Initialize the message history with system message, context, task, and other initial messages"""
|
||||
self._add_message_with_tokens(self.system_prompt, message_type='init')
|
||||
|
||||
if self.settings.message_context:
|
||||
context_message = HumanMessage(content='Context for the task' + self.settings.message_context)
|
||||
self._add_message_with_tokens(context_message, message_type='init')
|
||||
|
||||
task_message = HumanMessage(
|
||||
content=f'Your ultimate task is: """{self.task}""". If you achieved your ultimate task, stop everything and use the done action in the next step to complete the task. If not, continue as usual.'
|
||||
)
|
||||
self._add_message_with_tokens(task_message, message_type='init')
|
||||
|
||||
if self.settings.sensitive_data:
|
||||
info = f'Here are placeholders for sensitive data: {list(self.settings.sensitive_data.keys())}'
|
||||
info += '\nTo use them, write <secret>the placeholder name</secret>'
|
||||
info_message = HumanMessage(content=info)
|
||||
self._add_message_with_tokens(info_message, message_type='init')
|
||||
|
||||
placeholder_message = HumanMessage(content='Example output:')
|
||||
self._add_message_with_tokens(placeholder_message, message_type='init')
|
||||
|
||||
example_tool_call = AIMessage(
|
||||
content='',
|
||||
tool_calls=[
|
||||
{
|
||||
'name': 'AgentOutput',
|
||||
'args': {
|
||||
'current_state': {
|
||||
'evaluation_previous_goal': """
|
||||
Success - I successfully clicked on the 'Apple' link from the Google Search results page,
|
||||
which directed me to the 'Apple' company homepage. This is a good start toward finding
|
||||
the best place to buy a new iPhone as the Apple website often list iPhones for sale.
|
||||
""".strip(),
|
||||
'memory': """
|
||||
I searched for 'iPhone retailers' on Google. From the Google Search results page,
|
||||
I used the 'click_element_by_index' tool to click on element at index [45] labeled 'Best Buy' but calling
|
||||
the tool did not direct me to a new page. I then used the 'click_element_by_index' tool to click
|
||||
on element at index [82] labeled 'Apple' which redirected me to the 'Apple' company homepage.
|
||||
Currently at step 3/15.
|
||||
""".strip(),
|
||||
'next_goal': """
|
||||
Looking at reported structure of the current page, I can see the item '[127]<h3 iPhone/>'
|
||||
in the content. I think this button will lead to more information and potentially prices
|
||||
for iPhones. I'll click on the link at index [127] using the 'click_element_by_index'
|
||||
tool and hope to see prices on the next page.
|
||||
""".strip(),
|
||||
},
|
||||
'action': [{'click_element_by_index': {'index': 127}}],
|
||||
},
|
||||
'id': str(self.state.tool_id),
|
||||
'type': 'tool_call',
|
||||
},
|
||||
],
|
||||
)
|
||||
self._add_message_with_tokens(example_tool_call, message_type='init')
|
||||
self.add_tool_message(content='Browser started', message_type='init')
|
||||
|
||||
placeholder_message = HumanMessage(content='[Your task history memory starts here]')
|
||||
self._add_message_with_tokens(placeholder_message)
|
||||
|
||||
if self.settings.available_file_paths:
|
||||
filepaths_msg = HumanMessage(content=f'Here are file paths you can use: {self.settings.available_file_paths}')
|
||||
self._add_message_with_tokens(filepaths_msg, message_type='init')
|
||||
|
||||
def add_new_task(self, new_task: str) -> None:
|
||||
content = f'Your new ultimate task is: """{new_task}""". Take the previous context into account and finish your new ultimate task. '
|
||||
msg = HumanMessage(content=content)
|
||||
self._add_message_with_tokens(msg)
|
||||
self.task = new_task
|
||||
|
||||
@time_execution_sync('--add_state_message')
|
||||
def add_state_message(
|
||||
self,
|
||||
state: BrowserState,
|
||||
result: list[ActionResult] | None = None,
|
||||
step_info: AgentStepInfo | None = None,
|
||||
use_vision=True,
|
||||
) -> None:
|
||||
"""Add browser state as human message"""
|
||||
|
||||
# if keep in memory, add to directly to history and add state without result
|
||||
if result:
|
||||
for r in result:
|
||||
if r.include_in_memory:
|
||||
if r.extracted_content:
|
||||
msg = HumanMessage(content='Action result: ' + str(r.extracted_content))
|
||||
self._add_message_with_tokens(msg)
|
||||
if r.error:
|
||||
# if endswith \n, remove it
|
||||
if r.error.endswith('\n'):
|
||||
r.error = r.error[:-1]
|
||||
# get only last line of error
|
||||
last_line = r.error.split('\n')[-1]
|
||||
msg = HumanMessage(content='Action error: ' + last_line)
|
||||
self._add_message_with_tokens(msg)
|
||||
result = None # if result in history, we dont want to add it again
|
||||
|
||||
# otherwise add state message and result to next message (which will not stay in memory)
|
||||
state_message = AgentMessagePrompt(
|
||||
state,
|
||||
result,
|
||||
include_attributes=self.settings.include_attributes,
|
||||
step_info=step_info,
|
||||
).get_user_message(use_vision)
|
||||
self._add_message_with_tokens(state_message)
|
||||
|
||||
def add_model_output(self, model_output: AgentOutput) -> None:
|
||||
"""Add model output as AI message"""
|
||||
tool_calls = [
|
||||
{
|
||||
'name': 'AgentOutput',
|
||||
'args': model_output.model_dump(mode='json', exclude_unset=True),
|
||||
'id': str(self.state.tool_id),
|
||||
'type': 'tool_call',
|
||||
}
|
||||
]
|
||||
|
||||
msg = AIMessage(
|
||||
content='',
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
self._add_message_with_tokens(msg)
|
||||
# empty tool response
|
||||
self.add_tool_message(content='')
|
||||
|
||||
def add_plan(self, plan: str | None, position: int | None = None) -> None:
|
||||
if plan:
|
||||
msg = AIMessage(content=plan)
|
||||
self._add_message_with_tokens(msg, position)
|
||||
|
||||
@time_execution_sync('--get_messages')
|
||||
def get_messages(self) -> list[BaseMessage]:
|
||||
"""Get current message list, potentially trimmed to max tokens"""
|
||||
|
||||
msg = [m.message for m in self.state.history.messages]
|
||||
# debug which messages are in history with token count # log
|
||||
total_input_tokens = 0
|
||||
logger.debug(f'Messages in history: {len(self.state.history.messages)}:')
|
||||
for m in self.state.history.messages:
|
||||
total_input_tokens += m.metadata.tokens
|
||||
logger.debug(f'{m.message.__class__.__name__} - Token count: {m.metadata.tokens}')
|
||||
logger.debug(f'Total input tokens: {total_input_tokens}')
|
||||
|
||||
return msg
|
||||
|
||||
def _add_message_with_tokens(
|
||||
self, message: BaseMessage, position: int | None = None, message_type: str | None = None
|
||||
) -> None:
|
||||
"""Add message with token count metadata
|
||||
position: None for last, -1 for second last, etc.
|
||||
"""
|
||||
|
||||
# filter out sensitive data from the message
|
||||
if self.settings.sensitive_data:
|
||||
message = self._filter_sensitive_data(message)
|
||||
|
||||
token_count = self._count_tokens(message)
|
||||
metadata = MessageMetadata(tokens=token_count, message_type=message_type)
|
||||
self.state.history.add_message(message, metadata, position)
|
||||
|
||||
@time_execution_sync('--filter_sensitive_data')
|
||||
def _filter_sensitive_data(self, message: BaseMessage) -> BaseMessage:
|
||||
"""Filter out sensitive data from the message"""
|
||||
|
||||
def replace_sensitive(value: str) -> str:
|
||||
if not self.settings.sensitive_data:
|
||||
return value
|
||||
|
||||
# Create a dictionary with all key-value pairs from sensitive_data where value is not None or empty
|
||||
valid_sensitive_data = {k: v for k, v in self.settings.sensitive_data.items() if v}
|
||||
|
||||
# If there are no valid sensitive data entries, just return the original value
|
||||
if not valid_sensitive_data:
|
||||
logger.warning('No valid entries found in sensitive_data dictionary')
|
||||
return value
|
||||
|
||||
# Replace all valid sensitive data values with their placeholder tags
|
||||
for key, val in valid_sensitive_data.items():
|
||||
value = value.replace(val, f'<secret>{key}</secret>')
|
||||
|
||||
return value
|
||||
|
||||
if isinstance(message.content, str):
|
||||
message.content = replace_sensitive(message.content)
|
||||
elif isinstance(message.content, list):
|
||||
for i, item in enumerate(message.content):
|
||||
if isinstance(item, dict) and 'text' in item:
|
||||
item['text'] = replace_sensitive(item['text'])
|
||||
message.content[i] = item
|
||||
return message
|
||||
|
||||
def _count_tokens(self, message: BaseMessage) -> int:
|
||||
"""Count tokens in a message using the model's tokenizer"""
|
||||
tokens = 0
|
||||
if isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if 'image_url' in item:
|
||||
tokens += self.settings.image_tokens
|
||||
elif isinstance(item, dict) and 'text' in item:
|
||||
tokens += self._count_text_tokens(item['text'])
|
||||
else:
|
||||
msg = message.content
|
||||
if hasattr(message, 'tool_calls'):
|
||||
msg += str(message.tool_calls) # type: ignore
|
||||
tokens += self._count_text_tokens(msg)
|
||||
return tokens
|
||||
|
||||
def _count_text_tokens(self, text: str) -> int:
|
||||
"""Count tokens in a text string"""
|
||||
tokens = len(text) // self.settings.estimated_characters_per_token # Rough estimate if no tokenizer available
|
||||
return tokens
|
||||
|
||||
def cut_messages(self):
|
||||
"""Get current message list, potentially trimmed to max tokens"""
|
||||
diff = self.state.history.current_tokens - self.settings.max_input_tokens
|
||||
if diff <= 0:
|
||||
return None
|
||||
|
||||
msg = self.state.history.messages[-1]
|
||||
|
||||
# if list with image remove image
|
||||
if isinstance(msg.message.content, list):
|
||||
text = ''
|
||||
for item in msg.message.content:
|
||||
if 'image_url' in item:
|
||||
msg.message.content.remove(item)
|
||||
diff -= self.settings.image_tokens
|
||||
msg.metadata.tokens -= self.settings.image_tokens
|
||||
self.state.history.current_tokens -= self.settings.image_tokens
|
||||
logger.debug(
|
||||
f'Removed image with {self.settings.image_tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens}'
|
||||
)
|
||||
elif 'text' in item and isinstance(item, dict):
|
||||
text += item['text']
|
||||
msg.message.content = text
|
||||
self.state.history.messages[-1] = msg
|
||||
|
||||
if diff <= 0:
|
||||
return None
|
||||
|
||||
# if still over, remove text from state message proportionally to the number of tokens needed with buffer
|
||||
# Calculate the proportion of content to remove
|
||||
proportion_to_remove = diff / msg.metadata.tokens
|
||||
if proportion_to_remove > 0.99:
|
||||
raise ValueError(
|
||||
f'Max token limit reached - history is too long - reduce the system prompt or task. '
|
||||
f'proportion_to_remove: {proportion_to_remove}'
|
||||
)
|
||||
logger.debug(
|
||||
f'Removing {proportion_to_remove * 100:.2f}% of the last message {proportion_to_remove * msg.metadata.tokens:.2f} / {msg.metadata.tokens:.2f} tokens)'
|
||||
)
|
||||
|
||||
content = msg.message.content
|
||||
characters_to_remove = int(len(content) * proportion_to_remove)
|
||||
content = content[:-characters_to_remove]
|
||||
|
||||
# remove tokens and old long message
|
||||
self.state.history.remove_last_state_message()
|
||||
|
||||
# new message with updated content
|
||||
msg = HumanMessage(content=content)
|
||||
self._add_message_with_tokens(msg)
|
||||
|
||||
last_msg = self.state.history.messages[-1]
|
||||
|
||||
logger.debug(
|
||||
f'Added message with {last_msg.metadata.tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens} - total messages: {len(self.state.history.messages)}'
|
||||
)
|
||||
|
||||
def _remove_last_state_message(self) -> None:
|
||||
"""Remove last state message from history"""
|
||||
self.state.history.remove_last_state_message()
|
||||
|
||||
def add_tool_message(self, content: str, message_type: str | None = None) -> None:
|
||||
"""Add tool message to history"""
|
||||
msg = ToolMessage(content=content, tool_call_id=str(self.state.tool_id))
|
||||
self.state.tool_id += 1
|
||||
self._add_message_with_tokens(msg, message_type=message_type)
|
||||
237
browser-use/browser_use/agent/message_manager/tests.py
Normal file
237
browser-use/browser_use/agent/message_manager/tests.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
import pytest
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||
|
||||
from browser_use.agent.message_manager.service import MessageManager, MessageManagerSettings
|
||||
from browser_use.agent.views import ActionResult
|
||||
from browser_use.browser.views import BrowserState, TabInfo
|
||||
from browser_use.dom.views import DOMElementNode, DOMTextNode
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
ChatOpenAI(model='gpt-4o-mini'),
|
||||
AzureChatOpenAI(model='gpt-4o', api_version='2024-02-15-preview'),
|
||||
ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=100, temperature=0.0, stop=None),
|
||||
],
|
||||
ids=['gpt-4o-mini', 'gpt-4o', 'claude-3-5-sonnet'],
|
||||
)
|
||||
def message_manager(request: pytest.FixtureRequest):
|
||||
task = 'Test task'
|
||||
action_descriptions = 'Test actions'
|
||||
return MessageManager(
|
||||
task=task,
|
||||
system_message=SystemMessage(content=action_descriptions),
|
||||
settings=MessageManagerSettings(
|
||||
max_input_tokens=1000,
|
||||
estimated_characters_per_token=3,
|
||||
image_tokens=800,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_initial_messages(message_manager: MessageManager):
|
||||
"""Test that message manager initializes with system and task messages"""
|
||||
messages = message_manager.get_messages()
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], SystemMessage)
|
||||
assert isinstance(messages[1], HumanMessage)
|
||||
assert 'Test task' in messages[1].content
|
||||
|
||||
|
||||
def test_add_state_message(message_manager: MessageManager):
|
||||
"""Test adding browser state message"""
|
||||
state = BrowserState(
|
||||
url='https://test.com',
|
||||
title='Test Page',
|
||||
element_tree=DOMElementNode(
|
||||
tag_name='div',
|
||||
attributes={},
|
||||
children=[],
|
||||
is_visible=True,
|
||||
parent=None,
|
||||
xpath='//div',
|
||||
),
|
||||
selector_map={},
|
||||
tabs=[TabInfo(page_id=1, url='https://test.com', title='Test Page')],
|
||||
)
|
||||
message_manager.add_state_message(state)
|
||||
|
||||
messages = message_manager.get_messages()
|
||||
assert len(messages) == 3
|
||||
assert isinstance(messages[2], HumanMessage)
|
||||
assert 'https://test.com' in messages[2].content
|
||||
|
||||
|
||||
def test_add_state_with_memory_result(message_manager: MessageManager):
|
||||
"""Test adding state with result that should be included in memory"""
|
||||
state = BrowserState(
|
||||
url='https://test.com',
|
||||
title='Test Page',
|
||||
element_tree=DOMElementNode(
|
||||
tag_name='div',
|
||||
attributes={},
|
||||
children=[],
|
||||
is_visible=True,
|
||||
parent=None,
|
||||
xpath='//div',
|
||||
),
|
||||
selector_map={},
|
||||
tabs=[TabInfo(page_id=1, url='https://test.com', title='Test Page')],
|
||||
)
|
||||
result = ActionResult(extracted_content='Important content', include_in_memory=True)
|
||||
|
||||
message_manager.add_state_message(state, [result])
|
||||
messages = message_manager.get_messages()
|
||||
|
||||
# Should have system, task, extracted content, and state messages
|
||||
assert len(messages) == 4
|
||||
assert 'Important content' in messages[2].content
|
||||
assert isinstance(messages[2], HumanMessage)
|
||||
assert isinstance(messages[3], HumanMessage)
|
||||
assert 'Important content' not in messages[3].content
|
||||
|
||||
|
||||
def test_add_state_with_non_memory_result(message_manager: MessageManager):
|
||||
"""Test adding state with result that should not be included in memory"""
|
||||
state = BrowserState(
|
||||
url='https://test.com',
|
||||
title='Test Page',
|
||||
element_tree=DOMElementNode(
|
||||
tag_name='div',
|
||||
attributes={},
|
||||
children=[],
|
||||
is_visible=True,
|
||||
parent=None,
|
||||
xpath='//div',
|
||||
),
|
||||
selector_map={},
|
||||
tabs=[TabInfo(page_id=1, url='https://test.com', title='Test Page')],
|
||||
)
|
||||
result = ActionResult(extracted_content='Temporary content', include_in_memory=False)
|
||||
|
||||
message_manager.add_state_message(state, [result])
|
||||
messages = message_manager.get_messages()
|
||||
|
||||
# Should have system, task, and combined state+result message
|
||||
assert len(messages) == 3
|
||||
assert 'Temporary content' in messages[2].content
|
||||
assert isinstance(messages[2], HumanMessage)
|
||||
|
||||
|
||||
@pytest.mark.skip('not sure how to fix this')
|
||||
@pytest.mark.parametrize('max_tokens', [100000, 10000, 5000])
|
||||
def test_token_overflow_handling_with_real_flow(message_manager: MessageManager, max_tokens):
|
||||
"""Test handling of token overflow in a realistic message flow"""
|
||||
# Set more realistic token limit
|
||||
message_manager.settings.max_input_tokens = max_tokens
|
||||
|
||||
# Create a long sequence of interactions
|
||||
for i in range(200): # Simulate 40 steps of interaction
|
||||
# Create state with varying content length
|
||||
state = BrowserState(
|
||||
url=f'https://test{i}.com',
|
||||
title=f'Test Page {i}',
|
||||
element_tree=DOMElementNode(
|
||||
tag_name='div',
|
||||
attributes={},
|
||||
children=[
|
||||
DOMTextNode(
|
||||
text=f'Content {j} ' * (10 + i), # Increasing content length
|
||||
is_visible=True,
|
||||
parent=None,
|
||||
)
|
||||
for j in range(5) # Multiple DOM items
|
||||
],
|
||||
is_visible=True,
|
||||
parent=None,
|
||||
xpath='//div',
|
||||
),
|
||||
selector_map={j: f'//div[{j}]' for j in range(5)},
|
||||
tabs=[TabInfo(page_id=1, url=f'https://test{i}.com', title=f'Test Page {i}')],
|
||||
)
|
||||
|
||||
# Alternate between different types of results
|
||||
result = None
|
||||
if i % 2 == 0: # Every other iteration
|
||||
result = ActionResult(
|
||||
extracted_content=f'Important content from step {i}' * 5,
|
||||
include_in_memory=i % 4 == 0, # Include in memory every 4th message
|
||||
)
|
||||
|
||||
# Add state message
|
||||
if result:
|
||||
message_manager.add_state_message(state, [result])
|
||||
else:
|
||||
message_manager.add_state_message(state)
|
||||
|
||||
try:
|
||||
messages = message_manager.get_messages()
|
||||
except ValueError as e:
|
||||
if 'Max token limit reached - history is too long' in str(e):
|
||||
return # If error occurs, end the test
|
||||
else:
|
||||
raise e
|
||||
|
||||
assert message_manager.state.history.current_tokens <= message_manager.settings.max_input_tokens + 100
|
||||
|
||||
last_msg = messages[-1]
|
||||
assert isinstance(last_msg, HumanMessage)
|
||||
|
||||
if i % 4 == 0:
|
||||
assert isinstance(message_manager.state.history.messages[-2].message, HumanMessage)
|
||||
if i % 2 == 0 and not i % 4 == 0:
|
||||
if isinstance(last_msg.content, list):
|
||||
assert 'Current url: https://test' in last_msg.content[0]['text']
|
||||
else:
|
||||
assert 'Current url: https://test' in last_msg.content
|
||||
|
||||
# Add model output every time
|
||||
from browser_use.agent.views import AgentBrain, AgentOutput
|
||||
from browser_use.controller.registry.views import ActionModel
|
||||
|
||||
output = AgentOutput(
|
||||
current_state=AgentBrain(
|
||||
evaluation_previous_goal=f'Success in step {i}',
|
||||
memory=f'Memory from step {i}',
|
||||
next_goal=f'Goal for step {i + 1}',
|
||||
),
|
||||
action=[ActionModel()],
|
||||
)
|
||||
message_manager._remove_last_state_message()
|
||||
message_manager.add_model_output(output)
|
||||
|
||||
# Get messages and verify after each addition
|
||||
messages = [m.message for m in message_manager.state.history.messages]
|
||||
|
||||
# Verify token limit is respected
|
||||
|
||||
# Verify essential messages are preserved
|
||||
assert isinstance(messages[0], SystemMessage) # System prompt always first
|
||||
assert isinstance(messages[1], HumanMessage) # Task always second
|
||||
assert 'Test task' in messages[1].content
|
||||
|
||||
# Verify structure of latest messages
|
||||
assert isinstance(messages[-1], AIMessage) # Last message should be model output
|
||||
assert f'step {i}' in messages[-1].content # Should contain current step info
|
||||
|
||||
# Log token usage for debugging
|
||||
token_usage = message_manager.state.history.current_tokens
|
||||
token_limit = message_manager.settings.max_input_tokens
|
||||
# print(f'Step {i}: Using {token_usage}/{token_limit} tokens')
|
||||
|
||||
# go through all messages and verify that the token count and total tokens is correct
|
||||
total_tokens = 0
|
||||
real_tokens = []
|
||||
stored_tokens = []
|
||||
for msg in message_manager.state.history.messages:
|
||||
total_tokens += msg.metadata.tokens
|
||||
stored_tokens.append(msg.metadata.tokens)
|
||||
real_tokens.append(message_manager._count_tokens(msg.message))
|
||||
assert total_tokens == sum(real_tokens)
|
||||
assert stored_tokens == real_tokens
|
||||
assert message_manager.state.history.current_tokens == total_tokens
|
||||
|
||||
|
||||
# pytest -s browser_use/agent/message_manager/tests.py
|
||||
147
browser-use/browser_use/agent/message_manager/utils.py
Normal file
147
browser-use/browser_use/agent/message_manager/utils.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODELS_WITHOUT_TOOL_SUPPORT_PATTERNS = [
|
||||
'deepseek-reasoner',
|
||||
'deepseek-r1',
|
||||
'.*gemma.*-it',
|
||||
]
|
||||
|
||||
|
||||
def is_model_without_tool_support(model_name: str) -> bool:
|
||||
return any(re.match(pattern, model_name) for pattern in MODELS_WITHOUT_TOOL_SUPPORT_PATTERNS)
|
||||
|
||||
|
||||
def extract_json_from_model_output(content: str) -> dict:
|
||||
"""Extract JSON from model output, handling both plain JSON and code-block-wrapped JSON."""
|
||||
try:
|
||||
# If content is wrapped in code blocks, extract just the JSON part
|
||||
if '```' in content:
|
||||
# Find the JSON content between code blocks
|
||||
content = content.split('```')[1]
|
||||
# Remove language identifier if present (e.g., 'json\n')
|
||||
if '\n' in content:
|
||||
content = content.split('\n', 1)[1]
|
||||
# Parse the cleaned content
|
||||
result_dict = json.loads(content)
|
||||
|
||||
# some models occasionally respond with a list containing one dict: https://github.com/browser-use/browser-use/issues/1458
|
||||
if isinstance(result_dict, list) and len(result_dict) == 1 and isinstance(result_dict[0], dict):
|
||||
result_dict = result_dict[0]
|
||||
|
||||
assert isinstance(result_dict, dict), f'Expected JSON dictionary in response, got JSON {type(result_dict)} instead'
|
||||
return result_dict
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f'Failed to parse model output: {content} {str(e)}')
|
||||
raise ValueError('Could not parse response.')
|
||||
|
||||
|
||||
def convert_input_messages(input_messages: list[BaseMessage], model_name: str | None) -> list[BaseMessage]:
|
||||
"""Convert input messages to a format that is compatible with the planner model"""
|
||||
if model_name is None:
|
||||
return input_messages
|
||||
|
||||
if is_model_without_tool_support(model_name):
|
||||
converted_input_messages = _convert_messages_for_non_function_calling_models(input_messages)
|
||||
merged_input_messages = _merge_successive_messages(converted_input_messages, HumanMessage)
|
||||
merged_input_messages = _merge_successive_messages(merged_input_messages, AIMessage)
|
||||
return merged_input_messages
|
||||
return input_messages
|
||||
|
||||
|
||||
def _convert_messages_for_non_function_calling_models(input_messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Convert messages for non-function-calling models"""
|
||||
output_messages = []
|
||||
for message in input_messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
output_messages.append(message)
|
||||
elif isinstance(message, SystemMessage):
|
||||
output_messages.append(message)
|
||||
elif isinstance(message, ToolMessage):
|
||||
output_messages.append(HumanMessage(content=message.content))
|
||||
elif isinstance(message, AIMessage):
|
||||
# check if tool_calls is a valid JSON object
|
||||
if message.tool_calls:
|
||||
tool_calls = json.dumps(message.tool_calls)
|
||||
output_messages.append(AIMessage(content=tool_calls))
|
||||
else:
|
||||
output_messages.append(message)
|
||||
else:
|
||||
raise ValueError(f'Unknown message type: {type(message)}')
|
||||
return output_messages
|
||||
|
||||
|
||||
def _merge_successive_messages(messages: list[BaseMessage], class_to_merge: type[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Some models like deepseek-reasoner dont allow multiple human messages in a row. This function merges them into one."""
|
||||
merged_messages = []
|
||||
streak = 0
|
||||
for message in messages:
|
||||
if isinstance(message, class_to_merge):
|
||||
streak += 1
|
||||
if streak > 1:
|
||||
if isinstance(message.content, list):
|
||||
merged_messages[-1].content += message.content[0]['text'] # type:ignore
|
||||
else:
|
||||
merged_messages[-1].content += message.content
|
||||
else:
|
||||
merged_messages.append(message)
|
||||
else:
|
||||
merged_messages.append(message)
|
||||
streak = 0
|
||||
return merged_messages
|
||||
|
||||
|
||||
def save_conversation(input_messages: list[BaseMessage], response: Any, target: str, encoding: str | None = None) -> None:
|
||||
"""Save conversation history to file."""
|
||||
|
||||
# create folders if not exists
|
||||
if dirname := os.path.dirname(target):
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
|
||||
with open(
|
||||
target,
|
||||
'w',
|
||||
encoding=encoding,
|
||||
) as f:
|
||||
_write_messages_to_file(f, input_messages)
|
||||
_write_response_to_file(f, response)
|
||||
|
||||
|
||||
def _write_messages_to_file(f: Any, messages: list[BaseMessage]) -> None:
|
||||
"""Write messages to conversation file"""
|
||||
for message in messages:
|
||||
f.write(f' {message.__class__.__name__} \n')
|
||||
|
||||
if isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if isinstance(item, dict) and item.get('type') == 'text':
|
||||
f.write(item['text'].strip() + '\n')
|
||||
elif isinstance(message.content, str):
|
||||
try:
|
||||
content = json.loads(message.content)
|
||||
f.write(json.dumps(content, indent=2) + '\n')
|
||||
except json.JSONDecodeError:
|
||||
f.write(message.content.strip() + '\n')
|
||||
|
||||
f.write('\n')
|
||||
|
||||
|
||||
def _write_response_to_file(f: Any, response: Any) -> None:
|
||||
"""Write model response to conversation file"""
|
||||
f.write(' RESPONSE\n')
|
||||
f.write(json.dumps(json.loads(response.model_dump_json(exclude_unset=True)), indent=2))
|
||||
135
browser-use/browser_use/agent/message_manager/views.py
Normal file
135
browser-use/browser_use/agent/message_manager/views.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from warnings import filterwarnings
|
||||
|
||||
from langchain_core._api import LangChainBetaWarning
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
|
||||
|
||||
filterwarnings('ignore', category=LangChainBetaWarning)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from browser_use.agent.views import AgentOutput
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
"""Metadata for a message"""
|
||||
|
||||
tokens: int = 0
|
||||
message_type: str | None = None
|
||||
|
||||
|
||||
class ManagedMessage(BaseModel):
|
||||
"""A message with its metadata"""
|
||||
|
||||
message: BaseMessage
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
# https://github.com/pydantic/pydantic/discussions/7558
|
||||
@model_serializer(mode='wrap')
|
||||
def to_json(self, original_dump):
|
||||
"""
|
||||
Returns the JSON representation of the model.
|
||||
|
||||
It uses langchain's `dumps` function to serialize the `message`
|
||||
property before encoding the overall dict with json.dumps.
|
||||
"""
|
||||
data = original_dump(self)
|
||||
|
||||
# NOTE: We override the message field to use langchain JSON serialization.
|
||||
data['message'] = dumpd(self.message)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def validate(
|
||||
cls,
|
||||
value: Any,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
from_attributes: bool | None = None,
|
||||
context: Any | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Custom validator that uses langchain's `loads` function
|
||||
to parse the message if it is provided as a JSON string.
|
||||
"""
|
||||
if isinstance(value, dict) and 'message' in value:
|
||||
# NOTE: We use langchain's load to convert the JSON string back into a BaseMessage object.
|
||||
filterwarnings('ignore', category=LangChainBetaWarning)
|
||||
value['message'] = load(value['message'])
|
||||
return value
|
||||
|
||||
|
||||
class MessageHistory(BaseModel):
|
||||
"""History of messages with metadata"""
|
||||
|
||||
messages: list[ManagedMessage] = Field(default_factory=list)
|
||||
current_tokens: int = 0
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def add_message(self, message: BaseMessage, metadata: MessageMetadata, position: int | None = None) -> None:
|
||||
"""Add message with metadata to history"""
|
||||
if position is None:
|
||||
self.messages.append(ManagedMessage(message=message, metadata=metadata))
|
||||
else:
|
||||
self.messages.insert(position, ManagedMessage(message=message, metadata=metadata))
|
||||
self.current_tokens += metadata.tokens
|
||||
|
||||
def add_model_output(self, output: AgentOutput) -> None:
|
||||
"""Add model output as AI message"""
|
||||
tool_calls = [
|
||||
{
|
||||
'name': 'AgentOutput',
|
||||
'args': output.model_dump(mode='json', exclude_unset=True),
|
||||
'id': '1',
|
||||
'type': 'tool_call',
|
||||
}
|
||||
]
|
||||
|
||||
msg = AIMessage(
|
||||
content='',
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
self.add_message(msg, MessageMetadata(tokens=100)) # Estimate tokens for tool calls
|
||||
|
||||
# Empty tool response
|
||||
tool_message = ToolMessage(content='', tool_call_id='1')
|
||||
self.add_message(tool_message, MessageMetadata(tokens=10)) # Estimate tokens for empty response
|
||||
|
||||
def get_messages(self) -> list[BaseMessage]:
|
||||
"""Get all messages"""
|
||||
return [m.message for m in self.messages]
|
||||
|
||||
def get_total_tokens(self) -> int:
|
||||
"""Get total tokens in history"""
|
||||
return self.current_tokens
|
||||
|
||||
def remove_oldest_message(self) -> None:
|
||||
"""Remove oldest non-system message"""
|
||||
for i, msg in enumerate(self.messages):
|
||||
if not isinstance(msg.message, SystemMessage):
|
||||
self.current_tokens -= msg.metadata.tokens
|
||||
self.messages.pop(i)
|
||||
break
|
||||
|
||||
def remove_last_state_message(self) -> None:
|
||||
"""Remove last state message from history"""
|
||||
if len(self.messages) > 2 and isinstance(self.messages[-1].message, HumanMessage):
|
||||
self.current_tokens -= self.messages[-1].metadata.tokens
|
||||
self.messages.pop()
|
||||
|
||||
|
||||
class MessageManagerState(BaseModel):
|
||||
"""Holds the state for MessageManager"""
|
||||
|
||||
history: MessageHistory = Field(default_factory=MessageHistory)
|
||||
tool_id: int = 1
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
722
browser-use/browser_use/agent/playwright_script_generator.py
Normal file
722
browser-use/browser_use/agent/playwright_script_generator.py
Normal file
|
|
@ -0,0 +1,722 @@
|
|||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from browser_use.browser.browser import BrowserConfig
|
||||
from browser_use.browser.context import BrowserContextConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlaywrightScriptGenerator:
|
||||
"""Generates a Playwright script from AgentHistoryList."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_list: list[dict[str, Any]],
|
||||
sensitive_data_keys: list[str] | None = None,
|
||||
browser_config: BrowserConfig | None = None,
|
||||
context_config: BrowserContextConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Initializes the script generator.
|
||||
|
||||
Args:
|
||||
history_list: A list of dictionaries, where each dictionary represents an AgentHistory item.
|
||||
Expected to be raw dictionaries from `AgentHistoryList.model_dump()`.
|
||||
sensitive_data_keys: A list of keys used as placeholders for sensitive data.
|
||||
browser_config: Configuration from the original Browser instance.
|
||||
context_config: Configuration from the original BrowserContext instance.
|
||||
"""
|
||||
self.history = history_list
|
||||
self.sensitive_data_keys = sensitive_data_keys or []
|
||||
self.browser_config = browser_config
|
||||
self.context_config = context_config
|
||||
self._imports_helpers_added = False
|
||||
self._page_counter = 0 # Track pages for tab management
|
||||
|
||||
# Dictionary mapping action types to handler methods
|
||||
self._action_handlers = {
|
||||
'go_to_url': self._map_go_to_url,
|
||||
'wait': self._map_wait,
|
||||
'input_text': self._map_input_text,
|
||||
'click_element': self._map_click_element,
|
||||
'click_element_by_index': self._map_click_element, # Map legacy action
|
||||
'scroll_down': self._map_scroll_down,
|
||||
'scroll_up': self._map_scroll_up,
|
||||
'send_keys': self._map_send_keys,
|
||||
'go_back': self._map_go_back,
|
||||
'open_tab': self._map_open_tab,
|
||||
'close_tab': self._map_close_tab,
|
||||
'switch_tab': self._map_switch_tab,
|
||||
'search_google': self._map_search_google,
|
||||
'drag_drop': self._map_drag_drop,
|
||||
'extract_content': self._map_extract_content,
|
||||
'click_download_button': self._map_click_download_button,
|
||||
'done': self._map_done,
|
||||
}
|
||||
|
||||
def _generate_browser_launch_args(self) -> str:
|
||||
"""Generates the arguments string for browser launch based on BrowserConfig."""
|
||||
if not self.browser_config:
|
||||
# Default launch if no config provided
|
||||
return 'headless=False'
|
||||
|
||||
args_dict = {
|
||||
'headless': self.browser_config.headless,
|
||||
# Add other relevant launch options here based on self.browser_config
|
||||
# Example: 'proxy': self.browser_config.proxy.model_dump() if self.browser_config.proxy else None
|
||||
# Example: 'args': self.browser_config.extra_browser_args # Be careful inheriting args
|
||||
}
|
||||
if self.browser_config.proxy:
|
||||
args_dict['proxy'] = self.browser_config.proxy.model_dump()
|
||||
|
||||
# Filter out None values
|
||||
args_dict = {k: v for k, v in args_dict.items() if v is not None}
|
||||
|
||||
# Format as keyword arguments string
|
||||
args_str = ', '.join(f'{key}={repr(value)}' for key, value in args_dict.items())
|
||||
return args_str
|
||||
|
||||
def _generate_context_options(self) -> str:
|
||||
"""Generates the options string for context creation based on BrowserContextConfig."""
|
||||
if not self.context_config:
|
||||
return '' # Default context
|
||||
|
||||
options_dict = {}
|
||||
|
||||
# Map relevant BrowserContextConfig fields to Playwright context options
|
||||
if self.context_config.user_agent:
|
||||
options_dict['user_agent'] = self.context_config.user_agent
|
||||
if self.context_config.locale:
|
||||
options_dict['locale'] = self.context_config.locale
|
||||
if self.context_config.permissions:
|
||||
options_dict['permissions'] = self.context_config.permissions
|
||||
if self.context_config.geolocation:
|
||||
options_dict['geolocation'] = self.context_config.geolocation
|
||||
if self.context_config.timezone_id:
|
||||
options_dict['timezone_id'] = self.context_config.timezone_id
|
||||
if self.context_config.http_credentials:
|
||||
options_dict['http_credentials'] = self.context_config.http_credentials
|
||||
if self.context_config.is_mobile is not None:
|
||||
options_dict['is_mobile'] = self.context_config.is_mobile
|
||||
if self.context_config.has_touch is not None:
|
||||
options_dict['has_touch'] = self.context_config.has_touch
|
||||
if self.context_config.save_recording_path:
|
||||
options_dict['record_video_dir'] = self.context_config.save_recording_path
|
||||
if self.context_config.save_har_path:
|
||||
options_dict['record_har_path'] = self.context_config.save_har_path
|
||||
|
||||
# Handle viewport/window size
|
||||
if self.context_config.no_viewport:
|
||||
options_dict['no_viewport'] = True
|
||||
elif hasattr(self.context_config, 'window_width') and hasattr(self.context_config, 'window_height'):
|
||||
options_dict['viewport'] = {
|
||||
'width': self.context_config.window_width,
|
||||
'height': self.context_config.window_height,
|
||||
}
|
||||
|
||||
# Note: cookies_file and save_downloads_path are handled separately
|
||||
|
||||
# Filter out None values
|
||||
options_dict = {k: v for k, v in options_dict.items() if v is not None}
|
||||
|
||||
# Format as keyword arguments string
|
||||
options_str = ', '.join(f'{key}={repr(value)}' for key, value in options_dict.items())
|
||||
return options_str
|
||||
|
||||
def _get_imports_and_helpers(self) -> list[str]:
|
||||
"""Generates necessary import statements (excluding helper functions)."""
|
||||
# Return only the standard imports needed by the main script body
|
||||
return [
|
||||
'import asyncio',
|
||||
'import json',
|
||||
'import os',
|
||||
'import sys',
|
||||
'from pathlib import Path', # Added Path import
|
||||
'import urllib.parse', # Needed for search_google
|
||||
'from playwright.async_api import async_playwright, Page, BrowserContext', # Added BrowserContext
|
||||
'from dotenv import load_dotenv',
|
||||
'',
|
||||
'# Load environment variables',
|
||||
'load_dotenv(override=True)',
|
||||
'',
|
||||
# Helper function definitions are no longer here
|
||||
]
|
||||
|
||||
def _get_sensitive_data_definitions(self) -> list[str]:
|
||||
"""Generates the SENSITIVE_DATA dictionary definition."""
|
||||
if not self.sensitive_data_keys:
|
||||
return ['SENSITIVE_DATA = {}', '']
|
||||
|
||||
lines = ['# Sensitive data placeholders mapped to environment variables']
|
||||
lines.append('SENSITIVE_DATA = {')
|
||||
for key in self.sensitive_data_keys:
|
||||
env_var_name = key.upper()
|
||||
default_value_placeholder = f'YOUR_{env_var_name}'
|
||||
lines.append(f' "{key}": os.getenv("{env_var_name}", {json.dumps(default_value_placeholder)}),')
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
return lines
|
||||
|
||||
def _get_selector_for_action(self, history_item: dict, action_index_in_step: int) -> str | None:
|
||||
"""
|
||||
Gets the selector (preferring XPath) for a given action index within a history step.
|
||||
Formats the XPath correctly for Playwright.
|
||||
"""
|
||||
state = history_item.get('state')
|
||||
if not isinstance(state, dict):
|
||||
return None
|
||||
interacted_elements = state.get('interacted_element')
|
||||
if not isinstance(interacted_elements, list):
|
||||
return None
|
||||
if action_index_in_step >= len(interacted_elements):
|
||||
return None
|
||||
element_data = interacted_elements[action_index_in_step]
|
||||
if not isinstance(element_data, dict):
|
||||
return None
|
||||
|
||||
# Prioritize XPath
|
||||
xpath = element_data.get('xpath')
|
||||
if isinstance(xpath, str) and xpath.strip():
|
||||
if not xpath.startswith('xpath=') and not xpath.startswith('/') and not xpath.startswith('//'):
|
||||
xpath_selector = f'xpath=//{xpath}' # Make relative if not already
|
||||
elif not xpath.startswith('xpath='):
|
||||
xpath_selector = f'xpath={xpath}' # Add prefix if missing
|
||||
else:
|
||||
xpath_selector = xpath
|
||||
return xpath_selector
|
||||
|
||||
# Fallback to CSS selector if XPath is missing
|
||||
css_selector = element_data.get('css_selector')
|
||||
if isinstance(css_selector, str) and css_selector.strip():
|
||||
return css_selector # Use CSS selector as is
|
||||
|
||||
logger.warning(
|
||||
f'Could not find a usable XPath or CSS selector for action index {action_index_in_step} (element index {element_data.get("highlight_index", "N/A")}).'
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_goto_timeout(self) -> int:
|
||||
"""Gets the page navigation timeout in milliseconds."""
|
||||
default_timeout = 90000 # Default 90 seconds
|
||||
if self.context_config and self.context_config.maximum_wait_page_load_time:
|
||||
# Convert seconds to milliseconds
|
||||
return int(self.context_config.maximum_wait_page_load_time * 1000)
|
||||
return default_timeout
|
||||
|
||||
# --- Action Mapping Methods ---
|
||||
def _map_go_to_url(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
url = params.get('url')
|
||||
goto_timeout = self._get_goto_timeout()
|
||||
script_lines = []
|
||||
if url and isinstance(url, str):
|
||||
escaped_url = json.dumps(url)
|
||||
script_lines.append(f' print(f"Navigating to: {url} ({step_info_str})")')
|
||||
script_lines.append(f' await page.goto({escaped_url}, timeout={goto_timeout})')
|
||||
script_lines.append(f" await page.wait_for_load_state('load', timeout={goto_timeout})")
|
||||
script_lines.append(' await page.wait_for_timeout(1000)') # Short pause
|
||||
else:
|
||||
script_lines.append(f' # Skipping go_to_url ({step_info_str}): missing or invalid url')
|
||||
return script_lines
|
||||
|
||||
def _map_wait(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
seconds = params.get('seconds', 3)
|
||||
try:
|
||||
wait_seconds = int(seconds)
|
||||
except (ValueError, TypeError):
|
||||
wait_seconds = 3
|
||||
return [
|
||||
f' print(f"Waiting for {wait_seconds} seconds... ({step_info_str})")',
|
||||
f' await asyncio.sleep({wait_seconds})',
|
||||
]
|
||||
|
||||
def _map_input_text(
|
||||
self, params: dict, history_item: dict, action_index_in_step: int, step_info_str: str, **kwargs
|
||||
) -> list[str]:
|
||||
index = params.get('index')
|
||||
text = params.get('text', '')
|
||||
selector = self._get_selector_for_action(history_item, action_index_in_step)
|
||||
script_lines = []
|
||||
if selector and index is not None:
|
||||
clean_text_expression = f'replace_sensitive_data({json.dumps(str(text))}, SENSITIVE_DATA)'
|
||||
escaped_selector = json.dumps(selector)
|
||||
escaped_step_info = json.dumps(step_info_str)
|
||||
script_lines.append(
|
||||
f' await _try_locate_and_act(page, {escaped_selector}, "fill", text={clean_text_expression}, step_info={escaped_step_info})'
|
||||
)
|
||||
else:
|
||||
script_lines.append(
|
||||
f' # Skipping input_text ({step_info_str}): missing index ({index}) or selector ({selector})'
|
||||
)
|
||||
return script_lines
|
||||
|
||||
def _map_click_element(
|
||||
self, params: dict, history_item: dict, action_index_in_step: int, step_info_str: str, action_type: str, **kwargs
|
||||
) -> list[str]:
|
||||
if action_type == 'click_element_by_index':
|
||||
logger.warning(f"Mapping legacy 'click_element_by_index' to 'click_element' ({step_info_str})")
|
||||
index = params.get('index')
|
||||
selector = self._get_selector_for_action(history_item, action_index_in_step)
|
||||
script_lines = []
|
||||
if selector and index is not None:
|
||||
escaped_selector = json.dumps(selector)
|
||||
escaped_step_info = json.dumps(step_info_str)
|
||||
script_lines.append(
|
||||
f' await _try_locate_and_act(page, {escaped_selector}, "click", step_info={escaped_step_info})'
|
||||
)
|
||||
else:
|
||||
script_lines.append(
|
||||
f' # Skipping {action_type} ({step_info_str}): missing index ({index}) or selector ({selector})'
|
||||
)
|
||||
return script_lines
|
||||
|
||||
def _map_scroll_down(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
amount = params.get('amount')
|
||||
script_lines = []
|
||||
if amount and isinstance(amount, int):
|
||||
script_lines.append(f' print(f"Scrolling down by {amount} pixels ({step_info_str})")')
|
||||
script_lines.append(f" await page.evaluate('window.scrollBy(0, {amount})')")
|
||||
else:
|
||||
script_lines.append(f' print(f"Scrolling down by one page height ({step_info_str})")')
|
||||
script_lines.append(" await page.evaluate('window.scrollBy(0, window.innerHeight)')")
|
||||
script_lines.append(' await page.wait_for_timeout(500)')
|
||||
return script_lines
|
||||
|
||||
def _map_scroll_up(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
amount = params.get('amount')
|
||||
script_lines = []
|
||||
if amount and isinstance(amount, int):
|
||||
script_lines.append(f' print(f"Scrolling up by {amount} pixels ({step_info_str})")')
|
||||
script_lines.append(f" await page.evaluate('window.scrollBy(0, -{amount})')")
|
||||
else:
|
||||
script_lines.append(f' print(f"Scrolling up by one page height ({step_info_str})")')
|
||||
script_lines.append(" await page.evaluate('window.scrollBy(0, -window.innerHeight)')")
|
||||
script_lines.append(' await page.wait_for_timeout(500)')
|
||||
return script_lines
|
||||
|
||||
def _map_send_keys(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
keys = params.get('keys')
|
||||
script_lines = []
|
||||
if keys and isinstance(keys, str):
|
||||
escaped_keys = json.dumps(keys)
|
||||
script_lines.append(f' print(f"Sending keys: {keys} ({step_info_str})")')
|
||||
script_lines.append(f' await page.keyboard.press({escaped_keys})')
|
||||
script_lines.append(' await page.wait_for_timeout(500)')
|
||||
else:
|
||||
script_lines.append(f' # Skipping send_keys ({step_info_str}): missing or invalid keys')
|
||||
return script_lines
|
||||
|
||||
def _map_go_back(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
goto_timeout = self._get_goto_timeout()
|
||||
return [
|
||||
' await asyncio.sleep(60) # Wait 1 minute (important) before going back',
|
||||
f' print(f"Navigating back using browser history ({step_info_str})")',
|
||||
f' await page.go_back(timeout={goto_timeout})',
|
||||
f" await page.wait_for_load_state('load', timeout={goto_timeout})",
|
||||
' await page.wait_for_timeout(1000)',
|
||||
]
|
||||
|
||||
def _map_open_tab(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
url = params.get('url')
|
||||
goto_timeout = self._get_goto_timeout()
|
||||
script_lines = []
|
||||
if url and isinstance(url, str):
|
||||
escaped_url = json.dumps(url)
|
||||
script_lines.append(f' print(f"Opening new tab and navigating to: {url} ({step_info_str})")')
|
||||
script_lines.append(' page = await context.new_page()')
|
||||
script_lines.append(f' await page.goto({escaped_url}, timeout={goto_timeout})')
|
||||
script_lines.append(f" await page.wait_for_load_state('load', timeout={goto_timeout})")
|
||||
script_lines.append(' await page.wait_for_timeout(1000)')
|
||||
self._page_counter += 1 # Increment page counter
|
||||
else:
|
||||
script_lines.append(f' # Skipping open_tab ({step_info_str}): missing or invalid url')
|
||||
return script_lines
|
||||
|
||||
def _map_close_tab(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
page_id = params.get('page_id')
|
||||
script_lines = []
|
||||
if page_id is not None:
|
||||
script_lines.extend(
|
||||
[
|
||||
f' print(f"Attempting to close tab with page_id {page_id} ({step_info_str})")',
|
||||
f' if {page_id} < len(context.pages):',
|
||||
f' target_page = context.pages[{page_id}]',
|
||||
' await target_page.close()',
|
||||
' await page.wait_for_timeout(500)',
|
||||
' if context.pages: page = context.pages[-1]', # Switch to last page
|
||||
' else:',
|
||||
" print(' Warning: No pages left after closing tab. Cannot switch.', file=sys.stderr)",
|
||||
' # Optionally, create a new page here if needed: page = await context.new_page()',
|
||||
' if page: await page.bring_to_front()', # Bring to front if page exists
|
||||
' else:',
|
||||
f' print(f" Warning: Tab with page_id {page_id} not found to close ({step_info_str})", file=sys.stderr)',
|
||||
]
|
||||
)
|
||||
else:
|
||||
script_lines.append(f' # Skipping close_tab ({step_info_str}): missing page_id')
|
||||
return script_lines
|
||||
|
||||
def _map_switch_tab(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
page_id = params.get('page_id')
|
||||
script_lines = []
|
||||
if page_id is not None:
|
||||
script_lines.extend(
|
||||
[
|
||||
f' print(f"Switching to tab with page_id {page_id} ({step_info_str})")',
|
||||
f' if {page_id} < len(context.pages):',
|
||||
f' page = context.pages[{page_id}]',
|
||||
' await page.bring_to_front()',
|
||||
" await page.wait_for_load_state('load', timeout=15000)",
|
||||
' await page.wait_for_timeout(500)',
|
||||
' else:',
|
||||
f' print(f" Warning: Tab with page_id {page_id} not found to switch ({step_info_str})", file=sys.stderr)',
|
||||
]
|
||||
)
|
||||
else:
|
||||
script_lines.append(f' # Skipping switch_tab ({step_info_str}): missing page_id')
|
||||
return script_lines
|
||||
|
||||
def _map_search_google(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
query = params.get('query')
|
||||
goto_timeout = self._get_goto_timeout()
|
||||
script_lines = []
|
||||
if query and isinstance(query, str):
|
||||
clean_query = f'replace_sensitive_data({json.dumps(query)}, SENSITIVE_DATA)'
|
||||
search_url_expression = f'f"https://www.google.com/search?q={{ urllib.parse.quote_plus({clean_query}) }}&udm=14"'
|
||||
script_lines.extend(
|
||||
[
|
||||
f' search_url = {search_url_expression}',
|
||||
f' print(f"Searching Google for query related to: {{ {clean_query} }} ({step_info_str})")',
|
||||
f' await page.goto(search_url, timeout={goto_timeout})',
|
||||
f" await page.wait_for_load_state('load', timeout={goto_timeout})",
|
||||
' await page.wait_for_timeout(1000)',
|
||||
]
|
||||
)
|
||||
else:
|
||||
script_lines.append(f' # Skipping search_google ({step_info_str}): missing or invalid query')
|
||||
return script_lines
|
||||
|
||||
def _map_drag_drop(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
source_sel = params.get('element_source')
|
||||
target_sel = params.get('element_target')
|
||||
source_coords = (params.get('coord_source_x'), params.get('coord_source_y'))
|
||||
target_coords = (params.get('coord_target_x'), params.get('coord_target_y'))
|
||||
script_lines = [f' print(f"Attempting drag and drop ({step_info_str})")']
|
||||
if source_sel and target_sel:
|
||||
escaped_source = json.dumps(source_sel)
|
||||
escaped_target = json.dumps(target_sel)
|
||||
script_lines.append(f' await page.drag_and_drop({escaped_source}, {escaped_target})')
|
||||
script_lines.append(f" print(f' Dragged element {escaped_source} to {escaped_target}')")
|
||||
elif all(c is not None for c in source_coords) and all(c is not None for c in target_coords):
|
||||
sx, sy = source_coords
|
||||
tx, ty = target_coords
|
||||
script_lines.extend(
|
||||
[
|
||||
f' await page.mouse.move({sx}, {sy})',
|
||||
' await page.mouse.down()',
|
||||
f' await page.mouse.move({tx}, {ty})',
|
||||
' await page.mouse.up()',
|
||||
f" print(f' Dragged from ({sx},{sy}) to ({tx},{ty})')",
|
||||
]
|
||||
)
|
||||
else:
|
||||
script_lines.append(
|
||||
f' # Skipping drag_drop ({step_info_str}): requires either element selectors or full coordinates'
|
||||
)
|
||||
script_lines.append(' await page.wait_for_timeout(500)')
|
||||
return script_lines
|
||||
|
||||
def _map_extract_content(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
goal = params.get('goal', 'content')
|
||||
logger.warning(f"Action 'extract_content' ({step_info_str}) cannot be directly translated to Playwright script.")
|
||||
return [f' # Action: extract_content (Goal: {goal}) - Skipped in Playwright script ({step_info_str})']
|
||||
|
||||
def _map_click_download_button(
|
||||
self, params: dict, history_item: dict, action_index_in_step: int, step_info_str: str, **kwargs
|
||||
) -> list[str]:
|
||||
index = params.get('index')
|
||||
selector = self._get_selector_for_action(history_item, action_index_in_step)
|
||||
download_dir_in_script = "'./files'" # Default
|
||||
if self.context_config and self.context_config.save_downloads_path:
|
||||
download_dir_in_script = repr(self.context_config.save_downloads_path)
|
||||
|
||||
script_lines = []
|
||||
if selector and index is not None:
|
||||
script_lines.append(
|
||||
f' print(f"Attempting to download file by clicking element ({selector}) ({step_info_str})")'
|
||||
)
|
||||
script_lines.append(' try:')
|
||||
script_lines.append(
|
||||
' async with page.expect_download(timeout=120000) as download_info:'
|
||||
) # 2 min timeout
|
||||
step_info_for_download = f'{step_info_str} (triggering download)'
|
||||
script_lines.append(
|
||||
f' await _try_locate_and_act(page, {json.dumps(selector)}, "click", step_info={json.dumps(step_info_for_download)})'
|
||||
)
|
||||
script_lines.append(' download = await download_info.value')
|
||||
script_lines.append(f' configured_download_dir = {download_dir_in_script}')
|
||||
script_lines.append(' download_dir_path = Path(configured_download_dir).resolve()')
|
||||
script_lines.append(' download_dir_path.mkdir(parents=True, exist_ok=True)')
|
||||
script_lines.append(
|
||||
" base, ext = os.path.splitext(download.suggested_filename or f'download_{{len(list(download_dir_path.iterdir())) + 1}}.tmp')"
|
||||
)
|
||||
script_lines.append(' counter = 1')
|
||||
script_lines.append(" download_path_obj = download_dir_path / f'{base}{ext}'")
|
||||
script_lines.append(' while download_path_obj.exists():')
|
||||
script_lines.append(" download_path_obj = download_dir_path / f'{base}({{counter}}){ext}'")
|
||||
script_lines.append(' counter += 1')
|
||||
script_lines.append(' await download.save_as(str(download_path_obj))')
|
||||
script_lines.append(" print(f' File downloaded successfully to: {str(download_path_obj)}')")
|
||||
script_lines.append(' except PlaywrightActionError as pae:')
|
||||
script_lines.append(' raise pae') # Re-raise to stop script
|
||||
script_lines.append(' except Exception as download_err:')
|
||||
script_lines.append(
|
||||
f" raise PlaywrightActionError(f'Download failed for {step_info_str}: {{download_err}}') from download_err"
|
||||
)
|
||||
else:
|
||||
script_lines.append(
|
||||
f' # Skipping click_download_button ({step_info_str}): missing index ({index}) or selector ({selector})'
|
||||
)
|
||||
return script_lines
|
||||
|
||||
def _map_done(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
|
||||
script_lines = []
|
||||
if isinstance(params, dict):
|
||||
final_text = params.get('text', '')
|
||||
success_status = params.get('success', False)
|
||||
escaped_final_text_with_placeholders = json.dumps(str(final_text))
|
||||
script_lines.append(f' print("\\n--- Task marked as Done by agent ({step_info_str}) ---")')
|
||||
script_lines.append(f' print(f"Agent reported success: {success_status}")')
|
||||
script_lines.append(' # Final Message from agent (may contain placeholders):')
|
||||
script_lines.append(
|
||||
f' final_message = replace_sensitive_data({escaped_final_text_with_placeholders}, SENSITIVE_DATA)'
|
||||
)
|
||||
script_lines.append(' print(final_message)')
|
||||
else:
|
||||
script_lines.append(f' print("\\n--- Task marked as Done by agent ({step_info_str}) ---")')
|
||||
script_lines.append(' print("Success: N/A (invalid params)")')
|
||||
script_lines.append(' print("Final Message: N/A (invalid params)")')
|
||||
return script_lines
|
||||
|
||||
def _map_action_to_playwright(
|
||||
self,
|
||||
action_dict: dict,
|
||||
history_item: dict,
|
||||
previous_history_item: dict | None,
|
||||
action_index_in_step: int,
|
||||
step_info_str: str,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Translates a single action dictionary into Playwright script lines using dictionary dispatch.
|
||||
"""
|
||||
if not isinstance(action_dict, dict) or not action_dict:
|
||||
return [f' # Invalid action format: {action_dict} ({step_info_str})']
|
||||
|
||||
action_type = next(iter(action_dict.keys()), None)
|
||||
params = action_dict.get(action_type)
|
||||
|
||||
if not action_type or params is None:
|
||||
if action_dict == {}:
|
||||
return [f' # Empty action dictionary found ({step_info_str})']
|
||||
return [f' # Could not determine action type or params: {action_dict} ({step_info_str})']
|
||||
|
||||
# Get the handler function from the dictionary
|
||||
handler = self._action_handlers.get(action_type)
|
||||
|
||||
if handler:
|
||||
# Call the specific handler method
|
||||
return handler(
|
||||
params=params,
|
||||
history_item=history_item,
|
||||
action_index_in_step=action_index_in_step,
|
||||
step_info_str=step_info_str,
|
||||
action_type=action_type, # Pass action_type for legacy handling etc.
|
||||
previous_history_item=previous_history_item,
|
||||
)
|
||||
else:
|
||||
# Handle unsupported actions
|
||||
logger.warning(f'Unsupported action type encountered: {action_type} ({step_info_str})')
|
||||
return [f' # Unsupported action type: {action_type} ({step_info_str})']
|
||||
|
||||
def generate_script_content(self) -> str:
|
||||
"""Generates the full Playwright script content as a string."""
|
||||
script_lines = []
|
||||
self._page_counter = 0 # Reset page counter for new script generation
|
||||
|
||||
if not self._imports_helpers_added:
|
||||
script_lines.extend(self._get_imports_and_helpers())
|
||||
self._imports_helpers_added = True
|
||||
|
||||
# Read helper script content
|
||||
helper_script_path = Path(__file__).parent / 'playwright_script_helpers.py'
|
||||
try:
|
||||
with open(helper_script_path, encoding='utf-8') as f_helper:
|
||||
helper_script_content = f_helper.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f'Helper script not found at {helper_script_path}. Cannot generate script.')
|
||||
return '# Error: Helper script file missing.'
|
||||
except Exception as e:
|
||||
logger.error(f'Error reading helper script {helper_script_path}: {e}')
|
||||
return f'# Error: Could not read helper script: {e}'
|
||||
|
||||
script_lines.extend(self._get_sensitive_data_definitions())
|
||||
|
||||
# Add the helper script content after imports and sensitive data
|
||||
script_lines.append('\n# --- Helper Functions (from playwright_script_helpers.py) ---')
|
||||
script_lines.append(helper_script_content)
|
||||
script_lines.append('# --- End Helper Functions ---')
|
||||
|
||||
# Generate browser launch and context creation code
|
||||
browser_launch_args = self._generate_browser_launch_args()
|
||||
context_options = self._generate_context_options()
|
||||
# Determine browser type (defaulting to chromium)
|
||||
browser_type = 'chromium'
|
||||
if self.browser_config and self.browser_config.browser_class in ['firefox', 'webkit']:
|
||||
browser_type = self.browser_config.browser_class
|
||||
|
||||
script_lines.extend(
|
||||
[
|
||||
'async def run_generated_script():',
|
||||
' global SENSITIVE_DATA', # Ensure sensitive data is accessible
|
||||
' async with async_playwright() as p:',
|
||||
' browser = None',
|
||||
' context = None',
|
||||
' page = None',
|
||||
' exit_code = 0 # Default success exit code',
|
||||
' try:',
|
||||
f" print('Launching {browser_type} browser...')",
|
||||
# Use generated launch args, remove slow_mo
|
||||
f' browser = await p.{browser_type}.launch({browser_launch_args})',
|
||||
# Use generated context options
|
||||
f' context = await browser.new_context({context_options})',
|
||||
" print('Browser context created.')",
|
||||
]
|
||||
)
|
||||
|
||||
# Add cookie loading logic if cookies_file is specified
|
||||
if self.context_config and self.context_config.cookies_file:
|
||||
cookies_file_path = repr(self.context_config.cookies_file)
|
||||
script_lines.extend(
|
||||
[
|
||||
' # Load cookies if specified',
|
||||
f' cookies_path = {cookies_file_path}',
|
||||
' if cookies_path and os.path.exists(cookies_path):',
|
||||
' try:',
|
||||
" with open(cookies_path, 'r', encoding='utf-8') as f_cookies:",
|
||||
' cookies = json.load(f_cookies)',
|
||||
' # Validate sameSite attribute',
|
||||
" valid_same_site = ['Strict', 'Lax', 'None']",
|
||||
' for cookie in cookies:',
|
||||
" if 'sameSite' in cookie and cookie['sameSite'] not in valid_same_site:",
|
||||
' print(f\' Warning: Fixing invalid sameSite value "{{cookie["sameSite"]}}" to None for cookie {{cookie.get("name")}}\', file=sys.stderr)',
|
||||
" cookie['sameSite'] = 'None'",
|
||||
' await context.add_cookies(cookies)',
|
||||
" print(f' Successfully loaded {{len(cookies)}} cookies from {{cookies_path}}')",
|
||||
' except Exception as cookie_err:',
|
||||
" print(f' Warning: Failed to load or add cookies from {{cookies_path}}: {{cookie_err}}', file=sys.stderr)",
|
||||
' else:',
|
||||
' if cookies_path:', # Only print if a path was specified but not found
|
||||
" print(f' Cookie file not found at: {cookies_path}')",
|
||||
'',
|
||||
]
|
||||
)
|
||||
|
||||
script_lines.extend(
|
||||
[
|
||||
' # Initial page handling',
|
||||
' if context.pages:',
|
||||
' page = context.pages[0]',
|
||||
" print('Using initial page provided by context.')",
|
||||
' else:',
|
||||
' page = await context.new_page()',
|
||||
" print('Created a new page as none existed.')",
|
||||
" print('\\n--- Starting Generated Script Execution ---')",
|
||||
]
|
||||
)
|
||||
|
||||
action_counter = 0
|
||||
stop_processing_steps = False
|
||||
previous_item_dict = None
|
||||
|
||||
for step_index, item_dict in enumerate(self.history):
|
||||
if stop_processing_steps:
|
||||
break
|
||||
|
||||
if not isinstance(item_dict, dict):
|
||||
logger.warning(f'Skipping step {step_index + 1}: Item is not a dictionary ({type(item_dict)})')
|
||||
script_lines.append(f'\n # --- Step {step_index + 1}: Skipped (Invalid Format) ---')
|
||||
previous_item_dict = item_dict
|
||||
continue
|
||||
|
||||
script_lines.append(f'\n # --- Step {step_index + 1} ---')
|
||||
model_output = item_dict.get('model_output')
|
||||
|
||||
if not isinstance(model_output, dict) or 'action' not in model_output:
|
||||
script_lines.append(' # No valid model_output or action found for this step')
|
||||
previous_item_dict = item_dict
|
||||
continue
|
||||
|
||||
actions = model_output.get('action')
|
||||
if not isinstance(actions, list):
|
||||
script_lines.append(f' # Actions format is not a list: {type(actions)}')
|
||||
previous_item_dict = item_dict
|
||||
continue
|
||||
|
||||
for action_index_in_step, action_detail in enumerate(actions):
|
||||
action_counter += 1
|
||||
script_lines.append(f' # Action {action_counter}')
|
||||
|
||||
step_info_str = f'Step {step_index + 1}, Action {action_index_in_step + 1}'
|
||||
action_lines = self._map_action_to_playwright(
|
||||
action_dict=action_detail,
|
||||
history_item=item_dict,
|
||||
previous_history_item=previous_item_dict,
|
||||
action_index_in_step=action_index_in_step,
|
||||
step_info_str=step_info_str,
|
||||
)
|
||||
script_lines.extend(action_lines)
|
||||
|
||||
action_type = next(iter(action_detail.keys()), None) if isinstance(action_detail, dict) else None
|
||||
if action_type == 'done':
|
||||
stop_processing_steps = True
|
||||
break
|
||||
|
||||
previous_item_dict = item_dict
|
||||
|
||||
# Updated final block to include sys.exit
|
||||
script_lines.extend(
|
||||
[
|
||||
' except PlaywrightActionError as pae:', # Catch specific action errors
|
||||
" print(f'\\n--- Playwright Action Error: {pae} ---', file=sys.stderr)",
|
||||
' exit_code = 1', # Set exit code to failure
|
||||
' except Exception as e:',
|
||||
" print(f'\\n--- An unexpected error occurred: {e} ---', file=sys.stderr)",
|
||||
' import traceback',
|
||||
' traceback.print_exc()',
|
||||
' exit_code = 1', # Set exit code to failure
|
||||
' finally:',
|
||||
" print('\\n--- Generated Script Execution Finished ---')",
|
||||
" print('Closing browser/context...')",
|
||||
' if context:',
|
||||
' try: await context.close()',
|
||||
" except Exception as ctx_close_err: print(f' Warning: could not close context: {ctx_close_err}', file=sys.stderr)",
|
||||
' if browser:',
|
||||
' try: await browser.close()',
|
||||
" except Exception as browser_close_err: print(f' Warning: could not close browser: {browser_close_err}', file=sys.stderr)",
|
||||
" print('Browser/context closed.')",
|
||||
' # Exit with the determined exit code',
|
||||
' if exit_code != 0:',
|
||||
" print(f'Script finished with errors (exit code {exit_code}).', file=sys.stderr)",
|
||||
' sys.exit(exit_code)', # Exit with non-zero code on error
|
||||
'',
|
||||
'# --- Script Entry Point ---',
|
||||
"if __name__ == '__main__':",
|
||||
" if os.name == 'nt':",
|
||||
' asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())',
|
||||
' asyncio.run(run_generated_script())',
|
||||
]
|
||||
)
|
||||
|
||||
return '\n'.join(script_lines)
|
||||
94
browser-use/browser_use/agent/playwright_script_helpers.py
Normal file
94
browser-use/browser_use/agent/playwright_script_helpers.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
from playwright.async_api import Page
|
||||
|
||||
|
||||
# --- Helper Function for Replacing Sensitive Data ---
|
||||
def replace_sensitive_data(text: str, sensitive_map: dict) -> str:
|
||||
"""Replaces sensitive data placeholders in text."""
|
||||
if not isinstance(text, str):
|
||||
return text
|
||||
for placeholder, value in sensitive_map.items():
|
||||
replacement_value = str(value) if value is not None else ''
|
||||
text = text.replace(f'<secret>{placeholder}</secret>', replacement_value)
|
||||
return text
|
||||
|
||||
|
||||
# --- Helper Function for Robust Action Execution ---
|
||||
class PlaywrightActionError(Exception):
|
||||
"""Custom exception for errors during Playwright script action execution."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _try_locate_and_act(page: Page, selector: str, action_type: str, text: str | None = None, step_info: str = '') -> None:
|
||||
"""
|
||||
Attempts an action (click/fill) with XPath fallback by trimming prefixes.
|
||||
Raises PlaywrightActionError if the action fails after all fallbacks.
|
||||
"""
|
||||
print(f'Attempting {action_type} ({step_info}) using selector: {repr(selector)}')
|
||||
original_selector = selector
|
||||
MAX_FALLBACKS = 50 # Increased fallbacks
|
||||
# Increased timeouts for potentially slow pages
|
||||
INITIAL_TIMEOUT = 10000 # Milliseconds for the first attempt (10 seconds)
|
||||
FALLBACK_TIMEOUT = 1000 # Shorter timeout for fallback attempts (1 second)
|
||||
|
||||
try:
|
||||
locator = page.locator(selector).first
|
||||
if action_type == 'click':
|
||||
await locator.click(timeout=INITIAL_TIMEOUT)
|
||||
elif action_type == 'fill' and text is not None:
|
||||
await locator.fill(text, timeout=INITIAL_TIMEOUT)
|
||||
else:
|
||||
# This case should ideally not happen if called correctly
|
||||
raise PlaywrightActionError(f"Invalid action_type '{action_type}' or missing text for fill. ({step_info})")
|
||||
print(f" Action '{action_type}' successful with original selector.")
|
||||
await page.wait_for_timeout(500) # Wait after successful action
|
||||
return # Successful exit
|
||||
except Exception as e:
|
||||
print(f" Warning: Action '{action_type}' failed with original selector ({repr(selector)}): {e}. Starting fallback...")
|
||||
|
||||
# Fallback only works for XPath selectors
|
||||
if not selector.startswith('xpath='):
|
||||
# Raise error immediately if not XPath, as fallback won't work
|
||||
raise PlaywrightActionError(
|
||||
f"Action '{action_type}' failed. Fallback not possible for non-XPath selector: {repr(selector)}. ({step_info})"
|
||||
)
|
||||
|
||||
xpath_parts = selector.split('=', 1)
|
||||
if len(xpath_parts) < 2:
|
||||
raise PlaywrightActionError(
|
||||
f"Action '{action_type}' failed. Could not extract XPath string from selector: {repr(selector)}. ({step_info})"
|
||||
)
|
||||
xpath = xpath_parts[1] # Correctly get the XPath string
|
||||
|
||||
segments = [seg for seg in xpath.split('/') if seg]
|
||||
|
||||
for i in range(1, min(MAX_FALLBACKS + 1, len(segments))):
|
||||
trimmed_xpath_raw = '/'.join(segments[i:])
|
||||
fallback_xpath = f'xpath=//{trimmed_xpath_raw}'
|
||||
|
||||
print(f' Fallback attempt {i}/{MAX_FALLBACKS}: Trying selector: {repr(fallback_xpath)}')
|
||||
try:
|
||||
locator = page.locator(fallback_xpath).first
|
||||
if action_type == 'click':
|
||||
await locator.click(timeout=FALLBACK_TIMEOUT)
|
||||
elif action_type == 'fill' and text is not None:
|
||||
try:
|
||||
await locator.clear(timeout=FALLBACK_TIMEOUT)
|
||||
await page.wait_for_timeout(100)
|
||||
except Exception as clear_error:
|
||||
print(f' Warning: Failed to clear field during fallback ({step_info}): {clear_error}')
|
||||
await locator.fill(text, timeout=FALLBACK_TIMEOUT)
|
||||
|
||||
print(f" Action '{action_type}' successful with fallback selector: {repr(fallback_xpath)}")
|
||||
await page.wait_for_timeout(500)
|
||||
return # Successful exit after fallback
|
||||
except Exception as fallback_e:
|
||||
print(f' Fallback attempt {i} failed: {fallback_e}')
|
||||
if i == MAX_FALLBACKS:
|
||||
# Raise exception after exhausting fallbacks
|
||||
raise PlaywrightActionError(
|
||||
f"Action '{action_type}' failed after {MAX_FALLBACKS} fallback attempts. Original selector: {repr(original_selector)}. ({step_info})"
|
||||
)
|
||||
|
||||
# This part should not be reachable if logic is correct, but added as safeguard
|
||||
raise PlaywrightActionError(f"Action '{action_type}' failed unexpectedly for {repr(original_selector)}. ({step_info})")
|
||||
187
browser-use/browser_use/agent/prompts.py
Normal file
187
browser-use/browser_use/agent/prompts.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
import importlib.resources
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from browser_use.agent.views import ActionResult, AgentStepInfo
|
||||
from browser_use.browser.views import BrowserState
|
||||
|
||||
|
||||
class SystemPrompt:
|
||||
def __init__(
|
||||
self,
|
||||
action_description: str,
|
||||
max_actions_per_step: int = 10,
|
||||
override_system_message: str | None = None,
|
||||
extend_system_message: str | None = None,
|
||||
):
|
||||
self.default_action_description = action_description
|
||||
self.max_actions_per_step = max_actions_per_step
|
||||
prompt = ''
|
||||
if override_system_message:
|
||||
prompt = override_system_message
|
||||
else:
|
||||
self._load_prompt_template()
|
||||
prompt = self.prompt_template.format(max_actions=self.max_actions_per_step)
|
||||
|
||||
if extend_system_message:
|
||||
prompt += f'\n{extend_system_message}'
|
||||
|
||||
self.system_message = SystemMessage(content=prompt)
|
||||
|
||||
def _load_prompt_template(self) -> None:
|
||||
"""Load the prompt template from the markdown file."""
|
||||
try:
|
||||
# This works both in development and when installed as a package
|
||||
with importlib.resources.files('browser_use.agent').joinpath('system_prompt.md').open('r') as f:
|
||||
self.prompt_template = f.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Failed to load system prompt template: {e}')
|
||||
|
||||
def get_system_message(self) -> SystemMessage:
|
||||
"""
|
||||
Get the system prompt for the agent.
|
||||
|
||||
Returns:
|
||||
SystemMessage: Formatted system prompt
|
||||
"""
|
||||
return self.system_message
|
||||
|
||||
|
||||
# Functions:
|
||||
# {self.default_action_description}
|
||||
|
||||
# Example:
|
||||
# {self.example_response()}
|
||||
# Your AVAILABLE ACTIONS:
|
||||
# {self.default_action_description}
|
||||
|
||||
|
||||
class AgentMessagePrompt:
|
||||
def __init__(
|
||||
self,
|
||||
state: 'BrowserState',
|
||||
result: list['ActionResult'] | None = None,
|
||||
include_attributes: list[str] | None = None,
|
||||
step_info: Optional['AgentStepInfo'] = None,
|
||||
):
|
||||
self.state = state
|
||||
self.result = result
|
||||
self.include_attributes = include_attributes or []
|
||||
self.step_info = step_info
|
||||
|
||||
def get_user_message(self, use_vision: bool = True) -> HumanMessage:
|
||||
elements_text = self.state.element_tree.clickable_elements_to_string(include_attributes=self.include_attributes)
|
||||
|
||||
has_content_above = (self.state.pixels_above or 0) > 0
|
||||
has_content_below = (self.state.pixels_below or 0) > 0
|
||||
|
||||
if elements_text != '':
|
||||
if has_content_above:
|
||||
elements_text = (
|
||||
f'... {self.state.pixels_above} pixels above - scroll or extract content to see more ...\n{elements_text}'
|
||||
)
|
||||
else:
|
||||
elements_text = f'[Start of page]\n{elements_text}'
|
||||
if has_content_below:
|
||||
elements_text = (
|
||||
f'{elements_text}\n... {self.state.pixels_below} pixels below - scroll or extract content to see more ...'
|
||||
)
|
||||
else:
|
||||
elements_text = f'{elements_text}\n[End of page]'
|
||||
else:
|
||||
elements_text = 'empty page'
|
||||
|
||||
if self.step_info:
|
||||
step_info_description = f'Current step: {self.step_info.step_number + 1}/{self.step_info.max_steps}'
|
||||
else:
|
||||
step_info_description = ''
|
||||
time_str = datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
step_info_description += f'Current date and time: {time_str}'
|
||||
|
||||
state_description = f"""
|
||||
[Task history memory ends]
|
||||
[Current state starts here]
|
||||
The following is one-time information - if you need to remember it write it to memory:
|
||||
Current url: {self.state.url}
|
||||
Available tabs:
|
||||
{self.state.tabs}
|
||||
Interactive elements from top layer of the current page inside the viewport:
|
||||
{elements_text}
|
||||
{step_info_description}
|
||||
"""
|
||||
|
||||
if self.result:
|
||||
for i, result in enumerate(self.result):
|
||||
if result.extracted_content:
|
||||
state_description += f'\nAction result {i + 1}/{len(self.result)}: {result.extracted_content}'
|
||||
if result.error:
|
||||
# only use last line of error
|
||||
error = result.error.split('\n')[-1]
|
||||
state_description += f'\nAction error {i + 1}/{len(self.result)}: ...{error}'
|
||||
|
||||
if self.state.screenshot and use_vision is True:
|
||||
# Format message for vision model
|
||||
return HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': state_description},
|
||||
{
|
||||
'type': 'image_url',
|
||||
'image_url': {'url': f'data:image/png;base64,{self.state.screenshot}'}, # , 'detail': 'low'
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
return HumanMessage(content=state_description)
|
||||
|
||||
|
||||
class PlannerPrompt(SystemPrompt):
|
||||
def __init__(self, available_actions: str):
|
||||
self.available_actions = available_actions
|
||||
|
||||
def get_system_message(
|
||||
self, is_planner_reasoning: bool, extended_planner_system_prompt: str | None = None
|
||||
) -> SystemMessage | HumanMessage:
|
||||
"""Get the system message for the planner.
|
||||
|
||||
Args:
|
||||
is_planner_reasoning: If True, return as HumanMessage for chain-of-thought
|
||||
extended_planner_system_prompt: Optional text to append to the base prompt
|
||||
|
||||
Returns:
|
||||
SystemMessage or HumanMessage depending on is_planner_reasoning
|
||||
"""
|
||||
|
||||
planner_prompt_text = """
|
||||
You are a planning agent that helps break down tasks into smaller steps and reason about the current state.
|
||||
Your role is to:
|
||||
1. Analyze the current state and history
|
||||
2. Evaluate progress towards the ultimate goal
|
||||
3. Identify potential challenges or roadblocks
|
||||
4. Suggest the next high-level steps to take
|
||||
|
||||
Inside your messages, there will be AI messages from different agents with different formats.
|
||||
|
||||
Your output format should be always a JSON object with the following fields:
|
||||
{{
|
||||
"state_analysis": "Brief analysis of the current state and what has been done so far",
|
||||
"progress_evaluation": "Evaluation of progress towards the ultimate goal (as percentage and description)",
|
||||
"challenges": "List any potential challenges or roadblocks",
|
||||
"next_steps": "List 2-3 concrete next steps to take",
|
||||
"reasoning": "Explain your reasoning for the suggested next steps"
|
||||
}}
|
||||
|
||||
Ignore the other AI messages output structures.
|
||||
|
||||
Keep your responses concise and focused on actionable insights.
|
||||
"""
|
||||
|
||||
if extended_planner_system_prompt:
|
||||
planner_prompt_text += f'\n{extended_planner_system_prompt}'
|
||||
|
||||
if is_planner_reasoning:
|
||||
return HumanMessage(content=planner_prompt_text)
|
||||
else:
|
||||
return SystemMessage(content=planner_prompt_text)
|
||||
1458
browser-use/browser_use/agent/service.py
Normal file
1458
browser-use/browser_use/agent/service.py
Normal file
File diff suppressed because it is too large
Load diff
82
browser-use/browser_use/agent/system_prompt.md
Normal file
82
browser-use/browser_use/agent/system_prompt.md
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
You are an AI agent designed to automate browser tasks. Your goal is to accomplish the ultimate task following the rules.
|
||||
|
||||
# Input Format
|
||||
|
||||
Task
|
||||
Previous steps
|
||||
Current URL
|
||||
Open Tabs
|
||||
Interactive Elements
|
||||
[index]<type>text</type>
|
||||
|
||||
- index: Numeric identifier for interaction
|
||||
- type: HTML element type (button, input, etc.)
|
||||
- text: Element description
|
||||
Example:
|
||||
[33]<div>User form</div>
|
||||
\t*[35]*<button aria-label='Submit form'>Submit</button>
|
||||
|
||||
- Only elements with numeric indexes in [] are interactive
|
||||
- (stacked) indentation (with \t) is important and means that the element is a (html) child of the element above (with a lower index)
|
||||
- Elements with \* are new elements that were added after the previous step (if url has not changed)
|
||||
|
||||
# Response Rules
|
||||
|
||||
1. RESPONSE FORMAT: You must ALWAYS respond with valid JSON in this exact format:
|
||||
{{"current_state": {{"evaluation_previous_goal": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions are successful like intended by the task. Mention if something unexpected happened. Shortly state why/why not",
|
||||
"memory": "Description of what has been done and what you need to remember. Be very specific. Count here ALWAYS how many times you have done something and how many remain. E.g. 0 out of 10 websites analyzed. Continue with abc and xyz",
|
||||
"next_goal": "What needs to be done with the next immediate action"}},
|
||||
"action":[{{"one_action_name": {{// action-specific parameter}}}}, // ... more actions in sequence]}}
|
||||
|
||||
2. ACTIONS: You can specify multiple actions in the list to be executed in sequence. But always specify only one action name per item. Use maximum {max_actions} actions per sequence.
|
||||
Common action sequences:
|
||||
|
||||
- Form filling: [{{"input_text": {{"index": 1, "text": "username"}}}}, {{"input_text": {{"index": 2, "text": "password"}}}}, {{"click_element": {{"index": 3}}}}]
|
||||
- Navigation and extraction: [{{"go_to_url": {{"url": "https://example.com"}}}}, {{"extract_content": {{"goal": "extract the names"}}}}]
|
||||
- Actions are executed in the given order
|
||||
- If the page changes after an action, the sequence is interrupted and you get the new state.
|
||||
- Only provide the action sequence until an action which changes the page state significantly.
|
||||
- Try to be efficient, e.g. fill forms at once, or chain actions where nothing changes on the page
|
||||
- only use multiple actions if it makes sense.
|
||||
|
||||
3. ELEMENT INTERACTION:
|
||||
|
||||
- Only use indexes of the interactive elements
|
||||
|
||||
4. NAVIGATION & ERROR HANDLING:
|
||||
|
||||
- If no suitable elements exist, use other functions to complete the task
|
||||
- If stuck, try alternative approaches - like going back to a previous page, new search, new tab etc.
|
||||
- Handle popups/cookies by accepting or closing them
|
||||
- Use scroll to find elements you are looking for
|
||||
- If you want to research something, open a new tab instead of using the current tab
|
||||
- If captcha pops up, try to solve it - else try a different approach
|
||||
- If the page is not fully loaded, use wait action
|
||||
|
||||
5. TASK COMPLETION:
|
||||
|
||||
- Use the done action as the last action as soon as the ultimate task is complete
|
||||
- Dont use "done" before you are done with everything the user asked you, except you reach the last step of max_steps.
|
||||
- If you reach your last step, use the done action even if the task is not fully finished. Provide all the information you have gathered so far. If the ultimate task is completely finished set success to true. If not everything the user asked for is completed set success in done to false!
|
||||
- If you have to do something repeatedly for example the task says for "each", or "for all", or "x times", count always inside "memory" how many times you have done it and how many remain. Don't stop until you have completed like the task asked you. Only call done after the last step.
|
||||
- Don't hallucinate actions
|
||||
- Make sure you include everything you found out for the ultimate task in the done text parameter. Do not just say you are done, but include the requested information of the task.
|
||||
|
||||
6. VISUAL CONTEXT:
|
||||
|
||||
- When an image is provided, use it to understand the page layout
|
||||
- Bounding boxes with labels on their top right corner correspond to element indexes
|
||||
|
||||
7. Form filling:
|
||||
|
||||
- If you fill an input field and your action sequence is interrupted, most often something changed e.g. suggestions popped up under the field.
|
||||
|
||||
8. Long tasks:
|
||||
|
||||
- Keep track of the status and subresults in the memory.
|
||||
- You are provided with procedural memory summaries that condense previous task history (every N steps). Use these summaries to maintain context about completed actions, current progress, and next steps. The summaries appear in chronological order and contain key information about navigation history, findings, errors encountered, and current state. Refer to these summaries to avoid repeating actions and to ensure consistent progress toward the task goal.
|
||||
|
||||
9. Extraction:
|
||||
|
||||
- If your task is to find information - call extract_content on the specific pages to get and store the information.
|
||||
Your responses must be always JSON with the specified format.
|
||||
197
browser-use/browser_use/agent/tests.py
Normal file
197
browser-use/browser_use/agent/tests.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
import pytest
|
||||
|
||||
from browser_use.agent.views import (
|
||||
ActionResult,
|
||||
AgentBrain,
|
||||
AgentHistory,
|
||||
AgentHistoryList,
|
||||
AgentOutput,
|
||||
)
|
||||
from browser_use.browser.views import BrowserState, BrowserStateHistory, TabInfo
|
||||
from browser_use.controller.registry.service import Registry
|
||||
from browser_use.controller.views import ClickElementAction, DoneAction, ExtractPageContentAction
|
||||
from browser_use.dom.views import DOMElementNode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_browser_state():
|
||||
return BrowserState(
|
||||
url='https://example.com',
|
||||
title='Example Page',
|
||||
tabs=[TabInfo(url='https://example.com', title='Example Page', page_id=1)],
|
||||
screenshot='screenshot1.png',
|
||||
element_tree=DOMElementNode(
|
||||
tag_name='root',
|
||||
is_visible=True,
|
||||
parent=None,
|
||||
xpath='',
|
||||
attributes={},
|
||||
children=[],
|
||||
),
|
||||
selector_map={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def action_registry():
|
||||
registry = Registry()
|
||||
|
||||
# Register the actions we need for testing
|
||||
@registry.action(description='Click an element', param_model=ClickElementAction)
|
||||
def click_element(params: ClickElementAction, browser=None):
|
||||
pass
|
||||
|
||||
@registry.action(
|
||||
description='Extract page content',
|
||||
param_model=ExtractPageContentAction,
|
||||
)
|
||||
def extract_page_content(params: ExtractPageContentAction, browser=None):
|
||||
pass
|
||||
|
||||
@registry.action(description='Mark task as done', param_model=DoneAction)
|
||||
def done(params: DoneAction):
|
||||
pass
|
||||
|
||||
# Create the dynamic ActionModel with all registered actions
|
||||
return registry.create_action_model()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_history(action_registry):
|
||||
# Create actions with nested params structure
|
||||
click_action = action_registry(click_element={'index': 1})
|
||||
|
||||
extract_action = action_registry(extract_page_content={'value': 'text'})
|
||||
|
||||
done_action = action_registry(done={'text': 'Task completed'})
|
||||
|
||||
histories = [
|
||||
AgentHistory(
|
||||
model_output=AgentOutput(
|
||||
current_state=AgentBrain(
|
||||
evaluation_previous_goal='None',
|
||||
memory='Started task',
|
||||
next_goal='Click button',
|
||||
),
|
||||
action=[click_action],
|
||||
),
|
||||
result=[ActionResult(is_done=False)],
|
||||
state=BrowserStateHistory(
|
||||
url='https://example.com',
|
||||
title='Page 1',
|
||||
tabs=[TabInfo(url='https://example.com', title='Page 1', page_id=1)],
|
||||
screenshot='screenshot1.png',
|
||||
interacted_element=[{'xpath': '//button[1]'}],
|
||||
),
|
||||
),
|
||||
AgentHistory(
|
||||
model_output=AgentOutput(
|
||||
current_state=AgentBrain(
|
||||
evaluation_previous_goal='Clicked button',
|
||||
memory='Button clicked',
|
||||
next_goal='Extract content',
|
||||
),
|
||||
action=[extract_action],
|
||||
),
|
||||
result=[
|
||||
ActionResult(
|
||||
is_done=False,
|
||||
extracted_content='Extracted text',
|
||||
error='Failed to extract completely',
|
||||
)
|
||||
],
|
||||
state=BrowserStateHistory(
|
||||
url='https://example.com/page2',
|
||||
title='Page 2',
|
||||
tabs=[TabInfo(url='https://example.com/page2', title='Page 2', page_id=2)],
|
||||
screenshot='screenshot2.png',
|
||||
interacted_element=[{'xpath': '//div[1]'}],
|
||||
),
|
||||
),
|
||||
AgentHistory(
|
||||
model_output=AgentOutput(
|
||||
current_state=AgentBrain(
|
||||
evaluation_previous_goal='Extracted content',
|
||||
memory='Content extracted',
|
||||
next_goal='Finish task',
|
||||
),
|
||||
action=[done_action],
|
||||
),
|
||||
result=[ActionResult(is_done=True, extracted_content='Task completed', error=None)],
|
||||
state=BrowserStateHistory(
|
||||
url='https://example.com/page2',
|
||||
title='Page 2',
|
||||
tabs=[TabInfo(url='https://example.com/page2', title='Page 2', page_id=2)],
|
||||
screenshot='screenshot3.png',
|
||||
interacted_element=[{'xpath': '//div[1]'}],
|
||||
),
|
||||
),
|
||||
]
|
||||
return AgentHistoryList(history=histories)
|
||||
|
||||
|
||||
def test_last_model_output(sample_history: AgentHistoryList):
|
||||
last_output = sample_history.last_action()
|
||||
print(last_output)
|
||||
assert last_output == {'done': {'text': 'Task completed'}}
|
||||
|
||||
|
||||
def test_get_errors(sample_history: AgentHistoryList):
|
||||
errors = sample_history.errors()
|
||||
assert len(errors) == 1
|
||||
assert errors[0] == 'Failed to extract completely'
|
||||
|
||||
|
||||
def test_final_result(sample_history: AgentHistoryList):
|
||||
assert sample_history.final_result() == 'Task completed'
|
||||
|
||||
|
||||
def test_is_done(sample_history: AgentHistoryList):
|
||||
assert sample_history.is_done() is True
|
||||
|
||||
|
||||
def test_urls(sample_history: AgentHistoryList):
|
||||
urls = sample_history.urls()
|
||||
assert 'https://example.com' in urls
|
||||
assert 'https://example.com/page2' in urls
|
||||
|
||||
|
||||
def test_all_screenshots(sample_history: AgentHistoryList):
|
||||
screenshots = sample_history.screenshots()
|
||||
assert len(screenshots) == 3
|
||||
assert screenshots == ['screenshot1.png', 'screenshot2.png', 'screenshot3.png']
|
||||
|
||||
|
||||
def test_all_model_outputs(sample_history: AgentHistoryList):
|
||||
outputs = sample_history.model_actions()
|
||||
print(f'DEBUG: {outputs[0]}')
|
||||
assert len(outputs) == 3
|
||||
# get first key value pair
|
||||
assert dict([next(iter(outputs[0].items()))]) == {'click_element': {'index': 1}}
|
||||
assert dict([next(iter(outputs[1].items()))]) == {'extract_page_content': {'value': 'text'}}
|
||||
assert dict([next(iter(outputs[2].items()))]) == {'done': {'text': 'Task completed'}}
|
||||
|
||||
|
||||
def test_all_model_outputs_filtered(sample_history: AgentHistoryList):
|
||||
filtered = sample_history.model_actions_filtered(include=['click_element'])
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]['click_element']['index'] == 1
|
||||
|
||||
|
||||
def test_empty_history():
|
||||
empty_history = AgentHistoryList(history=[])
|
||||
assert empty_history.last_action() is None
|
||||
assert empty_history.final_result() is None
|
||||
assert empty_history.is_done() is False
|
||||
assert len(empty_history.urls()) == 0
|
||||
|
||||
|
||||
# Add a test to verify action creation
|
||||
def test_action_creation(action_registry):
|
||||
click_action = action_registry(click_element={'index': 1})
|
||||
|
||||
assert click_action.model_dump(exclude_none=True) == {'click_element': {'index': 1}}
|
||||
|
||||
|
||||
# run this with:
|
||||
# pytest browser_use/agent/tests.py
|
||||
440
browser-use/browser_use/agent/views.py
Normal file
440
browser-use/browser_use/agent/views.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from openai import RateLimitError
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, create_model
|
||||
|
||||
from browser_use.agent.message_manager.views import MessageManagerState
|
||||
from browser_use.agent.playwright_script_generator import PlaywrightScriptGenerator
|
||||
from browser_use.browser.browser import BrowserConfig
|
||||
from browser_use.browser.context import BrowserContextConfig
|
||||
from browser_use.browser.views import BrowserStateHistory
|
||||
from browser_use.controller.registry.views import ActionModel
|
||||
from browser_use.dom.history_tree_processor.service import (
|
||||
DOMElementNode,
|
||||
DOMHistoryElement,
|
||||
HistoryTreeProcessor,
|
||||
)
|
||||
from browser_use.dom.views import SelectorMap
|
||||
|
||||
ToolCallingMethod = Literal['function_calling', 'json_mode', 'raw', 'auto', 'tools']
|
||||
REQUIRED_LLM_API_ENV_VARS = {
|
||||
'ChatOpenAI': ['OPENAI_API_KEY'],
|
||||
'AzureChatOpenAI': ['AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_KEY'],
|
||||
'ChatBedrockConverse': ['ANTHROPIC_API_KEY'],
|
||||
'ChatAnthropic': ['ANTHROPIC_API_KEY'],
|
||||
'ChatGoogleGenerativeAI': ['GOOGLE_API_KEY'],
|
||||
'ChatDeepSeek': ['DEEPSEEK_API_KEY'],
|
||||
'ChatOllama': [],
|
||||
'ChatGrok': ['GROK_API_KEY'],
|
||||
}
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
"""Options for the agent"""
|
||||
|
||||
use_vision: bool = True
|
||||
use_vision_for_planner: bool = False
|
||||
save_conversation_path: str | None = None
|
||||
save_conversation_path_encoding: str | None = 'utf-8'
|
||||
max_failures: int = 3
|
||||
retry_delay: int = 10
|
||||
max_input_tokens: int = 128000
|
||||
validate_output: bool = False
|
||||
message_context: str | None = None
|
||||
generate_gif: bool | str = False
|
||||
available_file_paths: list[str] | None = None
|
||||
override_system_message: str | None = None
|
||||
extend_system_message: str | None = None
|
||||
include_attributes: list[str] = [
|
||||
'title',
|
||||
'type',
|
||||
'name',
|
||||
'role',
|
||||
'tabindex',
|
||||
'aria-label',
|
||||
'placeholder',
|
||||
'value',
|
||||
'alt',
|
||||
'aria-expanded',
|
||||
]
|
||||
max_actions_per_step: int = 10
|
||||
|
||||
tool_calling_method: ToolCallingMethod | None = 'auto'
|
||||
page_extraction_llm: BaseChatModel | None = None
|
||||
planner_llm: BaseChatModel | None = None
|
||||
planner_interval: int = 1 # Run planner every N steps
|
||||
is_planner_reasoning: bool = False # type: ignore
|
||||
extend_planner_system_message: str | None = None
|
||||
|
||||
# Playwright script generation setting
|
||||
save_playwright_script_path: str | None = None # Path to save the generated Playwright script
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
"""Holds all state information for an Agent"""
|
||||
|
||||
agent_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
n_steps: int = 1
|
||||
consecutive_failures: int = 0
|
||||
last_result: list[ActionResult] | None = None
|
||||
history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[]))
|
||||
last_plan: str | None = None
|
||||
paused: bool = False
|
||||
stopped: bool = False
|
||||
|
||||
message_manager_state: MessageManagerState = Field(default_factory=MessageManagerState)
|
||||
|
||||
# class Config:
|
||||
# arbitrary_types_allowed = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStepInfo:
|
||||
step_number: int
|
||||
max_steps: int
|
||||
|
||||
def is_last_step(self) -> bool:
|
||||
"""Check if this is the last step"""
|
||||
return self.step_number >= self.max_steps - 1
|
||||
|
||||
|
||||
class ActionResult(BaseModel):
|
||||
"""Result of executing an action"""
|
||||
|
||||
is_done: bool | None = False
|
||||
success: bool | None = None
|
||||
extracted_content: str | None = None
|
||||
error: str | None = None
|
||||
include_in_memory: bool = False # whether to include in past messages as context or not
|
||||
|
||||
|
||||
class StepMetadata(BaseModel):
|
||||
"""Metadata for a single step including timing and token information"""
|
||||
|
||||
step_start_time: float
|
||||
step_end_time: float
|
||||
input_tokens: int # Approximate tokens from message manager for this step
|
||||
step_number: int
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> float:
|
||||
"""Calculate step duration in seconds"""
|
||||
return self.step_end_time - self.step_start_time
|
||||
|
||||
|
||||
class AgentBrain(BaseModel):
|
||||
"""Current state of the agent"""
|
||||
|
||||
evaluation_previous_goal: str
|
||||
memory: str
|
||||
next_goal: str
|
||||
|
||||
|
||||
class AgentOutput(BaseModel):
|
||||
"""Output model for agent
|
||||
|
||||
@dev note: this model is extended with custom actions in AgentService. You can also use some fields that are not in this model as provided by the linter, as long as they are registered in the DynamicActions model.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
current_state: AgentBrain
|
||||
action: list[ActionModel] = Field(
|
||||
...,
|
||||
description='List of actions to execute',
|
||||
json_schema_extra={'min_items': 1}, # Ensure at least one action is provided
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def type_with_custom_actions(custom_actions: type[ActionModel]) -> type[AgentOutput]:
|
||||
"""Extend actions with custom actions"""
|
||||
model_ = create_model(
|
||||
'AgentOutput',
|
||||
__base__=AgentOutput,
|
||||
action=(
|
||||
list[custom_actions],
|
||||
Field(..., description='List of actions to execute', json_schema_extra={'min_items': 1}),
|
||||
),
|
||||
__module__=AgentOutput.__module__,
|
||||
)
|
||||
model_.__doc__ = 'AgentOutput model with custom actions'
|
||||
return model_
|
||||
|
||||
|
||||
class AgentHistory(BaseModel):
|
||||
"""History item for agent actions"""
|
||||
|
||||
model_output: AgentOutput | None
|
||||
result: list[ActionResult]
|
||||
state: BrowserStateHistory
|
||||
metadata: StepMetadata | None = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
|
||||
|
||||
@staticmethod
|
||||
def get_interacted_element(model_output: AgentOutput, selector_map: SelectorMap) -> list[DOMHistoryElement | None]:
|
||||
elements = []
|
||||
for action in model_output.action:
|
||||
index = action.get_index()
|
||||
if index is not None and index in selector_map:
|
||||
el: DOMElementNode = selector_map[index]
|
||||
elements.append(HistoryTreeProcessor.convert_dom_element_to_history_element(el))
|
||||
else:
|
||||
elements.append(None)
|
||||
return elements
|
||||
|
||||
def model_dump(self, **kwargs) -> dict[str, Any]:
|
||||
"""Custom serialization handling circular references"""
|
||||
|
||||
# Handle action serialization
|
||||
model_output_dump = None
|
||||
if self.model_output:
|
||||
action_dump = [action.model_dump(exclude_none=True) for action in self.model_output.action]
|
||||
model_output_dump = {
|
||||
'current_state': self.model_output.current_state.model_dump(),
|
||||
'action': action_dump, # This preserves the actual action data
|
||||
}
|
||||
|
||||
return {
|
||||
'model_output': model_output_dump,
|
||||
'result': [r.model_dump(exclude_none=True) for r in self.result],
|
||||
'state': self.state.to_dict(),
|
||||
'metadata': self.metadata.model_dump() if self.metadata else None,
|
||||
}
|
||||
|
||||
|
||||
class AgentHistoryList(BaseModel):
|
||||
"""List of agent history items"""
|
||||
|
||||
history: list[AgentHistory]
|
||||
|
||||
def total_duration_seconds(self) -> float:
|
||||
"""Get total duration of all steps in seconds"""
|
||||
total = 0.0
|
||||
for h in self.history:
|
||||
if h.metadata:
|
||||
total += h.metadata.duration_seconds
|
||||
return total
|
||||
|
||||
def total_input_tokens(self) -> int:
|
||||
"""
|
||||
Get total tokens used across all steps.
|
||||
Note: These are from the approximate token counting of the message manager.
|
||||
For accurate token counting, use tools like LangChain Smith or OpenAI's token counters.
|
||||
"""
|
||||
total = 0
|
||||
for h in self.history:
|
||||
if h.metadata:
|
||||
total += h.metadata.input_tokens
|
||||
return total
|
||||
|
||||
def input_token_usage(self) -> list[int]:
|
||||
"""Get token usage for each step"""
|
||||
return [h.metadata.input_tokens for h in self.history if h.metadata]
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Representation of the AgentHistoryList object"""
|
||||
return f'AgentHistoryList(all_results={self.action_results()}, all_model_outputs={self.model_actions()})'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Representation of the AgentHistoryList object"""
|
||||
return self.__str__()
|
||||
|
||||
def save_to_file(self, filepath: str | Path) -> None:
|
||||
"""Save history to JSON file with proper serialization"""
|
||||
try:
|
||||
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
|
||||
data = self.model_dump()
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def save_as_playwright_script(
|
||||
self,
|
||||
output_path: str | Path,
|
||||
sensitive_data_keys: list[str] | None = None,
|
||||
browser_config: BrowserConfig | None = None,
|
||||
context_config: BrowserContextConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generates a Playwright script based on the agent's history and saves it to a file.
|
||||
Args:
|
||||
output_path: The path where the generated Python script will be saved.
|
||||
sensitive_data_keys: A list of keys used as placeholders for sensitive data
|
||||
(e.g., ['username_placeholder', 'password_placeholder']).
|
||||
These will be loaded from environment variables in the
|
||||
generated script.
|
||||
browser_config: Configuration of the original Browser instance.
|
||||
context_config: Configuration of the original BrowserContext instance.
|
||||
"""
|
||||
try:
|
||||
serialized_history = self.model_dump()['history']
|
||||
generator = PlaywrightScriptGenerator(serialized_history, sensitive_data_keys, browser_config, context_config)
|
||||
script_content = generator.generate_script_content()
|
||||
path_obj = Path(output_path)
|
||||
path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path_obj, 'w', encoding='utf-8') as f:
|
||||
f.write(script_content)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def model_dump(self, **kwargs) -> dict[str, Any]:
|
||||
"""Custom serialization that properly uses AgentHistory's model_dump"""
|
||||
return {
|
||||
'history': [h.model_dump(**kwargs) for h in self.history],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, filepath: str | Path, output_model: type[AgentOutput]) -> AgentHistoryList:
|
||||
"""Load history from JSON file"""
|
||||
with open(filepath, encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
# loop through history and validate output_model actions to enrich with custom actions
|
||||
for h in data['history']:
|
||||
if h['model_output']:
|
||||
if isinstance(h['model_output'], dict):
|
||||
h['model_output'] = output_model.model_validate(h['model_output'])
|
||||
else:
|
||||
h['model_output'] = None
|
||||
if 'interacted_element' not in h['state']:
|
||||
h['state']['interacted_element'] = None
|
||||
history = cls.model_validate(data)
|
||||
return history
|
||||
|
||||
def last_action(self) -> None | dict:
|
||||
"""Last action in history"""
|
||||
if self.history and self.history[-1].model_output:
|
||||
return self.history[-1].model_output.action[-1].model_dump(exclude_none=True)
|
||||
return None
|
||||
|
||||
def errors(self) -> list[str | None]:
|
||||
"""Get all errors from history, with None for steps without errors"""
|
||||
errors = []
|
||||
for h in self.history:
|
||||
step_errors = [r.error for r in h.result if r.error]
|
||||
|
||||
# each step can have only one error
|
||||
errors.append(step_errors[0] if step_errors else None)
|
||||
return errors
|
||||
|
||||
def final_result(self) -> None | str:
|
||||
"""Final result from history"""
|
||||
if self.history and self.history[-1].result[-1].extracted_content:
|
||||
return self.history[-1].result[-1].extracted_content
|
||||
return None
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if the agent is done"""
|
||||
if self.history and len(self.history[-1].result) > 0:
|
||||
last_result = self.history[-1].result[-1]
|
||||
return last_result.is_done is True
|
||||
return False
|
||||
|
||||
def is_successful(self) -> bool | None:
|
||||
"""Check if the agent completed successfully - the agent decides in the last step if it was successful or not. None if not done yet."""
|
||||
if self.history and len(self.history[-1].result) > 0:
|
||||
last_result = self.history[-1].result[-1]
|
||||
if last_result.is_done is True:
|
||||
return last_result.success
|
||||
return None
|
||||
|
||||
def has_errors(self) -> bool:
|
||||
"""Check if the agent has any non-None errors"""
|
||||
return any(error is not None for error in self.errors())
|
||||
|
||||
def urls(self) -> list[str | None]:
|
||||
"""Get all unique URLs from history"""
|
||||
return [h.state.url if h.state.url is not None else None for h in self.history]
|
||||
|
||||
def screenshots(self) -> list[str | None]:
|
||||
"""Get all screenshots from history"""
|
||||
return [h.state.screenshot if h.state.screenshot is not None else None for h in self.history]
|
||||
|
||||
def action_names(self) -> list[str]:
|
||||
"""Get all action names from history"""
|
||||
action_names = []
|
||||
for action in self.model_actions():
|
||||
actions = list(action.keys())
|
||||
if actions:
|
||||
action_names.append(actions[0])
|
||||
return action_names
|
||||
|
||||
def model_thoughts(self) -> list[AgentBrain]:
|
||||
"""Get all thoughts from history"""
|
||||
return [h.model_output.current_state for h in self.history if h.model_output]
|
||||
|
||||
def model_outputs(self) -> list[AgentOutput]:
|
||||
"""Get all model outputs from history"""
|
||||
return [h.model_output for h in self.history if h.model_output]
|
||||
|
||||
# get all actions with params
|
||||
def model_actions(self) -> list[dict]:
|
||||
"""Get all actions from history"""
|
||||
outputs = []
|
||||
|
||||
for h in self.history:
|
||||
if h.model_output:
|
||||
for action, interacted_element in zip(h.model_output.action, h.state.interacted_element):
|
||||
output = action.model_dump(exclude_none=True)
|
||||
output['interacted_element'] = interacted_element
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
||||
def action_results(self) -> list[ActionResult]:
|
||||
"""Get all results from history"""
|
||||
results = []
|
||||
for h in self.history:
|
||||
results.extend([r for r in h.result if r])
|
||||
return results
|
||||
|
||||
def extracted_content(self) -> list[str]:
|
||||
"""Get all extracted content from history"""
|
||||
content = []
|
||||
for h in self.history:
|
||||
content.extend([r.extracted_content for r in h.result if r.extracted_content])
|
||||
return content
|
||||
|
||||
def model_actions_filtered(self, include: list[str] | None = None) -> list[dict]:
|
||||
"""Get all model actions from history as JSON"""
|
||||
if include is None:
|
||||
include = []
|
||||
outputs = self.model_actions()
|
||||
result = []
|
||||
for o in outputs:
|
||||
for i in include:
|
||||
if i == list(o.keys())[0]:
|
||||
result.append(o)
|
||||
return result
|
||||
|
||||
def number_of_steps(self) -> int:
|
||||
"""Get the number of steps in the history"""
|
||||
return len(self.history)
|
||||
|
||||
|
||||
class AgentError:
|
||||
"""Container for agent error handling"""
|
||||
|
||||
VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
|
||||
RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
|
||||
NO_VALID_ACTION = 'No valid action found'
|
||||
|
||||
@staticmethod
|
||||
def format_error(error: Exception, include_trace: bool = False) -> str:
|
||||
"""Format error message based on error type and optionally include trace"""
|
||||
message = ''
|
||||
if isinstance(error, ValidationError):
|
||||
return f'{AgentError.VALIDATION_ERROR}\nDetails: {str(error)}'
|
||||
if isinstance(error, RateLimitError):
|
||||
return AgentError.RATE_LIMIT_ERROR
|
||||
if include_trace:
|
||||
return f'{str(error)}\nStacktrace:\n{traceback.format_exc()}'
|
||||
return f'{str(error)}'
|
||||
Loading…
Add table
Add a link
Reference in a new issue