[Add] browser-use and main.py

This commit is contained in:
tv0924@icloud.com 2025-05-18 21:57:54 +09:00
commit 96914d44ac
221 changed files with 30952 additions and 1 deletions

View file

@ -0,0 +1,51 @@
# Codebase Structure
> The code structure inspired by https://github.com/Netflix/dispatch.
Another very good reference on how to structure a scalable codebase is [this repo](https://github.com/zhanymkanov/fastapi-best-practices).
Just a brief document about how we should structure our backend codebase.
## Code Structure
```markdown
src/
/<service name>/
models.py
services.py
prompts.py
views.py
utils.py
routers.py
/_<subservice name>/
```
### Services.py
Always a single file, unless it becomes too long (more than ~500 lines); in that case, split it into \_subservices.
### Views.py
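Always split the views into separate sections: shared models first (`# All`), followed by the two main parts, `# Requests` and `# Responses`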
```python
# All
...
# Requests
...
# Responses
...
```
If too long → split into multiple files
### Prompts.py
Single file; if too long → split into multiple files (one prompt per file or so)
### Routers.py
Never split into more than one file

View file

@ -0,0 +1,35 @@
import warnings
# Suppress specific deprecation warnings from FAISS
# The filters are installed before the browser_use imports below run, so that
# warnings matching these patterns are ignored if raised during those imports.
warnings.filterwarnings('ignore', category=DeprecationWarning, module='faiss.loader')
warnings.filterwarnings('ignore', message='builtin type SwigPyPacked has no __module__ attribute')
warnings.filterwarnings('ignore', message='builtin type SwigPyObject has no __module__ attribute')
warnings.filterwarnings('ignore', message='builtin type swigvarlink has no __module__ attribute')
from browser_use.logging_config import setup_logging
# Logging is configured at import time, before the remaining package modules are imported.
setup_logging()
# Public re-exports of the package; every name imported here is also listed in `__all__`.
from browser_use.agent.prompts import SystemPrompt as SystemPrompt
from browser_use.agent.service import Agent as Agent
from browser_use.agent.views import ActionModel as ActionModel
from browser_use.agent.views import ActionResult as ActionResult
from browser_use.agent.views import AgentHistoryList as AgentHistoryList
from browser_use.browser.browser import Browser as Browser
from browser_use.browser.browser import BrowserConfig as BrowserConfig
from browser_use.browser.context import BrowserContextConfig
from browser_use.controller.service import Controller as Controller
from browser_use.dom.service import DomService as DomService
# Names exported by `from browser_use import *`.
__all__ = [
    'Agent',
    'Browser',
    'BrowserConfig',
    'Controller',
    'DomService',
    'SystemPrompt',
    'ActionResult',
    'ActionModel',
    'AgentHistoryList',
    'BrowserContextConfig',
]

View file

@ -0,0 +1,370 @@
from __future__ import annotations
import base64
import io
import logging
import os
import platform
from typing import TYPE_CHECKING
from browser_use.agent.views import AgentHistoryList
if TYPE_CHECKING:
from PIL import Image, ImageFont
logger = logging.getLogger(__name__)
def decode_unicode_escapes_to_utf8(text: str) -> str:
    r"""Decode literal ``\uXXXX`` escape sequences embedded in *text*.

    Needed to render non-ASCII languages (e.g. Chinese or Arabic) in the GIF
    overlay text when the model emitted them as escape sequences.

    Args:
        text: Text that may contain literal backslash-u escape sequences.

    Returns:
        The decoded text. The input is returned unchanged when it contains no
        ``\u`` marker or when the escapes are malformed.
    """
    import re  # local import: only needed on the fallback path

    if r'\u' not in text:
        # doesn't have any escape sequences that need to be decoded
        return text

    try:
        # Fast path: pure Latin-1 text, let the codec decode every escape.
        return text.encode('latin1').decode('unicode_escape')
    except UnicodeEncodeError:
        # The text already mixes real non-Latin-1 characters with escapes, so it
        # cannot be round-tripped through Latin-1. Decode only the well-formed
        # \uXXXX sequences and leave every other character untouched.
        return re.sub(
            r'\\u([0-9a-fA-F]{4})',
            lambda match: chr(int(match.group(1), 16)),
            text,
        )
    except UnicodeDecodeError:
        # Malformed escape sequence (e.g. truncated \uXX): keep the text as-is.
        return text
def create_history_gif(
    task: str,
    history: AgentHistoryList,
    #
    output_path: str = 'agent_history.gif',
    duration: int = 3000,
    show_goals: bool = True,
    show_task: bool = True,
    show_logo: bool = False,
    font_size: int = 40,
    title_font_size: int = 56,
    goal_font_size: int = 44,
    margin: int = 40,
    line_spacing: float = 1.5,
) -> None:
    """Render the agent's screenshots into an animated GIF with text overlays.

    Args:
        task: Task description shown on the optional intro frame.
        history: Agent history whose items provide base64 screenshots and model outputs.
        output_path: File path the GIF is written to.
        duration: Frame duration passed to the GIF writer.
        show_goals: Overlay the step number and next goal on each screenshot.
        show_task: Prepend a frame that displays ``task``.
        show_logo: Try to load ``./static/browser-use.png`` and draw it on the frames.
        font_size: Point size of the regular font.
        title_font_size: Point size of the title font.
        goal_font_size: Point size of the goal font (loaded, but not drawn with here).
        margin: Margin in pixels handed to the overlay helper.
        line_spacing: Line-height multiplier used on the task frame.
    """
    if not history.history:
        logger.warning('No history to create GIF from')
        return

    from PIL import Image, ImageFont

    frames = []

    # Without a first screenshot there is nothing to size the intro frame from.
    if not history.history or not history.history[0].state.screenshot:
        logger.warning('No history or first screenshot to create GIF from')
        return

    # Preferred fonts, most capable (CJK coverage) first.
    candidate_fonts = [
        'Microsoft YaHei',  # 微软雅黑
        'SimHei',  # 黑体
        'SimSun',  # 宋体
        'Noto Sans CJK SC',  # 思源黑体
        'WenQuanYi Micro Hei',  # 文泉驿微米黑
        'Helvetica',
        'Arial',
        'DejaVuSans',
        'Verdana',
    ]
    for candidate in candidate_fonts:
        font_file = candidate
        if platform.system() == 'Windows':
            # Windows needs the absolute path of the font file.
            font_file = os.path.join(os.getenv('WIN_FONT_DIR', 'C:\\Windows\\Fonts'), candidate + '.ttf')
        try:
            regular_font = ImageFont.truetype(font_file, font_size)
            title_font = ImageFont.truetype(font_file, title_font_size)
            goal_font = ImageFont.truetype(font_file, goal_font_size)
        except OSError:
            continue
        break
    else:
        # None of the preferred fonts could be opened: use Pillow's default.
        regular_font = ImageFont.load_default()
        title_font = ImageFont.load_default()
        goal_font = regular_font

    # Optional logo, resized to a fixed height of 150px with its aspect ratio kept.
    logo = None
    if show_logo:
        try:
            logo = Image.open('./static/browser-use.png')
            logo_height = 150
            logo_width = int(logo_height * (logo.width / logo.height))
            logo = logo.resize((logo_width, logo_height), Image.Resampling.LANCZOS)
        except Exception as e:
            logger.warning(f'Could not load logo: {e}')

    # Intro frame with the task text.
    if show_task and task:
        frames.append(
            _create_task_frame(
                task,
                history.history[0].state.screenshot,
                title_font,  # type: ignore
                regular_font,  # type: ignore
                logo,
                line_spacing,
            )
        )

    # One frame per history item that has a screenshot; step numbers start at 1.
    for step_number, entry in enumerate(history.history, 1):
        screenshot = entry.state.screenshot
        if not screenshot:
            continue

        frame = Image.open(io.BytesIO(base64.b64decode(screenshot)))

        if show_goals and entry.model_output:
            frame = _add_overlay_to_image(
                image=frame,
                step_number=step_number,
                goal_text=entry.model_output.current_state.next_goal,
                regular_font=regular_font,  # type: ignore
                title_font=title_font,  # type: ignore
                margin=margin,
                logo=logo,
            )

        frames.append(frame)

    if not frames:
        logger.warning('No images found in history to create GIF')
        return

    first_frame, remaining_frames = frames[0], frames[1:]
    first_frame.save(
        output_path,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=0,
        optimize=False,
    )
    logger.info(f'Created GIF at {output_path}')
def _create_task_frame(
    task: str,
    first_screenshot: str,
    title_font: ImageFont.FreeTypeFont,
    regular_font: ImageFont.FreeTypeFont,
    logo: Image.Image | None = None,
    line_spacing: float = 1.5,
) -> Image.Image:
    """Create the initial frame showing the task on a black background.

    Args:
        task: Task text to render, wrapped and centered.
        first_screenshot: Base64-encoded screenshot; only its size is used.
        title_font: Accepted for interface compatibility; not used here.
        regular_font: Font the task font is derived from.
        logo: Optional logo pasted into the top-right corner.
        line_spacing: Multiplier applied to the font size to get the line height.

    Returns:
        A new RGB image with the same size as the first screenshot.
    """
    from PIL import Image, ImageDraw

    img_data = base64.b64decode(first_screenshot)
    template = Image.open(io.BytesIO(img_data))
    image = Image.new('RGB', template.size, (0, 0, 0))
    draw = ImageDraw.Draw(image)

    # Calculate vertical center of image
    center_y = image.height // 2

    margin = 140  # Increased margin
    max_width = image.width - (2 * margin)

    regular_size = getattr(regular_font, 'size', None)
    if regular_size is not None and hasattr(regular_font, 'font_variant'):
        # Scalable font: start at regular + 16 and shrink linearly for long tasks.
        base_font_size = regular_size + 16
        min_font_size = max(regular_size - 10, 16)  # Don't go below 16pt

        text_length = len(task)
        if text_length > 200:
            font_size = max(base_font_size - int(10 * (text_length / 200)), min_font_size)
        else:
            font_size = base_font_size

        # font_variant() copies the font at the new size and also works for fonts
        # that were loaded from memory (e.g. Pillow's default font), where
        # re-opening `regular_font.path` would fail.
        larger_font = regular_font.font_variant(size=font_size)
        line_height = larger_font.size * line_spacing
    else:
        # Non-scalable (bitmap) fallback font: use it as-is and derive the line
        # height from a rendered sample instead of a point size.
        larger_font = regular_font
        sample_bbox = draw.textbbox((0, 0), 'Ag', font=larger_font)
        line_height = (sample_bbox[3] - sample_bbox[1]) * line_spacing

    # Generate wrapped text with the calculated font size
    wrapped_text = _wrap_text(task, larger_font, max_width)

    # Split text into lines and draw with custom spacing
    lines = wrapped_text.split('\n')
    total_height = line_height * len(lines)

    # Start position for first line
    text_y = center_y - (total_height / 2) + 50  # Shifted down slightly

    for line in lines:
        # Get line width for centering
        line_bbox = draw.textbbox((0, 0), line, font=larger_font)
        text_x = (image.width - (line_bbox[2] - line_bbox[0])) // 2

        draw.text(
            (text_x, text_y),
            line,
            font=larger_font,
            fill=(255, 255, 255),
        )
        text_y += line_height

    # Add logo if provided (top right corner)
    if logo:
        logo_margin = 20
        logo_x = image.width - logo.width - logo_margin
        # Use the logo's own alpha channel as the paste mask when it has one.
        image.paste(logo, (logo_x, logo_margin), logo if logo.mode == 'RGBA' else None)

    return image
def _add_overlay_to_image(
    image: Image.Image,
    step_number: int,
    goal_text: str,
    regular_font: ImageFont.FreeTypeFont,
    title_font: ImageFont.FreeTypeFont,
    margin: int,
    logo: Image.Image | None = None,
    display_step: bool = True,
    text_color: tuple[int, int, int, int] = (255, 255, 255, 255),
    text_box_color: tuple[int, int, int, int] = (0, 0, 0, 255),
) -> Image.Image:
    """Add step number and goal overlay to an image.

    Args:
        image: Screenshot to annotate.
        step_number: Number drawn in the bottom-left box.
        goal_text: Goal text drawn centered above the step box.
        regular_font: Accepted for interface compatibility; not used here.
        title_font: Font used for both the step number and the goal text.
        margin: Distance in pixels kept from the image edges.
        logo: Optional logo drawn in the top-right corner.
        display_step: When False the step box is not drawn; the goal text keeps
            the same position it would have with the step box present.
        text_color: RGBA color of the text.
        text_box_color: RGBA fill color of the rounded background boxes.

    Returns:
        The annotated image converted to RGB.
    """
    from PIL import Image, ImageDraw

    goal_text = decode_unicode_escapes_to_utf8(goal_text)
    image = image.convert('RGBA')
    txt_layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(txt_layer)

    # The step geometry is always computed because the goal text is positioned
    # relative to it; previously these names only existed when display_step was
    # True, so display_step=False raised UnboundLocalError.
    padding = 20  # Increased padding
    step_text = str(step_number)
    step_bbox = draw.textbbox((0, 0), step_text, font=title_font)
    step_width = step_bbox[2] - step_bbox[0]
    step_height = step_bbox[3] - step_bbox[1]

    # Position step number in bottom left
    x_step = margin + 10  # Slight additional offset from edge
    y_step = image.height - margin - step_height - 10  # Slight offset from bottom

    if display_step:
        # Draw rounded rectangle background for step number
        step_bg_bbox = (
            x_step - padding,
            y_step - padding,
            x_step + step_width + padding,
            y_step + step_height + padding,
        )
        draw.rounded_rectangle(
            step_bg_bbox,
            radius=15,  # Add rounded corners
            fill=text_box_color,
        )

        # Draw step number
        draw.text(
            (x_step, y_step),
            step_text,
            font=title_font,
            fill=text_color,
        )

    # Draw goal text (centered, bottom)
    max_width = image.width - (4 * margin)
    wrapped_goal = _wrap_text(goal_text, title_font, max_width)
    goal_bbox = draw.multiline_textbbox((0, 0), wrapped_goal, font=title_font)
    goal_width = goal_bbox[2] - goal_bbox[0]
    goal_height = goal_bbox[3] - goal_bbox[1]

    # Center goal text horizontally, place above step number
    x_goal = (image.width - goal_width) // 2
    y_goal = y_step - goal_height - padding * 4  # More space between step and goal

    # Draw rounded rectangle background for goal
    padding_goal = 25  # Increased padding for goal
    goal_bg_bbox = (
        x_goal - padding_goal,  # Remove extra space for logo
        y_goal - padding_goal,
        x_goal + goal_width + padding_goal,
        y_goal + goal_height + padding_goal,
    )
    draw.rounded_rectangle(
        goal_bg_bbox,
        radius=15,  # Add rounded corners
        fill=text_box_color,
    )

    # Draw goal text
    draw.multiline_text(
        (x_goal, y_goal),
        wrapped_goal,
        font=title_font,
        fill=text_color,
        align='center',
    )

    # Add logo if provided (top right corner)
    if logo:
        logo_layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
        logo_margin = 20
        logo_x = image.width - logo.width - logo_margin
        logo_layer.paste(logo, (logo_x, logo_margin), logo if logo.mode == 'RGBA' else None)
        # The text layer is composited over the logo layer.
        txt_layer = Image.alpha_composite(logo_layer, txt_layer)

    # Composite and convert
    result = Image.alpha_composite(image, txt_layer)
    return result.convert('RGB')
def _wrap_text(text: str, font: ImageFont.FreeTypeFont, max_width: int) -> str:
    """
    Greedily wrap whitespace-separated words so each line fits ``max_width``.

    Args:
        text: Text to wrap (unicode escapes are decoded first)
        font: Font used to measure each candidate line
        max_width: Maximum width in pixels

    Returns:
        The wrapped text, lines joined with newlines
    """
    decoded = decode_unicode_escapes_to_utf8(text)

    wrapped: list[str] = []
    pending: list[str] = []

    for token in decoded.split():
        pending.append(token)
        right_edge = font.getbbox(' '.join(pending))[2]
        if right_edge <= max_width:
            continue

        if len(pending) == 1:
            # A single word wider than the limit gets a line of its own.
            wrapped.append(pending.pop())
        else:
            # Flush the line without the overflowing word, which starts the next line.
            pending.pop()
            wrapped.append(' '.join(pending))
            pending = [token]

    if pending:
        wrapped.append(' '.join(pending))

    return '\n'.join(wrapped)

View file

@ -0,0 +1,4 @@
# Re-export the memory service class and its configuration model from this package.
from browser_use.agent.memory.service import Memory
from browser_use.agent.memory.views import MemoryConfig
__all__ = ['Memory', 'MemoryConfig']

View file

@ -0,0 +1,151 @@
from __future__ import annotations
import logging
import os
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
)
from langchain_core.messages.utils import convert_to_openai_messages
from browser_use.agent.memory.views import MemoryConfig
from browser_use.agent.message_manager.service import MessageManager
from browser_use.agent.message_manager.views import ManagedMessage, MessageMetadata
from browser_use.utils import time_execution_sync
logger = logging.getLogger(__name__)
class Memory:
    """
    Manages procedural memory for agents.

    This class implements a procedural memory management system using Mem0 that transforms agent interaction history
    into concise, structured representations at specified intervals. It serves to optimize context window
    utilization during extended task execution by converting verbose historical information into compact,
    yet comprehensive memory constructs that preserve essential operational knowledge.
    """

    def __init__(
        self,
        message_manager: MessageManager,
        llm: BaseChatModel,
        config: MemoryConfig | None = None,
    ):
        """Set up the Mem0 backend for the given message manager and LLM.

        Args:
            message_manager: Manager whose history is consolidated into memories.
            llm: Chat model handed to Mem0 through the configuration.
            config: Optional configuration; when omitted, defaults are derived
                from the class name of ``llm``.

        Raises:
            ImportError: If ``mem0`` is not installed, or if the huggingface
                embedder is selected and ``sentence_transformers`` is missing.
        """
        self.message_manager = message_manager
        self.llm = llm

        # Initialize configuration with defaults based on the LLM if not provided
        if config is None:
            self.config = MemoryConfig(llm_instance=llm, agent_id=f'agent_{id(self)}')

            # Set appropriate embedder based on LLM type
            llm_class = llm.__class__.__name__
            if llm_class == 'ChatOpenAI':
                self.config.embedder_provider = 'openai'
                self.config.embedder_model = 'text-embedding-3-small'
                self.config.embedder_dims = 1536
            elif llm_class == 'ChatGoogleGenerativeAI':
                self.config.embedder_provider = 'gemini'
                self.config.embedder_model = 'models/text-embedding-004'
                self.config.embedder_dims = 768
            elif llm_class == 'ChatOllama':
                self.config.embedder_provider = 'ollama'
                self.config.embedder_model = 'nomic-embed-text'
                # NOTE(review): confirm 512 matches the vector size this model returns.
                self.config.embedder_dims = 512
        else:
            # Ensure LLM instance is set in the config
            self.config = MemoryConfig(**dict(config))  # re-validate untrusted user-provided config
            self.config.llm_instance = llm

        # Check for required packages
        try:
            # also disable mem0's telemetry when ANONYMIZED_TELEMETRY=False
            # Slicing with [:1] (instead of indexing [0]) keeps an empty value
            # from raising IndexError; an empty value leaves telemetry untouched.
            telemetry_flag = os.getenv('ANONYMIZED_TELEMETRY', 'true').lower()[:1]
            if telemetry_flag in ('f', 'n', '0'):
                os.environ['MEM0_TELEMETRY'] = 'False'
            from mem0 import Memory as Mem0Memory
        except ImportError as err:
            # The `mem0` module is distributed on PyPI as `mem0ai`.
            raise ImportError(
                'mem0 is required when enable_memory=True. Please install it with `pip install mem0ai`.'
            ) from err

        if self.config.embedder_provider == 'huggingface':
            try:
                # check that required package is installed if huggingface is used
                from sentence_transformers import SentenceTransformer  # noqa: F401
            except ImportError as err:
                raise ImportError(
                    'sentence_transformers is required when enable_memory=True and embedder_provider="huggingface". Please install it with `pip install sentence-transformers`.'
                ) from err

        # Initialize Mem0 with the configuration
        self.mem0 = Mem0Memory.from_config(config_dict=self.config.full_config_dict)

    @time_execution_sync('--create_procedural_memory')
    def create_procedural_memory(self, current_step: int) -> None:
        """
        Create a procedural memory if needed based on the current step.

        Messages tagged 'init' or 'memory' are kept untouched; all other
        non-empty messages are summarized into one memory message that
        replaces them, and the history token counter is adjusted accordingly.

        Args:
            current_step: The current step number of the agent
        """
        logger.info(f'Creating procedural memory at step {current_step}')

        # Get all messages
        all_messages = self.message_manager.state.history.messages

        # Separate messages into those to keep as-is and those to process for memory
        new_messages = []
        messages_to_process = []

        for msg in all_messages:
            if isinstance(msg, ManagedMessage) and msg.metadata.message_type in {'init', 'memory'}:
                # Keep system and memory messages as they are
                new_messages.append(msg)
            else:
                # Messages with empty content are dropped from the history entirely.
                if len(msg.message.content) > 0:
                    messages_to_process.append(msg)

        # Need at least 2 messages to create a meaningful summary
        if len(messages_to_process) <= 1:
            logger.info('Not enough non-memory messages to summarize')
            return

        # Create a procedural memory
        memory_content = self._create([m.message for m in messages_to_process], current_step)

        if not memory_content:
            logger.warning('Failed to create procedural memory')
            return

        # Replace the processed messages with the consolidated memory
        memory_message = HumanMessage(content=memory_content)
        memory_tokens = self.message_manager._count_tokens(memory_message)
        memory_metadata = MessageMetadata(tokens=memory_tokens, message_type='memory')

        # Calculate the total tokens being removed
        removed_tokens = sum(m.metadata.tokens for m in messages_to_process)

        # Add the memory message
        new_messages.append(ManagedMessage(message=memory_message, metadata=memory_metadata))

        # Update the history
        self.message_manager.state.history.messages = new_messages
        self.message_manager.state.history.current_tokens -= removed_tokens
        self.message_manager.state.history.current_tokens += memory_tokens
        logger.info(f'Messages consolidated: {len(messages_to_process)} messages converted to procedural memory')

    def _create(self, messages: list[BaseMessage], current_step: int) -> str | None:
        """Ask Mem0 to build a procedural memory from ``messages``.

        Args:
            messages: Messages to summarize.
            current_step: Step number stored in the memory metadata.

        Returns:
            The memory text of the first result, or None when Mem0 returned no
            results or raised an error (errors are logged, not propagated).
        """
        parsed_messages = convert_to_openai_messages(messages)
        try:
            results = self.mem0.add(
                messages=parsed_messages,
                agent_id=self.config.agent_id,
                memory_type='procedural_memory',
                metadata={'step': current_step},
            )
            entries = results.get('results', [])
            if entries:
                return entries[0].get('memory')
            return None
        except Exception as e:
            # Best-effort: a failed memory creation must not abort the agent run.
            logger.error(f'Error creating procedural memory: {e}')
            return None

View file

@ -0,0 +1,67 @@
from typing import Any, Literal
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, ConfigDict, Field
class MemoryConfig(BaseModel):
    """Settings for the procedural memory, plus the dictionaries derived from them for Mem0."""

    model_config = ConfigDict(
        from_attributes=True, validate_default=True, revalidate_instances='always', validate_assignment=True
    )

    # Memory settings
    agent_id: str = Field(default='browser_use_agent', min_length=1)
    memory_interval: int = Field(default=10, gt=1, lt=100)

    # Embedder settings
    embedder_provider: Literal['openai', 'gemini', 'ollama', 'huggingface'] = 'huggingface'
    embedder_model: str = Field(min_length=2, default='all-MiniLM-L6-v2')
    embedder_dims: int = Field(default=384, gt=10, lt=10000)

    # LLM settings - the LLM instance can be passed separately
    llm_provider: Literal['langchain'] = 'langchain'
    llm_instance: BaseChatModel | None = None

    # Vector store settings
    vector_store_provider: Literal['faiss'] = 'faiss'
    vector_store_base_path: str = Field(default='/tmp/mem0')

    @property
    def vector_store_path(self) -> str:
        """Base path suffixed with the embedding size and store provider, e.g. /tmp/mem0_384_faiss"""
        suffix_parts = [str(self.embedder_dims), self.vector_store_provider]
        return '_'.join([self.vector_store_base_path, *suffix_parts])

    @property
    def embedder_config_dict(self) -> dict[str, Any]:
        """Embedder section of the Mem0 configuration."""
        embedder_options = {'model': self.embedder_model, 'embedding_dims': self.embedder_dims}
        return {'provider': self.embedder_provider, 'config': embedder_options}

    @property
    def llm_config_dict(self) -> dict[str, Any]:
        """LLM section of the Mem0 configuration."""
        llm_options = {'model': self.llm_instance}
        return {'provider': self.llm_provider, 'config': llm_options}

    @property
    def vector_store_config_dict(self) -> dict[str, Any]:
        """Vector store section of the Mem0 configuration."""
        store_options = {
            'embedding_model_dims': self.embedder_dims,
            'path': self.vector_store_path,
        }
        return {'provider': self.vector_store_provider, 'config': store_options}

    @property
    def full_config_dict(self) -> dict[str, dict[str, Any]]:
        """Complete Mem0 configuration assembled from the three sections."""
        sections: dict[str, dict[str, Any]] = {}
        sections['embedder'] = self.embedder_config_dict
        sections['llm'] = self.llm_config_dict
        sections['vector_store'] = self.vector_store_config_dict
        return sections

View file

@ -0,0 +1,329 @@
from __future__ import annotations
import logging
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from pydantic import BaseModel
from browser_use.agent.message_manager.views import MessageMetadata
from browser_use.agent.prompts import AgentMessagePrompt
from browser_use.agent.views import ActionResult, AgentOutput, AgentStepInfo, MessageManagerState
from browser_use.browser.views import BrowserState
from browser_use.utils import time_execution_sync
logger = logging.getLogger(__name__)
class MessageManagerSettings(BaseModel):
    """Settings consumed by MessageManager for token budgeting and initial prompt content."""

    # Token budget that cut_messages() compares the history's token count against.
    max_input_tokens: int = 128000
    # Divisor used to estimate a text's token count from its character length.
    estimated_characters_per_token: int = 3
    # Flat token cost charged for every image item in a message.
    image_tokens: int = 800
    # Forwarded to AgentMessagePrompt when the browser state message is built.
    include_attributes: list[str] = []
    # Optional extra context added as an initial human message.
    message_context: str | None = None
    # Placeholder name -> secret value; values found in messages are replaced by <secret>name</secret>.
    sensitive_data: dict[str, str] | None = None
    # Optional file paths announced to the model in an initial human message.
    available_file_paths: list[str] | None = None
class MessageManager:
    """Builds and maintains the message history sent to the LLM, with per-message token estimates."""

    def __init__(
        self,
        task: str,
        system_message: SystemMessage,
        settings: MessageManagerSettings | None = None,
        state: MessageManagerState | None = None,
    ):
        """Create a manager and seed the history if ``state`` is empty.

        Args:
            task: The task description given to the agent.
            system_message: System prompt placed first in the history.
            settings: Optional settings; a fresh default instance is created when omitted.
            state: Optional existing state to continue from; a fresh state is created when omitted.
        """
        self.task = task
        # Defaults are created per instance. Instantiating them in the signature
        # would share one settings object and one history between all managers.
        self.settings = settings if settings is not None else MessageManagerSettings()
        self.state = state if state is not None else MessageManagerState()
        self.system_prompt = system_message

        # Only initialize messages if state is empty
        if len(self.state.history.messages) == 0:
            self._init_messages()

    def _init_messages(self) -> None:
        """Initialize the message history with system message, context, task, and other initial messages"""
        self._add_message_with_tokens(self.system_prompt, message_type='init')

        if self.settings.message_context:
            # NOTE(review): no separator between the label and the context text - confirm this is intended.
            context_message = HumanMessage(content='Context for the task' + self.settings.message_context)
            self._add_message_with_tokens(context_message, message_type='init')

        task_message = HumanMessage(
            content=f'Your ultimate task is: """{self.task}""". If you achieved your ultimate task, stop everything and use the done action in the next step to complete the task. If not, continue as usual.'
        )
        self._add_message_with_tokens(task_message, message_type='init')

        if self.settings.sensitive_data:
            info = f'Here are placeholders for sensitive data: {list(self.settings.sensitive_data.keys())}'
            info += '\nTo use them, write <secret>the placeholder name</secret>'
            info_message = HumanMessage(content=info)
            self._add_message_with_tokens(info_message, message_type='init')

        placeholder_message = HumanMessage(content='Example output:')
        self._add_message_with_tokens(placeholder_message, message_type='init')

        # One worked example of the expected tool call, followed by its tool response.
        example_tool_call = AIMessage(
            content='',
            tool_calls=[
                {
                    'name': 'AgentOutput',
                    'args': {
                        'current_state': {
                            'evaluation_previous_goal': """
Success - I successfully clicked on the 'Apple' link from the Google Search results page,
which directed me to the 'Apple' company homepage. This is a good start toward finding
the best place to buy a new iPhone as the Apple website often list iPhones for sale.
""".strip(),
                            'memory': """
I searched for 'iPhone retailers' on Google. From the Google Search results page,
I used the 'click_element_by_index' tool to click on element at index [45] labeled 'Best Buy' but calling
the tool did not direct me to a new page. I then used the 'click_element_by_index' tool to click
on element at index [82] labeled 'Apple' which redirected me to the 'Apple' company homepage.
Currently at step 3/15.
""".strip(),
                            'next_goal': """
Looking at reported structure of the current page, I can see the item '[127]<h3 iPhone/>'
in the content. I think this button will lead to more information and potentially prices
for iPhones. I'll click on the link at index [127] using the 'click_element_by_index'
tool and hope to see prices on the next page.
""".strip(),
                        },
                        'action': [{'click_element_by_index': {'index': 127}}],
                    },
                    'id': str(self.state.tool_id),
                    'type': 'tool_call',
                },
            ],
        )
        self._add_message_with_tokens(example_tool_call, message_type='init')
        self.add_tool_message(content='Browser started', message_type='init')

        placeholder_message = HumanMessage(content='[Your task history memory starts here]')
        self._add_message_with_tokens(placeholder_message)

        if self.settings.available_file_paths:
            filepaths_msg = HumanMessage(content=f'Here are file paths you can use: {self.settings.available_file_paths}')
            self._add_message_with_tokens(filepaths_msg, message_type='init')

    def add_new_task(self, new_task: str) -> None:
        """Append a message announcing ``new_task`` and make it the current task."""
        content = f'Your new ultimate task is: """{new_task}""". Take the previous context into account and finish your new ultimate task. '
        msg = HumanMessage(content=content)
        self._add_message_with_tokens(msg)
        self.task = new_task

    @time_execution_sync('--add_state_message')
    def add_state_message(
        self,
        state: BrowserState,
        result: list[ActionResult] | None = None,
        step_info: AgentStepInfo | None = None,
        use_vision=True,
    ) -> None:
        """Add browser state as human message

        Results flagged ``include_in_memory`` are stored as their own messages;
        once any such result is found, ``result`` is not passed on to the state
        message.
        """
        # if keep in memory, add to directly to history and add state without result
        if result:
            for r in result:
                if r.include_in_memory:
                    if r.extracted_content:
                        msg = HumanMessage(content='Action result: ' + str(r.extracted_content))
                        self._add_message_with_tokens(msg)
                    if r.error:
                        # if endswith \n, remove it
                        if r.error.endswith('\n'):
                            r.error = r.error[:-1]
                        # get only last line of error
                        last_line = r.error.split('\n')[-1]
                        msg = HumanMessage(content='Action error: ' + last_line)
                        self._add_message_with_tokens(msg)
                    result = None  # if result in history, we dont want to add it again

        # otherwise add state message and result to next message (which will not stay in memory)
        state_message = AgentMessagePrompt(
            state,
            result,
            include_attributes=self.settings.include_attributes,
            step_info=step_info,
        ).get_user_message(use_vision)
        self._add_message_with_tokens(state_message)

    def add_model_output(self, model_output: AgentOutput) -> None:
        """Add model output as AI message"""
        tool_calls = [
            {
                'name': 'AgentOutput',
                'args': model_output.model_dump(mode='json', exclude_unset=True),
                'id': str(self.state.tool_id),
                'type': 'tool_call',
            }
        ]

        msg = AIMessage(
            content='',
            tool_calls=tool_calls,
        )

        self._add_message_with_tokens(msg)
        # empty tool response
        self.add_tool_message(content='')

    def add_plan(self, plan: str | None, position: int | None = None) -> None:
        """Add ``plan`` as an AI message at ``position``; does nothing for an empty plan."""
        if plan:
            msg = AIMessage(content=plan)
            self._add_message_with_tokens(msg, position)

    @time_execution_sync('--get_messages')
    def get_messages(self) -> list[BaseMessage]:
        """Return the messages currently in history and log their token counts at debug level"""
        msg = [m.message for m in self.state.history.messages]

        # debug which messages are in history with token count # log
        total_input_tokens = 0
        logger.debug(f'Messages in history: {len(self.state.history.messages)}:')
        for m in self.state.history.messages:
            total_input_tokens += m.metadata.tokens
            logger.debug(f'{m.message.__class__.__name__} - Token count: {m.metadata.tokens}')
        logger.debug(f'Total input tokens: {total_input_tokens}')

        return msg

    def _add_message_with_tokens(
        self, message: BaseMessage, position: int | None = None, message_type: str | None = None
    ) -> None:
        """Add message with token count metadata
        position: None for last, -1 for second last, etc.
        """
        # filter out sensitive data from the message
        if self.settings.sensitive_data:
            message = self._filter_sensitive_data(message)

        token_count = self._count_tokens(message)
        metadata = MessageMetadata(tokens=token_count, message_type=message_type)
        self.state.history.add_message(message, metadata, position)

    @time_execution_sync('--filter_sensitive_data')
    def _filter_sensitive_data(self, message: BaseMessage) -> BaseMessage:
        """Filter out sensitive data from the message

        The message is modified in place and returned.
        """

        def replace_sensitive(value: str) -> str:
            if not self.settings.sensitive_data:
                return value

            # Create a dictionary with all key-value pairs from sensitive_data where value is not None or empty
            valid_sensitive_data = {k: v for k, v in self.settings.sensitive_data.items() if v}

            # If there are no valid sensitive data entries, just return the original value
            if not valid_sensitive_data:
                logger.warning('No valid entries found in sensitive_data dictionary')
                return value

            # Replace all valid sensitive data values with their placeholder tags
            for key, val in valid_sensitive_data.items():
                value = value.replace(val, f'<secret>{key}</secret>')

            return value

        if isinstance(message.content, str):
            message.content = replace_sensitive(message.content)
        elif isinstance(message.content, list):
            for i, item in enumerate(message.content):
                if isinstance(item, dict) and 'text' in item:
                    item['text'] = replace_sensitive(item['text'])
                    message.content[i] = item
        return message

    def _count_tokens(self, message: BaseMessage) -> int:
        """Estimate the tokens of a message: a flat cost per image plus a character-based estimate for text"""
        tokens = 0
        if isinstance(message.content, list):
            for item in message.content:
                if 'image_url' in item:
                    tokens += self.settings.image_tokens
                elif isinstance(item, dict) and 'text' in item:
                    tokens += self._count_text_tokens(item['text'])
        else:
            msg = message.content
            if hasattr(message, 'tool_calls'):
                # Tool calls are counted through their string representation.
                msg += str(message.tool_calls)  # type: ignore
            tokens += self._count_text_tokens(msg)
        return tokens

    def _count_text_tokens(self, text: str) -> int:
        """Count tokens in a text string"""
        tokens = len(text) // self.settings.estimated_characters_per_token  # Rough estimate if no tokenizer available
        return tokens

    def cut_messages(self):
        """Shrink the last message so the history fits within ``max_input_tokens``.

        Images are dropped from the last message first; if the history is still
        over budget, the tail of its text is cut proportionally to the excess.

        Raises:
            ValueError: If more than 99% of the last message would have to be removed.
        """
        diff = self.state.history.current_tokens - self.settings.max_input_tokens
        if diff <= 0:
            return None

        msg = self.state.history.messages[-1]

        # if list with image remove image
        if isinstance(msg.message.content, list):
            # Collect the text parts instead of removing items from the list
            # while iterating it, which skipped the element after each image.
            text_parts = []
            for item in msg.message.content:
                if 'image_url' in item:
                    diff -= self.settings.image_tokens
                    msg.metadata.tokens -= self.settings.image_tokens
                    self.state.history.current_tokens -= self.settings.image_tokens
                    logger.debug(
                        f'Removed image with {self.settings.image_tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens}'
                    )
                elif isinstance(item, dict) and 'text' in item:
                    text_parts.append(item['text'])
            msg.message.content = ''.join(text_parts)
            self.state.history.messages[-1] = msg

        if diff <= 0:
            return None

        # if still over, remove text from state message proportionally to the number of tokens needed with buffer
        # Calculate the proportion of content to remove
        last_tokens = msg.metadata.tokens
        # A zero-token last message cannot absorb any cut; treat it as "everything must go".
        proportion_to_remove = diff / last_tokens if last_tokens > 0 else float('inf')
        if proportion_to_remove > 0.99:
            raise ValueError(
                f'Max token limit reached - history is too long - reduce the system prompt or task. '
                f'proportion_to_remove: {proportion_to_remove}'
            )
        logger.debug(
            f'Removing {proportion_to_remove * 100:.2f}% of the last message {proportion_to_remove * msg.metadata.tokens:.2f} / {msg.metadata.tokens:.2f} tokens)'
        )

        content = msg.message.content
        characters_to_remove = int(len(content) * proportion_to_remove)
        # Explicit end index: `content[:-0]` would return an empty string.
        content = content[: len(content) - characters_to_remove]

        # remove tokens and old long message
        self.state.history.remove_last_state_message()

        # new message with updated content
        msg = HumanMessage(content=content)
        self._add_message_with_tokens(msg)

        last_msg = self.state.history.messages[-1]

        logger.debug(
            f'Added message with {last_msg.metadata.tokens} tokens - total tokens now: {self.state.history.current_tokens}/{self.settings.max_input_tokens} - total messages: {len(self.state.history.messages)}'
        )

    def _remove_last_state_message(self) -> None:
        """Remove last state message from history"""
        self.state.history.remove_last_state_message()

    def add_tool_message(self, content: str, message_type: str | None = None) -> None:
        """Add tool message to history

        The message answers the current ``tool_id``, which is then incremented.
        """
        msg = ToolMessage(content=content, tool_call_id=str(self.state.tool_id))
        self.state.tool_id += 1
        self._add_message_with_tokens(msg, message_type=message_type)

View file

@ -0,0 +1,237 @@
import pytest
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from browser_use.agent.message_manager.service import MessageManager, MessageManagerSettings
from browser_use.agent.views import ActionResult
from browser_use.browser.views import BrowserState, TabInfo
from browser_use.dom.views import DOMElementNode, DOMTextNode
@pytest.fixture(
    params=[
        ChatOpenAI(model='gpt-4o-mini'),
        AzureChatOpenAI(model='gpt-4o', api_version='2024-02-15-preview'),
        ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=100, temperature=0.0, stop=None),
    ],
    ids=['gpt-4o-mini', 'gpt-4o', 'claude-3-5-sonnet'],
)
def message_manager(request: pytest.FixtureRequest):
    """Provide a MessageManager with a small token budget.

    NOTE: ``request.param`` is not read here, so the parametrization only repeats
    each dependent test once per listed model id.
    """
    system_prompt = SystemMessage(content='Test actions')
    small_budget = MessageManagerSettings(
        max_input_tokens=1000,
        estimated_characters_per_token=3,
        image_tokens=800,
    )
    return MessageManager(task='Test task', system_message=system_prompt, settings=small_budget)
def test_initial_messages(message_manager: MessageManager):
    """A fresh manager exposes exactly the system prompt followed by the task message."""
    initial = message_manager.get_messages()
    assert len(initial) == 2
    system_msg, task_msg = initial
    assert isinstance(system_msg, SystemMessage)
    assert isinstance(task_msg, HumanMessage)
    assert 'Test task' in task_msg.content
def test_add_state_message(message_manager: MessageManager):
    """Adding a browser state appends one human message that mentions the page URL."""
    page_url = 'https://test.com'
    page_title = 'Test Page'
    root = DOMElementNode(
        tag_name='div',
        attributes={},
        children=[],
        is_visible=True,
        parent=None,
        xpath='//div',
    )
    browser_state = BrowserState(
        url=page_url,
        title=page_title,
        element_tree=root,
        selector_map={},
        tabs=[TabInfo(page_id=1, url=page_url, title=page_title)],
    )

    message_manager.add_state_message(browser_state)

    history = message_manager.get_messages()
    assert len(history) == 3
    state_msg = history[2]
    assert isinstance(state_msg, HumanMessage)
    assert 'https://test.com' in state_msg.content
def test_add_state_with_memory_result(message_manager: MessageManager):
    """A result flagged ``include_in_memory`` gets its own message ahead of the state message."""
    page_url = 'https://test.com'
    page_title = 'Test Page'
    root = DOMElementNode(
        tag_name='div',
        attributes={},
        children=[],
        is_visible=True,
        parent=None,
        xpath='//div',
    )
    browser_state = BrowserState(
        url=page_url,
        title=page_title,
        element_tree=root,
        selector_map={},
        tabs=[TabInfo(page_id=1, url=page_url, title=page_title)],
    )
    remembered = ActionResult(extracted_content='Important content', include_in_memory=True)

    message_manager.add_state_message(browser_state, [remembered])

    history = message_manager.get_messages()
    # Expected layout: system, task, extracted content, state
    assert len(history) == 4
    memory_msg, state_msg = history[2], history[3]
    assert 'Important content' in memory_msg.content
    assert isinstance(memory_msg, HumanMessage)
    assert isinstance(state_msg, HumanMessage)
    assert 'Important content' not in state_msg.content
def test_add_state_with_non_memory_result(message_manager: MessageManager):
    """A result not flagged for memory is folded into the single state message."""
    page_url = 'https://test.com'
    page_title = 'Test Page'
    root = DOMElementNode(
        tag_name='div',
        attributes={},
        children=[],
        is_visible=True,
        parent=None,
        xpath='//div',
    )
    browser_state = BrowserState(
        url=page_url,
        title=page_title,
        element_tree=root,
        selector_map={},
        tabs=[TabInfo(page_id=1, url=page_url, title=page_title)],
    )
    transient = ActionResult(extracted_content='Temporary content', include_in_memory=False)

    message_manager.add_state_message(browser_state, [transient])

    history = message_manager.get_messages()
    # Expected layout: system, task, combined state+result
    assert len(history) == 3
    combined_msg = history[2]
    assert 'Temporary content' in combined_msg.content
    assert isinstance(combined_msg, HumanMessage)
@pytest.mark.skip('not sure how to fix this')
@pytest.mark.parametrize('max_tokens', [100000, 10000, 5000])
def test_token_overflow_handling_with_real_flow(message_manager: MessageManager, max_tokens):
    """Test handling of token overflow in a realistic message flow.

    Currently skipped. For each ``max_tokens`` value the test feeds up to 200 steps of
    growing browser state (plus an action result on even steps) into the manager and,
    after every step, checks the token budget, the message ordering, and that the
    stored per-message token counts agree with ``_count_tokens``. The test returns
    early if ``get_messages`` raises the "Max token limit reached" ValueError.

    NOTE(review): statement nesting below was reconstructed from a view without
    indentation; every check is assumed to run inside the step loop - confirm.
    """
    # Set more realistic token limit
    message_manager.settings.max_input_tokens = max_tokens
    # Create a long sequence of interactions
    for i in range(200):  # Simulate 200 steps of interaction
        # Create state with varying content length
        state = BrowserState(
            url=f'https://test{i}.com',
            title=f'Test Page {i}',
            element_tree=DOMElementNode(
                tag_name='div',
                attributes={},
                children=[
                    DOMTextNode(
                        text=f'Content {j} ' * (10 + i),  # Increasing content length
                        is_visible=True,
                        parent=None,
                    )
                    for j in range(5)  # Multiple DOM items
                ],
                is_visible=True,
                parent=None,
                xpath='//div',
            ),
            selector_map={j: f'//div[{j}]' for j in range(5)},
            tabs=[TabInfo(page_id=1, url=f'https://test{i}.com', title=f'Test Page {i}')],
        )
        # Alternate between different types of results
        result = None
        if i % 2 == 0:  # Every other iteration
            result = ActionResult(
                extracted_content=f'Important content from step {i}' * 5,
                include_in_memory=i % 4 == 0,  # Include in memory every 4th message
            )
        # Add state message
        if result:
            message_manager.add_state_message(state, [result])
        else:
            message_manager.add_state_message(state)
        try:
            messages = message_manager.get_messages()
        except ValueError as e:
            if 'Max token limit reached - history is too long' in str(e):
                return  # If error occurs, end the test
            else:
                raise e
        # Allow a slack of 100 tokens over the configured limit
        assert message_manager.state.history.current_tokens <= message_manager.settings.max_input_tokens + 100
        last_msg = messages[-1]
        assert isinstance(last_msg, HumanMessage)
        if i % 4 == 0:
            # Steps whose result is kept in memory: the message before the state is also a human message
            assert isinstance(message_manager.state.history.messages[-2].message, HumanMessage)
        if i % 2 == 0 and not i % 4 == 0:
            # Steps with a non-memory result: the state message still carries the current URL
            if isinstance(last_msg.content, list):
                assert 'Current url: https://test' in last_msg.content[0]['text']
            else:
                assert 'Current url: https://test' in last_msg.content
        # Add model output every time
        from browser_use.agent.views import AgentBrain, AgentOutput
        from browser_use.controller.registry.views import ActionModel
        output = AgentOutput(
            current_state=AgentBrain(
                evaluation_previous_goal=f'Success in step {i}',
                memory=f'Memory from step {i}',
                next_goal=f'Goal for step {i + 1}',
            ),
            action=[ActionModel()],
        )
        # Drop the state message before recording the model output for this step
        message_manager._remove_last_state_message()
        message_manager.add_model_output(output)
        # Get messages and verify after each addition
        messages = [m.message for m in message_manager.state.history.messages]
        # Verify token limit is respected
        # Verify essential messages are preserved
        assert isinstance(messages[0], SystemMessage)  # System prompt always first
        assert isinstance(messages[1], HumanMessage)  # Task always second
        assert 'Test task' in messages[1].content
        # Verify structure of latest messages
        assert isinstance(messages[-1], AIMessage)  # Last message should be model output
        assert f'step {i}' in messages[-1].content  # Should contain current step info
        # Log token usage for debugging (values are only used by the commented-out print)
        token_usage = message_manager.state.history.current_tokens
        token_limit = message_manager.settings.max_input_tokens
        # print(f'Step {i}: Using {token_usage}/{token_limit} tokens')
        # go through all messages and verify that the token count and total tokens is correct
        total_tokens = 0
        real_tokens = []
        stored_tokens = []
        for msg in message_manager.state.history.messages:
            total_tokens += msg.metadata.tokens
            stored_tokens.append(msg.metadata.tokens)
            real_tokens.append(message_manager._count_tokens(msg.message))
        assert total_tokens == sum(real_tokens)
        assert stored_tokens == real_tokens
        assert message_manager.state.history.current_tokens == total_tokens
# pytest -s browser_use/agent/message_manager/tests.py

View file

@ -0,0 +1,147 @@
from __future__ import annotations
import json
import logging
import os
import re
from typing import Any
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
logger = logging.getLogger(__name__)
# Regex patterns (applied with re.match, i.e. anchored at the start of the name)
# identifying models that are handled without tool/function calling.
MODELS_WITHOUT_TOOL_SUPPORT_PATTERNS = [
    'deepseek-reasoner',
    'deepseek-r1',
    '.*gemma.*-it',
]


def is_model_without_tool_support(model_name: str) -> bool:
    """Return True when ``model_name`` matches any pattern in MODELS_WITHOUT_TOOL_SUPPORT_PATTERNS."""
    for pattern in MODELS_WITHOUT_TOOL_SUPPORT_PATTERNS:
        if re.match(pattern, model_name) is not None:
            return True
    return False
def extract_json_from_model_output(content: str) -> dict:
    """Extract a JSON object from model output.

    Handles plain JSON as well as JSON wrapped in a Markdown code fence. When a fence
    is present, the text between the first pair of fences is used and a leading
    language tag line (e.g. ``json``) is dropped.

    Args:
        content: Raw text returned by the model.

    Returns:
        The parsed dictionary. A list containing exactly one dictionary is unwrapped.

    Raises:
        ValueError: If the text is not valid JSON, or the parsed value is not a
            dictionary (after the single-element-list unwrapping).
    """
    try:
        # If content is wrapped in code blocks, extract just the JSON part
        if '```' in content:
            content = content.split('```')[1]
            first_line, newline, remainder = content.partition('\n')
            # Drop the first line only when it is a language tag / blank line;
            # keep it when the JSON itself starts right after the fence.
            if newline and not first_line.lstrip().startswith(('{', '[')):
                content = remainder
        result_dict = json.loads(content)
    except json.JSONDecodeError as e:
        logger.warning(f'Failed to parse model output: {content} {str(e)}')
        raise ValueError('Could not parse response.') from e

    # some models occasionally respond with a list containing one dict: https://github.com/browser-use/browser-use/issues/1458
    if isinstance(result_dict, list) and len(result_dict) == 1 and isinstance(result_dict[0], dict):
        result_dict = result_dict[0]

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(result_dict, dict):
        raise ValueError(f'Expected JSON dictionary in response, got JSON {type(result_dict)} instead')
    return result_dict
def convert_input_messages(input_messages: list[BaseMessage], model_name: str | None) -> list[BaseMessage]:
    """Adapt input messages for models that are handled without tool calling.

    The input list is returned untouched when ``model_name`` is None or the model is
    not matched by ``is_model_without_tool_support``. Otherwise the messages are
    converted and successive human messages, then successive AI messages, are merged.
    """
    if model_name is None or not is_model_without_tool_support(model_name):
        return input_messages

    adapted = _convert_messages_for_non_function_calling_models(input_messages)
    for message_cls in (HumanMessage, AIMessage):
        adapted = _merge_successive_messages(adapted, message_cls)
    return adapted
def _convert_messages_for_non_function_calling_models(input_messages: list[BaseMessage]) -> list[BaseMessage]:
    """Rewrite tool-related messages into plain human/AI messages.

    Human and system messages pass through unchanged. Tool messages become human
    messages with the same content. AI messages with tool calls become AI messages
    whose content is the JSON dump of those tool calls; other AI messages pass through.

    Raises:
        ValueError: For any other message type.
    """
    converted: list[BaseMessage] = []
    for msg in input_messages:
        if isinstance(msg, (HumanMessage, SystemMessage)):
            converted.append(msg)
            continue
        if isinstance(msg, ToolMessage):
            converted.append(HumanMessage(content=msg.content))
            continue
        if not isinstance(msg, AIMessage):
            raise ValueError(f'Unknown message type: {type(msg)}')
        # AI message: serialize tool calls into the content when there are any
        if msg.tool_calls:
            serialized_calls = json.dumps(msg.tool_calls)
            converted.append(AIMessage(content=serialized_calls))
        else:
            converted.append(msg)
    return converted
def _merge_successive_messages(messages: list[BaseMessage], class_to_merge: type[BaseMessage]) -> list[BaseMessage]:
"""Some models like deepseek-reasoner dont allow multiple human messages in a row. This function merges them into one."""
merged_messages = []
streak = 0
for message in messages:
if isinstance(message, class_to_merge):
streak += 1
if streak > 1:
if isinstance(message.content, list):
merged_messages[-1].content += message.content[0]['text'] # type:ignore
else:
merged_messages[-1].content += message.content
else:
merged_messages.append(message)
else:
merged_messages.append(message)
streak = 0
return merged_messages
def save_conversation(input_messages: list[BaseMessage], response: Any, target: str, encoding: str | None = None) -> None:
    """Write the input messages followed by the model response to ``target``.

    The parent directory of ``target`` is created when the path has one; the file is
    opened in write mode, replacing any previous content.
    """
    parent_dir = os.path.dirname(target)
    if parent_dir:
        # create folders if not exists
        os.makedirs(parent_dir, exist_ok=True)

    with open(target, 'w', encoding=encoding) as conversation_file:
        _write_messages_to_file(conversation_file, input_messages)
        _write_response_to_file(conversation_file, response)
def _write_messages_to_file(f: Any, messages: list[BaseMessage]) -> None:
"""Write messages to conversation file"""
for message in messages:
f.write(f' {message.__class__.__name__} \n')
if isinstance(message.content, list):
for item in message.content:
if isinstance(item, dict) and item.get('type') == 'text':
f.write(item['text'].strip() + '\n')
elif isinstance(message.content, str):
try:
content = json.loads(message.content)
f.write(json.dumps(content, indent=2) + '\n')
except json.JSONDecodeError:
f.write(message.content.strip() + '\n')
f.write('\n')
def _write_response_to_file(f: Any, response: Any) -> None:
"""Write model response to conversation file"""
f.write(' RESPONSE\n')
f.write(json.dumps(json.loads(response.model_dump_json(exclude_unset=True)), indent=2))

View file

@ -0,0 +1,135 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from warnings import filterwarnings
from langchain_core._api import LangChainBetaWarning
from langchain_core.load import dumpd, load
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
filterwarnings('ignore', category=LangChainBetaWarning)
if TYPE_CHECKING:
from browser_use.agent.views import AgentOutput
class MessageMetadata(BaseModel):
    """Metadata for a message"""

    # Token count attributed to the message; defaults to 0.
    tokens: int = 0
    # Optional tag describing the message; None when not provided.
    message_type: str | None = None
class ManagedMessage(BaseModel):
    """A message with its metadata"""

    # The wrapped langchain message.
    message: BaseMessage
    # Bookkeeping for the message; a fresh MessageMetadata when not supplied.
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # https://github.com/pydantic/pydantic/discussions/7558
    @model_serializer(mode='wrap')
    def to_json(self, original_dump):
        """
        Return the serialized form of the model.

        The dump produced by ``original_dump`` is used as-is except for the ``message``
        entry, which is replaced by the output of langchain's ``dumpd``.
        """
        data = original_dump(self)
        # NOTE: We override the message field to use langchain JSON serialization.
        data['message'] = dumpd(self.message)
        return data

    @model_validator(mode='before')
    @classmethod
    def validate(
        cls,
        value: Any,
        *,
        strict: bool | None = None,
        from_attributes: bool | None = None,
        context: Any | None = None,
    ) -> Any:
        """
        Before-validator that passes the ``message`` entry through langchain's ``load``.

        Only dict inputs that contain a ``message`` key are touched; the entry is
        replaced in place on the incoming dict. Any other input is returned unchanged.
        The keyword-only parameters are not used in the body.
        """
        if isinstance(value, dict) and 'message' in value:
            # NOTE: We use langchain's load to turn the serialized message back into a message object
            # (presumably a BaseMessage - confirm against langchain_core.load.load).
            filterwarnings('ignore', category=LangChainBetaWarning)
            value['message'] = load(value['message'])
        return value
class MessageHistory(BaseModel):
    """Ordered list of managed messages together with a running token total."""

    messages: list[ManagedMessage] = Field(default_factory=list)
    current_tokens: int = 0
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def add_message(self, message: BaseMessage, metadata: MessageMetadata, position: int | None = None) -> None:
        """Store ``message`` (appended, or inserted at ``position``) and add its tokens to the total."""
        entry = ManagedMessage(message=message, metadata=metadata)
        if position is not None:
            self.messages.insert(position, entry)
        else:
            self.messages.append(entry)
        self.current_tokens += metadata.tokens

    def add_model_output(self, output: AgentOutput) -> None:
        """Record the model output as an AI tool-call message followed by an empty tool reply."""
        agent_output_call = {
            'name': 'AgentOutput',
            'args': output.model_dump(mode='json', exclude_unset=True),
            'id': '1',
            'type': 'tool_call',
        }
        ai_message = AIMessage(content='', tool_calls=[agent_output_call])
        # Fixed estimate of 100 tokens for the tool call
        self.add_message(ai_message, MessageMetadata(tokens=100))

        # Empty tool response, estimated at 10 tokens
        empty_reply = ToolMessage(content='', tool_call_id='1')
        self.add_message(empty_reply, MessageMetadata(tokens=10))

    def get_messages(self) -> list[BaseMessage]:
        """Return the bare messages, in order, without their metadata."""
        return [entry.message for entry in self.messages]

    def get_total_tokens(self) -> int:
        """Return the running token total."""
        return self.current_tokens

    def remove_oldest_message(self) -> None:
        """Drop the first message that is not a system message, if there is one."""
        for index, entry in enumerate(self.messages):
            if isinstance(entry.message, SystemMessage):
                continue
            self.current_tokens -= entry.metadata.tokens
            del self.messages[index]
            return

    def remove_last_state_message(self) -> None:
        """Drop the final message when there are more than two messages and it is a human message."""
        if len(self.messages) <= 2:
            return
        last_entry = self.messages[-1]
        if not isinstance(last_entry.message, HumanMessage):
            return
        self.current_tokens -= last_entry.metadata.tokens
        self.messages.pop()
class MessageManagerState(BaseModel):
    """Holds the state for MessageManager"""

    # Message history; a fresh, empty MessageHistory by default.
    history: MessageHistory = Field(default_factory=MessageHistory)
    # Counter for tool-call ids; starts at 1.
    tool_id: int = 1
    model_config = ConfigDict(arbitrary_types_allowed=True)

View file

@ -0,0 +1,722 @@
import json
import logging
from pathlib import Path
from typing import Any
from browser_use.browser.browser import BrowserConfig
from browser_use.browser.context import BrowserContextConfig
logger = logging.getLogger(__name__)
class PlaywrightScriptGenerator:
"""Generates a Playwright script from AgentHistoryList."""
def __init__(
    self,
    history_list: list[dict[str, Any]],
    sensitive_data_keys: list[str] | None = None,
    browser_config: BrowserConfig | None = None,
    context_config: BrowserContextConfig | None = None,
):
    """
    Initializes the script generator.

    Args:
        history_list: A list of dictionaries, where each dictionary represents an AgentHistory item.
            Expected to be raw dictionaries from `AgentHistoryList.model_dump()`.
        sensitive_data_keys: A list of keys used as placeholders for sensitive data.
            ``None`` is stored as an empty list.
        browser_config: Configuration from the original Browser instance.
        context_config: Configuration from the original BrowserContext instance.
    """
    self.history = history_list
    self.sensitive_data_keys = sensitive_data_keys or []
    self.browser_config = browser_config
    self.context_config = context_config
    self._imports_helpers_added = False
    self._page_counter = 0  # Track pages for tab management
    # Dictionary mapping action types to handler methods.
    # Each handler returns the generated script lines for one action.
    self._action_handlers = {
        'go_to_url': self._map_go_to_url,
        'wait': self._map_wait,
        'input_text': self._map_input_text,
        'click_element': self._map_click_element,
        'click_element_by_index': self._map_click_element,  # Map legacy action
        'scroll_down': self._map_scroll_down,
        'scroll_up': self._map_scroll_up,
        'send_keys': self._map_send_keys,
        'go_back': self._map_go_back,
        'open_tab': self._map_open_tab,
        'close_tab': self._map_close_tab,
        'switch_tab': self._map_switch_tab,
        'search_google': self._map_search_google,
        'drag_drop': self._map_drag_drop,
        'extract_content': self._map_extract_content,
        'click_download_button': self._map_click_download_button,
        'done': self._map_done,
    }
def _generate_browser_launch_args(self) -> str:
"""Generates the arguments string for browser launch based on BrowserConfig."""
if not self.browser_config:
# Default launch if no config provided
return 'headless=False'
args_dict = {
'headless': self.browser_config.headless,
# Add other relevant launch options here based on self.browser_config
# Example: 'proxy': self.browser_config.proxy.model_dump() if self.browser_config.proxy else None
# Example: 'args': self.browser_config.extra_browser_args # Be careful inheriting args
}
if self.browser_config.proxy:
args_dict['proxy'] = self.browser_config.proxy.model_dump()
# Filter out None values
args_dict = {k: v for k, v in args_dict.items() if v is not None}
# Format as keyword arguments string
args_str = ', '.join(f'{key}={repr(value)}' for key, value in args_dict.items())
return args_str
def _generate_context_options(self) -> str:
"""Generates the options string for context creation based on BrowserContextConfig."""
if not self.context_config:
return '' # Default context
options_dict = {}
# Map relevant BrowserContextConfig fields to Playwright context options
if self.context_config.user_agent:
options_dict['user_agent'] = self.context_config.user_agent
if self.context_config.locale:
options_dict['locale'] = self.context_config.locale
if self.context_config.permissions:
options_dict['permissions'] = self.context_config.permissions
if self.context_config.geolocation:
options_dict['geolocation'] = self.context_config.geolocation
if self.context_config.timezone_id:
options_dict['timezone_id'] = self.context_config.timezone_id
if self.context_config.http_credentials:
options_dict['http_credentials'] = self.context_config.http_credentials
if self.context_config.is_mobile is not None:
options_dict['is_mobile'] = self.context_config.is_mobile
if self.context_config.has_touch is not None:
options_dict['has_touch'] = self.context_config.has_touch
if self.context_config.save_recording_path:
options_dict['record_video_dir'] = self.context_config.save_recording_path
if self.context_config.save_har_path:
options_dict['record_har_path'] = self.context_config.save_har_path
# Handle viewport/window size
if self.context_config.no_viewport:
options_dict['no_viewport'] = True
elif hasattr(self.context_config, 'window_width') and hasattr(self.context_config, 'window_height'):
options_dict['viewport'] = {
'width': self.context_config.window_width,
'height': self.context_config.window_height,
}
# Note: cookies_file and save_downloads_path are handled separately
# Filter out None values
options_dict = {k: v for k, v in options_dict.items() if v is not None}
# Format as keyword arguments string
options_str = ', '.join(f'{key}={repr(value)}' for key, value in options_dict.items())
return options_str
def _get_imports_and_helpers(self) -> list[str]:
"""Generates necessary import statements (excluding helper functions)."""
# Return only the standard imports needed by the main script body
return [
'import asyncio',
'import json',
'import os',
'import sys',
'from pathlib import Path', # Added Path import
'import urllib.parse', # Needed for search_google
'from playwright.async_api import async_playwright, Page, BrowserContext', # Added BrowserContext
'from dotenv import load_dotenv',
'',
'# Load environment variables',
'load_dotenv(override=True)',
'',
# Helper function definitions are no longer here
]
def _get_sensitive_data_definitions(self) -> list[str]:
"""Generates the SENSITIVE_DATA dictionary definition."""
if not self.sensitive_data_keys:
return ['SENSITIVE_DATA = {}', '']
lines = ['# Sensitive data placeholders mapped to environment variables']
lines.append('SENSITIVE_DATA = {')
for key in self.sensitive_data_keys:
env_var_name = key.upper()
default_value_placeholder = f'YOUR_{env_var_name}'
lines.append(f' "{key}": os.getenv("{env_var_name}", {json.dumps(default_value_placeholder)}),')
lines.append('}')
lines.append('')
return lines
def _get_selector_for_action(self, history_item: dict, action_index_in_step: int) -> str | None:
"""
Gets the selector (preferring XPath) for a given action index within a history step.
Formats the XPath correctly for Playwright.
"""
state = history_item.get('state')
if not isinstance(state, dict):
return None
interacted_elements = state.get('interacted_element')
if not isinstance(interacted_elements, list):
return None
if action_index_in_step >= len(interacted_elements):
return None
element_data = interacted_elements[action_index_in_step]
if not isinstance(element_data, dict):
return None
# Prioritize XPath
xpath = element_data.get('xpath')
if isinstance(xpath, str) and xpath.strip():
if not xpath.startswith('xpath=') and not xpath.startswith('/') and not xpath.startswith('//'):
xpath_selector = f'xpath=//{xpath}' # Make relative if not already
elif not xpath.startswith('xpath='):
xpath_selector = f'xpath={xpath}' # Add prefix if missing
else:
xpath_selector = xpath
return xpath_selector
# Fallback to CSS selector if XPath is missing
css_selector = element_data.get('css_selector')
if isinstance(css_selector, str) and css_selector.strip():
return css_selector # Use CSS selector as is
logger.warning(
f'Could not find a usable XPath or CSS selector for action index {action_index_in_step} (element index {element_data.get("highlight_index", "N/A")}).'
)
return None
def _get_goto_timeout(self) -> int:
"""Gets the page navigation timeout in milliseconds."""
default_timeout = 90000 # Default 90 seconds
if self.context_config and self.context_config.maximum_wait_page_load_time:
# Convert seconds to milliseconds
return int(self.context_config.maximum_wait_page_load_time * 1000)
return default_timeout
# --- Action Mapping Methods ---
def _map_go_to_url(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
url = params.get('url')
goto_timeout = self._get_goto_timeout()
script_lines = []
if url and isinstance(url, str):
escaped_url = json.dumps(url)
script_lines.append(f' print(f"Navigating to: {url} ({step_info_str})")')
script_lines.append(f' await page.goto({escaped_url}, timeout={goto_timeout})')
script_lines.append(f" await page.wait_for_load_state('load', timeout={goto_timeout})")
script_lines.append(' await page.wait_for_timeout(1000)') # Short pause
else:
script_lines.append(f' # Skipping go_to_url ({step_info_str}): missing or invalid url')
return script_lines
def _map_wait(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
seconds = params.get('seconds', 3)
try:
wait_seconds = int(seconds)
except (ValueError, TypeError):
wait_seconds = 3
return [
f' print(f"Waiting for {wait_seconds} seconds... ({step_info_str})")',
f' await asyncio.sleep({wait_seconds})',
]
def _map_input_text(
self, params: dict, history_item: dict, action_index_in_step: int, step_info_str: str, **kwargs
) -> list[str]:
index = params.get('index')
text = params.get('text', '')
selector = self._get_selector_for_action(history_item, action_index_in_step)
script_lines = []
if selector and index is not None:
clean_text_expression = f'replace_sensitive_data({json.dumps(str(text))}, SENSITIVE_DATA)'
escaped_selector = json.dumps(selector)
escaped_step_info = json.dumps(step_info_str)
script_lines.append(
f' await _try_locate_and_act(page, {escaped_selector}, "fill", text={clean_text_expression}, step_info={escaped_step_info})'
)
else:
script_lines.append(
f' # Skipping input_text ({step_info_str}): missing index ({index}) or selector ({selector})'
)
return script_lines
def _map_click_element(
self, params: dict, history_item: dict, action_index_in_step: int, step_info_str: str, action_type: str, **kwargs
) -> list[str]:
if action_type == 'click_element_by_index':
logger.warning(f"Mapping legacy 'click_element_by_index' to 'click_element' ({step_info_str})")
index = params.get('index')
selector = self._get_selector_for_action(history_item, action_index_in_step)
script_lines = []
if selector and index is not None:
escaped_selector = json.dumps(selector)
escaped_step_info = json.dumps(step_info_str)
script_lines.append(
f' await _try_locate_and_act(page, {escaped_selector}, "click", step_info={escaped_step_info})'
)
else:
script_lines.append(
f' # Skipping {action_type} ({step_info_str}): missing index ({index}) or selector ({selector})'
)
return script_lines
def _map_scroll_down(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
amount = params.get('amount')
script_lines = []
if amount and isinstance(amount, int):
script_lines.append(f' print(f"Scrolling down by {amount} pixels ({step_info_str})")')
script_lines.append(f" await page.evaluate('window.scrollBy(0, {amount})')")
else:
script_lines.append(f' print(f"Scrolling down by one page height ({step_info_str})")')
script_lines.append(" await page.evaluate('window.scrollBy(0, window.innerHeight)')")
script_lines.append(' await page.wait_for_timeout(500)')
return script_lines
def _map_scroll_up(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
amount = params.get('amount')
script_lines = []
if amount and isinstance(amount, int):
script_lines.append(f' print(f"Scrolling up by {amount} pixels ({step_info_str})")')
script_lines.append(f" await page.evaluate('window.scrollBy(0, -{amount})')")
else:
script_lines.append(f' print(f"Scrolling up by one page height ({step_info_str})")')
script_lines.append(" await page.evaluate('window.scrollBy(0, -window.innerHeight)')")
script_lines.append(' await page.wait_for_timeout(500)')
return script_lines
def _map_send_keys(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
keys = params.get('keys')
script_lines = []
if keys and isinstance(keys, str):
escaped_keys = json.dumps(keys)
script_lines.append(f' print(f"Sending keys: {keys} ({step_info_str})")')
script_lines.append(f' await page.keyboard.press({escaped_keys})')
script_lines.append(' await page.wait_for_timeout(500)')
else:
script_lines.append(f' # Skipping send_keys ({step_info_str}): missing or invalid keys')
return script_lines
def _map_go_back(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
goto_timeout = self._get_goto_timeout()
return [
' await asyncio.sleep(60) # Wait 1 minute (important) before going back',
f' print(f"Navigating back using browser history ({step_info_str})")',
f' await page.go_back(timeout={goto_timeout})',
f" await page.wait_for_load_state('load', timeout={goto_timeout})",
' await page.wait_for_timeout(1000)',
]
def _map_open_tab(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
url = params.get('url')
goto_timeout = self._get_goto_timeout()
script_lines = []
if url and isinstance(url, str):
escaped_url = json.dumps(url)
script_lines.append(f' print(f"Opening new tab and navigating to: {url} ({step_info_str})")')
script_lines.append(' page = await context.new_page()')
script_lines.append(f' await page.goto({escaped_url}, timeout={goto_timeout})')
script_lines.append(f" await page.wait_for_load_state('load', timeout={goto_timeout})")
script_lines.append(' await page.wait_for_timeout(1000)')
self._page_counter += 1 # Increment page counter
else:
script_lines.append(f' # Skipping open_tab ({step_info_str}): missing or invalid url')
return script_lines
def _map_close_tab(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
page_id = params.get('page_id')
script_lines = []
if page_id is not None:
script_lines.extend(
[
f' print(f"Attempting to close tab with page_id {page_id} ({step_info_str})")',
f' if {page_id} < len(context.pages):',
f' target_page = context.pages[{page_id}]',
' await target_page.close()',
' await page.wait_for_timeout(500)',
' if context.pages: page = context.pages[-1]', # Switch to last page
' else:',
" print(' Warning: No pages left after closing tab. Cannot switch.', file=sys.stderr)",
' # Optionally, create a new page here if needed: page = await context.new_page()',
' if page: await page.bring_to_front()', # Bring to front if page exists
' else:',
f' print(f" Warning: Tab with page_id {page_id} not found to close ({step_info_str})", file=sys.stderr)',
]
)
else:
script_lines.append(f' # Skipping close_tab ({step_info_str}): missing page_id')
return script_lines
def _map_switch_tab(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
page_id = params.get('page_id')
script_lines = []
if page_id is not None:
script_lines.extend(
[
f' print(f"Switching to tab with page_id {page_id} ({step_info_str})")',
f' if {page_id} < len(context.pages):',
f' page = context.pages[{page_id}]',
' await page.bring_to_front()',
" await page.wait_for_load_state('load', timeout=15000)",
' await page.wait_for_timeout(500)',
' else:',
f' print(f" Warning: Tab with page_id {page_id} not found to switch ({step_info_str})", file=sys.stderr)',
]
)
else:
script_lines.append(f' # Skipping switch_tab ({step_info_str}): missing page_id')
return script_lines
def _map_search_google(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
query = params.get('query')
goto_timeout = self._get_goto_timeout()
script_lines = []
if query and isinstance(query, str):
clean_query = f'replace_sensitive_data({json.dumps(query)}, SENSITIVE_DATA)'
search_url_expression = f'f"https://www.google.com/search?q={{ urllib.parse.quote_plus({clean_query}) }}&udm=14"'
script_lines.extend(
[
f' search_url = {search_url_expression}',
f' print(f"Searching Google for query related to: {{ {clean_query} }} ({step_info_str})")',
f' await page.goto(search_url, timeout={goto_timeout})',
f" await page.wait_for_load_state('load', timeout={goto_timeout})",
' await page.wait_for_timeout(1000)',
]
)
else:
script_lines.append(f' # Skipping search_google ({step_info_str}): missing or invalid query')
return script_lines
def _map_drag_drop(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
    """Emit the script lines for a drag-and-drop, by selectors or by coordinates.

    Element selectors (``element_source`` / ``element_target``) take precedence;
    otherwise all four ``coord_*`` values must be present. When neither form is
    complete only an explanatory comment is emitted. A short wait is appended
    in every case.

    Args:
        params: Parameters of the recorded ``drag_drop`` action.
        step_info_str: Step/action label embedded in the emitted messages.
        **kwargs: Extra dispatch arguments; not used by this handler.

    Returns:
        The script lines for the drag-and-drop step.
    """
    source_sel = params.get('element_source')
    target_sel = params.get('element_target')
    source_coords = (params.get('coord_source_x'), params.get('coord_source_y'))
    target_coords = (params.get('coord_target_x'), params.get('coord_target_y'))

    script_lines = [f' print(f"Attempting drag and drop ({step_info_str})")']
    if source_sel and target_sel:
        escaped_source = json.dumps(source_sel)
        escaped_target = json.dumps(target_sel)
        script_lines.append(f' await page.drag_and_drop({escaped_source}, {escaped_target})')
        # Emit the message as a plain string literal. Pasting selectors into a
        # generated f-string breaks the script as soon as a selector contains
        # a single quote or a brace.
        dragged_message = f' Dragged element {escaped_source} to {escaped_target}'
        script_lines.append(f' print({dragged_message!r})')
    elif all(c is not None for c in source_coords) and all(c is not None for c in target_coords):
        sx, sy = source_coords
        tx, ty = target_coords
        script_lines.extend(
            [
                f' await page.mouse.move({sx}, {sy})',
                ' await page.mouse.down()',
                f' await page.mouse.move({tx}, {ty})',
                ' await page.mouse.up()',
                f" print(f' Dragged from ({sx},{sy}) to ({tx},{ty})')",
            ]
        )
    else:
        script_lines.append(
            f' # Skipping drag_drop ({step_info_str}): requires either element selectors or full coordinates'
        )
    script_lines.append(' await page.wait_for_timeout(500)')
    return script_lines
def _map_extract_content(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
    """Emit a placeholder comment for an ``extract_content`` action.

    This action is not translated into Playwright calls; a warning is logged
    and the script only records that the step was skipped.

    Args:
        params: Parameters of the recorded action; ``goal`` defaults to ``'content'``.
        step_info_str: Step/action label embedded in the emitted comment.
        **kwargs: Extra dispatch arguments; not used by this handler.

    Returns:
        A one-element list holding the explanatory comment line.
    """
    # The goal is free text. Fold it onto a single line so an embedded newline
    # cannot end the comment early and turn the remainder into script code.
    goal = ' '.join(str(params.get('goal', 'content')).split())
    logger.warning(f"Action 'extract_content' ({step_info_str}) cannot be directly translated to Playwright script.")
    return [f' # Action: extract_content (Goal: {goal}) - Skipped in Playwright script ({step_info_str})']
def _map_click_download_button(
    self, params: dict, history_item: dict, action_index_in_step: int, step_info_str: str, **kwargs
) -> list[str]:
    """Emit the script lines that click an element and save the resulting download.

    The emitted code waits for the download (2 minute timeout), stores it in
    the configured download directory (``'./files'`` when none is configured)
    and appends ``(N)`` to the file name while a file of that name already exists.

    Args:
        params: Parameters of the recorded action; ``index`` must be present.
        history_item: History entry the action belongs to, used to look up the selector.
        action_index_in_step: Position of the action inside its step.
        step_info_str: Step/action label embedded in the emitted messages.
        **kwargs: Extra dispatch arguments; not used by this handler.

    Returns:
        The script lines for the download, or a single comment line when the
        index or the selector is unavailable.
    """
    index = params.get('index')
    selector = self._get_selector_for_action(history_item, action_index_in_step)
    download_dir_in_script = "'./files'"  # Default
    if self.context_config and self.context_config.save_downloads_path:
        download_dir_in_script = repr(self.context_config.save_downloads_path)

    if not selector or index is None:
        return [
            f' # Skipping click_download_button ({step_info_str}): missing index ({index}) or selector ({selector})'
        ]

    # Emitted as a plain literal: selectors routinely contain quotes, which
    # would break a generated f-string.
    announce = f'Attempting to download file by clicking element ({selector}) ({step_info_str})'
    step_info_for_download = f'{step_info_str} (triggering download)'
    return [
        f' print({announce!r})',
        ' try:',
        ' async with page.expect_download(timeout=120000) as download_info:',  # 2 min timeout
        f' await _try_locate_and_act(page, {json.dumps(selector)}, "click", step_info={json.dumps(step_info_for_download)})',
        ' download = await download_info.value',
        f' configured_download_dir = {download_dir_in_script}',
        ' download_dir_path = Path(configured_download_dir).resolve()',
        ' download_dir_path.mkdir(parents=True, exist_ok=True)',
        # The next lines are plain strings, so braces must be single for the
        # generated f-strings to interpolate at script run time.
        " base, ext = os.path.splitext(download.suggested_filename or f'download_{len(list(download_dir_path.iterdir())) + 1}.tmp')",
        ' counter = 1',
        " download_path_obj = download_dir_path / f'{base}{ext}'",
        ' while download_path_obj.exists():',
        " download_path_obj = download_dir_path / f'{base}({counter}){ext}'",
        ' counter += 1',
        ' await download.save_as(str(download_path_obj))',
        " print(f' File downloaded successfully to: {str(download_path_obj)}')",
        ' except PlaywrightActionError as pae:',
        ' raise pae',  # Re-raise to stop script
        ' except Exception as download_err:',
        f" raise PlaywrightActionError(f'Download failed for {step_info_str}: {{download_err}}') from download_err",
    ]
def _map_done(self, params: dict, step_info_str: str, **kwargs) -> list[str]:
    """Emit the script lines that report the agent's final ``done`` action.

    Args:
        params: Parameters of the recorded ``done`` action (``text``, ``success``).
        step_info_str: Step/action label embedded in the emitted banner.
        **kwargs: Extra dispatch arguments; not used by this handler.

    Returns:
        Lines printing a banner, the reported success flag and the final
        message (after placeholder substitution), or "N/A" lines when
        ``params`` is not a dictionary.
    """
    banner = f' print("\\n--- Task marked as Done by agent ({step_info_str}) ---")'
    if not isinstance(params, dict):
        return [
            banner,
            ' print("Success: N/A (invalid params)")',
            ' print("Final Message: N/A (invalid params)")',
        ]

    message_literal = json.dumps(str(params.get('text', '')))
    reported_success = params.get('success', False)
    return [
        banner,
        f' print(f"Agent reported success: {reported_success}")',
        ' # Final Message from agent (may contain placeholders):',
        f' final_message = replace_sensitive_data({message_literal}, SENSITIVE_DATA)',
        ' print(final_message)',
    ]
def _map_action_to_playwright(
self,
action_dict: dict,
history_item: dict,
previous_history_item: dict | None,
action_index_in_step: int,
step_info_str: str,
) -> list[str]:
"""
Translates a single action dictionary into Playwright script lines using dictionary dispatch.
"""
if not isinstance(action_dict, dict) or not action_dict:
return [f' # Invalid action format: {action_dict} ({step_info_str})']
action_type = next(iter(action_dict.keys()), None)
params = action_dict.get(action_type)
if not action_type or params is None:
if action_dict == {}:
return [f' # Empty action dictionary found ({step_info_str})']
return [f' # Could not determine action type or params: {action_dict} ({step_info_str})']
# Get the handler function from the dictionary
handler = self._action_handlers.get(action_type)
if handler:
# Call the specific handler method
return handler(
params=params,
history_item=history_item,
action_index_in_step=action_index_in_step,
step_info_str=step_info_str,
action_type=action_type, # Pass action_type for legacy handling etc.
previous_history_item=previous_history_item,
)
else:
# Handle unsupported actions
logger.warning(f'Unsupported action type encountered: {action_type} ({step_info_str})')
return [f' # Unsupported action type: {action_type} ({step_info_str})']
def generate_script_content(self) -> str:
    """Generate the full Playwright script for the recorded history as one string.

    The script is assembled from: imports/helpers (emitted only the first time
    this method runs on an instance), sensitive-data definitions, the verbatim
    content of ``playwright_script_helpers.py``, browser/context start-up,
    optional cookie loading, one section per history step, and a closing block
    that handles errors, closes the browser and exits non-zero on failure.
    Processing of steps stops after the first ``done`` action.

    Returns:
        The script text, or a one-line ``# Error: ...`` comment when the
        helper file cannot be read.
    """
    script_lines = []
    self._page_counter = 0  # Reset page counter for new script generation

    if not self._imports_helpers_added:
        script_lines.extend(self._get_imports_and_helpers())
        self._imports_helpers_added = True

    # Read helper script content
    helper_script_path = Path(__file__).parent / 'playwright_script_helpers.py'
    try:
        with open(helper_script_path, encoding='utf-8') as f_helper:
            helper_script_content = f_helper.read()
    except FileNotFoundError:
        logger.error(f'Helper script not found at {helper_script_path}. Cannot generate script.')
        return '# Error: Helper script file missing.'
    except Exception as e:
        logger.error(f'Error reading helper script {helper_script_path}: {e}')
        return f'# Error: Could not read helper script: {e}'

    script_lines.extend(self._get_sensitive_data_definitions())

    # Add the helper script content after imports and sensitive data
    script_lines.append('\n# --- Helper Functions (from playwright_script_helpers.py) ---')
    script_lines.append(helper_script_content)
    script_lines.append('# --- End Helper Functions ---')

    # Generate browser launch and context creation code
    browser_launch_args = self._generate_browser_launch_args()
    context_options = self._generate_context_options()

    # Determine browser type (defaulting to chromium)
    browser_type = 'chromium'
    if self.browser_config and self.browser_config.browser_class in ['firefox', 'webkit']:
        browser_type = self.browser_config.browser_class

    script_lines.extend(
        [
            'async def run_generated_script():',
            ' global SENSITIVE_DATA',  # Ensure sensitive data is accessible
            ' async with async_playwright() as p:',
            ' browser = None',
            ' context = None',
            ' page = None',
            ' exit_code = 0 # Default success exit code',
            ' try:',
            f" print('Launching {browser_type} browser...')",
            # Use generated launch args, remove slow_mo
            f' browser = await p.{browser_type}.launch({browser_launch_args})',
            # Use generated context options
            f' context = await browser.new_context({context_options})',
            " print('Browser context created.')",
        ]
    )

    # Add cookie loading logic if cookies_file is specified
    if self.context_config and self.context_config.cookies_file:
        cookies_file_path = repr(self.context_config.cookies_file)
        # Apart from the cookies_path assignment these are plain (non-f)
        # strings, so the generated f-strings need single braces to
        # interpolate at script run time.
        script_lines.extend(
            [
                ' # Load cookies if specified',
                f' cookies_path = {cookies_file_path}',
                ' if cookies_path and os.path.exists(cookies_path):',
                ' try:',
                " with open(cookies_path, 'r', encoding='utf-8') as f_cookies:",
                ' cookies = json.load(f_cookies)',
                ' # Validate sameSite attribute',
                " valid_same_site = ['Strict', 'Lax', 'None']",
                ' for cookie in cookies:',
                " if 'sameSite' in cookie and cookie['sameSite'] not in valid_same_site:",
                ' print(f\' Warning: Fixing invalid sameSite value "{cookie["sameSite"]}" to None for cookie {cookie.get("name")}\', file=sys.stderr)',
                " cookie['sameSite'] = 'None'",
                ' await context.add_cookies(cookies)',
                " print(f' Successfully loaded {len(cookies)} cookies from {cookies_path}')",
                ' except Exception as cookie_err:',
                " print(f' Warning: Failed to load or add cookies from {cookies_path}: {cookie_err}', file=sys.stderr)",
                ' else:',
                ' if cookies_path:',  # Only print if a path was specified but not found
                " print(f' Cookie file not found at: {cookies_path}')",
                '',
            ]
        )

    script_lines.extend(
        [
            ' # Initial page handling',
            ' if context.pages:',
            ' page = context.pages[0]',
            " print('Using initial page provided by context.')",
            ' else:',
            ' page = await context.new_page()',
            " print('Created a new page as none existed.')",
            " print('\\n--- Starting Generated Script Execution ---')",
        ]
    )

    action_counter = 0
    stop_processing_steps = False
    previous_item_dict = None
    for step_index, item_dict in enumerate(self.history):
        if stop_processing_steps:
            break

        if not isinstance(item_dict, dict):
            logger.warning(f'Skipping step {step_index + 1}: Item is not a dictionary ({type(item_dict)})')
            script_lines.append(f'\n # --- Step {step_index + 1}: Skipped (Invalid Format) ---')
            previous_item_dict = item_dict
            continue

        script_lines.append(f'\n # --- Step {step_index + 1} ---')
        model_output = item_dict.get('model_output')
        if not isinstance(model_output, dict) or 'action' not in model_output:
            script_lines.append(' # No valid model_output or action found for this step')
            previous_item_dict = item_dict
            continue

        actions = model_output.get('action')
        if not isinstance(actions, list):
            script_lines.append(f' # Actions format is not a list: {type(actions)}')
            previous_item_dict = item_dict
            continue

        for action_index_in_step, action_detail in enumerate(actions):
            action_counter += 1
            script_lines.append(f' # Action {action_counter}')
            step_info_str = f'Step {step_index + 1}, Action {action_index_in_step + 1}'
            action_lines = self._map_action_to_playwright(
                action_dict=action_detail,
                history_item=item_dict,
                previous_history_item=previous_item_dict,
                action_index_in_step=action_index_in_step,
                step_info_str=step_info_str,
            )
            script_lines.extend(action_lines)

            # A 'done' action ends the recording; nothing after it is emitted.
            action_type = next(iter(action_detail.keys()), None) if isinstance(action_detail, dict) else None
            if action_type == 'done':
                stop_processing_steps = True
                break

        previous_item_dict = item_dict

    # Updated final block to include sys.exit
    script_lines.extend(
        [
            ' except PlaywrightActionError as pae:',  # Catch specific action errors
            " print(f'\\n--- Playwright Action Error: {pae} ---', file=sys.stderr)",
            ' exit_code = 1',  # Set exit code to failure
            ' except Exception as e:',
            " print(f'\\n--- An unexpected error occurred: {e} ---', file=sys.stderr)",
            ' import traceback',
            ' traceback.print_exc()',
            ' exit_code = 1',  # Set exit code to failure
            ' finally:',
            " print('\\n--- Generated Script Execution Finished ---')",
            " print('Closing browser/context...')",
            ' if context:',
            ' try: await context.close()',
            " except Exception as ctx_close_err: print(f' Warning: could not close context: {ctx_close_err}', file=sys.stderr)",
            ' if browser:',
            ' try: await browser.close()',
            " except Exception as browser_close_err: print(f' Warning: could not close browser: {browser_close_err}', file=sys.stderr)",
            " print('Browser/context closed.')",
            ' # Exit with the determined exit code',
            ' if exit_code != 0:',
            " print(f'Script finished with errors (exit code {exit_code}).', file=sys.stderr)",
            ' sys.exit(exit_code)',  # Exit with non-zero code on error
            '',
            '# --- Script Entry Point ---',
            "if __name__ == '__main__':",
            " if os.name == 'nt':",
            ' asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())',
            ' asyncio.run(run_generated_script())',
        ]
    )
    return '\n'.join(script_lines)

View file

@ -0,0 +1,94 @@
from playwright.async_api import Page
# --- Helper Function for Replacing Sensitive Data ---
def replace_sensitive_data(text: str, sensitive_map: dict) -> str:
    """Substitute every ``<secret>name</secret>`` placeholder found in ``text``.

    Args:
        text: Text that may contain placeholders. Non-string values are
            returned unchanged.
        sensitive_map: Mapping of placeholder name to its value. A ``None``
            value substitutes an empty string; other values are stringified.

    Returns:
        The text with all known placeholders replaced.
    """
    if not isinstance(text, str):
        return text

    resolved = text
    for name in sensitive_map:
        secret = sensitive_map[name]
        token = f'<secret>{name}</secret>'
        resolved = resolved.replace(token, '' if secret is None else str(secret))
    return resolved
# --- Helper Function for Robust Action Execution ---
class PlaywrightActionError(Exception):
"""Custom exception for errors during Playwright script action execution."""
pass
async def _try_locate_and_act(page: Page, selector: str, action_type: str, text: str | None = None, step_info: str = '') -> None:
"""
Attempts an action (click/fill) with XPath fallback by trimming prefixes.
Raises PlaywrightActionError if the action fails after all fallbacks.
"""
print(f'Attempting {action_type} ({step_info}) using selector: {repr(selector)}')
original_selector = selector
MAX_FALLBACKS = 50 # Increased fallbacks
# Increased timeouts for potentially slow pages
INITIAL_TIMEOUT = 10000 # Milliseconds for the first attempt (10 seconds)
FALLBACK_TIMEOUT = 1000 # Shorter timeout for fallback attempts (1 second)
try:
locator = page.locator(selector).first
if action_type == 'click':
await locator.click(timeout=INITIAL_TIMEOUT)
elif action_type == 'fill' and text is not None:
await locator.fill(text, timeout=INITIAL_TIMEOUT)
else:
# This case should ideally not happen if called correctly
raise PlaywrightActionError(f"Invalid action_type '{action_type}' or missing text for fill. ({step_info})")
print(f" Action '{action_type}' successful with original selector.")
await page.wait_for_timeout(500) # Wait after successful action
return # Successful exit
except Exception as e:
print(f" Warning: Action '{action_type}' failed with original selector ({repr(selector)}): {e}. Starting fallback...")
# Fallback only works for XPath selectors
if not selector.startswith('xpath='):
# Raise error immediately if not XPath, as fallback won't work
raise PlaywrightActionError(
f"Action '{action_type}' failed. Fallback not possible for non-XPath selector: {repr(selector)}. ({step_info})"
)
xpath_parts = selector.split('=', 1)
if len(xpath_parts) < 2:
raise PlaywrightActionError(
f"Action '{action_type}' failed. Could not extract XPath string from selector: {repr(selector)}. ({step_info})"
)
xpath = xpath_parts[1] # Correctly get the XPath string
segments = [seg for seg in xpath.split('/') if seg]
for i in range(1, min(MAX_FALLBACKS + 1, len(segments))):
trimmed_xpath_raw = '/'.join(segments[i:])
fallback_xpath = f'xpath=//{trimmed_xpath_raw}'
print(f' Fallback attempt {i}/{MAX_FALLBACKS}: Trying selector: {repr(fallback_xpath)}')
try:
locator = page.locator(fallback_xpath).first
if action_type == 'click':
await locator.click(timeout=FALLBACK_TIMEOUT)
elif action_type == 'fill' and text is not None:
try:
await locator.clear(timeout=FALLBACK_TIMEOUT)
await page.wait_for_timeout(100)
except Exception as clear_error:
print(f' Warning: Failed to clear field during fallback ({step_info}): {clear_error}')
await locator.fill(text, timeout=FALLBACK_TIMEOUT)
print(f" Action '{action_type}' successful with fallback selector: {repr(fallback_xpath)}")
await page.wait_for_timeout(500)
return # Successful exit after fallback
except Exception as fallback_e:
print(f' Fallback attempt {i} failed: {fallback_e}')
if i == MAX_FALLBACKS:
# Raise exception after exhausting fallbacks
raise PlaywrightActionError(
f"Action '{action_type}' failed after {MAX_FALLBACKS} fallback attempts. Original selector: {repr(original_selector)}. ({step_info})"
)
# This part should not be reachable if logic is correct, but added as safeguard
raise PlaywrightActionError(f"Action '{action_type}' failed unexpectedly for {repr(original_selector)}. ({step_info})")

View file

@ -0,0 +1,187 @@
import importlib.resources
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from langchain_core.messages import HumanMessage, SystemMessage
if TYPE_CHECKING:
from browser_use.agent.views import ActionResult, AgentStepInfo
from browser_use.browser.views import BrowserState
class SystemPrompt:
    """Holds the system message sent to the agent's language model.

    The message text is either the caller-supplied override or the packaged
    ``system_prompt.md`` template formatted with ``max_actions``; an optional
    extension is appended on a new line in both cases.
    """

    def __init__(
        self,
        action_description: str,
        max_actions_per_step: int = 10,
        override_system_message: str | None = None,
        extend_system_message: str | None = None,
    ):
        """Build and store the system message.

        Args:
            action_description: Description of the available actions (stored
                as ``default_action_description``).
            max_actions_per_step: Value substituted for ``{max_actions}`` in the template.
            override_system_message: If given, used instead of the template.
            extend_system_message: If given, appended after a newline.
        """
        self.default_action_description = action_description
        self.max_actions_per_step = max_actions_per_step
        prompt = ''
        if override_system_message:
            prompt = override_system_message
        else:
            self._load_prompt_template()
            prompt = self.prompt_template.format(max_actions=self.max_actions_per_step)
        if extend_system_message:
            prompt += f'\n{extend_system_message}'
        self.system_message = SystemMessage(content=prompt)

    def _load_prompt_template(self) -> None:
        """Load the prompt template from the markdown file.

        Raises:
            RuntimeError: If the template cannot be read.
        """
        try:
            # This works both in development and when installed as a package.
            # The encoding is pinned so the result does not depend on the
            # platform's default locale encoding.
            template_file = importlib.resources.files('browser_use.agent').joinpath('system_prompt.md')
            with template_file.open('r', encoding='utf-8') as f:
                self.prompt_template = f.read()
        except Exception as e:
            raise RuntimeError(f'Failed to load system prompt template: {e}') from e

    def get_system_message(self) -> SystemMessage:
        """
        Get the system prompt for the agent.

        Returns:
            SystemMessage: Formatted system prompt
        """
        return self.system_message
# Functions:
# {self.default_action_description}
# Example:
# {self.example_response()}
# Your AVAILABLE ACTIONS:
# {self.default_action_description}
class AgentMessagePrompt:
    """Builds the per-step human message describing the current browser state."""

    def __init__(
        self,
        state: 'BrowserState',
        result: list['ActionResult'] | None = None,
        include_attributes: list[str] | None = None,
        step_info: Optional['AgentStepInfo'] = None,
    ):
        """Store the inputs used by :meth:`get_user_message`.

        Args:
            state: Browser state to describe.
            result: Results of the previous actions, if any.
            include_attributes: Element attributes to include when rendering
                the clickable elements; defaults to an empty list.
            step_info: Current step counter information, if available.
        """
        self.state = state
        self.result = result
        self.include_attributes = include_attributes or []
        self.step_info = step_info

    def get_user_message(self, use_vision: bool = True) -> HumanMessage:
        """Render the state (and previous action results) as a message.

        Args:
            use_vision: When ``True`` and the state has a screenshot, the
                message carries the screenshot as a base64 PNG image part.

        Returns:
            A ``HumanMessage`` with text only, or with text plus image.
        """
        elements_text = self.state.element_tree.clickable_elements_to_string(include_attributes=self.include_attributes)

        has_content_above = (self.state.pixels_above or 0) > 0
        has_content_below = (self.state.pixels_below or 0) > 0

        if elements_text != '':
            if has_content_above:
                elements_text = (
                    f'... {self.state.pixels_above} pixels above - scroll or extract content to see more ...\n{elements_text}'
                )
            else:
                elements_text = f'[Start of page]\n{elements_text}'
            if has_content_below:
                elements_text = (
                    f'{elements_text}\n... {self.state.pixels_below} pixels below - scroll or extract content to see more ...'
                )
            else:
                elements_text = f'{elements_text}\n[End of page]'
        else:
            elements_text = 'empty page'

        # Keep the step counter and the timestamp on separate lines; without
        # the separator they ran together as "...1/10Current date and time: ...".
        if self.step_info:
            step_info_description = f'Current step: {self.step_info.step_number + 1}/{self.step_info.max_steps}\n'
        else:
            step_info_description = ''
        time_str = datetime.now().strftime('%Y-%m-%d %H:%M')
        step_info_description += f'Current date and time: {time_str}'

        state_description = f"""
[Task history memory ends]
[Current state starts here]
The following is one-time information - if you need to remember it write it to memory:
Current url: {self.state.url}
Available tabs:
{self.state.tabs}
Interactive elements from top layer of the current page inside the viewport:
{elements_text}
{step_info_description}
"""

        if self.result:
            for i, result in enumerate(self.result):
                if result.extracted_content:
                    state_description += f'\nAction result {i + 1}/{len(self.result)}: {result.extracted_content}'
                if result.error:
                    # only use last line of error
                    error = result.error.split('\n')[-1]
                    state_description += f'\nAction error {i + 1}/{len(self.result)}: ...{error}'

        if self.state.screenshot and use_vision is True:
            # Format message for vision model
            return HumanMessage(
                content=[
                    {'type': 'text', 'text': state_description},
                    {
                        'type': 'image_url',
                        'image_url': {'url': f'data:image/png;base64,{self.state.screenshot}'},  # , 'detail': 'low'
                    },
                ]
            )

        return HumanMessage(content=state_description)
class PlannerPrompt(SystemPrompt):
    """Prompt provider for the planning model."""

    def __init__(self, available_actions: str):
        """Store the description of the available actions.

        Args:
            available_actions: Text describing the actions the agent can take.
        """
        self.available_actions = available_actions

    def get_system_message(
        self, is_planner_reasoning: bool, extended_planner_system_prompt: str | None = None
    ) -> SystemMessage | HumanMessage:
        """Get the system message for the planner.

        Args:
            is_planner_reasoning: If True, return as HumanMessage for chain-of-thought
            extended_planner_system_prompt: Optional text to append to the base prompt

        Returns:
            SystemMessage or HumanMessage depending on is_planner_reasoning
        """
        # This text is sent as-is (it is never passed through str.format), so
        # the JSON example uses single braces.
        planner_prompt_text = """
You are a planning agent that helps break down tasks into smaller steps and reason about the current state.
Your role is to:
1. Analyze the current state and history
2. Evaluate progress towards the ultimate goal
3. Identify potential challenges or roadblocks
4. Suggest the next high-level steps to take
Inside your messages, there will be AI messages from different agents with different formats.
Your output format should be always a JSON object with the following fields:
{
"state_analysis": "Brief analysis of the current state and what has been done so far",
"progress_evaluation": "Evaluation of progress towards the ultimate goal (as percentage and description)",
"challenges": "List any potential challenges or roadblocks",
"next_steps": "List 2-3 concrete next steps to take",
"reasoning": "Explain your reasoning for the suggested next steps"
}
Ignore the other AI messages output structures.
Keep your responses concise and focused on actionable insights.
"""

        if extended_planner_system_prompt:
            planner_prompt_text += f'\n{extended_planner_system_prompt}'

        if is_planner_reasoning:
            return HumanMessage(content=planner_prompt_text)
        else:
            return SystemMessage(content=planner_prompt_text)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,82 @@
You are an AI agent designed to automate browser tasks. Your goal is to accomplish the ultimate task following the rules.
# Input Format
Task
Previous steps
Current URL
Open Tabs
Interactive Elements
[index]<type>text</type>
- index: Numeric identifier for interaction
- type: HTML element type (button, input, etc.)
- text: Element description
Example:
[33]<div>User form</div>
\t*[35]*<button aria-label='Submit form'>Submit</button>
- Only elements with numeric indexes in [] are interactive
- (stacked) indentation (with \t) is important and means that the element is a (html) child of the element above (with a lower index)
- Elements with \* are new elements that were added after the previous step (if url has not changed)
# Response Rules
1. RESPONSE FORMAT: You must ALWAYS respond with valid JSON in this exact format:
{{"current_state": {{"evaluation_previous_goal": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions were successful as intended by the task. Mention if something unexpected happened. Briefly state why/why not",
"memory": "Description of what has been done and what you need to remember. Be very specific. Count here ALWAYS how many times you have done something and how many remain. E.g. 0 out of 10 websites analyzed. Continue with abc and xyz",
"next_goal": "What needs to be done with the next immediate action"}},
"action":[{{"one_action_name": {{// action-specific parameter}}}}, // ... more actions in sequence]}}
2. ACTIONS: You can specify multiple actions in the list to be executed in sequence. But always specify only one action name per item. Use maximum {max_actions} actions per sequence.
Common action sequences:
- Form filling: [{{"input_text": {{"index": 1, "text": "username"}}}}, {{"input_text": {{"index": 2, "text": "password"}}}}, {{"click_element": {{"index": 3}}}}]
- Navigation and extraction: [{{"go_to_url": {{"url": "https://example.com"}}}}, {{"extract_content": {{"goal": "extract the names"}}}}]
- Actions are executed in the given order
- If the page changes after an action, the sequence is interrupted and you get the new state.
- Only provide the action sequence until an action which changes the page state significantly.
- Try to be efficient, e.g. fill forms at once, or chain actions where nothing changes on the page
- only use multiple actions if it makes sense.
3. ELEMENT INTERACTION:
- Only use indexes of the interactive elements
4. NAVIGATION & ERROR HANDLING:
- If no suitable elements exist, use other functions to complete the task
- If stuck, try alternative approaches - like going back to a previous page, new search, new tab etc.
- Handle popups/cookies by accepting or closing them
- Use scroll to find elements you are looking for
- If you want to research something, open a new tab instead of using the current tab
- If captcha pops up, try to solve it - else try a different approach
- If the page is not fully loaded, use wait action
5. TASK COMPLETION:
- Use the done action as the last action as soon as the ultimate task is complete
- Don't use "done" before you are done with everything the user asked you, unless you reach the last step of max_steps.
- If you reach your last step, use the done action even if the task is not fully finished. Provide all the information you have gathered so far. If the ultimate task is completely finished set success to true. If not everything the user asked for is completed set success in done to false!
- If you have to do something repeatedly, for example when the task says "for each", "for all", or "x times", always count inside "memory" how many times you have done it and how many remain. Don't stop until you have completed the task as asked. Only call done after the last step.
- Don't hallucinate actions
- Make sure you include everything you found out for the ultimate task in the done text parameter. Do not just say you are done, but include the requested information of the task.
6. VISUAL CONTEXT:
- When an image is provided, use it to understand the page layout
- Bounding boxes with labels on their top right corner correspond to element indexes
7. Form filling:
- If you fill an input field and your action sequence is interrupted, most often something changed e.g. suggestions popped up under the field.
8. Long tasks:
- Keep track of the status and subresults in the memory.
- You are provided with procedural memory summaries that condense previous task history (every N steps). Use these summaries to maintain context about completed actions, current progress, and next steps. The summaries appear in chronological order and contain key information about navigation history, findings, errors encountered, and current state. Refer to these summaries to avoid repeating actions and to ensure consistent progress toward the task goal.
9. Extraction:
- If your task is to find information - call extract_content on the specific pages to get and store the information.
Your responses must be always JSON with the specified format.

View file

@ -0,0 +1,197 @@
import pytest
from browser_use.agent.views import (
ActionResult,
AgentBrain,
AgentHistory,
AgentHistoryList,
AgentOutput,
)
from browser_use.browser.views import BrowserState, BrowserStateHistory, TabInfo
from browser_use.controller.registry.service import Registry
from browser_use.controller.views import ClickElementAction, DoneAction, ExtractPageContentAction
from browser_use.dom.views import DOMElementNode
@pytest.fixture
def sample_browser_state():
    """Provide a ``BrowserState`` for one example page with a childless root DOM node."""
    page_url = 'https://example.com'
    page_title = 'Example Page'
    only_tab = TabInfo(url=page_url, title=page_title, page_id=1)
    root_node = DOMElementNode(
        tag_name='root',
        is_visible=True,
        parent=None,
        xpath='',
        attributes={},
        children=[],
    )
    return BrowserState(
        url=page_url,
        title=page_title,
        tabs=[only_tab],
        screenshot='screenshot1.png',
        element_tree=root_node,
        selector_map={},
    )
@pytest.fixture
def action_registry():
    """Provide the dynamic action model built from three no-op test actions."""
    test_registry = Registry()

    # The action bodies are intentionally empty; the tests only need the
    # registered names and parameter models.
    @test_registry.action(description='Click an element', param_model=ClickElementAction)
    def click_element(params: ClickElementAction, browser=None):
        return None

    @test_registry.action(
        description='Extract page content',
        param_model=ExtractPageContentAction,
    )
    def extract_page_content(params: ExtractPageContentAction, browser=None):
        return None

    @test_registry.action(description='Mark task as done', param_model=DoneAction)
    def done(params: DoneAction):
        return None

    # Create the dynamic ActionModel with all registered actions
    action_model = test_registry.create_action_model()
    return action_model
@pytest.fixture
def sample_history(action_registry):
    """Provide a three-step history: click, extract (with an error), then done."""

    def build_entry(previous_goal, memory, next_goal, action, result, url, title, page_id, screenshot, xpath):
        # Assemble one history entry from its scalar parts.
        brain = AgentBrain(
            evaluation_previous_goal=previous_goal,
            memory=memory,
            next_goal=next_goal,
        )
        browser_snapshot = BrowserStateHistory(
            url=url,
            title=title,
            tabs=[TabInfo(url=url, title=title, page_id=page_id)],
            screenshot=screenshot,
            interacted_element=[{'xpath': xpath}],
        )
        return AgentHistory(
            model_output=AgentOutput(current_state=brain, action=[action]),
            result=[result],
            state=browser_snapshot,
        )

    # Create actions with nested params structure
    click_action = action_registry(click_element={'index': 1})
    extract_action = action_registry(extract_page_content={'value': 'text'})
    done_action = action_registry(done={'text': 'Task completed'})

    first_step = build_entry(
        'None',
        'Started task',
        'Click button',
        click_action,
        ActionResult(is_done=False),
        'https://example.com',
        'Page 1',
        1,
        'screenshot1.png',
        '//button[1]',
    )
    second_step = build_entry(
        'Clicked button',
        'Button clicked',
        'Extract content',
        extract_action,
        ActionResult(
            is_done=False,
            extracted_content='Extracted text',
            error='Failed to extract completely',
        ),
        'https://example.com/page2',
        'Page 2',
        2,
        'screenshot2.png',
        '//div[1]',
    )
    third_step = build_entry(
        'Extracted content',
        'Content extracted',
        'Finish task',
        done_action,
        ActionResult(is_done=True, extracted_content='Task completed', error=None),
        'https://example.com/page2',
        'Page 2',
        2,
        'screenshot3.png',
        '//div[1]',
    )
    return AgentHistoryList(history=[first_step, second_step, third_step])
def test_last_model_output(sample_history: AgentHistoryList):
    """``last_action()`` yields the final ``done`` action with its text payload."""
    final_action = sample_history.last_action()
    print(final_action)
    expected_action = {'done': {'text': 'Task completed'}}
    assert final_action == expected_action
def test_get_errors(sample_history: AgentHistoryList):
    """``errors()`` reports exactly the one error message present in the fixture."""
    reported = sample_history.errors()
    assert len(reported) == 1
    first_error = reported[0]
    assert first_error == 'Failed to extract completely'
def test_final_result(sample_history: AgentHistoryList):
    """``final_result()`` returns the text extracted by the last step."""
    outcome = sample_history.final_result()
    assert outcome == 'Task completed'
def test_is_done(sample_history: AgentHistoryList):
    """A history whose last result is marked done reports ``is_done() is True``."""
    finished = sample_history.is_done()
    assert finished is True
def test_urls(sample_history: AgentHistoryList):
    """``urls()`` contains both URLs recorded in the fixture."""
    visited = sample_history.urls()
    for expected_url in ('https://example.com', 'https://example.com/page2'):
        assert expected_url in visited
def test_all_screenshots(sample_history: AgentHistoryList):
    """``screenshots()`` returns one entry per step, in step order."""
    captured = sample_history.screenshots()
    expected_names = ['screenshot1.png', 'screenshot2.png', 'screenshot3.png']
    assert len(captured) == 3
    assert captured == expected_names
def test_all_model_outputs(sample_history: AgentHistoryList):
    """``model_actions()`` lists the three actions; the first item of each is checked."""

    def leading_item(action: dict) -> dict:
        # Reduce an action dict to its first key/value pair.
        key, value = next(iter(action.items()))
        return {key: value}

    outputs = sample_history.model_actions()
    print(f'DEBUG: {outputs[0]}')
    assert len(outputs) == 3
    assert leading_item(outputs[0]) == {'click_element': {'index': 1}}
    assert leading_item(outputs[1]) == {'extract_page_content': {'value': 'text'}}
    assert leading_item(outputs[2]) == {'done': {'text': 'Task completed'}}
def test_all_model_outputs_filtered(sample_history: AgentHistoryList):
    """Filtering by action name keeps only the matching action."""
    matches = sample_history.model_actions_filtered(include=['click_element'])
    assert len(matches) == 1
    first_match = matches[0]
    assert first_match['click_element']['index'] == 1
def test_empty_history():
    """An empty history reports no action, no result, not done and no URLs."""
    blank = AgentHistoryList(history=[])
    assert blank.last_action() is None
    assert blank.final_result() is None
    assert blank.is_done() is False
    visited = blank.urls()
    assert len(visited) == 0
# Verify that an action built through the registry fixture serialises as expected
def test_action_creation(action_registry):
    """A click action dumps to a single-key dict when None values are excluded."""
    expected = {'click_element': {'index': 1}}
    click_action = action_registry(click_element={'index': 1})
    dumped = click_action.model_dump(exclude_none=True)
    assert dumped == expected
# run this with:
# pytest browser_use/agent/tests.py

View file

@ -0,0 +1,440 @@
from __future__ import annotations
import json
import traceback
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from langchain_core.language_models.chat_models import BaseChatModel
from openai import RateLimitError
from pydantic import BaseModel, ConfigDict, Field, ValidationError, create_model
from browser_use.agent.message_manager.views import MessageManagerState
from browser_use.agent.playwright_script_generator import PlaywrightScriptGenerator
from browser_use.browser.browser import BrowserConfig
from browser_use.browser.context import BrowserContextConfig
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.registry.views import ActionModel
from browser_use.dom.history_tree_processor.service import (
DOMElementNode,
DOMHistoryElement,
HistoryTreeProcessor,
)
from browser_use.dom.views import SelectorMap
# Allowed values for `AgentSettings.tool_calling_method`.
ToolCallingMethod = Literal['function_calling', 'json_mode', 'raw', 'auto', 'tools']
# Keyed by chat-model class name; each value lists environment variable names.
# Presumably these are the variables that must be set before that model can be
# used (an empty list meaning none) - verify against the code that reads this map.
# NOTE(review): 'ChatBedrockConverse' is mapped to ANTHROPIC_API_KEY; Bedrock
# usually authenticates with AWS credentials - confirm this entry is intentional.
REQUIRED_LLM_API_ENV_VARS = {
    'ChatOpenAI': ['OPENAI_API_KEY'],
    'AzureChatOpenAI': ['AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_KEY'],
    'ChatBedrockConverse': ['ANTHROPIC_API_KEY'],
    'ChatAnthropic': ['ANTHROPIC_API_KEY'],
    'ChatGoogleGenerativeAI': ['GOOGLE_API_KEY'],
    'ChatDeepSeek': ['DEEPSEEK_API_KEY'],
    'ChatOllama': [],
    'ChatGrok': ['GROK_API_KEY'],
}
class AgentSettings(BaseModel):
    """Options for the agent"""

    # Every field has a default, so `AgentSettings()` is valid without arguments.
    # NOTE(review): the units and exact semantics of the numeric fields (e.g.
    # retry_delay, max_input_tokens) are not visible in this file - confirm
    # against the agent service before relying on them.
    use_vision: bool = True
    use_vision_for_planner: bool = False
    save_conversation_path: str | None = None
    save_conversation_path_encoding: str | None = 'utf-8'
    max_failures: int = 3
    retry_delay: int = 10
    max_input_tokens: int = 128000
    validate_output: bool = False
    message_context: str | None = None
    # Accepts either a bool or a string - presumably a string is an output path; TODO confirm
    generate_gif: bool | str = False
    available_file_paths: list[str] | None = None
    override_system_message: str | None = None
    extend_system_message: str | None = None
    # A list literal is safe as a default here: pydantic copies field defaults per instance.
    include_attributes: list[str] = [
        'title',
        'type',
        'name',
        'role',
        'tabindex',
        'aria-label',
        'placeholder',
        'value',
        'alt',
        'aria-expanded',
    ]
    max_actions_per_step: int = 10

    # Must be one of the ToolCallingMethod literals, or None
    tool_calling_method: ToolCallingMethod | None = 'auto'
    page_extraction_llm: BaseChatModel | None = None
    planner_llm: BaseChatModel | None = None
    planner_interval: int = 1  # Run planner every N steps
    is_planner_reasoning: bool = False  # type: ignore
    extend_planner_system_message: str | None = None

    # Playwright script generation setting
    save_playwright_script_path: str | None = None  # Path to save the generated Playwright script
class AgentState(BaseModel):
    """Holds all state information for an Agent"""

    # A fresh random UUID4 string is generated for every new state object
    agent_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Step counter starts at 1, not 0
    n_steps: int = 1
    consecutive_failures: int = 0
    last_result: list[ActionResult] | None = None
    # default_factory gives each state its own empty history rather than a shared instance
    history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[]))
    last_plan: str | None = None
    paused: bool = False
    stopped: bool = False
    message_manager_state: MessageManagerState = Field(default_factory=MessageManagerState)

    # class Config:
    # arbitrary_types_allowed = True
@dataclass
class AgentStepInfo:
    """Position of the current step within a run that is capped at `max_steps`."""

    step_number: int
    max_steps: int

    def is_last_step(self) -> bool:
        """Return True once `step_number` has reached `max_steps - 1` (or gone beyond it)."""
        final_step_number = self.max_steps - 1
        return self.step_number >= final_step_number
class ActionResult(BaseModel):
    """Result of executing an action"""

    # AgentHistoryList.is_done() checks this flag on the last result of the last step
    is_done: bool | None = False
    # Only read by AgentHistoryList.is_successful(), and only when is_done is True
    success: bool | None = None
    # AgentHistoryList.final_result() returns this value from the last result
    extracted_content: str | None = None
    # AgentHistoryList.errors() reports the first non-empty error of each step
    error: str | None = None
    include_in_memory: bool = False  # whether to include in past messages as context or not
class StepMetadata(BaseModel):
    """Metadata for a single step including timing and token information"""

    step_start_time: float
    step_end_time: float
    input_tokens: int  # Approximate tokens from message manager for this step
    step_number: int

    @property
    def duration_seconds(self) -> float:
        """Elapsed time of the step: end timestamp minus start timestamp."""
        started, finished = self.step_start_time, self.step_end_time
        return finished - started
class AgentBrain(BaseModel):
    """Current state of the agent"""

    # All three fields are required strings: no defaults are declared.
    # This model is the type of `AgentOutput.current_state`.
    evaluation_previous_goal: str
    memory: str
    next_goal: str
class AgentOutput(BaseModel):
    """Output model for agent
    @dev note: this model is extended with custom actions in AgentService. You can also use some fields that are not in this model as provided by the linter, as long as they are registered in the DynamicActions model.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    current_state: AgentBrain
    action: list[ActionModel] = Field(
        ...,
        description='List of actions to execute',
        json_schema_extra={'min_items': 1},  # Ensure at least one action is provided
    )

    @staticmethod
    def type_with_custom_actions(custom_actions: type[ActionModel]) -> type[AgentOutput]:
        """Return an AgentOutput subclass whose `action` items are typed as `custom_actions`."""
        # Same field definition as the base class, only the item type differs
        action_field = Field(..., description='List of actions to execute', json_schema_extra={'min_items': 1})
        extended_model = create_model(
            'AgentOutput',
            __base__=AgentOutput,
            action=(list[custom_actions], action_field),
            __module__=AgentOutput.__module__,
        )
        extended_model.__doc__ = 'AgentOutput model with custom actions'
        return extended_model
class AgentHistory(BaseModel):
    """History item for agent actions"""

    model_output: AgentOutput | None
    result: list[ActionResult]
    state: BrowserStateHistory
    metadata: StepMetadata | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    @staticmethod
    def get_interacted_element(model_output: AgentOutput, selector_map: SelectorMap) -> list[DOMHistoryElement | None]:
        """Return, for every action in `model_output`, the history element it targeted.

        The result is aligned with `model_output.action`: an action without an
        index, or whose index is missing from `selector_map`, contributes `None`.
        """

        def to_history_element(action) -> DOMHistoryElement | None:
            index = action.get_index()
            if index is None or index not in selector_map:
                return None
            node: DOMElementNode = selector_map[index]
            return HistoryTreeProcessor.convert_dom_element_to_history_element(node)

        return [to_history_element(action) for action in model_output.action]

    def model_dump(self, **kwargs) -> dict[str, Any]:
        """Serialize this history item to a plain dict (keyword arguments are accepted but unused)."""
        serialized_output = None
        if self.model_output:
            serialized_output = {
                'current_state': self.model_output.current_state.model_dump(),
                # Each action is dumped individually so the actual action data is preserved
                'action': [action.model_dump(exclude_none=True) for action in self.model_output.action],
            }

        serialized_metadata = self.metadata.model_dump() if self.metadata else None
        serialized_results = [r.model_dump(exclude_none=True) for r in self.result]
        return {
            'model_output': serialized_output,
            'result': serialized_results,
            'state': self.state.to_dict(),
            'metadata': serialized_metadata,
        }
class AgentHistoryList(BaseModel):
    """List of agent history items"""

    history: list[AgentHistory]

    def total_duration_seconds(self) -> float:
        """Get total duration of all steps in seconds (steps without metadata are skipped)"""
        return sum((h.metadata.duration_seconds for h in self.history if h.metadata), 0.0)

    def total_input_tokens(self) -> int:
        """
        Get total tokens used across all steps.
        Note: These are from the approximate token counting of the message manager.
        For accurate token counting, use tools like LangChain Smith or OpenAI's token counters.
        """
        return sum(h.metadata.input_tokens for h in self.history if h.metadata)

    def input_token_usage(self) -> list[int]:
        """Get token usage for each step that has metadata"""
        return [h.metadata.input_tokens for h in self.history if h.metadata]

    def __str__(self) -> str:
        """Representation of the AgentHistoryList object"""
        return f'AgentHistoryList(all_results={self.action_results()}, all_model_outputs={self.model_actions()})'

    def __repr__(self) -> str:
        """Representation of the AgentHistoryList object"""
        return self.__str__()

    def save_to_file(self, filepath: str | Path) -> None:
        """Save history to a JSON file, creating missing parent directories.

        I/O and serialization errors propagate unchanged to the caller.
        """
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        data = self.model_dump()
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def save_as_playwright_script(
        self,
        output_path: str | Path,
        sensitive_data_keys: list[str] | None = None,
        browser_config: BrowserConfig | None = None,
        context_config: BrowserContextConfig | None = None,
    ) -> None:
        """
        Generates a Playwright script based on the agent's history and saves it to a file.

        Errors raised while generating or writing the script propagate unchanged.

        Args:
            output_path: The path where the generated Python script will be saved.
            sensitive_data_keys: A list of keys used as placeholders for sensitive data
                (e.g., ['username_placeholder', 'password_placeholder']).
                These will be loaded from environment variables in the
                generated script.
            browser_config: Configuration of the original Browser instance.
            context_config: Configuration of the original BrowserContext instance.
        """
        serialized_history = self.model_dump()['history']
        generator = PlaywrightScriptGenerator(serialized_history, sensitive_data_keys, browser_config, context_config)
        script_content = generator.generate_script_content()

        path_obj = Path(output_path)
        path_obj.parent.mkdir(parents=True, exist_ok=True)
        with open(path_obj, 'w', encoding='utf-8') as f:
            f.write(script_content)

    def model_dump(self, **kwargs) -> dict[str, Any]:
        """Custom serialization that properly uses AgentHistory's model_dump"""
        return {
            'history': [h.model_dump(**kwargs) for h in self.history],
        }

    @classmethod
    def load_from_file(cls, filepath: str | Path, output_model: type[AgentOutput]) -> AgentHistoryList:
        """Load history from a JSON file.

        Args:
            filepath: JSON file to read, in the shape written by `save_to_file`.
            output_model: AgentOutput (sub)class used to validate each step's
                `model_output`, so custom actions are restored with their own types.
        """
        with open(filepath, encoding='utf-8') as f:
            data = json.load(f)

        # loop through history and validate output_model actions to enrich with custom actions
        for h in data['history']:
            if h['model_output']:
                if isinstance(h['model_output'], dict):
                    h['model_output'] = output_model.model_validate(h['model_output'])
                else:
                    h['model_output'] = None
            # Files that lack the key get an explicit None
            if 'interacted_element' not in h['state']:
                h['state']['interacted_element'] = None

        return cls.model_validate(data)

    def last_action(self) -> None | dict:
        """Last action in history, or None when there is no step, no model output or no action"""
        if not self.history:
            return None
        last_output = self.history[-1].model_output
        if last_output and last_output.action:
            return last_output.action[-1].model_dump(exclude_none=True)
        return None

    def errors(self) -> list[str | None]:
        """Get all errors from history, with None for steps without errors"""
        errors = []
        for h in self.history:
            step_errors = [r.error for r in h.result if r.error]
            # each step can have only one error
            errors.append(step_errors[0] if step_errors else None)
        return errors

    def final_result(self) -> None | str:
        """Extracted content of the last result of the last step.

        Returns None when the history is empty, when the last step produced no
        results, or when the last result has no (or empty) extracted content.
        """
        # Guard the empty result list the same way is_done() does, instead of raising IndexError
        if not self.history or not self.history[-1].result:
            return None
        return self.history[-1].result[-1].extracted_content or None

    def is_done(self) -> bool:
        """Check if the agent is done"""
        if self.history and len(self.history[-1].result) > 0:
            last_result = self.history[-1].result[-1]
            return last_result.is_done is True
        return False

    def is_successful(self) -> bool | None:
        """Check if the agent completed successfully - the agent decides in the last step if it was successful or not. None if not done yet."""
        if self.history and len(self.history[-1].result) > 0:
            last_result = self.history[-1].result[-1]
            if last_result.is_done is True:
                return last_result.success
        return None

    def has_errors(self) -> bool:
        """Check if the agent has any non-None errors"""
        return any(error is not None for error in self.errors())

    def urls(self) -> list[str | None]:
        """Get the URL of every step, in order (duplicates are kept; None where a step has no URL)"""
        return [h.state.url for h in self.history]

    def screenshots(self) -> list[str | None]:
        """Get the screenshot of every step, in order (None where a step has no screenshot)"""
        return [h.state.screenshot for h in self.history]

    def action_names(self) -> list[str]:
        """Get all action names from history (the first key of each action dict)"""
        return [next(iter(action)) for action in self.model_actions() if action]

    def model_thoughts(self) -> list[AgentBrain]:
        """Get all thoughts from history"""
        return [h.model_output.current_state for h in self.history if h.model_output]

    def model_outputs(self) -> list[AgentOutput]:
        """Get all model outputs from history"""
        return [h.model_output for h in self.history if h.model_output]

    # get all actions with params
    def model_actions(self) -> list[dict]:
        """Get all actions from history.

        Each entry is the action dumped without None values, plus an
        'interacted_element' key holding the element recorded for that action.
        """
        outputs = []
        for h in self.history:
            if not h.model_output:
                continue
            actions = h.model_output.action
            interacted_elements = h.state.interacted_element
            if interacted_elements is None:
                # load_from_file() stores None when the file has no interacted elements;
                # pad with None so every action is still reported instead of zip() raising TypeError
                interacted_elements = [None] * len(actions)
            for action, interacted_element in zip(actions, interacted_elements):
                output = action.model_dump(exclude_none=True)
                output['interacted_element'] = interacted_element
                outputs.append(output)
        return outputs

    def action_results(self) -> list[ActionResult]:
        """Get all results from history"""
        results = []
        for h in self.history:
            results.extend([r for r in h.result if r])
        return results

    def extracted_content(self) -> list[str]:
        """Get all extracted content from history"""
        content = []
        for h in self.history:
            content.extend([r.extracted_content for r in h.result if r.extracted_content])
        return content

    def model_actions_filtered(self, include: list[str] | None = None) -> list[dict]:
        """Get the model actions whose action name (first key) is listed in `include`.

        Returns an empty list when `include` is None or empty. Each matching
        action is returned once, in history order.
        """
        if not include:
            return []
        return [output for output in self.model_actions() if next(iter(output)) in include]

    def number_of_steps(self) -> int:
        """Get the number of steps in the history"""
        return len(self.history)
class AgentError:
    """Container for agent error handling"""

    VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
    RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
    NO_VALID_ACTION = 'No valid action found'

    @staticmethod
    def format_error(error: Exception, include_trace: bool = False) -> str:
        """Format an error message based on the error type.

        Args:
            error: The exception to describe.
            include_trace: When True, append the stack trace of `error` itself.
                Ignored for validation and rate-limit errors, which always map
                to their fixed messages.

        Returns:
            The message text to report for `error`.
        """
        if isinstance(error, ValidationError):
            return f'{AgentError.VALIDATION_ERROR}\nDetails: {str(error)}'
        if isinstance(error, RateLimitError):
            return AgentError.RATE_LIMIT_ERROR
        if include_trace:
            # Format the traceback attached to `error` rather than whatever exception is
            # currently being handled, so the trace is right outside an `except` block too
            trace = ''.join(traceback.format_exception(type(error), error, error.__traceback__))
            return f'{str(error)}\nStacktrace:\n{trace}'
        return f'{str(error)}'

View file

@ -0,0 +1,421 @@
"""
Playwright browser on steroids.
"""
import asyncio
import gc
import logging
import os
import socket
import subprocess
from pathlib import Path
from tempfile import gettempdir
from typing import Literal
import httpx
import psutil
from dotenv import load_dotenv
from playwright.async_api import Browser as PlaywrightBrowser
from playwright.async_api import Playwright, async_playwright
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
load_dotenv()
from browser_use.browser.chrome import (
CHROME_ARGS,
CHROME_DEBUG_PORT,
CHROME_DETERMINISTIC_RENDERING_ARGS,
CHROME_DISABLE_SECURITY_ARGS,
CHROME_DOCKER_ARGS,
CHROME_HEADLESS_ARGS,
)
from browser_use.browser.context import BrowserContext, BrowserContextConfig
from browser_use.browser.utils.screen_resolution import get_screen_resolution, get_window_adjustments
from browser_use.utils import time_execution_async
logger = logging.getLogger(__name__)


def _env_flag_enabled(raw_value: str) -> bool:
    """Return True when `raw_value` starts with 't', 'y' or '1' (case-insensitive)."""
    # Slice instead of indexing so an empty value gives '' rather than IndexError, and
    # compare against a tuple because '' would count as a substring of 'ty1'.
    return raw_value.lower()[:1] in ('t', 'y', '1')


# True when the IN_DOCKER environment variable holds a truthy value; unset or empty means False
IN_DOCKER = _env_flag_enabled(os.environ.get('IN_DOCKER', 'false'))
class ProxySettings(BaseModel):
    """the same as playwright.sync_api.ProxySettings, but now as a Pydantic BaseModel so pydantic can validate it"""

    # Only `server` is required; the remaining fields default to None
    server: str
    bypass: str | None = None
    username: str | None = None
    password: str | None = None

    model_config = ConfigDict(populate_by_name=True, from_attributes=True)

    # Support dict-like behavior for compatibility with Playwright's ProxySettings
    def __getitem__(self, key):
        """Return the attribute named `key`, enabling `proxy['server']` style access.

        An unknown key raises AttributeError (from `getattr`), not KeyError.
        """
        return getattr(self, key)

    def get(self, key, default=None):
        """Return the attribute named `key`, or `default` when no such attribute exists."""
        return getattr(self, key, default)
class BrowserConfig(BaseModel):
    r"""
    Configuration for the Browser.

    Default values:
        headless: False
            Whether to run browser in headless mode (not recommended)

        disable_security: False
            Disable browser security features (required for cross-origin iframe support)

        extra_browser_args: []
            Extra arguments to pass to the browser

        wss_url: None
            Connect to a browser instance via WebSocket

        cdp_url: None
            Connect to a browser instance via CDP

        browser_class: 'chromium'
            Which Playwright browser type to use: 'chromium', 'firefox' or 'webkit'

        browser_binary_path: None
            Path to a Browser instance to use to connect to your normal browser
            e.g. '/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome'
            Also accepted under the names 'browser_instance_path' and 'chrome_instance_path'.

        chrome_remote_debugging_port: CHROME_DEBUG_PORT (imported from browser_use.browser.chrome)
            Chrome remote debugging port to use when browser_binary_path is supplied.
            This allows running multiple chrome browsers with the same browser_binary_path but on different ports.
            It also makes it possible to launch a new user-provided chrome browser without closing already opened
            chrome instances, by providing a non-default chrome debugging port.

        keep_alive: False
            Keep the browser alive after the agent has finished running

        deterministic_rendering: False
            Enable deterministic rendering (makes GPU/font rendering consistent across different OS's and docker)

        proxy: None
            Optional ProxySettings handed to the browser launch

        new_context_config: BrowserContextConfig()
            Context configuration; a fresh default instance is created per BrowserConfig
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra='ignore',
        populate_by_name=True,
        from_attributes=True,
        validate_assignment=True,
        revalidate_instances='subclass-instances',
    )

    wss_url: str | None = None
    cdp_url: str | None = None

    browser_class: Literal['chromium', 'firefox', 'webkit'] = 'chromium'
    # The two validation aliases are additional accepted input names for this field
    browser_binary_path: str | None = Field(
        default=None, validation_alias=AliasChoices('browser_instance_path', 'chrome_instance_path')
    )
    chrome_remote_debugging_port: int | None = CHROME_DEBUG_PORT
    extra_browser_args: list[str] = Field(default_factory=list)

    headless: bool = False
    disable_security: bool = False  # disable_security=True is dangerous as any malicious URL visited could embed an iframe for the user's bank, and use their cookies to steal money
    deterministic_rendering: bool = False
    keep_alive: bool = Field(default=False, alias='_force_keep_browser_alive')  # used to be called _force_keep_browser_alive

    proxy: ProxySettings | None = None
    new_context_config: BrowserContextConfig = Field(default_factory=BrowserContextConfig)
# @singleton: TODO - think about if singleton makes sense here
# NOTE: no singleton decorator is applied, so every Browser() call creates a separate instance.
class Browser:
    """
    Playwright browser on steroids.

    This is a persistent browser factory that can spawn multiple browser contexts.
    It is recommended to use only one instance of Browser per application (RAM usage will grow otherwise).
    """

    def __init__(
        self,
        config: BrowserConfig | None = None,
    ):
        """Store the configuration; the Playwright browser itself is created lazily."""
        logger.debug('🌎 Initializing new browser')
        self.config = config or BrowserConfig()
        self.playwright: Playwright | None = None
        self.playwright_browser: PlaywrightBrowser | None = None
        # Only set when this instance launches its own Chrome process
        self._chrome_subprocess: psutil.Process | None = None

    async def new_context(self, config: BrowserContextConfig | None = None) -> BrowserContext:
        """Create a browser context.

        The context configuration is this browser's config overlaid with `config`,
        so values given in `config` win on conflicting keys.
        """
        browser_config = self.config.model_dump() if self.config else {}
        context_config = config.model_dump() if config else {}
        merged_config = {**browser_config, **context_config}
        return BrowserContext(config=BrowserContextConfig(**merged_config), browser=self)

    async def get_playwright_browser(self) -> PlaywrightBrowser:
        """Return the Playwright browser, initializing it on first use."""
        if self.playwright_browser is None:
            return await self._init()

        return self.playwright_browser

    @time_execution_async('--init (browser)')
    async def _init(self):
        """Start Playwright, set up the browser and remember both on the instance."""
        playwright = await async_playwright().start()
        self.playwright = playwright

        browser = await self._setup_browser(playwright)
        self.playwright_browser = browser

        return self.playwright_browser

    async def _setup_remote_cdp_browser(self, playwright: Playwright) -> PlaywrightBrowser:
        """Connect to an already running browser via `config.cdp_url`. Firefox no longer supports CDP.

        Raises:
            ValueError: if the configured binary path looks like Firefox, or no CDP URL is set.
        """
        if 'firefox' in (self.config.browser_binary_path or '').lower():
            raise ValueError(
                'CDP has been deprecated for firefox, check: https://fxdx.dev/deprecating-cdp-support-in-firefox-embracing-the-future-with-webdriver-bidi/'
            )
        if not self.config.cdp_url:
            raise ValueError('CDP URL is required')
        logger.info(f'🔌 Connecting to remote browser via CDP {self.config.cdp_url}')
        browser_class = getattr(playwright, self.config.browser_class)
        browser = await browser_class.connect_over_cdp(self.config.cdp_url)
        return browser

    async def _setup_remote_wss_browser(self, playwright: Playwright) -> PlaywrightBrowser:
        """Connect to an already running browser via `config.wss_url`.

        Raises:
            ValueError: if no WSS URL is set.
        """
        if not self.config.wss_url:
            raise ValueError('WSS URL is required')
        logger.info(f'🔌 Connecting to remote browser via WSS {self.config.wss_url}')
        browser_class = getattr(playwright, self.config.browser_class)
        browser = await browser_class.connect(self.config.wss_url)
        return browser

    async def _setup_user_provided_browser(self, playwright: Playwright) -> PlaywrightBrowser:
        """Connect to the browser at `config.browser_binary_path`, launching it when needed.

        An instance already listening on the configured remote debugging port is
        reused. Otherwise the binary is started as a subprocess with that port and
        a user data dir, and the connection is made over CDP once it answers.

        Raises:
            ValueError: if no browser_binary_path is configured.
            RuntimeError: if connecting to the newly started instance fails.
        """
        if not self.config.browser_binary_path:
            raise ValueError('A browser_binary_path is required')

        assert self.config.browser_class == 'chromium', (
            'browser_binary_path only supports chromium browsers (make sure browser_class=chromium)'
        )

        debug_endpoint = f'http://localhost:{self.config.chrome_remote_debugging_port}'

        try:
            # Check if browser is already running
            async with httpx.AsyncClient() as client:
                response = await client.get(f'{debug_endpoint}/json/version', timeout=2)
                if response.status_code == 200:
                    logger.info(f'🔌 Reusing existing browser found running on {debug_endpoint}')
                    browser_class = getattr(playwright, self.config.browser_class)
                    browser = await browser_class.connect_over_cdp(
                        endpoint_url=debug_endpoint,
                        timeout=20000,  # 20 second timeout for connection
                    )
                    return browser
        except httpx.RequestError:
            logger.debug('🌎 No existing Chrome instance found, starting a new one')

        provided_user_data_dir = [arg for arg in self.config.extra_browser_args if '--user-data-dir=' in arg]

        if provided_user_data_dir:
            # Take everything after the flag name so paths containing '=' stay intact
            user_data_dir = Path(provided_user_data_dir[0].partition('--user-data-dir=')[2])
        else:
            fallback_user_data_dir = Path(gettempdir()) / 'browseruse' / 'profiles' / 'default'  # /tmp/browseruse
            try:
                # ~/.config/browseruse/profiles/default
                user_data_dir = Path('~/.config') / 'browseruse' / 'profiles' / 'default'
                user_data_dir = user_data_dir.expanduser()
                user_data_dir.mkdir(parents=True, exist_ok=True)
            except Exception as e:
                logger.error(f'❌ Failed to create ~/.config/browseruse directory: {type(e).__name__}: {e}')
                user_data_dir = fallback_user_data_dir
                user_data_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f'🌐 Storing Browser Profile user data dir in: {user_data_dir}')
        try:
            # Remove any existing SingletonLock file to allow the browser to start
            (user_data_dir / 'Default' / 'SingletonLock').unlink()
            self.config.extra_browser_args.append('--no-first-run')
        except (FileNotFoundError, PermissionError, OSError):
            pass

        # Start a new Chrome instance
        # dict.fromkeys drops duplicate flags while keeping first-seen order (a set would shuffle them)
        chrome_launch_args = list(
            dict.fromkeys(
                [
                    f'--remote-debugging-port={self.config.chrome_remote_debugging_port}',
                    *([f'--user-data-dir={user_data_dir.resolve()}'] if not provided_user_data_dir else []),
                    *CHROME_ARGS,
                    *(CHROME_DOCKER_ARGS if IN_DOCKER else []),
                    *(CHROME_HEADLESS_ARGS if self.config.headless else []),
                    *(CHROME_DISABLE_SECURITY_ARGS if self.config.disable_security else []),
                    *(CHROME_DETERMINISTIC_RENDERING_ARGS if self.config.deterministic_rendering else []),
                    *self.config.extra_browser_args,
                ]
            )
        )
        chrome_sub_process = await asyncio.create_subprocess_exec(
            self.config.browser_binary_path,
            *chrome_launch_args,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            shell=False,
        )
        self._chrome_subprocess = psutil.Process(chrome_sub_process.pid)

        # Poll the debugging endpoint (up to 10 tries, 1s apart) until the new instance answers
        for _ in range(10):
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.get(f'{debug_endpoint}/json/version', timeout=2)
                    if response.status_code == 200:
                        break
            except httpx.RequestError:
                pass
            await asyncio.sleep(1)

        # Attempt to connect again after starting a new instance
        try:
            browser_class = getattr(playwright, self.config.browser_class)
            browser = await browser_class.connect_over_cdp(
                endpoint_url=debug_endpoint,
                timeout=20000,  # 20 second timeout for connection
            )
            return browser
        except Exception as e:
            logger.error(f'❌ Failed to start a new Chrome instance: {str(e)}')
            raise RuntimeError(
                'To start chrome in Debug mode, you need to close all existing Chrome instances and try again otherwise we can not connect to the instance.'
            ) from e

    async def _setup_builtin_browser(self, playwright: Playwright) -> PlaywrightBrowser:
        """Launch one of Playwright's bundled browsers with the configured arguments."""
        assert self.config.browser_binary_path is None, 'browser_binary_path should be None if trying to use the builtin browsers'

        # Use the configured window size from new_context_config if available
        if (
            not self.config.headless
            and hasattr(self.config, 'new_context_config')
            and hasattr(self.config.new_context_config, 'window_width')
            and hasattr(self.config.new_context_config, 'window_height')
            and not self.config.new_context_config.no_viewport
        ):
            screen_size = {
                'width': self.config.new_context_config.window_width,
                'height': self.config.new_context_config.window_height,
            }
            offset_x, offset_y = get_window_adjustments()
        elif self.config.headless:
            screen_size = {'width': 1920, 'height': 1080}
            offset_x, offset_y = 0, 0
        else:
            screen_size = get_screen_resolution()
            offset_x, offset_y = get_window_adjustments()

        debug_port_arg = f'--remote-debugging-port={self.config.chrome_remote_debugging_port}'
        # De-duplicate while keeping a stable, first-seen order
        chrome_args = list(
            dict.fromkeys(
                [
                    debug_port_arg,
                    *CHROME_ARGS,
                    *(CHROME_DOCKER_ARGS if IN_DOCKER else []),
                    *(CHROME_HEADLESS_ARGS if self.config.headless else []),
                    *(CHROME_DISABLE_SECURITY_ARGS if self.config.disable_security else []),
                    *(CHROME_DETERMINISTIC_RENDERING_ARGS if self.config.deterministic_rendering else []),
                    f'--window-position={offset_x},{offset_y}',
                    f'--window-size={screen_size["width"]},{screen_size["height"]}',
                    *self.config.extra_browser_args,
                ]
            )
        )

        # check if chrome remote debugging port is already taken,
        # if so remove the remote-debugging-port arg to prevent conflicts
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', self.config.chrome_remote_debugging_port)) == 0:
                chrome_args.remove(debug_port_arg)

        browser_class = getattr(playwright, self.config.browser_class)
        args = {
            'chromium': chrome_args,
            'firefox': list(dict.fromkeys(['-no-remote', *self.config.extra_browser_args])),
            'webkit': list(dict.fromkeys(['--no-startup-window', *self.config.extra_browser_args])),
        }

        browser = await browser_class.launch(
            channel='chromium',  # https://github.com/microsoft/playwright/issues/33566
            headless=self.config.headless,
            args=args[self.config.browser_class],
            proxy=self.config.proxy.model_dump() if self.config.proxy else None,
            handle_sigterm=False,
            handle_sigint=False,
        )
        return browser

    async def _setup_browser(self, playwright: Playwright) -> PlaywrightBrowser:
        """Pick the setup strategy from the config and return the resulting browser.

        Precedence: `cdp_url`, then `wss_url`, then `browser_binary_path`, and
        finally Playwright's builtin browser. Failures are logged and re-raised.
        """
        try:
            if self.config.cdp_url:
                return await self._setup_remote_cdp_browser(playwright)
            if self.config.wss_url:
                return await self._setup_remote_wss_browser(playwright)

            if self.config.headless:
                logger.warning('⚠️ Headless mode is not recommended. Many sites will detect and block all headless browsers.')

            if self.config.browser_binary_path:
                return await self._setup_user_provided_browser(playwright)
            else:
                return await self._setup_builtin_browser(playwright)
        except Exception as e:
            logger.error(f'Failed to initialize Playwright browser: {e}')
            raise

    async def close(self):
        """Close the browser instance (no-op when `config.keep_alive` is set).

        Cleanup is best effort: failures are logged at debug level, and a failure
        while closing Playwright does not prevent a Chrome process started by this
        instance from being killed. All handles are reset to None afterwards.
        """
        if self.config.keep_alive:
            return

        try:
            try:
                if self.playwright_browser:
                    await self.playwright_browser.close()
                if self.playwright:
                    await self.playwright.stop()
            except Exception as e:
                if 'OpenAI error' not in str(e):
                    logger.debug(f'Failed to close browser properly: {e}')

            if chrome_proc := getattr(self, '_chrome_subprocess', None):
                try:
                    # always kill all children processes, otherwise chrome leaves a bunch of zombie processes
                    for proc in chrome_proc.children(recursive=True):
                        proc.kill()
                    chrome_proc.kill()
                except Exception as e:
                    logger.debug(f'Failed to terminate chrome subprocess: {e}')
        finally:
            self.playwright_browser = None
            self.playwright = None
            self._chrome_subprocess = None
            gc.collect()

    def __del__(self):
        """Schedule async cleanup when the object is destroyed while an event loop is running."""
        try:
            if self.playwright_browser or self.playwright:
                # get_running_loop() raises RuntimeError when no loop is running, so
                # cleanup is only scheduled inside a live loop; otherwise the failure
                # is logged below and nothing else happens.
                asyncio.get_running_loop().create_task(self.close())
        except Exception as e:
            logger.debug(f'Failed to cleanup browser in destructor: {e}')

View file

@ -0,0 +1,177 @@
# Extensions to load into Chrome; currently empty.
# NOTE(review): presumably keyed by extension id (a commented-out flag below joins the keys) — confirm once populated.
CHROME_EXTENSIONS = {} # coming in a separate PR
# Path name for unpacked extensions. NOTE(review): not referenced in this module's active flags — confirm usage.
CHROME_EXTENSIONS_PATH = 'chrome_extensions'
# Chrome user-data directory name (only referenced by a commented-out --user-data-dir flag below).
CHROME_PROFILE_PATH = 'chrome_profile'
# Profile directory name (only referenced by a commented-out --profile-directory flag below).
CHROME_PROFILE_USER = 'Default'
# NOTE(review): presumably the remote debugging port for launched Chrome — confirm against the launcher.
CHROME_DEBUG_PORT = 9242
# Chrome feature names that get joined into a single --disable-features=... flag.
# It's always best to give each chrome instance its own exclusive copy of the user profile.
CHROME_DISABLED_COMPONENTS = [
    'Translate',
    'AcceptCHFrame',
    'OptimizationHints',
    'ProcessPerSiteUpToMainFrameThreshold',
    'InterestFeedContentSuggestions',
    # 'CalculateNativeWinOcclusion',
    'BackForwardCache',
    # 'HeavyAdPrivacyMitigations',
    'LazyFrameLoading',
    # 'ImprovedCookieControls',
    'PrivacySandboxSettings4',
    'AutofillServerCommunication',
    'CertificateTransparencyComponentUpdater',
    'DestroyProfileOnBrowserClose',
    'CrashReporting',
    'OverscrollHistoryNavigation',
    'InfiniteSessionRestore',
    # 'LockProfileCookieDatabase',
    #   ^ disabling allows multiple chrome instances to concurrently modify the profile, but might make
    #     chrome much slower https://github.com/yt-dlp/yt-dlp/issues/7271 https://issues.chromium.org/issues/40901624
]
# Flags for running Chrome headless.
CHROME_HEADLESS_ARGS = [
    '--headless=new',
    # '--test-type',
    # '--test-type=gpu',  # https://github.com/puppeteer/puppeteer/issues/10516
    # '--enable-automation',  # <- DONT USE THIS, it makes you easily detectable / blocked by cloudflare
]
# Docker-specific options
# https://github.com/GoogleChrome/lighthouse-ci/tree/main/docs/recipes/docker-client#--no-sandbox-issues-explained
CHROME_DOCKER_ARGS = [
    # rely on docker sandboxing in docker, otherwise we need cap_add: SYS_ADMIN to use host sandboxing
    '--no-sandbox',
    '--disable-gpu-sandbox',
    '--disable-setuid-sandbox',
    # docker 75mb default shm size is not big enough, disabling just uses /tmp instead
    '--disable-dev-shm-usage',
    '--no-xshm',
    # dont try to disable (or install) dbus in docker, its not needed, chrome can work without dbus despite the errors
]
# Flags that switch off browser security features. See chrome://net-internals
CHROME_DISABLE_SECURITY_ARGS = [
    # DANGER: JS isolation security features (to allow easier tampering with pages during automation)
    # WARNING: --disable-web-security breaks some sites that expect/enforce strict CORS headers (try webflow.com)
    '--disable-web-security',
    '--disable-site-isolation-trials',
    '--disable-features=IsolateOrigins,site-per-process',
    # '--allow-file-access-from-files',  # <- WARNING, dangerous, allows JS to read filesystem using file:// URLs
    # DANGER: Disable HTTPS verification
    # --allow-running-insecure-content breaks CORS/CSRF/HSTS etc., useful sometimes but very easy to detect
    '--allow-running-insecure-content',
    '--ignore-certificate-errors',
    '--ignore-ssl-errors',
    '--ignore-certificate-errors-spki-list',
    # '--allow-insecure-localhost',
]
# Flags that make Chrome render more deterministically across different operating systems.
CHROME_DETERMINISTIC_RENDERING_ARGS = [
    '--deterministic-mode',
    # make all JS random numbers deterministic by providing a seed
    '--js-flags=--random-seed=1157259159',
    '--force-device-scale-factor=1',
    # GPU, canvas, text, and pdf rendering config
    # chrome://gpu
    '--enable-webgl',  # enable web-gl graphics support
    # ignore OS font hints; may also need a css override, try:
    #   * {text-rendering: geometricprecision !important; -webkit-font-smoothing: antialiased;}
    '--font-render-hinting=none',
    # use a consistent color profile; if the browser looks weird, try: generic-rgb
    '--force-color-profile=srgb',
    # '--disable-partial-raster',  # make rendering more deterministic (TODO: verify if still needed)
    # avoid Skia hot path runtime optimizations
    '--disable-skia-runtime-opts',
    # disable antialiasing on 2d canvas clips
    '--disable-2d-canvas-clip-aa',
    # '--disable-gpu',  # falls back to more consistent software renderer across all OS's, especially helps linux text rendering look less weird
    # // '--use-gl=swiftshader', <- DO NOT USE, breaks M1 ARM64. it makes rendering more deterministic by using simpler CPU renderer instead of OS GPU renderer bug: https://groups.google.com/a/chromium.org/g/chromium-dev/c/8eR2GctzGuw
    # // '--disable-software-rasterizer', <- DO NOT USE, harmless, used in tandem with --disable-gpu
    # // '--run-all-compositor-stages-before-draw', <- DO NOT USE, makes headful chrome hang on startup (tested v121 Google Chrome.app on macOS)
    # // '--disable-gl-drawing-for-tests', <- DO NOT USE, disables gl output (makes tests run faster if you dont care about canvas)
    # // '--blink-settings=imagesEnabled=false', <- DO NOT USE, disables images entirely (only sometimes useful to speed up loading)
]
# Default Chrome command-line flags. The final --disable-features flag is
# built from CHROME_DISABLED_COMPONENTS, so that list must be defined first.
CHROME_ARGS = [
    # Process management & performance tuning
    # chrome://process-internals
    # '--disable-lazy-loading',  # make rendering more deterministic by loading all content up-front instead of on-focus
    # '--disable-renderer-backgrounding',  # dont throttle tab rendering based on focus/visibility
    # '--disable-background-networking',  # dont throttle tab networking based on focus/visibility
    # '--disable-background-timer-throttling',  # dont throttle tab timers based on focus/visibility
    # '--disable-backgrounding-occluded-windows',  # dont throttle tab window based on focus/visibility
    # '--disable-ipc-flooding-protection',  # dont throttle ipc traffic or accessing big request/response/buffer/etc. objects will fail
    # '--disable-extensions-http-throttling',  # dont throttle http traffic based on runtime heuristics
    # '--disable-field-trial-config',  # disable shared field trial state between browser processes
    # '--disable-back-forward-cache',  # disable browsing navigation cache
    # Profile data dir setup
    # chrome://profile-internals
    # f'--user-data-dir={CHROME_PROFILE_PATH}',  # managed by playwright arg instead
    # f'--profile-directory={CHROME_PROFILE_USER}',
    # '--password-store=basic',  # use mock keychain instead of OS-provided keychain (we manage auth.json instead)
    # '--use-mock-keychain',
    # '--disable-cookie-encryption',  # we need to be able to write unencrypted cookies to save/load auth.json
    '--disable-sync',  # don't try to use Google account sync features while automation is active
    # Extensions
    # chrome://inspect/#extensions
    # f'--load-extension={CHROME_EXTENSIONS.map(({unpacked_path}) => unpacked_path).join(',')}',  # not needed when using existing profile that already has extensions installed
    # f'--allowlisted-extension-id={",".join(CHROME_EXTENSIONS.keys())}',
    '--allow-legacy-extension-manifests',
    '--allow-pre-commit-input',  # allow JS mutations before page rendering is complete
    '--disable-blink-features=AutomationControlled',  # hide the signatures that announce browser is being remote-controlled
    # '--proxy-server=https://PROXY_HOST:PROXY_PORT',  # send all network traffic through a proxy https://2captcha.com/proxy (never commit real proxy credentials here)
    # f'--proxy-bypass-list=127.0.0.1',
    # Browser window and viewport setup
    # chrome://version
    # f'--user-agent="{DEFAULT_USER_AGENT}"',
    # f'--window-size={DEFAULT_VIEWPORT.width},{DEFAULT_VIEWPORT.height}',
    # '--window-position=0,0',
    # '--start-maximized',
    '--install-autogenerated-theme=0,0,0',  # black border makes it easier to see which chrome window is browser-use's
    '--hide-scrollbars',  # stop scrollbars from affecting screenshot width/height
    # '--virtual-time-budget=60000',  # DONT USE THIS, makes chrome hang forever and doesn't work, used to fast-forward all animations & timers by 60s, dont use this it's unfortunately buggy and breaks screenshot and PDF capture sometimes
    # '--autoplay-policy=no-user-gesture-required',  # auto-start videos so they trigger network requests + show up in outputs
    # '--disable-gesture-requirement-for-media-playback',
    # '--lang=en-US,en;q=0.9',
    # IO: stdin/stdout, debug port config
    # chrome://inspect
    '--log-level=2',  # 1=DEBUG 2=WARNING 3=ERROR
    '--enable-logging=stderr',
    # '--remote-debugging-address=127.0.0.1', <- DONT USE THIS, no longer supported on chrome >100, never expose to non-localhost, would allow attacker to drive your browser from any machine
    # '--enable-experimental-extension-apis',  # add support for tab groups via chrome.tabs extension API
    '--disable-focus-on-load',  # prevent browser from hijacking focus
    '--disable-window-activation',
    # '--in-process-gpu', <- DONT USE THIS, makes headful startup time ~5-10s slower (tested v121 Google Chrome.app on macOS)
    # '--disable-component-extensions-with-background-pages',  # TODO: check this, disables chrome components that only run in background with no visible UI (could lower startup time)
    # uncomment to disable hardware camera/mic/speaker access + present fake devices to websites
    # (faster to disable, but disabling breaks recording browser audio in puppeteer-stream screenrecordings)
    # '--use-fake-device-for-media-stream',
    # '--use-fake-ui-for-media-stream',
    # '--disable-features=GlobalMediaControls,MediaRouter,DialMediaRouteProvider',
    # Output format options (PDF, screenshot, etc.)
    '--export-tagged-pdf',  # include table of contents and tags in printed PDFs
    '--generate-pdf-document-outline',
    # Suppress first-run features, popups, hints, updates, etc.
    # chrome://system
    '--no-pings',
    '--no-default-browser-check',
    '--no-startup-window',
    '--ash-no-nudges',
    '--disable-infobars',
    '--disable-search-engine-choice-screen',
    '--disable-session-crashed-bubble',
    '--simulate-outdated-no-au="Tue, 31 Dec 2099 23:59:59 GMT"',  # disable browser self-update while automation is active
    '--hide-crash-restore-bubble',
    '--suppress-message-center-popups',
    '--disable-client-side-phishing-detection',
    '--disable-domain-reliability',
    '--disable-datasaver-prompt',
    '--disable-hang-monitor',
    # ('--disable-session-crashed-bubble' used to be repeated here; it is already listed above)
    '--disable-speech-synthesis-api',
    '--disable-speech-api',
    '--disable-print-preview',
    '--safebrowsing-disable-auto-update',
    # '--deny-permission-prompts',
    '--disable-external-intent-requests',
    # '--disable-notifications',
    '--disable-desktop-notifications',
    '--noerrdialogs',
    '--disable-prompt-on-repost',
    '--silent-debugger-extension-api',
    # '--block-new-web-contents',
    '--metrics-recording-only',
    '--disable-breakpad',
    # other feature flags
    # chrome://flags chrome://components
    f'--disable-features={",".join(CHROME_DISABLED_COMPONENTS)}',
    '--enable-features=NetworkService',
]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,348 @@
import logging
import os
import aiohttp
from playwright.async_api import Page, async_playwright
from browser_use.browser.service import Browser
from browser_use.browser.views import BrowserState, TabInfo
logger = logging.getLogger(__name__)
class DolphinBrowser(Browser):
    """A class for managing Dolphin Anty browser sessions using Playwright.

    Connection settings are read from the environment when the instance is
    created: ``DOLPHIN_API_TOKEN``, ``DOLPHIN_API_URL`` (default
    ``http://localhost:3001/v1.0``) and ``DOLPHIN_PROFILE_ID``.
    """

    def __init__(self, headless: bool = False, keep_open: bool = False):
        """
        Initialize the DolphinBrowser instance.

        Args:
            headless (bool): Run browser in headless mode (default: False).
            keep_open (bool): Keep browser open after finishing tasks (default: False).
        """
        # NOTE(review): Browser.__init__ is not invoked here — confirm that is intentional.
        # Retrieve environment variables for API connection
        self.api_token = os.getenv('DOLPHIN_API_TOKEN')
        self.api_url = os.getenv('DOLPHIN_API_URL', 'http://localhost:3001/v1.0')
        self.profile_id = os.getenv('DOLPHIN_PROFILE_ID')

        # Initialize internal attributes
        self.playwright = None
        self.browser = None
        self.context = None
        self.page = None
        self.headless = headless
        # NOTE(review): keep_open is stored but never consulted by close() — confirm the intended behaviour.
        self.keep_open = keep_open
        self._pages: list[Page] = []  # List to store open pages
        self.session = None
        self.cached_state = None

    async def get_current_page(self) -> Page:
        """
        Get the currently active page.

        Raises:
            Exception: If no active page is available.
        """
        if not self.page:
            raise Exception('No active page. Browser might not be connected.')
        return self.page

    async def create_new_tab(self, url: str | None = None) -> None:
        """
        Create a new tab and optionally navigate to a given URL.

        The new tab becomes the current page even if the navigation fails.

        Args:
            url (str, optional): URL to navigate to after creating the tab. Defaults to None.

        Raises:
            Exception: If browser context is not initialized or navigation fails.
        """
        if not self.context:
            raise Exception('Browser context not initialized')

        # Create new page (tab) in the current browser context
        new_page = await self.context.new_page()
        self._pages.append(new_page)
        self.page = new_page  # Set as current page

        if url:
            try:
                # Navigate to the URL and wait for the page to load
                await new_page.goto(url, wait_until='networkidle')
                await self.wait_for_page_load()
            except Exception as e:
                logger.error(f'Failed to navigate to URL {url}: {str(e)}')
                raise

    async def switch_to_tab(self, page_id: int) -> None:
        """
        Switch to a specific tab by its page ID.

        Args:
            page_id (int): The index of the tab to switch to. Negative values
                count from the end (-1 is the last tab).

        Raises:
            Exception: If the tab index is out of range or no tabs are available.
        """
        if not self._pages:
            raise Exception('No tabs available')

        # Handle negative indices (e.g., -1 for last tab)
        if page_id < 0:
            page_id = len(self._pages) + page_id

        if page_id >= len(self._pages) or page_id < 0:
            raise Exception(f'Tab index {page_id} out of range')

        # Set the current page to the selected tab
        self.page = self._pages[page_id]
        await self.page.bring_to_front()  # Bring tab to the front
        await self.wait_for_page_load()

    async def get_tabs_info(self) -> list[TabInfo]:
        """
        Get information about all open tabs.

        Returns:
            list: A list of TabInfo objects containing details about each tab.
        """
        tabs_info = []
        for idx, page in enumerate(self._pages):
            tab_info = TabInfo(
                page_id=idx,
                url=page.url,
                title=await page.title(),  # Fetch the title of the page
            )
            tabs_info.append(tab_info)
        return tabs_info

    async def wait_for_page_load(self, timeout: int = 30000):
        """
        Wait for the current page to reach the 'networkidle' load state.

        A timeout or any other failure is logged as a warning and swallowed;
        this method does not raise. Does nothing when there is no current page.

        Args:
            timeout (int): Maximum time to wait for page load in milliseconds (default: 30000ms).
        """
        if self.page:
            try:
                await self.page.wait_for_load_state('networkidle', timeout=timeout)
            except Exception as e:
                logger.warning(f'Wait for page load timeout: {str(e)}')

    async def get_session(self):
        """
        Get the current session.

        Returns:
            DolphinBrowser: The current DolphinBrowser instance.

        Raises:
            Exception: If the browser is not connected.
        """
        if not self.browser:
            raise Exception('Browser not connected. Call connect() first.')
        self.session = self
        return self

    async def authenticate(self):
        """
        Authenticate with Dolphin Anty API using the API token.

        Returns:
            The decoded JSON body of the login response.

        Raises:
            Exception: If authentication fails.
        """
        async with aiohttp.ClientSession() as session:
            auth_url = f'{self.api_url}/auth/login-with-token'
            auth_data = {'token': self.api_token}
            async with session.post(auth_url, json=auth_data) as response:
                if not response.ok:
                    raise Exception(f'Failed to authenticate with Dolphin Anty: {await response.text()}')
                return await response.json()

    async def get_browser_profiles(self):
        """
        Get a list of available browser profiles from Dolphin Anty.

        Returns:
            list: A list of browser profiles.

        Raises:
            Exception: If fetching the browser profiles fails.
        """
        # Authenticate before fetching profiles
        await self.authenticate()

        async with aiohttp.ClientSession() as session:
            headers = {'Authorization': f'Bearer {self.api_token}'}
            async with session.get(f'{self.api_url}/browser_profiles', headers=headers) as response:
                if not response.ok:
                    raise Exception(f'Failed to get browser profiles: {await response.text()}')
                data = await response.json()
                return data.get('data', [])  # Return the profiles array from the response

    async def start_profile(self, profile_id: str | None = None, headless: bool = False) -> dict:
        """
        Start a browser profile on Dolphin Anty.

        Args:
            profile_id (str, optional): Profile ID to start (defaults to the one set in the environment).
            headless (bool): Run browser in headless mode (default: False).

        Returns:
            dict: Information about the started profile.

        Raises:
            ValueError: If no profile ID is provided and no default is set.
            Exception: If starting the profile fails.
        """
        # Authenticate before starting the profile
        await self.authenticate()

        profile_id = profile_id or self.profile_id
        if not profile_id:
            raise ValueError('No profile ID provided')

        url = f'{self.api_url}/browser_profiles/{profile_id}/start'
        params = {'automation': 1}
        if headless:
            params['headless'] = 1

        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as response:
                if not response.ok:
                    raise Exception(f'Failed to start profile: {await response.text()}')
                return await response.json()

    async def stop_profile(self, profile_id: str | None = None):
        """
        Stop a browser profile on Dolphin Anty.

        The HTTP status of the response is not checked; the decoded JSON body
        is returned as-is.

        Args:
            profile_id (str, optional): Profile ID to stop (defaults to the one set in the environment).

        Returns:
            dict: Information about the stopped profile.

        Raises:
            ValueError: If no profile ID is provided and no default is set.
        """
        # Authenticate before stopping the profile
        await self.authenticate()

        profile_id = profile_id or self.profile_id
        if not profile_id:
            raise ValueError('No profile ID provided')

        url = f'{self.api_url}/browser_profiles/{profile_id}/stop'
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                return await response.json()

    async def connect(self, profile_id: str | None = None):
        """
        Start a browser profile and connect to it using Playwright.

        The profile is started with the ``headless`` setting given to the
        constructor.

        Args:
            profile_id (str, optional): Profile ID to connect to (defaults to the one set in the environment).

        Returns:
            PlaywrightBrowser: The connected browser instance.

        Raises:
            Exception: If authentication or profile connection fails.
        """
        # start_profile() authenticates on its own, so no separate
        # authenticate() call is needed here.
        profile_data = await self.start_profile(profile_id, headless=self.headless)

        if not profile_data.get('success'):
            raise Exception(f'Failed to start profile: {profile_data}')

        automation = profile_data['automation']
        port = automation['port']
        ws_endpoint = automation['wsEndpoint']
        ws_url = f'ws://127.0.0.1:{port}{ws_endpoint}'

        # Use Playwright to connect to the browser's WebSocket endpoint
        self.playwright = await async_playwright().start()
        try:
            self.browser = await self.playwright.chromium.connect_over_cdp(ws_url)
        except Exception:
            # Do not leave the Playwright driver running when the connection fails
            await self.playwright.stop()
            self.playwright = None
            raise

        # Get or create a browser context and page
        contexts = self.browser.contexts
        self.context = contexts[0] if contexts else await self.browser.new_context()

        pages = self.context.pages
        self.page = pages[0] if pages else await self.context.new_page()
        self._pages = [self.page]  # Initialize pages list with the first page

        return self.browser

    async def close(self, force: bool = False):
        """
        Close the browser connection and clean up resources.

        Every cleanup step is attempted independently, so a failure in one
        step does not skip the others. Failures are logged and never raised.
        All handles are reset afterwards, so get_session()/get_current_page()
        report a disconnected browser.

        Args:
            force (bool): If True, forcefully stop the associated profile (default: False).
        """
        # Close all open pages; one broken page must not stop the rest
        for page in self._pages:
            try:
                await page.close()
            except Exception as e:
                logger.debug(f'Failed to close page: {str(e)}')
        self._pages = []
        self.page = None
        self.context = None

        # Close the browser and Playwright instance
        try:
            if self.browser:
                await self.browser.close()
        except Exception as e:
            logger.error(f'Error during browser cleanup: {str(e)}')
        finally:
            self.browser = None

        try:
            if self.playwright:
                await self.playwright.stop()
        except Exception as e:
            logger.error(f'Error during browser cleanup: {str(e)}')
        finally:
            self.playwright = None

        if force:
            try:
                await self.stop_profile()  # Force stop the profile
            except Exception as e:
                logger.error(f'Error during browser cleanup: {str(e)}')

    async def get_current_state(self) -> BrowserState:
        """
        Get the current state of the browser (URL, content, viewport size, tabs).

        Returns:
            BrowserState: The current state of the browser.

        Raises:
            Exception: If no active page is available.
        """
        if not self.page:
            raise Exception('No active page')

        # Get page content and viewport size
        content = await self.page.content()
        # viewport_size is a plain property in Playwright's Python API (a dict or None), not a coroutine
        viewport_size = self.page.viewport_size

        # Create and return the current browser state
        # NOTE(review): confirm that BrowserState accepts content/viewport_height/viewport_width
        # keyword arguments in the browser_use.browser.views version in use.
        state = BrowserState(
            url=self.page.url,
            content=content,
            viewport_height=viewport_size['height'] if viewport_size else 0,
            viewport_width=viewport_size['width'] if viewport_size else 0,
            tabs=await self.get_tabs_info(),
        )

        # Cache and return the state
        self.cached_state = state
        return state

    def __del__(self):
        """Clean up resources when the DolphinBrowser instance is deleted."""
        # No need to handle session cleanup as we're using self as session
        pass

View file

@ -0,0 +1,39 @@
import httpx
import pytest
from browser_use.browser.browser import Browser, BrowserConfig
@pytest.mark.asyncio
async def test_browser_close_doesnt_affect_external_httpx_clients():
    """
    Browser.close() must leave HTTPX clients created outside the Browser untouched.

    An independent client is created, a Browser is created and closed, and the
    client is then probed with a request. Only the "client has been closed"
    RuntimeError counts as a failure; any other request error is ignored.
    """
    # An HTTPX client that is not owned by the browser and should stay open
    outside_client = httpx.AsyncClient()

    # Create a browser and close it straight away
    instance = Browser(config=BrowserConfig(headless=True))
    await instance.close()

    # Probe the outside client with a simple HEAD request to a reliable URL
    probe_error = None
    try:
        await outside_client.head('https://www.example.com', timeout=2.0)
    except Exception as err:
        probe_error = err
    finally:
        # Always clean up our test client properly
        await outside_client.aclose()

    # A closed client raises RuntimeError("Cannot send a request, as the client has been closed");
    # every other outcome means the client was still open (even if the request itself failed)
    client_is_closed = isinstance(probe_error, RuntimeError) and 'client has been closed' in str(probe_error)

    assert not client_is_closed, 'External HTTPX client was incorrectly closed by Browser.close()'

View file

@ -0,0 +1,36 @@
import asyncio
import base64
import pytest
from browser_use.browser.browser import Browser, BrowserConfig
# The asyncio marker lets pytest-asyncio run this coroutine test (strict mode
# does not run unmarked coroutine tests). The function stays a plain coroutine
# function, so the direct asyncio.run() entry point below keeps working.
@pytest.mark.asyncio
async def test_take_full_page_screenshot():
    """Take a full-page screenshot of example.com and check it is non-empty base64 text."""
    browser = Browser(config=BrowserConfig(headless=False, disable_security=True))
    try:
        async with await browser.new_context() as context:
            page = await context.get_current_page()
            # Go to a test page
            await page.goto('https://example.com')

            await asyncio.sleep(3)
            # Take full page screenshot
            screenshot_b64 = await context.take_screenshot(full_page=True)
            await asyncio.sleep(3)

            # Verify screenshot is not empty and is valid base64
            assert screenshot_b64 is not None
            assert isinstance(screenshot_b64, str)
            assert len(screenshot_b64) > 0

            # Test we can decode the base64 string
            try:
                base64.b64decode(screenshot_b64)
            except Exception as e:
                pytest.fail(f'Failed to decode base64 screenshot: {str(e)}')
    finally:
        # Close the browser even when an assertion above fails
        await browser.close()


if __name__ == '__main__':
    asyncio.run(test_take_full_page_screenshot())

View file

@ -0,0 +1,96 @@
import asyncio
import json
import anyio
import pytest
from browser_use.browser.browser import Browser, BrowserConfig
from browser_use.dom.views import DOMBaseNode, DOMElementNode, DOMTextNode
from browser_use.utils import time_execution_sync
class ElementTreeSerializer:
    """Converts a DOM element tree into plain, JSON-serializable dictionaries."""

    @staticmethod
    def dom_element_node_to_json(element_tree: DOMElementNode) -> dict:
        """Return a nested dict describing ``element_tree`` and all of its descendants."""

        def serialize(node: DOMBaseNode) -> dict:
            # Text nodes are leaves
            if isinstance(node, DOMTextNode):
                return {'type': 'text', 'text': node.text}
            # Any other non-element node type is emitted as an empty dict
            if not isinstance(node, DOMElementNode):
                return {}
            return {
                'type': 'element',
                'tag_name': node.tag_name,
                'attributes': node.attributes,
                'highlight_index': node.highlight_index,
                'children': list(map(serialize, node.children)),
            }

        return serialize(element_tree)
# run with: pytest browser_use/browser/tests/test_clicks.py
@pytest.mark.asyncio
async def test_highlight_elements():
    """Interactive click loop: dump the DOM state, report duplicate XPaths, then click the chosen element.

    The loop never terminates on its own and reads the element index from
    stdin, so this is meant to be driven manually.
    """
    browser = Browser(config=BrowserConfig(headless=False, disable_security=True))

    async with await browser.new_context() as context:
        page = await context.get_current_page()
        # await page.goto('https://immobilienscout24.de')
        # await page.goto('https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/service-plans')
        # await page.goto('https://google.com/search?q=elon+musk')
        # await page.goto('https://kayak.com')
        # await page.goto('https://www.w3schools.com/tags/tryit.asp?filename=tryhtml_iframe')
        # await page.goto('https://dictionary.cambridge.org')
        # await page.goto('https://github.com')
        await page.goto('https://huggingface.co/')

        await asyncio.sleep(1)

        while True:
            try:
                # await asyncio.sleep(10)
                state = await context.get_state(True)

                # Persist the serialized element tree for inspection
                tree_json = json.dumps(
                    ElementTreeSerializer.dom_element_node_to_json(state.element_tree),
                    indent=1,
                )
                async with await anyio.open_file('./tmp/page.json', 'w') as dump_file:
                    await dump_file.write(tree_json)

                # await time_execution_sync('highlight_selector_map_elements')(
                # 	browser.highlight_selector_map_elements
                # )(state.selector_map)

                # Nothing selectable yet: fetch the state again
                if not state.selector_map:
                    continue

                # Find and print duplicate XPaths
                occurrences = {}
                for entry in state.selector_map.values():
                    occurrences[entry.xpath] = occurrences.get(entry.xpath, 0) + 1

                print('\nDuplicate XPaths found:')
                for xpath, seen in occurrences.items():
                    if seen > 1:
                        print(f'XPath: {xpath}')
                        print(f'Count: {seen}\n')

                print(list(state.selector_map.keys()), 'Selector map keys')
                print(state.element_tree.clickable_elements_to_string())
                choice = input('Select next action: ')

                await time_execution_sync('remove_highlight_elements')(context.remove_highlights)()

                target_node = state.selector_map[int(choice)]

                # check if index of selector map are the same as index of items in dom_items

                await context._click_element_node(target_node)

            except Exception as err:
                print(err)

View file

@ -0,0 +1,41 @@
import sys
def get_screen_resolution():
    """Return the primary screen size as ``{'width': ..., 'height': ...}``.

    macOS is queried through AppKit, every other platform through the
    ``screeninfo`` package. When the lookup fails for any reason a message is
    printed and a platform-specific default is returned instead
    (2560x1664 on macOS, 1920x1080 elsewhere).
    """
    if sys.platform != 'darwin':  # Windows & Linux
        try:
            from screeninfo import get_monitors

            detected = get_monitors()
            if not detected:
                raise Exception('No monitors detected.')
            primary = detected[0]
            return {'width': primary.width, 'height': primary.height}
        except ImportError:
            print("screeninfo package not found. Install it using 'pip install screeninfo'.")
        except Exception as err:
            print(f'Error retrieving screen resolution: {err}')

        return {'width': 1920, 'height': 1080}

    # macOS
    try:
        from AppKit import NSScreen

        frame = NSScreen.mainScreen().frame()
        return {'width': int(frame.size.width), 'height': int(frame.size.height)}
    except ImportError:
        print('AppKit is not available. Make sure you are running this on macOS with pyobjc installed.')
    except Exception as err:
        print(f'Error retrieving macOS screen resolution: {err}')

    return {'width': 2560, 'height': 1664}
def get_window_adjustments():
    """Returns recommended x, y offsets for window positioning"""
    offsets_by_platform = {
        'darwin': (-4, 24),  # macOS has a small title bar, no border
        'win32': (-8, 0),  # Windows has a border on the left
    }
    # Linux (and any other platform) needs no adjustment
    return offsets_by_platform.get(sys.platform, (0, 0))

View file

@ -0,0 +1,54 @@
from dataclasses import dataclass, field
from typing import Any
from pydantic import BaseModel
from browser_use.dom.history_tree_processor.service import DOMHistoryElement
from browser_use.dom.views import DOMState
# Pydantic
class TabInfo(BaseModel):
    """Describes one open browser tab."""

    page_id: int
    url: str
    title: str
    # parent page that contains this popup or cross-origin iframe
    parent_page_id: int | None = None
@dataclass
class BrowserState(DOMState):
    """DOM state extended with page-level browser information."""

    url: str
    title: str
    tabs: list[TabInfo]
    screenshot: str | None = None
    pixels_above: int = 0
    pixels_below: int = 0
    # default_factory gives every instance its own list
    browser_errors: list[str] = field(default_factory=list)
@dataclass
class BrowserStateHistory:
    """Record of a browser state together with the elements that were interacted with."""

    url: str
    title: str
    tabs: list[TabInfo]
    interacted_element: list[DOMHistoryElement | None] | list[None]
    screenshot: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict representation; tabs and interacted elements are converted to dicts."""
        return {
            'tabs': [tab.model_dump() for tab in self.tabs],
            'screenshot': self.screenshot,
            # Missing elements are kept as None placeholders
            'interacted_element': [el.to_dict() if el else None for el in self.interacted_element],
            'url': self.url,
            'title': self.title,
        }
class BrowserError(Exception):
    """Root of the browser exception hierarchy; catch this to handle any browser error."""
class URLNotAllowedError(BrowserError):
    """Raised when a URL is not allowed."""

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,246 @@
import asyncio
from collections.abc import Callable
from inspect import iscoroutinefunction, signature
from typing import Any, Generic, Optional, TypeVar
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, Field, create_model
from browser_use.browser.context import BrowserContext
from browser_use.controller.registry.views import (
ActionModel,
ActionRegistry,
RegisteredAction,
)
from browser_use.telemetry.service import ProductTelemetry
from browser_use.telemetry.views import (
ControllerRegisteredFunctionsTelemetryEvent,
RegisteredFunction,
)
from browser_use.utils import time_execution_async
Context = TypeVar('Context')
class Registry(Generic[Context]):
"""Service for registering and managing actions"""
def __init__(self, exclude_actions: list[str] | None = None):
    """Create an empty registry.

    Args:
        exclude_actions: Names of actions that must not be registered; ``None`` means no exclusions.
    """
    self.registry = ActionRegistry()
    self.telemetry = ProductTelemetry()
    self.exclude_actions = [] if exclude_actions is None else exclude_actions
# @time_execution_sync('--create_param_model')
def _create_param_model(self, function: Callable) -> type[BaseModel]:
    """Creates a Pydantic model from function signature.

    Parameters that ``execute_action`` injects itself (``browser``,
    ``page_extraction_llm``, ``available_file_paths`` and ``context``) are
    left out. They are not supplied through the action's ``params``, and
    including them would make the call receive the same keyword twice.

    Args:
        function: The action function whose signature is inspected.

    Returns:
        A model class named ``<function name>_parameters`` derived from ``ActionModel``.
    """
    injected = {'browser', 'page_extraction_llm', 'available_file_paths', 'context'}
    fields = {}
    for name, param in signature(function).parameters.items():
        if name in injected:
            continue
        # Identity check: `==` could invoke an arbitrary __eq__ on the default value
        default = ... if param.default is param.empty else param.default
        fields[name] = (param.annotation, default)
    # TODO: make the types here work
    return create_model(
        f'{function.__name__}_parameters',
        __base__=ActionModel,
        **fields,  # type: ignore
    )
def action(
    self,
    description: str,
    param_model: type[BaseModel] | None = None,
    domains: list[str] | None = None,
    page_filter: Callable[[Any], bool] | None = None,
):
    """Decorator factory that registers a function as an action.

    The decorated function is stored in the registry under its own name and
    is returned unchanged. Synchronous functions are stored behind an async
    wrapper that runs them in a worker thread. Functions whose name appears
    in ``exclude_actions`` are not registered.
    """

    def decorator(func: Callable):
        action_name = func.__name__

        # Excluded actions are handed back untouched and never registered
        if action_name in self.exclude_actions:
            return func

        # Fall back to a model derived from the function signature
        resolved_model = param_model if param_model else self._create_param_model(func)

        if iscoroutinefunction(func):
            registered_callable = func
        else:
            # Sync functions are run in a thread so they can be awaited

            async def async_wrapper(*args, **kwargs):
                return await asyncio.to_thread(func, *args, **kwargs)

            # Make the wrapper look like the original for later signature inspection
            async_wrapper.__signature__ = signature(func)
            async_wrapper.__name__ = func.__name__
            async_wrapper.__annotations__ = func.__annotations__
            registered_callable = async_wrapper

        self.registry.actions[action_name] = RegisteredAction(
            name=action_name,
            description=description,
            function=registered_callable,
            param_model=resolved_model,
            domains=domains,
            page_filter=page_filter,
        )
        return func

    return decorator
@time_execution_async('--execute_action')
async def execute_action(
    self,
    action_name: str,
    params: dict,
    browser: BrowserContext | None = None,
    page_extraction_llm: BaseChatModel | None = None,
    sensitive_data: dict[str, str] | None = None,
    available_file_paths: list[str] | None = None,
    #
    context: Context | None = None,
) -> Any:
    """Execute a registered action.

    ``params`` is validated against the action's parameter model, secret
    placeholders are substituted when ``sensitive_data`` is given, and the
    special arguments the action declares (``context``, ``browser``,
    ``page_extraction_llm``, ``available_file_paths``) are injected by name.
    If the action's first parameter is annotated with a ``BaseModel``
    subclass, the validated model is passed positionally; otherwise its
    fields are passed as keyword arguments.

    Raises:
        ValueError: If no action named ``action_name`` is registered.
        RuntimeError: If validation, a missing required special argument, or
            the action itself fails; the original error is chained.
    """
    if action_name not in self.registry.actions:
        raise ValueError(f'Action {action_name} not found')

    action = self.registry.actions[action_name]
    try:
        # Create the validated Pydantic model
        validated_params = action.param_model(**params)

        sig = signature(action.function)
        parameters = list(sig.parameters.values())
        parameter_names = [param.name for param in parameters]

        # Check if the first parameter is a Pydantic model.
        # issubclass() raises TypeError for annotations that are not plain
        # classes (e.g. `int | None` or `list[str]`); those are simply not
        # Pydantic models.
        is_pydantic = False
        if parameters:
            try:
                is_pydantic = issubclass(parameters[0].annotation, BaseModel)
            except TypeError:
                is_pydantic = False

        if sensitive_data:
            validated_params = self._replace_sensitive_data(validated_params, sensitive_data)

        # Check that every special argument the action requires was provided
        if 'browser' in parameter_names and not browser:
            raise ValueError(f'Action {action_name} requires browser but none provided.')
        if 'page_extraction_llm' in parameter_names and not page_extraction_llm:
            raise ValueError(f'Action {action_name} requires page_extraction_llm but none provided.')
        if 'available_file_paths' in parameter_names and not available_file_paths:
            raise ValueError(f'Action {action_name} requires available_file_paths but none provided.')
        # Context is an arbitrary user type, so only None means "not provided";
        # a falsy value such as an empty dict is a legitimate context.
        if 'context' in parameter_names and context is None:
            raise ValueError(f'Action {action_name} requires context but none provided.')

        # Prepare arguments based on parameter type
        extra_args = {}
        if 'context' in parameter_names:
            extra_args['context'] = context
        if 'browser' in parameter_names:
            extra_args['browser'] = browser
        if 'page_extraction_llm' in parameter_names:
            extra_args['page_extraction_llm'] = page_extraction_llm
        if 'available_file_paths' in parameter_names:
            extra_args['available_file_paths'] = available_file_paths
        # Only the input_text action is told that secrets were substituted
        if action_name == 'input_text' and sensitive_data:
            extra_args['has_sensitive_data'] = True

        if is_pydantic:
            return await action.function(validated_params, **extra_args)
        return await action.function(**validated_params.model_dump(), **extra_args)

    except Exception as e:
        raise RuntimeError(f'Error executing action {action_name}: {str(e)}') from e
def _replace_sensitive_data(self, params: BaseModel, sensitive_data: dict[str, str]) -> BaseModel:
"""Replaces the sensitive data in the params"""
# if there are any str with <secret>placeholder</secret> in the params, replace them with the actual value from sensitive_data
import logging
import re
logger = logging.getLogger(__name__)
secret_pattern = re.compile(r'<secret>(.*?)</secret>')
# Set to track all missing placeholders across the full object
all_missing_placeholders = set()
def replace_secrets(value):
if isinstance(value, str):
matches = secret_pattern.findall(value)
for placeholder in matches:
if placeholder in sensitive_data and sensitive_data[placeholder]:
value = value.replace(f'<secret>{placeholder}</secret>', sensitive_data[placeholder])
else:
# Keep track of missing placeholders
all_missing_placeholders.add(placeholder)
# Don't replace the tag, keep it as is
return value
elif isinstance(value, dict):
return {k: replace_secrets(v) for k, v in value.items()}
elif isinstance(value, list):
return [replace_secrets(v) for v in value]
return value
params_dump = params.model_dump()
processed_params = replace_secrets(params_dump)
# Log a warning if any placeholders are missing
if all_missing_placeholders:
logger.warning(f'Missing or empty keys in sensitive_data dictionary: {", ".join(all_missing_placeholders)}')
return type(params).model_validate(processed_params)
# @time_execution_sync('--create_action_model')
def create_action_model(self, include_actions: list[str] | None = None, page=None) -> type[ActionModel]:
    """Build the dynamic ``ActionModel`` subclass exposing the currently available actions.

    The resulting model has one optional field per available action, typed
    with that action's parameter model. It is used by LLM APIs that support
    tool calling and enforce a schema.

    Selection rules:
        - ``include_actions``, when given, restricts candidates to those names.
        - With no ``page``, only actions without ``page_filter`` and without
          ``domains`` are included.
        - With a ``page``, an action is included when both its domain
          patterns and its page filter accept the page (a missing filter
          accepts everything).

    The selected actions are also reported through ``self.telemetry``.
    """
    registry = self.registry

    def is_available(name, action) -> bool:
        if include_actions is not None and name not in include_actions:
            return False
        if page is None:
            # No page: only unfiltered actions qualify
            return action.page_filter is None and action.domains is None
        # Evaluate both filters before combining them
        domain_is_allowed = registry._match_domains(action.domains, page.url)
        page_is_allowed = registry._match_page_filter(action.page_filter, page)
        return bool(domain_is_allowed and page_is_allowed)

    available_actions = {}
    for name, action in registry.actions.items():
        if is_available(name, action):
            available_actions[name] = action

    fields = {}
    registered_functions = []
    for name, action in available_actions.items():
        fields[name] = (
            Optional[action.param_model],
            Field(default=None, description=action.description),
        )
    for name, action in available_actions.items():
        registered_functions.append(RegisteredFunction(name=name, params=action.param_model.model_json_schema()))

    self.telemetry.capture(ControllerRegisteredFunctionsTelemetryEvent(registered_functions=registered_functions))

    return create_model('ActionModel', __base__=ActionModel, **fields)  # type:ignore
def get_prompt_description(self, page=None) -> str:
    """Return the prompt text describing the registered actions.

    This is a thin wrapper: the call is forwarded to
    ``self.registry.get_prompt_description`` with the same ``page`` value,
    and the registry decides which actions are described.
    """
    registry = self.registry
    description = registry.get_prompt_description(page=page)
    return description

View file

@ -0,0 +1,149 @@
from collections.abc import Callable
from playwright.async_api import Page
from pydantic import BaseModel, ConfigDict
class RegisteredAction(BaseModel):
    """A single action registered with the registry."""

    name: str
    description: str
    function: Callable
    param_model: type[BaseModel]

    # Availability filters: a list of domain glob patterns and/or a predicate
    # deciding whether the action should be offered on a given page.
    domains: list[str] | None = None  # e.g. ['*.google.com', 'www.bing.com', 'yahoo.*']
    page_filter: Callable[[Page], bool] | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def prompt_description(self) -> str:
        """Render the action for the prompt.

        The output is the description followed by ``{name: {...}}``, where
        the inner mapping is the ``properties`` section of the parameter
        model's JSON schema with the ``title`` key removed from each property.
        """
        hidden_keys = ('title',)
        properties = self.param_model.model_json_schema()['properties']
        trimmed = {
            field: {key: val for key, val in spec.items() if key not in hidden_keys}
            for field, spec in properties.items()
        }
        header = f'{self.description}: \n'
        body = '{' + f'{self.name}: {trimmed}' + '}'
        return header + body
class ActionModel(BaseModel):
    """Base model for dynamically created action models.

    Subclasses carry one field per registered action, e.g.
    ``click_element = ClickElementParams``; at most the fields that were
    explicitly set are considered by the helpers below.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def get_index(self) -> int | None:
        """Return the ``index`` of the first set action that has one.

        A dump such as ``{'clicked_element': {'index': 5}}`` yields ``5``.
        Returns ``None`` when no action is set or none carries an ``index``.
        """
        for param in self.model_dump(exclude_unset=True).values():
            if param is not None and 'index' in param:
                return param['index']
        return None

    def set_index(self, index: int):
        """Overwrite the ``index`` of the first set action, if it has one.

        Does nothing when no action field has been set (previously this
        leaked ``StopIteration`` to the caller) or when the action's
        parameters have no ``index`` attribute.
        """
        action_data = self.model_dump(exclude_unset=True)
        action_name = next(iter(action_data), None)
        if action_name is None:
            return

        # Update the index directly on the model
        action_params = getattr(self, action_name)
        if hasattr(action_params, 'index'):
            action_params.index = index
class ActionRegistry(BaseModel):
    """Model representing the action registry: a mapping of action name to action."""

    actions: dict[str, RegisteredAction] = {}

    @staticmethod
    def _match_domains(domains: list[str] | None, url: str) -> bool:
        """
        Match a list of domain glob patterns against a URL.

        Args:
            domains: A list of domain patterns that can include glob patterns (* wildcard).
                ``None`` means "no domain restriction".
            url: The URL to match against

        Returns:
            True if there is no restriction (``domains`` is None or ``url`` is empty)
            or the URL's host matches one of the patterns, False otherwise
            (including when the URL has no host or cannot be parsed).
        """
        if domains is None or not url:
            return True

        import fnmatch
        from urllib.parse import urlparse

        try:
            parsed_url = urlparse(url)
            if not parsed_url.netloc:
                return False

            # Use the parsed hostname rather than splitting netloc on ':'.
            # netloc may carry userinfo ("allowed.com:x@evil.com") or a
            # bracketed IPv6 literal, and splitting would then compare the
            # patterns against the wrong text.
            domain = parsed_url.hostname
            if not domain:
                return False

            # hostname is lower-cased by urlparse; host names are
            # case-insensitive, so compare against lower-cased patterns.
            return any(fnmatch.fnmatch(domain, domain_pattern.lower()) for domain_pattern in domains)
        except Exception:
            # Fail closed: an unparsable URL or bad pattern never enables an action
            return False

    @staticmethod
    def _match_page_filter(page_filter: Callable[[Page], bool] | None, page: Page) -> bool:
        """Return True when there is no page filter, otherwise the filter's verdict for ``page``."""
        if page_filter is None:
            return True
        return page_filter(page)

    def get_prompt_description(self, page: Page | None = None) -> str:
        """Get a description of all actions for the prompt

        Args:
            page: If provided, filter actions by page using page_filter and domains.

        Returns:
            A string description of available actions, one action per line.
            - If page is None: return only actions with no page_filter and no domains (for system prompt)
            - If page is provided: return only filtered actions that match the current page (excluding unfiltered actions)
        """
        if page is None:
            # For system prompt (no page provided), include only actions with no filters
            return '\n'.join(
                action.prompt_description()
                for action in self.actions.values()
                if action.page_filter is None and action.domains is None
            )

        # only include filtered actions for the current page
        filtered_actions = []
        for action in self.actions.values():
            if not (action.domains or action.page_filter):
                # skip actions with no filters, they are already included in the system prompt
                continue

            domain_is_allowed = self._match_domains(action.domains, page.url)
            page_is_allowed = self._match_page_filter(action.page_filter, page)

            if domain_is_allowed and page_is_allowed:
                filtered_actions.append(action)

        return '\n'.join(action.prompt_description() for action in filtered_actions)

View file

@ -0,0 +1,874 @@
import asyncio
import enum
import json
import logging
import re
from typing import Generic, TypeVar, cast
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from playwright.async_api import ElementHandle, Page
# from lmnr.sdk.laminar import Laminar
from pydantic import BaseModel
from browser_use.agent.views import ActionModel, ActionResult
from browser_use.browser.context import BrowserContext
from browser_use.controller.registry.service import Registry
from browser_use.controller.views import (
ClickElementAction,
CloseTabAction,
DoneAction,
DragDropAction,
GoToUrlAction,
InputTextAction,
NoParamsAction,
OpenTabAction,
Position,
ScrollAction,
SearchGoogleAction,
SendKeysAction,
SwitchTabAction,
)
from browser_use.utils import time_execution_sync
# Module-level logger used by the actions defined in this module.
logger = logging.getLogger(__name__)
# Type variable for the context object that Controller and its Registry are generic over.
Context = TypeVar('Context')
class Controller(Generic[Context]):
def __init__(
self,
exclude_actions: list[str] = [],
output_model: type[BaseModel] | None = None,
):
self.registry = Registry[Context](exclude_actions)
"""Register all default browser actions"""
if output_model is not None:
    # Create a new model that extends the output model with success parameter
    class ExtendedOutputModel(BaseModel):  # type: ignore
        success: bool = True
        data: output_model  # type: ignore

    @self.registry.action(
        'Complete task - with return text and if the task is finished (success=True) or not yet completely finished (success=False), because last step is reached',
        param_model=ExtendedOutputModel,
    )
    async def done(params: ExtendedOutputModel):
        """Finish the task, returning the structured output serialized as JSON."""
        # Exclude success from the output JSON since it's an internal parameter.
        # mode='json' makes pydantic emit JSON-compatible values at every
        # nesting level (enums become their values), so json.dumps does not
        # fail on enums nested inside lists or sub-models.
        output_dict = params.data.model_dump(mode='json')

        return ActionResult(is_done=True, success=params.success, extracted_content=json.dumps(output_dict))
else:

    @self.registry.action(
        'Complete task - with return text and if the task is finished (success=True) or not yet completely finished (success=False), because last step is reached',
        param_model=DoneAction,
    )
    async def done(params: DoneAction):
        """Finish the task, returning the free-text result."""
        return ActionResult(is_done=True, success=params.success, extracted_content=params.text)
# Basic Navigation Actions
@self.registry.action(
    'Search the query in Google in the current tab, the query should be a search query like humans search in Google, concrete and not vague or super long. More the single most important items. ',
    param_model=SearchGoogleAction,
)
async def search_google(params: SearchGoogleAction, browser: BrowserContext):
    """Open a Google search for ``params.query`` in the current tab.

    The query is percent-encoded before being placed in the URL so that
    characters such as ``&``, ``#``, ``+`` or spaces are searched for
    literally instead of altering the query string.
    """
    from urllib.parse import quote_plus

    page = await browser.get_current_page()
    await page.goto(f'https://www.google.com/search?q={quote_plus(params.query)}&udm=14')
    await page.wait_for_load_state()
    msg = f'🔍 Searched for "{params.query}" in Google'
    logger.info(msg)
    return ActionResult(extracted_content=msg, include_in_memory=True)
@self.registry.action('Navigate to URL in the current tab', param_model=GoToUrlAction)
async def go_to_url(params: GoToUrlAction, browser: BrowserContext):
    """Load ``params.url`` in the current tab and wait for the page load state."""
    current_page = await browser.get_current_page()
    target_url = params.url
    await current_page.goto(target_url)
    await current_page.wait_for_load_state()

    msg = f'🔗 Navigated to {target_url}'
    logger.info(msg)
    result = ActionResult(extracted_content=msg, include_in_memory=True)
    return result
@self.registry.action('Go back', param_model=NoParamsAction)
async def go_back(_: NoParamsAction, browser: BrowserContext):
    """Ask the browser context to navigate back and report it."""
    await browser.go_back()

    confirmation = '🔙 Navigated back'
    logger.info(confirmation)
    result = ActionResult(extracted_content=confirmation, include_in_memory=True)
    return result
# wait for x seconds
@self.registry.action('Wait for x seconds default 3')
async def wait(seconds: int = 3):
    """Log the wait, sleep for ``seconds`` seconds, then return the logged message."""
    notice = f'🕒 Waiting for {seconds} seconds'
    logger.info(notice)

    # The message is logged before sleeping so the wait is visible while it happens
    await asyncio.sleep(seconds)

    result = ActionResult(extracted_content=notice, include_in_memory=True)
    return result
# Element Interaction Actions
@self.registry.action('Click element by index', param_model=ClickElementAction)
async def click_element_by_index(params: ClickElementAction, browser: BrowserContext):
    """Click the element at ``params.index`` of the selector map.

    Raises if the index is not in the selector map. Elements that open a file
    upload dialog are not clicked. A click that produces a download reports
    the download path, and if the click opened an additional page the browser
    switches to the last tab. Click failures are returned as an error result.
    """
    session = await browser.get_session()

    selector_map = await browser.get_selector_map()
    if params.index not in selector_map:
        raise Exception(f'Element with index {params.index} does not exist - retry or use alternative actions')

    node = await browser.get_dom_element_by_index(params.index)
    pages_before_click = len(session.context.pages)

    # Never click an element that would open a file upload dialog
    opens_file_dialog = await browser.is_file_uploader(node)
    if opens_file_dialog:
        msg = f'Index {params.index} - has an element which opens file upload dialog. To upload files please use a specific function to upload files '
        logger.info(msg)
        return ActionResult(extracted_content=msg, include_in_memory=True)

    msg = None
    try:
        download_path = await browser._click_element_node(node)
        msg = (
            f'💾 Downloaded file to {download_path}'
            if download_path
            else f'🖱️ Clicked button with index {params.index}: {node.get_all_text_till_next_clickable_element(max_depth=2)}'
        )
        logger.info(msg)
        logger.debug(f'Element xpath: {node.xpath}')

        opened_new_tab = len(session.context.pages) > pages_before_click
        if opened_new_tab:
            new_tab_msg = 'New tab opened - switching to it'
            msg += f' - {new_tab_msg}'
            logger.info(new_tab_msg)
            await browser.switch_to_tab(-1)
        return ActionResult(extracted_content=msg, include_in_memory=True)
    except Exception as click_error:
        logger.warning(f'Element not clickable with index {params.index} - most likely the page changed')
        return ActionResult(error=str(click_error))
@self.registry.action(
    'Input text into a input interactive element',
    param_model=InputTextAction,
)
async def input_text(params: InputTextAction, browser: BrowserContext, has_sensitive_data: bool = False):
    """Type ``params.text`` into the element at ``params.index``.

    Raises if the index is not in the selector map. When
    ``has_sensitive_data`` is true the typed text is left out of the
    returned and logged message.
    """
    selector_map = await browser.get_selector_map()
    if params.index not in selector_map:
        raise Exception(f'Element index {params.index} does not exist - retry or use alternative actions')

    node = await browser.get_dom_element_by_index(params.index)
    await browser._input_text_element_node(node, params.text)

    if has_sensitive_data:
        msg = f'⌨️ Input sensitive data into index {params.index}'
    else:
        msg = f'⌨️ Input {params.text} into index {params.index}'

    logger.info(msg)
    logger.debug(f'Element xpath: {node.xpath}')
    return ActionResult(extracted_content=msg, include_in_memory=True)
# Save PDF
@self.registry.action(
    'Save the current page as a PDF file',
)
async def save_pdf(browser: BrowserContext):
    """Render the current page to an A4 PDF named after its URL.

    The file name is the URL without scheme, leading ``www.`` and trailing
    slash, with every run of non-alphanumeric characters collapsed to ``-``
    and lower-cased. The path is relative, so the file is written relative to
    the process working directory.
    """
    page = await browser.get_current_page()

    url_without_scheme = re.sub(r'^https?://(?:www\.)?|/$', '', page.url)
    dashed = re.sub(r'[^a-zA-Z0-9]+', '-', url_without_scheme)
    slug = dashed.strip('-').lower()
    sanitized_filename = f'{slug}.pdf'

    # Render with screen styles rather than print styles
    await page.emulate_media(media='screen')
    await page.pdf(path=sanitized_filename, format='A4', print_background=False)

    msg = f'Saving page with URL {page.url} as PDF to ./{sanitized_filename}'
    logger.info(msg)
    return ActionResult(extracted_content=msg, include_in_memory=True)
# Tab Management Actions
@self.registry.action('Switch tab', param_model=SwitchTabAction)
async def switch_tab(params: SwitchTabAction, browser: BrowserContext):
    """Switch to the tab identified by ``params.page_id`` and wait for it to load."""
    tab_id = params.page_id
    await browser.switch_to_tab(tab_id)

    # Wait for tab to be ready and ensure references are synchronized
    active_page = await browser.get_agent_current_page()
    await active_page.wait_for_load_state()

    msg = f'🔄 Switched to tab {tab_id}'
    logger.info(msg)
    result = ActionResult(extracted_content=msg, include_in_memory=True)
    return result
@self.registry.action('Open url in new tab', param_model=OpenTabAction)
async def open_tab(params: OpenTabAction, browser: BrowserContext):
    """Create a new tab for ``params.url`` and report it."""
    new_tab_url = params.url
    await browser.create_new_tab(new_tab_url)

    # Called for its side effect of synchronizing tab references; the return
    # value is intentionally unused (it looks like a getter but is not one).
    await browser.get_agent_current_page()

    msg = f'🔗 Opened new tab with {new_tab_url}'
    logger.info(msg)
    result = ActionResult(extracted_content=msg, include_in_memory=True)
    return result
@self.registry.action('Close an existing tab', param_model=CloseTabAction)
async def close_tab(params: CloseTabAction, browser: BrowserContext):
    """Switch to the tab identified by ``params.page_id`` and close it."""
    tab_id = params.page_id
    await browser.switch_to_tab(tab_id)

    doomed_page = await browser.get_current_page()
    # Read the URL before closing so it can still be reported
    closed_url = doomed_page.url
    await doomed_page.close()

    msg = f'❌ Closed tab #{tab_id} with url {closed_url}'
    logger.info(msg)
    result = ActionResult(extracted_content=msg, include_in_memory=True)
    return result
# Content Actions
@self.registry.action(
    'Extract page content to retrieve specific information from the page, e.g. all company names, a specific description, all information about, links with companies in structured format or simply links',
)
async def extract_content(
    goal: str, should_strip_link_urls: bool, browser: BrowserContext, page_extraction_llm: BaseChatModel
):
    """Convert the current page to markdown and ask the LLM to extract ``goal`` from it.

    The page HTML (optionally with ``a`` and ``img`` tags stripped) is
    converted with markdownify, and the content of every frame whose URL
    differs from the page URL and is not a ``data:`` URL is appended. If the
    LLM call fails, the raw markdown is returned instead, and that fallback
    result is not flagged for memory.
    """
    page = await browser.get_current_page()
    import markdownify

    strip = ['a', 'img'] if should_strip_link_urls else []
    content = markdownify.markdownify(await page.content(), strip=strip)

    # manually append iframe text into the content so it's readable by the LLM (includes cross-origin iframes)
    for iframe in page.frames:
        if iframe.url == page.url or iframe.url.startswith('data:'):
            continue
        content += f'\n\nIFRAME {iframe.url}:\n'
        content += markdownify.markdownify(await iframe.content())

    prompt = 'Your task is to extract the content of the page. You will be given a page and a goal and you should extract all relevant information around this goal from the page. If the goal is vague, summarize the page. Respond in json format. Extraction goal: {goal}, Page: {page}'
    template = PromptTemplate(input_variables=['goal', 'page'], template=prompt)

    try:
        llm_input = template.format(goal=goal, page=content)
        output = await page_extraction_llm.ainvoke(llm_input)
        msg = f'📄 Extracted from page\n: {output.content}\n'
        logger.info(msg)
        return ActionResult(extracted_content=msg, include_in_memory=True)
    except Exception as extraction_error:
        # Fall back to the unprocessed markdown
        logger.debug(f'Error extracting content: {extraction_error}')
        msg = f'📄 Extracted from page\n: {content}\n'
        logger.info(msg)
        return ActionResult(extracted_content=msg)
@self.registry.action(
    'Scroll down the page by pixel amount - if no amount is specified, scroll down one page',
    param_model=ScrollAction,
)
async def scroll_down(params: ScrollAction, browser: BrowserContext):
    """Scroll down by ``params.amount`` pixels, or by one window height when no amount is given."""
    page = await browser.get_current_page()

    if params.amount is None:
        await page.evaluate('window.scrollBy(0, window.innerHeight);')
        amount = 'one page'
    else:
        await page.evaluate(f'window.scrollBy(0, {params.amount});')
        amount = f'{params.amount} pixels'

    msg = f'🔍 Scrolled down the page by {amount}'
    logger.info(msg)
    return ActionResult(extracted_content=msg, include_in_memory=True)
# scroll up
@self.registry.action(
    'Scroll up the page by pixel amount - if no amount is specified, scroll up one page',
    param_model=ScrollAction,
)
async def scroll_up(params: ScrollAction, browser: BrowserContext):
    """Scroll up by ``params.amount`` pixels, or by one window height when no amount is given."""
    page = await browser.get_current_page()
    if params.amount is not None:
        # Pass the offset as an argument instead of formatting it into the
        # script: '-{amount}' turns a negative amount into '--5', which is
        # not valid JavaScript.
        await page.evaluate('(dy) => window.scrollBy(0, dy)', -params.amount)
    else:
        await page.evaluate('window.scrollBy(0, -window.innerHeight);')

    amount = f'{params.amount} pixels' if params.amount is not None else 'one page'
    msg = f'🔍 Scrolled up the page by {amount}'
    logger.info(msg)
    return ActionResult(
        extracted_content=msg,
        include_in_memory=True,
    )
# send keys
@self.registry.action(
    'Send strings of special keys like Escape,Backspace, Insert, PageDown, Delete, Enter, Shortcuts such as `Control+o`, `Control+Shift+T` are supported as well. This gets used in keyboard.press. ',
    param_model=SendKeysAction,
)
async def send_keys(params: SendKeysAction, browser: BrowserContext):
    """Press ``params.keys`` on the current page's keyboard.

    The whole string is first sent as a single key or shortcut. If that fails
    with an 'Unknown key' error, each character is pressed individually; any
    other error, or a failure of an individual character, is re-raised.
    """
    page = await browser.get_current_page()
    keys = params.keys

    try:
        await page.keyboard.press(keys)
    except Exception as press_error:
        if 'Unknown key' not in str(press_error):
            raise press_error
        # loop over the keys and try to send each one
        for key in keys:
            try:
                await page.keyboard.press(key)
            except Exception as key_error:
                logger.debug(f'Error sending key {key}: {str(key_error)}')
                raise key_error

    msg = f'⌨️ Sent keys: {keys}'
    logger.info(msg)
    return ActionResult(extracted_content=msg, include_in_memory=True)
@self.registry.action(
    description='If you dont find something which you want to interact with, scroll to it',
)
async def scroll_to_text(text: str, browser: BrowserContext):  # type: ignore
    """Scroll the first visible element containing ``text`` into view.

    Several locator strategies are tried in order; the first one that yields
    a visible element with a non-empty bounding box wins. If none does, a
    "not found" message is returned as extracted content.
    """
    page = await browser.get_current_page()
    try:
        # Try different locator strategies
        locators = [
            page.get_by_text(text, exact=False),
            page.locator(f'text={text}'),
            page.locator(f"//*[contains(text(), '{text}')]"),
        ]

        for locator in locators:
            try:
                if await locator.count() == 0:
                    continue

                # `first` is a plain property returning a Locator; awaiting it
                # raises TypeError, which made every strategy fail silently.
                element = locator.first
                is_visible = await element.is_visible()
                bbox = await element.bounding_box()

                if is_visible and bbox is not None and bbox['width'] > 0 and bbox['height'] > 0:
                    await element.scroll_into_view_if_needed()
                    await asyncio.sleep(0.5)  # Wait for scroll to complete
                    msg = f'🔍 Scrolled to text: {text}'
                    logger.info(msg)
                    return ActionResult(extracted_content=msg, include_in_memory=True)

            except Exception as e:
                # A failing strategy is not fatal; move on to the next one
                logger.debug(f'Locator attempt failed: {str(e)}')
                continue

        msg = f"Text '{text}' not found or not visible on page"
        logger.info(msg)
        return ActionResult(extracted_content=msg, include_in_memory=True)

    except Exception as e:
        msg = f"Failed to scroll to text '{text}': {str(e)}"
        logger.error(msg)
        return ActionResult(error=msg, include_in_memory=True)
@self.registry.action(
    description='Get all options from a native dropdown',
)
async def get_dropdown_options(index: int, browser: BrowserContext) -> ActionResult:
    """List the options of the select element at ``index`` of the selector map.

    The element's xpath is evaluated in every frame of the page; options from
    each frame where it resolves are collected as ``<index>: text=<json text>``
    lines. Per-frame failures are logged at debug level and skipped.
    """
    page = await browser.get_current_page()
    selector_map = await browser.get_selector_map()
    dom_element = selector_map[index]

    read_options_js = """
(xpath) => {
const select = document.evaluate(xpath, document, null,
XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!select) return null;
return {
options: Array.from(select.options).map(opt => ({
text: opt.text, //do not trim, because we are doing exact match in select_dropdown_option
value: opt.value,
index: opt.index
})),
id: select.id,
name: select.name
};
}
"""

    try:
        # Frame-aware approach since we know it works
        all_options = []

        for frame_index, frame in enumerate(page.frames):
            try:
                options = await frame.evaluate(read_options_js, dom_element.xpath)
                if not options:
                    continue

                logger.debug(f'Found dropdown in frame {frame_index}')
                logger.debug(f'Dropdown ID: {options["id"]}, Name: {options["name"]}')

                frame_lines = []
                for opt in options['options']:
                    # encoding ensures AI uses the exact string in select_dropdown_option
                    encoded_text = json.dumps(opt['text'])
                    frame_lines.append(f'{opt["index"]}: text={encoded_text}')
                all_options.extend(frame_lines)
            except Exception as frame_e:
                logger.debug(f'Frame {frame_index} evaluation failed: {str(frame_e)}')

        if not all_options:
            msg = 'No options found in any frame for dropdown'
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

        msg = '\n'.join(all_options)
        msg += '\nUse the exact text string in select_dropdown_option'
        logger.info(msg)
        return ActionResult(extracted_content=msg, include_in_memory=True)

    except Exception as e:
        logger.error(f'Failed to get dropdown options: {str(e)}')
        msg = f'Error getting options: {str(e)}'
        logger.info(msg)
        return ActionResult(extracted_content=msg, include_in_memory=True)
@self.registry.action(
    description='Select dropdown option for interactive element index by the text of the option you want to select',
)
async def select_dropdown_option(
    index: int,
    text: str,
    browser: BrowserContext,
) -> ActionResult:
    """Select dropdown option by the text of the option you want to select

    The element at ``index`` of the selector map must be a ``select``. Its
    xpath is looked up in every frame of the page; in the first frame where
    it resolves to a select element, the option whose label equals ``text``
    is selected. Per-frame failures are logged and the next frame is tried.
    """
    page = await browser.get_current_page()
    selector_map = await browser.get_selector_map()
    dom_element = selector_map[index]

    # Validate that we're working with a select element
    if dom_element.tag_name != 'select':
        logger.error(f'Element is not a select! Tag: {dom_element.tag_name}, Attributes: {dom_element.attributes}')
        msg = f'Cannot select option: Element with index {index} is a {dom_element.tag_name}, not a select'
        return ActionResult(extracted_content=msg, include_in_memory=True)

    logger.debug(f"Attempting to select '{text}' using xpath: {dom_element.xpath}")
    logger.debug(f'Element attributes: {dom_element.attributes}')
    logger.debug(f'Element tag: {dom_element.tag_name}')

    xpath = '//' + dom_element.xpath

    # First verify we can find the dropdown in this frame
    find_dropdown_js = """
(xpath) => {
try {
const select = document.evaluate(xpath, document, null,
XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!select) return null;
if (select.tagName.toLowerCase() !== 'select') {
return {
error: `Found element but it's a ${select.tagName}, not a SELECT`,
found: false
};
}
return {
id: select.id,
name: select.name,
found: true,
tagName: select.tagName,
optionCount: select.options.length,
currentValue: select.value,
availableOptions: Array.from(select.options).map(o => o.text.trim())
};
} catch (e) {
return {error: e.toString(), found: false};
}
}
"""

    try:
        # enumerate keeps the frame number correct even when a frame is
        # skipped with `continue` (a manual counter was not advanced then)
        for frame_index, frame in enumerate(page.frames):
            try:
                logger.debug(f'Trying frame {frame_index} URL: {frame.url}')

                dropdown_info = await frame.evaluate(find_dropdown_js, dom_element.xpath)

                if dropdown_info:
                    if not dropdown_info.get('found'):
                        logger.error(f'Frame {frame_index} error: {dropdown_info.get("error")}')
                        continue

                    logger.debug(f'Found dropdown in frame {frame_index}: {dropdown_info}')

                    # "label" because we are selecting by text
                    # nth(0) to disable error thrown by strict mode
                    # timeout=1000 because we are already waiting for all network events, therefore ideally we don't need to wait a lot here (default 30s)
                    selected_option_values = await frame.locator(xpath).nth(0).select_option(label=text, timeout=1000)

                    msg = f'selected option {text} with value {selected_option_values}'
                    logger.info(msg + f' in frame {frame_index}')

                    return ActionResult(extracted_content=msg, include_in_memory=True)

            except Exception as frame_e:
                logger.error(f'Frame {frame_index} attempt failed: {str(frame_e)}')
                logger.error(f'Frame type: {type(frame)}')
                logger.error(f'Frame URL: {frame.url}')

        msg = f"Could not select option '{text}' in any frame"
        logger.info(msg)
        return ActionResult(extracted_content=msg, include_in_memory=True)

    except Exception as e:
        msg = f'Selection failed: {str(e)}'
        logger.error(msg)
        return ActionResult(error=msg, include_in_memory=True)
@self.registry.action(
    'Drag and drop elements or between coordinates on the page - useful for canvas drawing, sortable lists, sliders, file uploads, and UI rearrangement',
    param_model=DragDropAction,
)
async def drag_drop(params: DragDropAction, browser: BrowserContext) -> ActionResult:
    """
    Performs a precise drag and drop operation between elements or coordinates.

    Two modes are supported:
        - element mode, when both ``element_source`` and ``element_target``
          selectors are given: coordinates come from the optional
          ``element_source_offset`` / ``element_target_offset`` positions or,
          when absent, from the centre of each element's bounding box;
        - coordinate mode, when all four ``coord_*`` values are given.

    The drag is done with raw mouse events: move to the source, press, move
    to the target in ``steps`` increments with ``delay_ms`` between them,
    then release. Every failure is returned as an error result, not raised.
    """

    async def get_drag_elements(
        page: Page,
        source_selector: str,
        target_selector: str,
    ) -> tuple[ElementHandle | None, ElementHandle | None]:
        """Get source and target elements with appropriate error handling."""
        source_element = None
        target_element = None

        try:
            # page.locator() auto-detects CSS and XPath
            source_locator = page.locator(source_selector)
            target_locator = page.locator(target_selector)

            # Check if elements exist
            source_count = await source_locator.count()
            target_count = await target_locator.count()

            if source_count > 0:
                source_element = await source_locator.first.element_handle()
                logger.debug(f'Found source element with selector: {source_selector}')
            else:
                logger.warning(f'Source element not found: {source_selector}')

            if target_count > 0:
                target_element = await target_locator.first.element_handle()
                logger.debug(f'Found target element with selector: {target_selector}')
            else:
                logger.warning(f'Target element not found: {target_selector}')

        except Exception as e:
            logger.error(f'Error finding elements: {str(e)}')

        return source_element, target_element

    async def get_element_coordinates(
        source_element: ElementHandle,
        target_element: ElementHandle,
        source_position: Position | None,
        target_position: Position | None,
    ) -> tuple[tuple[int, int] | None, tuple[int, int] | None]:
        """Get coordinates from elements with appropriate error handling."""
        source_coords = None
        target_coords = None

        try:
            # Get source coordinates
            if source_position:
                source_coords = (source_position.x, source_position.y)
            else:
                source_box = await source_element.bounding_box()
                if source_box:
                    source_coords = (
                        int(source_box['x'] + source_box['width'] / 2),
                        int(source_box['y'] + source_box['height'] / 2),
                    )

            # Get target coordinates
            if target_position:
                target_coords = (target_position.x, target_position.y)
            else:
                target_box = await target_element.bounding_box()
                if target_box:
                    target_coords = (
                        int(target_box['x'] + target_box['width'] / 2),
                        int(target_box['y'] + target_box['height'] / 2),
                    )
        except Exception as e:
            logger.error(f'Error getting element coordinates: {str(e)}')

        return source_coords, target_coords

    async def execute_drag_operation(
        page: Page,
        source_x: int,
        source_y: int,
        target_x: int,
        target_y: int,
        steps: int,
        delay_ms: int,
    ) -> tuple[bool, str]:
        """Execute the drag operation with comprehensive error handling."""
        try:
            # Try to move to source position
            try:
                await page.mouse.move(source_x, source_y)
                logger.debug(f'Moved to source position ({source_x}, {source_y})')
            except Exception as e:
                logger.error(f'Failed to move to source position: {str(e)}')
                return False, f'Failed to move to source position: {str(e)}'

            # Press mouse button down
            await page.mouse.down()

            # Move to target position with intermediate steps
            for i in range(1, steps + 1):
                ratio = i / steps
                intermediate_x = int(source_x + (target_x - source_x) * ratio)
                intermediate_y = int(source_y + (target_y - source_y) * ratio)

                await page.mouse.move(intermediate_x, intermediate_y)

                if delay_ms > 0:
                    await asyncio.sleep(delay_ms / 1000)

            # Move to final target position
            await page.mouse.move(target_x, target_y)

            # Move again to ensure dragover events are properly triggered
            await page.mouse.move(target_x, target_y)

            # Release mouse button
            await page.mouse.up()

            return True, 'Drag operation completed successfully'

        except Exception as e:
            return False, f'Error during drag operation: {str(e)}'

    page = await browser.get_current_page()

    try:
        # Initialize variables
        source_x: int | None = None
        source_y: int | None = None
        target_x: int | None = None
        target_y: int | None = None

        # Normalize parameters
        steps = max(1, params.steps or 10)
        # Fall back to the default only when delay_ms is unset: `or` would
        # also replace an explicit 0, making a no-delay drag impossible.
        delay_ms = max(0, 5 if params.delay_ms is None else params.delay_ms)

        # Case 1: Element selectors provided
        if params.element_source and params.element_target:
            logger.debug('Using element-based approach with selectors')

            source_element, target_element = await get_drag_elements(
                page,
                params.element_source,
                params.element_target,
            )

            if not source_element or not target_element:
                error_msg = f'Failed to find {"source" if not source_element else "target"} element'
                return ActionResult(error=error_msg, include_in_memory=True)

            source_coords, target_coords = await get_element_coordinates(
                source_element, target_element, params.element_source_offset, params.element_target_offset
            )

            if not source_coords or not target_coords:
                error_msg = f'Failed to determine {"source" if not source_coords else "target"} coordinates'
                return ActionResult(error=error_msg, include_in_memory=True)

            source_x, source_y = source_coords
            target_x, target_y = target_coords

        # Case 2: Coordinates provided directly
        elif all(
            coord is not None
            for coord in [params.coord_source_x, params.coord_source_y, params.coord_target_x, params.coord_target_y]
        ):
            logger.debug('Using coordinate-based approach')
            source_x = params.coord_source_x
            source_y = params.coord_source_y
            target_x = params.coord_target_x
            target_y = params.coord_target_y
        else:
            error_msg = 'Must provide either source/target selectors or source/target coordinates'
            return ActionResult(error=error_msg, include_in_memory=True)

        # Validate coordinates
        if any(coord is None for coord in [source_x, source_y, target_x, target_y]):
            error_msg = 'Failed to determine source or target coordinates'
            return ActionResult(error=error_msg, include_in_memory=True)

        # Perform the drag operation
        success, message = await execute_drag_operation(
            page,
            cast(int, source_x),
            cast(int, source_y),
            cast(int, target_x),
            cast(int, target_y),
            steps,
            delay_ms,
        )

        if not success:
            logger.error(f'Drag operation failed: {message}')
            return ActionResult(error=message, include_in_memory=True)

        # Create descriptive message
        if params.element_source and params.element_target:
            msg = f"🖱️ Dragged element '{params.element_source}' to '{params.element_target}'"
        else:
            msg = f'🖱️ Dragged from ({source_x}, {source_y}) to ({target_x}, {target_y})'

        logger.info(msg)
        return ActionResult(extracted_content=msg, include_in_memory=True)

    except Exception as e:
        error_msg = f'Failed to perform drag and drop: {str(e)}'
        logger.error(error_msg)
        return ActionResult(error=error_msg, include_in_memory=True)
@self.registry.action('Google Sheets: Get the contents of the entire sheet', domains=['sheets.google.com'])
async def get_sheet_contents(browser: BrowserContext):
    """Select all cells, copy them, and return the clipboard text read from the page."""
    page = await browser.get_current_page()

    # select all cells, then copy the selection
    for shortcut in ('Enter', 'Escape', 'ControlOrMeta+A', 'ControlOrMeta+C'):
        await page.keyboard.press(shortcut)

    extracted_tsv = await page.evaluate('() => navigator.clipboard.readText()')
    result = ActionResult(extracted_content=extracted_tsv, include_in_memory=True)
    return result
@self.registry.action('Google Sheets: Select a specific cell or range of cells', domains=['sheets.google.com'])
async def select_cell_or_range(browser: BrowserContext, cell_or_range: str):
    # Drives the "go to range" popup with the keyboard to select `cell_or_range`.
    page = await browser.get_current_page()
    keyboard = page.keyboard

    async def press(key: str, pause: float = 0.0) -> None:
        # Press one key and optionally wait `pause` seconds afterwards.
        await keyboard.press(key)
        if pause:
            await asyncio.sleep(pause)

    await press('Enter')  # make sure we dont delete current cell contents if we were last editing
    await press('Escape', 0.1)  # to clear current focus (otherwise select range popup is additive)
    await press('Home')  # move cursor to the top left of the sheet first
    await press('ArrowUp', 0.1)
    await press('Control+G', 0.2)  # open the goto range popup
    # NOTE(review): Playwright documents `delay` in milliseconds, so 0.05 is
    # effectively no delay - confirm whether 50 was intended.
    await keyboard.type(cell_or_range, delay=0.05)
    await asyncio.sleep(0.2)
    await press('Enter', 0.2)
    await press('Escape')  # to make sure the popup still closes in the case where the jump failed
    return ActionResult(extracted_content=f'Selected cell {cell_or_range}', include_in_memory=False)
@self.registry.action(
    'Google Sheets: Get the contents of a specific cell or range of cells', domains=['sheets.google.com']
)
async def get_range_contents(browser: BrowserContext, cell_or_range: str):
    # Selects the requested range, copies it, and returns the clipboard text.
    page = await browser.get_current_page()
    await select_cell_or_range(browser, cell_or_range)
    await page.keyboard.press('ControlOrMeta+C')
    # Short pause between the copy shortcut and reading the clipboard.
    await asyncio.sleep(0.1)
    return ActionResult(
        extracted_content=await page.evaluate('() => navigator.clipboard.readText()'),
        include_in_memory=True,
    )
@self.registry.action('Google Sheets: Clear the currently selected cells', domains=['sheets.google.com'])
async def clear_selected_range(browser: BrowserContext):
    # Sends a single Backspace key press to the current page.
    current_page = await browser.get_current_page()
    await current_page.keyboard.press('Backspace')
    return ActionResult(extracted_content='Cleared selected range', include_in_memory=False)
@self.registry.action('Google Sheets: Input text into the currently selected cell', domains=['sheets.google.com'])
async def input_selected_cell_text(browser: BrowserContext, text: str):
    # Types `text` on the current page, then presses Enter followed by ArrowUp.
    page = await browser.get_current_page()
    keyboard = page.keyboard
    await keyboard.type(text, delay=0.1)
    # Enter commits the input so it doesn't get overwritten by the next action.
    for key in ('Enter', 'ArrowUp'):
        await keyboard.press(key)
    return ActionResult(extracted_content=f'Inputted text {text}', include_in_memory=False)
@self.registry.action('Google Sheets: Batch update a range of cells', domains=['sheets.google.com'])
async def update_range_contents(browser: BrowserContext, range: str, new_contents_tsv: str):
    # Selects `range`, then dispatches a synthetic paste event carrying
    # `new_contents_tsv` as text/plain to the focused element.
    page = await browser.get_current_page()
    await select_cell_or_range(browser, range)
    # The TSV is handed to the page as an evaluate() argument instead of being
    # interpolated into the script source: cell data containing backticks,
    # `${...}` or backslashes would otherwise break (or inject into) the JS
    # template literal.
    await page.evaluate(
        """(tsv) => {
            const clipboardData = new DataTransfer();
            clipboardData.setData('text/plain', tsv);
            document.activeElement.dispatchEvent(new ClipboardEvent('paste', {clipboardData}));
        }""",
        new_contents_tsv,
    )
    return ActionResult(extracted_content=f'Updated cell {range} with {new_contents_tsv}', include_in_memory=False)
# Register ---------------------------------------------------------------
def action(self, description: str, **kwargs):
    """Decorator for registering custom actions

    Thin pass-through to ``self.registry.action``; every keyword argument is
    forwarded unchanged.

    @param description: Describe the LLM what the function does (better description == better function calling)
    """
    register = self.registry.action
    return register(description, **kwargs)
# Act --------------------------------------------------------------------
# NOTE(review): a *sync* timing decorator is applied to an async def. If it
# times a plain call it would only measure coroutine creation - confirm
# against the decorator's implementation.
@time_execution_sync('--act')
async def act(
    self,
    action: ActionModel,
    browser_context: BrowserContext,
    #
    page_extraction_llm: BaseChatModel | None = None,
    sensitive_data: dict[str, str] | None = None,
    available_file_paths: list[str] | None = None,
    #
    context: Context | None = None,
) -> ActionResult:
    """Execute an action

    Walks the explicitly-set fields of ``action`` and runs the first one whose
    parameters are not ``None`` through ``self.registry.execute_action``.

    Returns:
        The registry's ``ActionResult`` as-is; a ``str`` result wrapped as
        ``extracted_content``; an empty ``ActionResult`` when the registry
        returned ``None`` or when no field carried parameters.

    Raises:
        ValueError: the registry returned a value of any other type.
        Exception: anything raised while executing the action propagates
            unchanged.
    """
    for action_name, params in action.model_dump(exclude_unset=True).items():
        if params is None:
            continue

        result = await self.registry.execute_action(
            action_name,
            params,
            browser=browser_context,
            page_extraction_llm=page_extraction_llm,
            sensitive_data=sensitive_data,
            available_file_paths=available_file_paths,
            context=context,
        )

        # Only the first populated action is executed; every branch returns or raises.
        if result is None:
            return ActionResult()
        if isinstance(result, str):
            return ActionResult(extracted_content=result)
        if isinstance(result, ActionResult):
            return result
        raise ValueError(f'Invalid action result type: {type(result)} of {result}')

    return ActionResult()

View file

@ -0,0 +1,91 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator
# Action Input Models
# Action input model with one required string field, `query`.
# Presumably the text to search for, judging by the class name - confirm at the registering action.
class SearchGoogleAction(BaseModel):
query: str
# Action input model with one required string field, `url`.
class GoToUrlAction(BaseModel):
url: str
# Action input model: required integer `index` plus an optional `xpath` (defaults to None).
# Presumably `index` is an element's highlight index - confirm at the registering action.
class ClickElementAction(BaseModel):
index: int
xpath: str | None = None
# Action input model: required `index` and `text`, plus an optional `xpath` (defaults to None).
class InputTextAction(BaseModel):
index: int
text: str
xpath: str | None = None
# Action input model with two required fields: a `text` string and a `success` flag.
class DoneAction(BaseModel):
text: str
success: bool
# Action input model with one required integer field, `page_id`.
class SwitchTabAction(BaseModel):
page_id: int
# Action input model with one required string field, `url`.
class OpenTabAction(BaseModel):
url: str
# Action input model with one required integer field, `page_id`.
class CloseTabAction(BaseModel):
page_id: int
# Action input model with a single optional field; omitting `amount` leaves it as None.
class ScrollAction(BaseModel):
amount: int | None = None # The number of pixels to scroll. If None, scroll down/up one page
# Action input model with one required string field, `keys`.
# NOTE(review): the accepted key syntax is not visible here - confirm at the registering action.
class SendKeysAction(BaseModel):
keys: str
# Action input model with one required string field, `value`.
# NOTE(review): what `value` denotes is not visible here - confirm at the registering action.
class ExtractPageContentAction(BaseModel):
value: str
class NoParamsAction(BaseModel):
"""
Accepts absolutely anything in the incoming data
and discards it, so the final parsed model is empty.
"""
# extra='allow' keeps unknown keys from being rejected by the model config.
model_config = ConfigDict(extra='allow')
# mode='before' runs on the raw input, before field validation; returning {}
# replaces whatever was passed in.
@model_validator(mode='before')
def ignore_all_inputs(cls, values):
# No matter what the user sends, discard it and return empty.
return {}
# Integer x/y pair; used below as the type of DragDropAction's element offsets.
class Position(BaseModel):
x: int
y: int
# Input model for a drag-and-drop action. Every field is optional at the model
# level; it declares two alternative ways to describe the drag (element
# selectors or absolute coordinates) plus shared movement options.
# NOTE(review): which combination is required is enforced by the consuming
# action, not by this model - confirm there.
class DragDropAction(BaseModel):
# Element-based approach
element_source: str | None = Field(None, description='CSS selector or XPath of the element to drag from')
element_target: str | None = Field(None, description='CSS selector or XPath of the element to drop onto')
element_source_offset: Position | None = Field(
None, description='Precise position within the source element to start drag (in pixels from top-left corner)'
)
element_target_offset: Position | None = Field(
None, description='Precise position within the target element to drop (in pixels from top-left corner)'
)
# Coordinate-based approach (used if selectors not provided)
coord_source_x: int | None = Field(None, description='Absolute X coordinate on page to start drag from (in pixels)')
coord_source_y: int | None = Field(None, description='Absolute Y coordinate on page to start drag from (in pixels)')
coord_target_x: int | None = Field(None, description='Absolute X coordinate on page to drop at (in pixels)')
coord_target_y: int | None = Field(None, description='Absolute Y coordinate on page to drop at (in pixels)')
# Common options
# Defaults: 10 steps and a 5 ms delay; both fields also accept None.
steps: int | None = Field(10, description='Number of intermediate points for smoother movement (5-20 recommended)')
delay_ms: int | None = Field(5, description='Delay in milliseconds between steps (0 for fastest, 10-20 for more natural)')

View file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,70 @@
import hashlib
from browser_use.dom.views import DOMElementNode
class ClickableElementProcessor:
    """Collects highlighted ("clickable") DOM elements and hashes them.

    An element counts as clickable here when its ``highlight_index`` is set
    (not ``None``). All hashes are SHA-256 hex digests.
    """

    @staticmethod
    def get_clickable_elements_hashes(dom_element: DOMElementNode) -> set[str]:
        """Return the hashes of all clickable elements below ``dom_element``."""
        clickable_elements = ClickableElementProcessor.get_clickable_elements(dom_element)
        return {ClickableElementProcessor.hash_dom_element(element) for element in clickable_elements}

    @staticmethod
    def get_clickable_elements(dom_element: DOMElementNode) -> list[DOMElementNode]:
        """Return every clickable descendant of ``dom_element`` in depth-first order.

        ``dom_element`` itself is not included; non-element children are skipped.
        """
        clickable_elements: list[DOMElementNode] = []
        for child in dom_element.children:
            if not isinstance(child, DOMElementNode):
                continue
            # Compare against None explicitly: a truthiness test would drop
            # the element whose highlight index is 0.
            if child.highlight_index is not None:
                clickable_elements.append(child)
            clickable_elements.extend(ClickableElementProcessor.get_clickable_elements(child))
        return clickable_elements

    @staticmethod
    def hash_dom_element(dom_element: DOMElementNode) -> str:
        """Hash an element from its branch path, attributes and xpath (text is not included)."""
        parent_branch_path = ClickableElementProcessor._get_parent_branch_path(dom_element)
        branch_path_hash = ClickableElementProcessor._parent_branch_path_hash(parent_branch_path)
        attributes_hash = ClickableElementProcessor._attributes_hash(dom_element.attributes)
        xpath_hash = ClickableElementProcessor._xpath_hash(dom_element.xpath)
        return ClickableElementProcessor._hash_string(f'{branch_path_hash}-{attributes_hash}-{xpath_hash}')

    @staticmethod
    def _get_parent_branch_path(dom_element: DOMElementNode) -> list[str]:
        """Return tag names from just below the root down to ``dom_element`` itself.

        The walk stops at the node whose ``parent`` is ``None``; that node's
        tag is not part of the path.
        """
        parents: list[DOMElementNode] = []
        current_element: DOMElementNode = dom_element
        while current_element.parent is not None:
            parents.append(current_element)
            current_element = current_element.parent
        parents.reverse()
        return [parent.tag_name for parent in parents]

    @staticmethod
    def _parent_branch_path_hash(parent_branch_path: list[str]) -> str:
        """Hash the branch path joined with '/'."""
        return ClickableElementProcessor._hash_string('/'.join(parent_branch_path))

    @staticmethod
    def _attributes_hash(attributes: dict[str, str]) -> str:
        """Hash the attributes as concatenated ``key=value`` pairs in dict order."""
        attributes_string = ''.join(f'{key}={value}' for key, value in attributes.items())
        return ClickableElementProcessor._hash_string(attributes_string)

    @staticmethod
    def _xpath_hash(xpath: str) -> str:
        """Hash the xpath string."""
        return ClickableElementProcessor._hash_string(xpath)

    @staticmethod
    def _text_hash(dom_element: DOMElementNode) -> str:
        """Hash the text returned by ``get_all_text_till_next_clickable_element``."""
        text_string = dom_element.get_all_text_till_next_clickable_element()
        return ClickableElementProcessor._hash_string(text_string)

    @staticmethod
    def _hash_string(string: str) -> str:
        """Return the SHA-256 hex digest of ``string`` (default ``str.encode`` encoding)."""
        return hashlib.sha256(string.encode()).hexdigest()

View file

@ -0,0 +1,106 @@
import hashlib
from browser_use.dom.history_tree_processor.view import DOMHistoryElement, HashedDomElement
from browser_use.dom.views import DOMElementNode
class HistoryTreeProcessor:
    """
    Operations on the DOM elements

    Converts live DOM elements into history records and matches history
    records back to elements by comparing hashes of branch path, attributes
    and xpath.

    @dev be careful - text nodes can change even if elements stay the same
    """

    @staticmethod
    def _sha256_hex(value: str) -> str:
        """Return the SHA-256 hex digest of ``value``."""
        return hashlib.sha256(value.encode()).hexdigest()

    @staticmethod
    def convert_dom_element_to_history_element(dom_element: DOMElementNode) -> DOMHistoryElement:
        """Build a ``DOMHistoryElement`` snapshot from a live DOM element."""
        # Imported at call time rather than at module import.
        from browser_use.browser.context import BrowserContext

        branch_path = HistoryTreeProcessor._get_parent_branch_path(dom_element)
        selector = BrowserContext._enhanced_css_selector_for_element(dom_element)
        return DOMHistoryElement(
            tag_name=dom_element.tag_name,
            xpath=dom_element.xpath,
            highlight_index=dom_element.highlight_index,
            entire_parent_branch_path=branch_path,
            attributes=dom_element.attributes,
            shadow_root=dom_element.shadow_root,
            css_selector=selector,
            page_coordinates=dom_element.page_coordinates,
            viewport_coordinates=dom_element.viewport_coordinates,
            viewport_info=dom_element.viewport_info,
        )

    @staticmethod
    def find_history_element_in_tree(dom_history_element: DOMHistoryElement, tree: DOMElementNode) -> DOMElementNode | None:
        """Return the first highlighted node (pre-order) whose hash equals the history element's, else None."""
        wanted = HistoryTreeProcessor._hash_dom_history_element(dom_history_element)

        # Explicit stack instead of recursion; children are pushed reversed so
        # nodes are visited in the same order a recursive pre-order walk would use.
        pending = [tree]
        while pending:
            node = pending.pop()
            if node.highlight_index is not None and HistoryTreeProcessor._hash_dom_element(node) == wanted:
                return node
            pending.extend(child for child in reversed(node.children) if isinstance(child, DOMElementNode))
        return None

    @staticmethod
    def compare_history_element_and_dom_element(dom_history_element: DOMHistoryElement, dom_element: DOMElementNode) -> bool:
        """Return True when both elements produce the same ``HashedDomElement``."""
        history_hash = HistoryTreeProcessor._hash_dom_history_element(dom_history_element)
        return history_hash == HistoryTreeProcessor._hash_dom_element(dom_element)

    @staticmethod
    def _hash_dom_history_element(dom_history_element: DOMHistoryElement) -> HashedDomElement:
        """Hash a history element from its stored branch path, attributes and xpath."""
        return HashedDomElement(
            HistoryTreeProcessor._parent_branch_path_hash(dom_history_element.entire_parent_branch_path),
            HistoryTreeProcessor._attributes_hash(dom_history_element.attributes),
            HistoryTreeProcessor._xpath_hash(dom_history_element.xpath),
        )

    @staticmethod
    def _hash_dom_element(dom_element: DOMElementNode) -> HashedDomElement:
        """Hash a live element from its computed branch path, attributes and xpath (text excluded)."""
        branch_path = HistoryTreeProcessor._get_parent_branch_path(dom_element)
        return HashedDomElement(
            HistoryTreeProcessor._parent_branch_path_hash(branch_path),
            HistoryTreeProcessor._attributes_hash(dom_element.attributes),
            HistoryTreeProcessor._xpath_hash(dom_element.xpath),
        )

    @staticmethod
    def _get_parent_branch_path(dom_element: DOMElementNode) -> list[str]:
        """Return tag names from just below the root down to ``dom_element`` itself.

        The node whose ``parent`` is ``None`` ends the walk and is not included.
        """
        tags_bottom_up: list[str] = []
        node = dom_element
        while node.parent is not None:
            tags_bottom_up.append(node.tag_name)
            node = node.parent
        return tags_bottom_up[::-1]

    @staticmethod
    def _parent_branch_path_hash(parent_branch_path: list[str]) -> str:
        """Hash the branch path joined with '/'."""
        return HistoryTreeProcessor._sha256_hex('/'.join(parent_branch_path))

    @staticmethod
    def _attributes_hash(attributes: dict[str, str]) -> str:
        """Hash the attributes as concatenated ``key=value`` pairs in dict order."""
        pairs = [f'{key}={value}' for key, value in attributes.items()]
        return HistoryTreeProcessor._sha256_hex(''.join(pairs))

    @staticmethod
    def _xpath_hash(xpath: str) -> str:
        """Hash the xpath string."""
        return HistoryTreeProcessor._sha256_hex(xpath)

    @staticmethod
    def _text_hash(dom_element: DOMElementNode) -> str:
        """Hash the text returned by ``get_all_text_till_next_clickable_element``."""
        return HistoryTreeProcessor._sha256_hex(dom_element.get_all_text_till_next_clickable_element())

View file

@ -0,0 +1,69 @@
from dataclasses import dataclass
from pydantic import BaseModel
# Plain dataclass of three string hashes; dataclass-generated equality
# compares all three fields.
@dataclass
class HashedDomElement:
"""
Hash of the dom element to be used as a unique identifier
"""
branch_path_hash: str
attributes_hash: str
xpath_hash: str
# text_hash: str
# Integer x/y point; used as the corner and center type of CoordinateSet.
class Coordinates(BaseModel):
x: int
y: int
# Four corner points, a center point and integer width/height.
# NOTE(review): units and coordinate origin are not visible here - confirm against the producer.
class CoordinateSet(BaseModel):
top_left: Coordinates
top_right: Coordinates
bottom_left: Coordinates
bottom_right: Coordinates
center: Coordinates
width: int
height: int
# Scroll offsets plus viewport width/height, all required integers.
class ViewportInfo(BaseModel):
scroll_x: int
scroll_y: int
width: int
height: int
@dataclass
class DOMHistoryElement:
    """Snapshot of a DOM element's identifying data, convertible to a plain dict."""

    tag_name: str
    xpath: str
    highlight_index: int | None
    entire_parent_branch_path: list[str]
    attributes: dict[str, str]
    shadow_root: bool = False
    css_selector: str | None = None
    page_coordinates: CoordinateSet | None = None
    viewport_coordinates: CoordinateSet | None = None
    viewport_info: ViewportInfo | None = None

    def to_dict(self) -> dict:
        """Return all fields as a dict, dumping the optional models via ``model_dump()``.

        A falsy coordinate/viewport value is emitted as ``None``.
        """

        def dump(model):
            return model.model_dump() if model else None

        return {
            'tag_name': self.tag_name,
            'xpath': self.xpath,
            'highlight_index': self.highlight_index,
            'entire_parent_branch_path': self.entire_parent_branch_path,
            'attributes': self.attributes,
            'shadow_root': self.shadow_root,
            'css_selector': self.css_selector,
            'page_coordinates': dump(self.page_coordinates),
            'viewport_coordinates': dump(self.viewport_coordinates),
            'viewport_info': dump(self.viewport_info),
        }

View file

@ -0,0 +1,203 @@
import json
import logging
from dataclasses import dataclass
from importlib import resources
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from playwright.async_api import Page
from browser_use.dom.views import (
DOMBaseNode,
DOMElementNode,
DOMState,
DOMTextNode,
SelectorMap,
)
from browser_use.utils import time_execution_async
logger = logging.getLogger(__name__)
# Width/height pair built by DomService._parse_node from a node's 'viewport' entry.
# NOTE(review): a pydantic model also named ViewportInfo (with extra scroll_x/scroll_y
# fields) is what DOMElementNode.viewport_info is annotated with - confirm which type
# downstream consumers expect.
@dataclass
class ViewportInfo:
width: int
height: int
class DomService:
    """Builds a Python DOM tree for a Playwright page by evaluating ``buildDomTree.js`` in it."""

    def __init__(self, page: 'Page'):
        self.page = page
        self.xpath_cache = {}
        # The injected script is read once per instance from the package's resources.
        self.js_code = resources.files('browser_use.dom').joinpath('buildDomTree.js').read_text()

    # region - Clickable elements
    @time_execution_async('--get_clickable_elements')
    async def get_clickable_elements(
        self,
        highlight_elements: bool = True,
        focus_element: int = -1,
        viewport_expansion: int = 0,
    ) -> DOMState:
        """Build the DOM tree and return it together with its selector map.

        The three options are forwarded unchanged to the injected script.
        """
        element_tree, selector_map = await self._build_dom_tree(highlight_elements, focus_element, viewport_expansion)
        return DOMState(element_tree=element_tree, selector_map=selector_map)

    @time_execution_async('--get_cross_origin_iframes')
    async def get_cross_origin_iframes(self) -> list[str]:
        """Return URLs of frames whose host differs from the page's, minus hidden and ad frames."""
        # invisible cross-origin iframes are used for ads and tracking, dont open those
        hidden_frame_urls = await self.page.locator('iframe').filter(visible=False).evaluate_all('e => e.map(e => e.src)')

        ad_domains = ('doubleclick.net', 'adroll.com', 'googletagmanager.com')
        page_netloc = urlparse(self.page.url).netloc

        def is_ad_host(netloc: str) -> bool:
            # Substring match against the most common ad/tracker domains.
            return any(domain in netloc for domain in ad_domains)

        cross_origin_urls: list[str] = []
        for frame in self.page.frames:
            frame_netloc = urlparse(frame.url).netloc
            if not frame_netloc:  # exclude data:urls and about:blank
                continue
            if frame_netloc == page_netloc:  # exclude same-origin iframes
                continue
            if frame.url in hidden_frame_urls:  # exclude hidden frames
                continue
            if is_ad_host(frame_netloc):  # exclude most common ad network tracker frame URLs
                continue
            cross_origin_urls.append(frame.url)
        return cross_origin_urls

    @time_execution_async('--build_dom_tree')
    async def _build_dom_tree(
        self,
        highlight_elements: bool,
        focus_element: int,
        viewport_expansion: int,
    ) -> tuple[DOMElementNode, SelectorMap]:
        """Run the DOM extraction script in the page and convert its result.

        Raises:
            ValueError: the page cannot evaluate a trivial expression.
            Exception: any error raised while evaluating the script is logged and re-raised.
        """
        if await self.page.evaluate('1+1') != 2:
            raise ValueError('The page cannot evaluate javascript code properly')

        if self.page.url == 'about:blank':
            # short-circuit if the page is a new empty tab for speed, no need to inject buildDomTree.js
            return (
                DOMElementNode(
                    tag_name='body',
                    xpath='',
                    attributes={},
                    children=[],
                    is_visible=False,
                    parent=None,
                ),
                {},
            )

        # NOTE: We execute JS code in the browser to extract important DOM information.
        #       The returned hash map contains information about the DOM tree and the
        #       relationship between the DOM elements.
        debug_mode = logger.getEffectiveLevel() == logging.DEBUG
        args = {
            'doHighlightElements': highlight_elements,
            'focusHighlightIndex': focus_element,
            'viewportExpansion': viewport_expansion,
            'debugMode': debug_mode,
        }

        try:
            eval_page: dict = await self.page.evaluate(self.js_code, args)
        except Exception as e:
            logger.error('Error evaluating JavaScript: %s', e)
            raise

        # Only log performance metrics in debug mode
        if debug_mode and 'perfMetrics' in eval_page:
            logger.debug(
                'DOM Tree Building Performance Metrics for: %s\n%s',
                self.page.url,
                json.dumps(eval_page['perfMetrics'], indent=2),
            )

        return await self._construct_dom_tree(eval_page)

    @time_execution_async('--construct_dom_tree')
    async def _construct_dom_tree(
        self,
        eval_page: dict,
    ) -> tuple[DOMElementNode, SelectorMap]:
        """Turn the script result (``map`` + ``rootId``) into linked nodes and a selector map.

        Raises:
            ValueError: the root id is missing from the parsed nodes or is not an element node.
        """
        js_node_map = eval_page['map']
        js_root_id = eval_page['rootId']

        selector_map = {}
        node_map = {}

        for node_id, node_data in js_node_map.items():
            node, children_ids = self._parse_node(node_data)
            if node is None:
                continue

            node_map[node_id] = node

            if not isinstance(node, DOMElementNode):
                continue

            if node.highlight_index is not None:
                selector_map[node.highlight_index] = node

            # NOTE: We know that we are building the tree bottom up
            #       and all children are already processed.
            for child_id in children_ids:
                child_node = node_map.get(child_id)
                if child_node is None:
                    continue
                child_node.parent = node
                node.children.append(child_node)

        # .get() so a missing root reaches the ValueError below instead of a bare KeyError.
        root = node_map.get(str(js_root_id))
        if not isinstance(root, DOMElementNode):
            raise ValueError('Failed to parse HTML to dictionary')

        return root, selector_map

    def _parse_node(
        self,
        node_data: dict,
    ) -> tuple[DOMBaseNode | None, list[int]]:
        """Convert one raw node dict into a DOM node plus the ids of its children.

        Returns ``(None, [])`` for empty data and ``(DOMTextNode, [])`` for text
        nodes; element nodes are returned with an empty ``children`` list that
        the caller fills in.
        """
        if not node_data:
            return None, []

        # Process text nodes immediately
        if node_data.get('type') == 'TEXT_NODE':
            text_node = DOMTextNode(
                text=node_data['text'],
                is_visible=node_data['isVisible'],
                parent=None,
            )
            return text_node, []

        # Process coordinates if they exist for element nodes
        # NOTE(review): this builds the module-local ViewportInfo dataclass (width/height
        # only); confirm it matches the type consumers of `viewport_info` expect.
        viewport_info = None
        if 'viewport' in node_data:
            viewport_info = ViewportInfo(
                width=node_data['viewport']['width'],
                height=node_data['viewport']['height'],
            )

        element_node = DOMElementNode(
            tag_name=node_data['tagName'],
            xpath=node_data['xpath'],
            attributes=node_data.get('attributes', {}),
            children=[],
            is_visible=node_data.get('isVisible', False),
            is_interactive=node_data.get('isInteractive', False),
            is_top_element=node_data.get('isTopElement', False),
            is_in_viewport=node_data.get('isInViewport', False),
            highlight_index=node_data.get('highlightIndex'),
            shadow_root=node_data.get('shadowRoot', False),
            parent=None,
            viewport_info=viewport_info,
        )

        children_ids = node_data.get('children', [])

        return element_node, children_ids

View file

@ -0,0 +1,123 @@
import asyncio
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from browser_use.browser.browser import Browser, BrowserConfig
from browser_use.browser.context import BrowserContext
# Manual debugging helper: opens `url` in a non-headless browser, prints viewport
# numbers, cookie/consent-related elements, fixed/sticky elements and the page
# structure, then waits for Enter. The browser is closed in the `finally` clause.
async def analyze_page_structure(url: str):
"""Analyze and print the structure of a webpage with enhanced debugging"""
browser = Browser(
config=BrowserConfig(
headless=False, # Set to True if you don't need to see the browser
)
)
context = BrowserContext(browser=browser)
try:
async with context as ctx:
# Navigate to the URL
page = await ctx.get_current_page()
await page.goto(url)
# Wait until the page reports the 'networkidle' load state before inspecting it.
await page.wait_for_load_state('networkidle')
# Get viewport dimensions
viewport_info = await page.evaluate("""() => {
return {
viewport: {
width: window.innerWidth,
height: window.innerHeight,
scrollX: window.scrollX,
scrollY: window.scrollY
}
}
}""")
print('\nViewport Information:')
print(f'Width: {viewport_info["viewport"]["width"]}')
print(f'Height: {viewport_info["viewport"]["height"]}')
print(f'ScrollX: {viewport_info["viewport"]["scrollX"]}')
print(f'ScrollY: {viewport_info["viewport"]["scrollY"]}')
# Enhanced debug information for cookie consent and fixed position elements
# The script returns two lists of per-element dicts: elements whose id/class contains
# "cookie" or "consent", and elements whose computed position is fixed or sticky.
debug_info = await page.evaluate("""() => {
function getElementInfo(element) {
const rect = element.getBoundingClientRect();
const style = window.getComputedStyle(element);
return {
tag: element.tagName.toLowerCase(),
id: element.id,
className: element.className,
position: style.position,
rect: {
top: rect.top,
right: rect.right,
bottom: rect.bottom,
left: rect.left,
width: rect.width,
height: rect.height
},
isFixed: style.position === 'fixed',
isSticky: style.position === 'sticky',
zIndex: style.zIndex,
visibility: style.visibility,
display: style.display,
opacity: style.opacity
};
}
// Find cookie-related elements
const cookieElements = Array.from(document.querySelectorAll('[id*="cookie"], [id*="consent"], [class*="cookie"], [class*="consent"]'));
const fixedElements = Array.from(document.querySelectorAll('*')).filter(el => {
const style = window.getComputedStyle(el);
return style.position === 'fixed' || style.position === 'sticky';
});
return {
cookieElements: cookieElements.map(getElementInfo),
fixedElements: fixedElements.map(getElementInfo)
};
}""")
print('\nCookie-related Elements:')
for elem in debug_info['cookieElements']:
print(f'\nElement: {elem["tag"]}#{elem["id"]} .{elem["className"]}')
print(f'Position: {elem["position"]}')
print(f'Rect: {elem["rect"]}')
print(f'Z-Index: {elem["zIndex"]}')
print(f'Visibility: {elem["visibility"]}')
print(f'Display: {elem["display"]}')
print(f'Opacity: {elem["opacity"]}')
print('\nFixed/Sticky Position Elements:')
for elem in debug_info['fixedElements']:
print(f'\nElement: {elem["tag"]}#{elem["id"]} .{elem["className"]}')
print(f'Position: {elem["position"]}')
print(f'Rect: {elem["rect"]}')
print(f'Z-Index: {elem["zIndex"]}')
print(f'\nPage Structure for {url}:\n')
structure = await ctx.get_page_structure()
print(structure)
# Blocks on stdin so the browser window stays open for manual inspection.
input('Press Enter to close the browser...')
finally:
await browser.close()
# Script entry point: analyzes each URL in turn, each in its own asyncio.run() call.
if __name__ == '__main__':
# You can modify this URL to analyze different pages
# NOTE(review): the zeiss.com URL appears twice, so that page is analyzed twice - confirm this is intended.
urls = [
'https://www.mlb.com/yankees/stats/',
'https://immobilienscout24.de',
'https://www.zeiss.com/career/en/job-search.html?page=1',
'https://www.zeiss.com/career/en/job-search.html?page=1',
'https://reddit.com',
]
for url in urls:
asyncio.run(analyze_page_structure(url))

View file

@ -0,0 +1,182 @@
import asyncio
import os
import anyio
from langchain_openai import ChatOpenAI
from browser_use.agent.prompts import AgentMessagePrompt
from browser_use.browser.browser import Browser, BrowserConfig
from browser_use.browser.context import BrowserContext, BrowserContextConfig
from browser_use.dom.service import DomService
def count_string_tokens(string: str, model: str) -> tuple[int, float]:
    """Return ``(token_count, price)`` for ``string`` under ``model``.

    Tokens are counted with ``ChatOpenAI(model=model).get_num_tokens``; the
    price is the count multiplied by a hard-coded per-token rate.

    Raises:
        KeyError: ``model`` has no entry in the price table.
    """
    # Per-token rates, written as price per million tokens.
    # @todo: move to utils, use a package or sth
    price_per_token = {
        'gpt-4o': 2.5 / 1e6,
        'gpt-4o-mini': 0.15 / 1e6,
    }
    token_count = ChatOpenAI(model=model).get_num_tokens(string)
    return token_count, token_count * price_per_token[model]
# NOTE(review): TIMEOUT is not referenced anywhere in the visible part of this script - confirm it is still needed.
TIMEOUT = 60
# Attribute names passed as `include_attributes` when building the AgentMessagePrompt below.
DEFAULT_INCLUDE_ATTRIBUTES = [
'id',
'title',
'type',
'name',
'role',
'aria-label',
'placeholder',
'value',
'alt',
'aria-expanded',
'data-date-format',
]
async def test_focus_vs_all_elements():
    """Interactive manual check of element highlighting, clicking and text input.

    For each website: load it, then repeatedly print the element count and
    prompt token cost, write the prompt to ``./tmp/user_message.txt`` and ask
    on stdin for an element index to click, ``index,text`` to type into an
    element, or ``q`` to move on to the next website.
    """
    config = BrowserContextConfig(
        # cookies_file='cookies3.json',
        disable_security=True,
        wait_for_network_idle_page_load_time=1,
    )

    browser = Browser(
        config=BrowserConfig(
            # browser_binary_path='/Applications/Google Chrome.app/Contents/MacOS/Google Chrome',
        )
    )
    context = BrowserContext(browser=browser, config=config)

    websites = [
        'https://demos.telerik.com/kendo-react-ui/treeview/overview/basic/func?theme=default-ocean-blue-a11y',
        'https://www.ycombinator.com/companies',
        'https://kayak.com/flights',
        # 'https://en.wikipedia.org/wiki/Humanist_Party_of_Ontario',
        # 'https://www.google.com/travel/flights?tfs=CBwQARoJagcIARIDTEpVGglyBwgBEgNMSlVAAUgBcAGCAQsI____________AZgBAQ&tfu=KgIIAw&hl=en-US&gl=US',
        # # 'https://www.concur.com/?&cookie_preferences=cpra',
        # 'https://immobilienscout24.de',
        'https://docs.google.com/spreadsheets/d/1INaIcfpYXlMRWO__de61SHFCaqt1lfHlcvtXZPItlpI/edit',
        'https://www.zeiss.com/career/en/job-search.html?page=1',
        'https://www.mlb.com/yankees/stats/',
        'https://www.amazon.com/s?k=laptop&s=review-rank&crid=1RZCEJ289EUSI&qid=1740202453&sprefix=laptop%2Caps%2C166&ref=sr_st_review-rank&ds=v1%3A4EnYKXVQA7DIE41qCvRZoNB4qN92Jlztd3BPsTFXmxU',
        'https://reddit.com',
        'https://codepen.io/geheimschriftstift/pen/mPLvQz',
        'https://www.google.com/search?q=google+hi&oq=google+hi&gs_lcrp=EgZjaHJvbWUyBggAEEUYOTIGCAEQRRhA0gEIMjI2NmowajSoAgCwAgE&sourceid=chrome&ie=UTF-8',
        'https://google.com',
        'https://amazon.com',
        'https://github.com',
    ]

    async with context as context:
        page = await context.get_current_page()
        # Kept from the original; not referenced below.
        dom_service = DomService(page)

        for website in websites:
            await page.goto(website)
            # The coroutine must be awaited, otherwise the pause never happens.
            await asyncio.sleep(1)

            while True:
                try:
                    print(f'\n{"=" * 50}\nTesting {website}\n{"=" * 50}')

                    # Get/refresh the state (includes removing old highlights)
                    print('\nGetting page state...')
                    all_elements_state = await context.get_state(True)

                    selector_map = all_elements_state.selector_map
                    total_elements = len(selector_map.keys())
                    print(f'Total number of elements: {total_elements}')

                    prompt = AgentMessagePrompt(
                        state=all_elements_state,
                        result=None,
                        include_attributes=DEFAULT_INCLUDE_ATTRIBUTES,
                        step_info=None,
                    )

                    # Write the user message to a file for analysis
                    user_message = prompt.get_user_message(use_vision=False).content
                    os.makedirs('./tmp', exist_ok=True)
                    async with await anyio.open_file('./tmp/user_message.txt', 'w', encoding='utf-8') as f:
                        await f.write(user_message)

                    token_count, price = count_string_tokens(user_message, model='gpt-4o')
                    print(f'Prompt token count: {token_count}, price: {round(price, 4)} USD')
                    print('User message written to ./tmp/user_message.txt')

                    answer = input("Enter element index to click, 'index,text' to input, or 'q' to quit: ")
                    if answer.lower() == 'q':
                        break

                    try:
                        if ',' in answer:
                            # Input text format: index,text (split on the first comma only)
                            index_part, _, text_to_input = answer.partition(',')
                            try:
                                target_index = int(index_part.strip())
                            except ValueError:
                                print(f'Invalid index format: {index_part}')
                            else:
                                if target_index in selector_map:
                                    element_node = selector_map[target_index]
                                    print(
                                        f"Inputting text '{text_to_input}' into element {target_index}: {element_node.tag_name}"
                                    )
                                    await context._input_text_element_node(element_node, text_to_input)
                                    print('Input successful.')
                                else:
                                    print(f'Invalid index: {target_index}')
                        else:
                            # Click element format: index
                            try:
                                clicked_index = int(answer)
                            except ValueError:
                                print(f"Invalid input: '{answer}'. Enter an index, 'index,text', or 'q'.")
                            else:
                                if clicked_index in selector_map:
                                    element_node = selector_map[clicked_index]
                                    print(f'Clicking element {clicked_index}: {element_node.tag_name}')
                                    await context._click_element_node(element_node)
                                    print('Click successful.')
                                else:
                                    print(f'Invalid index: {clicked_index}')
                    except Exception as action_e:
                        print(f'Action failed: {action_e}')

                    # No explicit highlight removal here, get_state handles it at the start of the loop

                except Exception as e:
                    print(f'Error in loop: {e}')
                    # Small delay before retrying so a persistent failure does not spin.
                    await asyncio.sleep(1)
# Script entry point: runs the interactive check in a fresh event loop.
if __name__ == '__main__':
asyncio.run(test_focus_vs_all_elements())
# asyncio.run(test_process_html_file()) # Commented out the other test

View file

@ -0,0 +1,43 @@
import asyncio
import json
import os
import time
import anyio
from browser_use.browser.browser import Browser, BrowserConfig
# Manual timing check: loads a page in a visible browser, evaluates buildDomTree.js
# in it, prints how long the evaluation took and writes the result to ./tmp/dom.json.
async def test_process_dom():
browser = Browser(config=BrowserConfig(headless=False))
async with await browser.new_context() as context:
page = await context.get_current_page()
await page.goto('https://kayak.com/flights')
# await page.goto('https://google.com/flights')
# await page.goto('https://immobilienscout24.de')
# await page.goto('https://seleniumbase.io/w3schools/iframes')
# Fixed 3 second pause after navigation before the script is evaluated.
await asyncio.sleep(3)
# The script path is relative, so this presumably must run from the repository root - confirm.
async with await anyio.open_file('browser_use/dom/buildDomTree.js', 'r') as f:
js_code = await f.read()
# NOTE(review): the script is evaluated without an argument here, whereas DomService
# passes an options dict - confirm the script tolerates a missing argument.
start = time.time()
dom_tree = await page.evaluate(js_code)
end = time.time()
# print(dom_tree)
print(f'Time: {end - start:.2f}s')
os.makedirs('./tmp', exist_ok=True)
async with await anyio.open_file('./tmp/dom.json', 'w') as f:
await f.write(json.dumps(dom_tree, indent=1))
# both of these work for immobilienscout24.de
# await page.click('.sc-dcJsrY.ezjNCe')
# await page.click(
# 'div > div:nth-of-type(2) > div > div:nth-of-type(2) > div > div:nth-of-type(2) > div > div > div > button:nth-of-type(2)'
# )
# Blocks on stdin until Enter is pressed.
input('Press Enter to continue...')

View file

@ -0,0 +1,265 @@
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Optional
from browser_use.dom.history_tree_processor.view import CoordinateSet, HashedDomElement, ViewportInfo
from browser_use.utils import time_execution_sync
# Avoid circular import issues
if TYPE_CHECKING:
from .views import DOMElementNode
@dataclass
class DOMBaseNode:
    """Base dataclass for DOM nodes: a visibility flag and a link to the parent element.

    Subclasses are expected to provide ``__json__``; the base version always raises.
    """

    is_visible: bool
    # Use None as default and set parent later to avoid circular reference issues
    parent: Optional['DOMElementNode']

    def __json__(self) -> dict:
        """Always raise ``NotImplementedError``; concrete node types override this."""
        raise NotImplementedError('DOMBaseNode is an abstract class')
@dataclass(frozen=False)
class DOMTextNode(DOMBaseNode):
    """Text node of the DOM tree; ``type`` defaults to ``'TEXT_NODE'``."""

    text: str
    type: str = 'TEXT_NODE'

    def has_parent_with_highlight_index(self) -> bool:
        """Return True if any ancestor has a ``highlight_index`` set."""
        ancestor = self.parent
        # Climb until the chain ends or an ancestor with a highlight index
        # (those are handled separately) is found.
        while ancestor is not None and ancestor.highlight_index is None:
            ancestor = ancestor.parent
        return ancestor is not None

    def is_parent_in_viewport(self) -> bool:
        """Return the parent's ``is_in_viewport``; False when there is no parent."""
        owner = self.parent
        return False if owner is None else owner.is_in_viewport

    def is_parent_top_element(self) -> bool:
        """Return the parent's ``is_top_element``; False when there is no parent."""
        owner = self.parent
        return False if owner is None else owner.is_top_element

    def __json__(self) -> dict:
        """Return the node as a dict with the keys ``text`` and ``type``."""
        return dict(text=self.text, type=self.type)
@dataclass(frozen=False)
class DOMElementNode(DOMBaseNode):
    """
    Element node of the DOM tree: tag, attributes, children and highlight state.

    xpath: the xpath of the element from the last root node (shadow root or iframe OR document if no shadow root or iframe).
    To properly reference the element we need to recursively switch the root node until we find the element (work your way up the tree with `.parent`)
    """

    tag_name: str
    xpath: str
    attributes: dict[str, str]
    children: list[DOMBaseNode]
    is_interactive: bool = False
    is_top_element: bool = False
    is_in_viewport: bool = False
    shadow_root: bool = False
    highlight_index: int | None = None
    viewport_coordinates: CoordinateSet | None = None
    page_coordinates: CoordinateSet | None = None
    viewport_info: ViewportInfo | None = None

    # State injected by the browser context.
    # The idea is that the clickable elements are sometimes persistent from the
    # previous page -> tells the model which objects are new/_how_ the state
    # has changed.
    is_new: bool | None = None

    def __json__(self) -> dict:
        """Return the element and, recursively, its children as a dict.

        ``viewport_info``, ``is_new`` and ``parent`` are not part of the
        output. The coordinate values are passed through as stored.
        """
        return {
            'tag_name': self.tag_name,
            'xpath': self.xpath,
            'attributes': self.attributes,
            'is_visible': self.is_visible,
            'is_interactive': self.is_interactive,
            'is_top_element': self.is_top_element,
            'is_in_viewport': self.is_in_viewport,
            'shadow_root': self.shadow_root,
            'highlight_index': self.highlight_index,
            'viewport_coordinates': self.viewport_coordinates,
            'page_coordinates': self.page_coordinates,
            'children': [child.__json__() for child in self.children],
        }

    def __repr__(self) -> str:
        """Return a compact ``<tag attr="v"> [flags]`` description (no children)."""
        tag_str = f'<{self.tag_name}'

        # Add attributes
        for key, value in self.attributes.items():
            tag_str += f' {key}="{value}"'
        tag_str += '>'

        # Add extra info
        extras = []
        if self.is_interactive:
            extras.append('interactive')
        if self.is_top_element:
            extras.append('top')
        if self.shadow_root:
            extras.append('shadow-root')
        if self.highlight_index is not None:
            extras.append(f'highlight:{self.highlight_index}')
        if self.is_in_viewport:
            extras.append('in-viewport')

        if extras:
            tag_str += f' [{", ".join(extras)}]'

        return tag_str

    @cached_property
    def hash(self) -> HashedDomElement:
        """Hash of this element, computed once per instance and then cached."""
        # Imported here rather than at module level (the module header notes
        # circular-import concerns).
        from browser_use.dom.history_tree_processor.service import (
            HistoryTreeProcessor,
        )

        return HistoryTreeProcessor._hash_dom_element(self)

    def get_all_text_till_next_clickable_element(self, max_depth: int = -1) -> str:
        """Collect the text below this element, stopping at highlighted descendants.

        Args:
            max_depth: Maximum depth to descend (this element is depth 0);
                ``-1`` means unlimited.

        Returns:
            The text of all reached text nodes joined by newlines, stripped.
        """
        text_parts = []

        def collect_text(node: DOMBaseNode, current_depth: int) -> None:
            if max_depth != -1 and current_depth > max_depth:
                return

            # Skip this branch if we hit a highlighted element (except for the current node).
            # Identity is what is meant here: `!=` would call the dataclass-generated
            # __eq__, which compares every field and recurses through `parent`.
            if isinstance(node, DOMElementNode) and node is not self and node.highlight_index is not None:
                return

            if isinstance(node, DOMTextNode):
                text_parts.append(node.text)
            elif isinstance(node, DOMElementNode):
                for child in node.children:
                    collect_text(child, current_depth + 1)

        collect_text(self, 0)
        return '\n'.join(text_parts).strip()

    @time_execution_sync('--clickable_elements_to_string')
    def clickable_elements_to_string(self, include_attributes: list[str] | None = None) -> str:
        """Render the highlighted elements and loose text of this subtree as text.

        Every element with a ``highlight_index`` becomes one line of the form
        ``[index]<tag attrs>text />`` (``*[index]*`` when ``is_new`` is truthy),
        indented with one tab per enclosing highlighted element. Text nodes
        without a highlighted ancestor are emitted on their own line when
        their parent is visible and a top element.

        Args:
            include_attributes: Attribute names to show on element lines;
                when empty or None no attributes are shown.

        Returns:
            The lines joined by newlines.
        """
        formatted_text = []

        def process_node(node: DOMBaseNode, depth: int) -> None:
            next_depth = int(depth)
            depth_str = depth * '\t'

            if isinstance(node, DOMElementNode):
                # Add element with highlight_index
                if node.highlight_index is not None:
                    # Only highlighted elements deepen the indentation of their subtree.
                    next_depth += 1

                    text = node.get_all_text_till_next_clickable_element()
                    attributes_html_str = ''
                    if include_attributes:
                        attributes_to_include = {
                            key: str(value) for key, value in node.attributes.items() if key in include_attributes
                        }

                        # Easy LLM optimizations
                        # if tag == role attribute, don't include it
                        if node.tag_name == attributes_to_include.get('role'):
                            del attributes_to_include['role']

                        # if aria-label == text of the node, don't include it
                        if (
                            attributes_to_include.get('aria-label')
                            and attributes_to_include.get('aria-label', '').strip() == text.strip()
                        ):
                            del attributes_to_include['aria-label']

                        # if placeholder == text of the node, don't include it
                        if (
                            attributes_to_include.get('placeholder')
                            and attributes_to_include.get('placeholder', '').strip() == text.strip()
                        ):
                            del attributes_to_include['placeholder']

                        if attributes_to_include:
                            # Format as key1='value1' key2='value2'
                            attributes_html_str = ' '.join(f"{key}='{value}'" for key, value in attributes_to_include.items())

                    # Build the line
                    if node.is_new:
                        highlight_indicator = f'*[{node.highlight_index}]*'
                    else:
                        highlight_indicator = f'[{node.highlight_index}]'

                    line = f'{depth_str}{highlight_indicator}<{node.tag_name}'

                    if attributes_html_str:
                        line += f' {attributes_html_str}'

                    if text:
                        # Add space before >text only if there were NO attributes added before
                        if not attributes_html_str:
                            line += ' '
                        line += f'>{text}'
                    # Add space before /> only if neither attributes NOR text were added
                    elif not attributes_html_str:
                        line += ' '

                    line += ' />'  # 1 token
                    formatted_text.append(line)

                # Process children regardless
                for child in node.children:
                    process_node(child, next_depth)

            elif isinstance(node, DOMTextNode):
                # Add text only if it doesn't have a highlighted parent
                if (
                    not node.has_parent_with_highlight_index()
                    and node.parent
                    and node.parent.is_visible
                    and node.parent.is_top_element
                ):  # and node.is_parent_top_element()
                    formatted_text.append(f'{depth_str}{node.text}')

        process_node(self, 0)
        return '\n'.join(formatted_text)

    def get_file_upload_element(self, check_siblings: bool = True) -> Optional['DOMElementNode']:
        """Find an ``<input type="file">`` at, below or next to this element.

        Search order: this element, then its descendants (depth first), then,
        only when ``check_siblings`` is true, the subtrees of its siblings.

        Returns:
            The first matching element, or None.
        """
        # Check if current element is a file input
        if self.tag_name == 'input' and self.attributes.get('type') == 'file':
            return self

        # Check children
        for child in self.children:
            if isinstance(child, DOMElementNode):
                result = child.get_file_upload_element(check_siblings=False)
                if result:
                    return result

        # Check siblings only for the initial call
        if check_siblings and self.parent:
            for sibling in self.parent.children:
                if sibling is not self and isinstance(sibling, DOMElementNode):
                    result = sibling.get_file_upload_element(check_siblings=False)
                    if result:
                        return result

        return None
# Lookup table from an integer key to an element node.
# NOTE(review): the key is presumably the element's highlight_index — confirm
# against the code that builds the map.
SelectorMap = dict[int, DOMElementNode]


@dataclass
class DOMState:
    """Bundle of an element tree and the selector map that goes with it."""

    element_tree: DOMElementNode
    selector_map: SelectorMap

View file

@ -0,0 +1,5 @@
class LLMException(Exception):
    """Exception that carries a status code together with a message.

    Both values stay available as ``status_code`` and ``message``; the string
    form of the exception is ``'Error <status_code>: <message>'``.
    """

    def __init__(self, status_code, message):
        super().__init__(f'Error {status_code}: {message}')
        self.status_code, self.message = status_code, message

View file

@ -0,0 +1,132 @@
import logging
import os
import sys
from dotenv import load_dotenv
# Load variables from a .env file (python-dotenv) at import time, so that
# BROWSER_USE_LOGGING_LEVEL read in setup_logging() can come from that file.
load_dotenv()
def addLoggingLevel(levelName, levelNum, methodName=None):
    """
    Register a brand-new logging level with the `logging` module and with the
    currently configured logger class.

    After the call, `logging.<levelName>` holds `levelNum`, and `methodName`
    (default: `levelName.lower()`) is available as a convenience method both
    on `logging` itself and on the class returned by
    `logging.getLoggerClass()` (usually just `logging.Logger`).

    Nothing is ever overwritten: an `AttributeError` is raised if the level
    name is already an attribute of the `logging` module, or if the method
    name already exists on the module or on the logger class.

    Example
    -------
    >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
    >>> logging.getLogger(__name__).setLevel('TRACE')
    >>> logging.getLogger(__name__).trace('that worked')
    >>> logging.trace('so did this')
    >>> logging.TRACE
    5
    """
    methodName = methodName or levelName.lower()
    logger_class = logging.getLoggerClass()

    # Refuse to clobber anything; the checks run in this order and the first hit wins.
    collisions = (
        (logging, levelName, f'{levelName} already defined in logging module'),
        (logging, methodName, f'{methodName} already defined in logging module'),
        (logger_class, methodName, f'{methodName} already defined in logger class'),
    )
    for owner, attribute, complaint in collisions:
        if hasattr(owner, attribute):
            raise AttributeError(complaint)

    # This method was inspired by the answers to Stack Overflow post
    # http://stackoverflow.com/q/2183233/2988730, especially
    # http://stackoverflow.com/a/13638084/2988730
    def logForLevel(self, message, *args, **kwargs):
        # Mirrors the built-in Logger.info()/debug(): check the level, then emit.
        if self.isEnabledFor(levelNum):
            self._log(levelNum, message, args, **kwargs)

    def logToRoot(message, *args, **kwargs):
        logging.log(levelNum, message, *args, **kwargs)

    logging.addLevelName(levelNum, levelName)
    setattr(logging, levelName, levelNum)
    setattr(logger_class, methodName, logForLevel)
    setattr(logging, methodName, logToRoot)
def setup_logging():
    """Configure console logging for the ``browser_use`` package.

    Registers the custom ``RESULT`` level (35), then - unless the root logger
    already has handlers, in which case nothing else is touched - attaches one
    stdout handler to both the root logger and the ``browser_use`` logger and
    raises a fixed list of third-party loggers to ERROR.

    The verbosity comes from the ``BROWSER_USE_LOGGING_LEVEL`` environment
    variable (case-insensitive): ``result``, ``debug`` or anything else for
    INFO.
    """
    # Try to add RESULT level, but ignore if it already exists
    try:
        addLoggingLevel('RESULT', 35)  # This allows ERROR, FATAL and CRITICAL
    except AttributeError:
        pass  # Level already exists, which is fine

    log_type = os.getenv('BROWSER_USE_LOGGING_LEVEL', 'info').lower()

    # Check if handlers are already set up
    root = logging.getLogger()
    if root.hasHandlers():
        return

    # Clear existing handlers
    root.handlers = []

    class BrowserUseFormatter(logging.Formatter):
        def format(self, record):
            # Shorten 'browser_use.<pkg>.<module>' to its second-to-last component.
            # The record is modified in place.
            if isinstance(record.name, str) and record.name.startswith('browser_use.'):
                record.name = record.name.split('.')[-2]
            return super().format(record)

    # Setup single handler for all loggers
    console = logging.StreamHandler(sys.stdout)

    # additional setLevel here to filter logs
    if log_type == 'result':
        console.setLevel('RESULT')
        console.setFormatter(BrowserUseFormatter('%(message)s'))
    else:
        console.setFormatter(BrowserUseFormatter('%(levelname)-8s [%(name)s] %(message)s'))

    # Configure root logger only
    root.addHandler(console)

    # switch cases for log_type
    if log_type == 'result':
        root.setLevel('RESULT')  # string usage to avoid syntax error
    elif log_type == 'debug':
        root.setLevel(logging.DEBUG)
    else:
        root.setLevel(logging.INFO)

    # Configure browser_use logger
    browser_use_logger = logging.getLogger('browser_use')
    browser_use_logger.propagate = False  # Don't propagate to root logger
    browser_use_logger.addHandler(console)
    browser_use_logger.setLevel(root.level)  # Set same level as root logger

    # Silence third-party loggers
    for logger_name in [
        'WDM',
        'httpx',
        'selenium',
        'playwright',
        'urllib3',
        'asyncio',
        'langchain',
        'openai',
        'httpcore',
        'charset_normalizer',
        'anthropic._base_client',
        'PIL.PngImagePlugin',
        'trafilatura.htmlprocessing',
        'trafilatura',
    ]:
        third_party = logging.getLogger(logger_name)
        third_party.setLevel(logging.ERROR)
        third_party.propagate = False

View file

@ -0,0 +1,124 @@
import logging
import os
import uuid
from pathlib import Path
from dotenv import load_dotenv
from posthog import Posthog
from browser_use.telemetry.views import BaseTelemetryEvent
from browser_use.utils import singleton
# Load variables from a .env file (python-dotenv) before the environment is read below.
load_dotenv()

logger = logging.getLogger(__name__)

# Extra properties merged into every event's properties in
# ProductTelemetry._direct_capture(); on a key clash these values win.
POSTHOG_EVENT_SETTINGS = {
    'process_person_profile': True,
}
def xdg_cache_home() -> Path:
    """Return the cache base directory.

    ``$XDG_CACHE_HOME`` is used when it is set to a non-empty, absolute path;
    otherwise the result is ``~/.cache``.
    """
    fallback = Path.home() / '.cache'
    configured = os.getenv('XDG_CACHE_HOME')
    if not configured:
        return fallback
    candidate = Path(configured)
    # Relative values are ignored rather than resolved.
    return candidate if candidate.is_absolute() else fallback
@singleton
class ProductTelemetry:
    """
    Service for capturing anonymized telemetry data.

    If the environment variable `ANONYMIZED_TELEMETRY=False`, anonymized telemetry will be disabled.
    """

    # File in which the generated user id is stored between runs.
    USER_ID_PATH = str(xdg_cache_home() / 'browser_use' / 'telemetry_user_id')
    PROJECT_API_KEY = 'phc_F8JMNjW1i2KbGUTaW1unnDdLSPCoyc52SGRU0JecaUh'
    HOST = 'https://eu.i.posthog.com'
    # User id reported when the id file can be neither read nor written.
    UNKNOWN_USER_ID = 'UNKNOWN'

    # Cache for the `user_id` property.
    _curr_user_id = None

    def __init__(self) -> None:
        """Create the PostHog client unless telemetry is disabled via the environment."""
        # Only the (case-insensitive) value 'false' disables telemetry.
        telemetry_disabled = os.getenv('ANONYMIZED_TELEMETRY', 'true').lower() == 'false'
        self.debug_logging = os.getenv('BROWSER_USE_LOGGING_LEVEL', 'info').lower() == 'debug'

        if telemetry_disabled:
            self._posthog_client = None
        else:
            logger.info(
                'Anonymized telemetry enabled. See https://docs.browser-use.com/development/telemetry for more information.'
            )
            self._posthog_client = Posthog(
                project_api_key=self.PROJECT_API_KEY,
                host=self.HOST,
                disable_geoip=False,
                enable_exception_autocapture=True,
            )

            # Silence posthog's logging
            if not self.debug_logging:
                posthog_logger = logging.getLogger('posthog')
                posthog_logger.disabled = True

        if self._posthog_client is None:
            logger.debug('Telemetry disabled')

    def capture(self, event: BaseTelemetryEvent) -> None:
        """Send ``event`` to PostHog; a no-op when telemetry is disabled."""
        if self._posthog_client is None:
            return

        if self.debug_logging:
            logger.debug(f'Telemetry event: {event.name} {event.properties}')
        self._direct_capture(event)

    def _direct_capture(self, event: BaseTelemetryEvent) -> None:
        """
        Should not be thread blocking because posthog magically handles it

        Failures are logged and never propagated to the caller.
        """
        if self._posthog_client is None:
            return

        try:
            self._posthog_client.capture(
                self.user_id,
                event.name,
                {**event.properties, **POSTHOG_EVENT_SETTINGS},
            )
        except Exception as e:
            logger.error(f'Failed to send telemetry event {event.name}: {e}')

    def flush(self) -> None:
        """Flush the PostHog client's queue; errors are logged, not raised."""
        if self._posthog_client:
            try:
                self._posthog_client.flush()
                logger.debug('PostHog client telemetry queue flushed.')
            except Exception as e:
                logger.error(f'Failed to flush PostHog client: {e}')
        else:
            logger.debug('PostHog client not available, skipping flush.')

    @property
    def user_id(self) -> str:
        """Return the telemetry user id, creating and storing one on first use.

        The id is read from ``USER_ID_PATH``; if that file does not exist a
        new UUID4 is generated and written there. When the file cannot be
        accessed, ``UNKNOWN_USER_ID`` is returned instead.
        """
        if self._curr_user_id:
            return self._curr_user_id

        # File access may fail due to permissions or other reasons. We don't want to
        # crash so we catch all exceptions.
        try:
            if not os.path.exists(self.USER_ID_PATH):
                os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True)
                with open(self.USER_ID_PATH, 'w') as f:
                    new_user_id = str(uuid.uuid4())
                    f.write(new_user_id)
                self._curr_user_id = new_user_id
            else:
                with open(self.USER_ID_PATH) as f:
                    self._curr_user_id = f.read()
        except Exception:
            # Use the class constant; the previous literal 'UNKNOWN_USER_ID'
            # was the constant's *name*, not its value.
            self._curr_user_id = self.UNKNOWN_USER_ID
        return self._curr_user_id

View file

@ -0,0 +1,56 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, dataclass
from typing import Any
@dataclass
class BaseTelemetryEvent(ABC):
    """Abstract base for telemetry events.

    Concrete events are dataclasses that provide ``name``; every other
    dataclass field ends up in ``properties``.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the event; must be provided by subclasses."""

    @property
    def properties(self) -> dict[str, Any]:
        """All dataclass fields as a dict (via ``asdict``), without ``name``."""
        payload = asdict(self)
        payload.pop('name', None)
        return payload
@dataclass
class RegisteredFunction:
    """Name and parameter description of one registered function."""

    name: str
    # NOTE(review): the exact shape of this dict is not visible here —
    # presumably a parameter schema; confirm with the code that builds it.
    params: dict[str, Any]
@dataclass
class ControllerRegisteredFunctionsTelemetryEvent(BaseTelemetryEvent):
    """Telemetry event named 'controller_registered_functions' listing registered functions."""

    registered_functions: list[RegisteredFunction]
    # The field default satisfies the abstract `name` property of the base class.
    name: str = 'controller_registered_functions'
@dataclass
class AgentTelemetryEvent(BaseTelemetryEvent):
    """Telemetry event named 'agent_event' with start, per-step and end details.

    All fields except ``name`` are required and are exposed through the
    inherited ``properties``.
    """

    # start details
    task: str
    model: str
    model_provider: str
    planner_llm: str | None
    max_steps: int
    max_actions_per_step: int
    use_vision: bool
    use_validation: bool
    version: str
    source: str
    # step details
    action_errors: Sequence[str | None]
    action_history: Sequence[list[dict] | None]
    urls_visited: Sequence[str | None]
    # end details
    steps: int
    total_input_tokens: int
    total_duration_seconds: float
    success: bool | None
    final_result_response: str | None
    error_message: str | None
    # The field default satisfies the abstract `name` property of the base class.
    name: str = 'agent_event'

View file

@ -0,0 +1,345 @@
import asyncio
import logging
import os
import platform
import signal
import time
from collections.abc import Callable, Coroutine
from functools import wraps
from sys import stderr
from typing import Any, ParamSpec, TypeVar
logger = logging.getLogger(__name__)

# Global flag to prevent duplicate exit messages: SignalHandler sets it once an
# exit is in progress and checks it before running the exit callback again.
_exiting = False

# Define generic type variables for return type and parameters
# (used to type the time_execution_* decorators below).
R = TypeVar('R')
P = ParamSpec('P')
class SignalHandler:
"""
A modular and reusable signal handling system for managing SIGINT (Ctrl+C), SIGTERM,
and other signals in asyncio applications.
This class provides:
- Configurable signal handling for SIGINT and SIGTERM
- Support for custom pause/resume callbacks
- Management of event loop state across signals
- Standardized handling of first and second Ctrl+C presses
- Cross-platform compatibility (with simplified behavior on Windows)
"""
def __init__(
self,
loop: asyncio.AbstractEventLoop | None = None,
pause_callback: Callable[[], None] | None = None,
resume_callback: Callable[[], None] | None = None,
custom_exit_callback: Callable[[], None] | None = None,
exit_on_second_int: bool = True,
interruptible_task_patterns: list[str] = None,
):
"""
Initialize the signal handler.
Args:
loop: The asyncio event loop to use. Defaults to current event loop.
pause_callback: Function to call when system is paused (first Ctrl+C)
resume_callback: Function to call when system is resumed
custom_exit_callback: Function to call on exit (second Ctrl+C or SIGTERM)
exit_on_second_int: Whether to exit on second SIGINT (Ctrl+C)
interruptible_task_patterns: List of patterns to match task names that should be
canceled on first Ctrl+C (default: ['step', 'multi_act', 'get_next_action'])
"""
self.loop = loop or asyncio.get_event_loop()
self.pause_callback = pause_callback
self.resume_callback = resume_callback
self.custom_exit_callback = custom_exit_callback
self.exit_on_second_int = exit_on_second_int
self.interruptible_task_patterns = interruptible_task_patterns or ['step', 'multi_act', 'get_next_action']
self.is_windows = platform.system() == 'Windows'
# Initialize loop state attributes
self._initialize_loop_state()
# Store original signal handlers to restore them later if needed
self.original_sigint_handler = None
self.original_sigterm_handler = None
def _initialize_loop_state(self) -> None:
"""Initialize loop state attributes used for signal handling."""
setattr(self.loop, 'ctrl_c_pressed', False)
setattr(self.loop, 'waiting_for_input', False)
def register(self) -> None:
"""Register signal handlers for SIGINT and SIGTERM."""
try:
if self.is_windows:
# On Windows, use simple signal handling with immediate exit on Ctrl+C
def windows_handler(sig, frame):
print('\n\n🛑 Got Ctrl+C. Exiting immediately on Windows...\n', file=stderr)
# Run the custom exit callback if provided
if self.custom_exit_callback:
self.custom_exit_callback()
os._exit(0)
self.original_sigint_handler = signal.signal(signal.SIGINT, windows_handler)
else:
# On Unix-like systems, use asyncio's signal handling for smoother experience
self.original_sigint_handler = self.loop.add_signal_handler(signal.SIGINT, lambda: self.sigint_handler())
self.original_sigterm_handler = self.loop.add_signal_handler(signal.SIGTERM, lambda: self.sigterm_handler())
except Exception:
# there are situations where signal handlers are not supported, e.g.
# - when running in a thread other than the main thread
# - some operating systems
# - inside jupyter notebooks
pass
def unregister(self) -> None:
"""Unregister signal handlers and restore original handlers if possible."""
try:
if self.is_windows:
# On Windows, just restore the original SIGINT handler
if self.original_sigint_handler:
signal.signal(signal.SIGINT, self.original_sigint_handler)
else:
# On Unix-like systems, use asyncio's signal handler removal
self.loop.remove_signal_handler(signal.SIGINT)
self.loop.remove_signal_handler(signal.SIGTERM)
# Restore original handlers if available
if self.original_sigint_handler:
signal.signal(signal.SIGINT, self.original_sigint_handler)
if self.original_sigterm_handler:
signal.signal(signal.SIGTERM, self.original_sigterm_handler)
except Exception as e:
logger.warning(f'Error while unregistering signal handlers: {e}')
def _handle_second_ctrl_c(self) -> None:
"""
Handle a second Ctrl+C press by performing cleanup and exiting.
This is shared logic used by both sigint_handler and wait_for_resume.
"""
global _exiting
if not _exiting:
_exiting = True
# Call custom exit callback if provided
if self.custom_exit_callback:
try:
self.custom_exit_callback()
except Exception as e:
logger.error(f'Error in exit callback: {e}')
# Force immediate exit - more reliable than sys.exit()
print('\n\n🛑 Got second Ctrl+C. Exiting immediately...\n', file=stderr)
# Reset terminal to a clean state by sending multiple escape sequences
# Order matters for terminal resets - we try different approaches
# Reset terminal modes for both stdout and stderr
print('\033[?25h', end='', flush=True, file=stderr) # Show cursor
print('\033[?25h', end='', flush=True) # Show cursor
# Reset text attributes and terminal modes
print('\033[0m', end='', flush=True, file=stderr) # Reset text attributes
print('\033[0m', end='', flush=True) # Reset text attributes
# Disable special input modes that may cause arrow keys to output control chars
print('\033[?1l', end='', flush=True, file=stderr) # Reset cursor keys to normal mode
print('\033[?1l', end='', flush=True) # Reset cursor keys to normal mode
# Disable bracketed paste mode
print('\033[?2004l', end='', flush=True, file=stderr)
print('\033[?2004l', end='', flush=True)
# Carriage return helps ensure a clean line
print('\r', end='', flush=True, file=stderr)
print('\r', end='', flush=True)
os._exit(0)
def sigint_handler(self) -> None:
"""
SIGINT (Ctrl+C) handler.
First Ctrl+C: Cancel current step and pause.
Second Ctrl+C: Exit immediately if exit_on_second_int is True.
"""
global _exiting
if _exiting:
# Already exiting, force exit immediately
os._exit(0)
if getattr(self.loop, 'ctrl_c_pressed', False):
# If we're in the waiting for input state, let the pause method handle it
if getattr(self.loop, 'waiting_for_input', False):
return
# Second Ctrl+C - exit immediately if configured to do so
if self.exit_on_second_int:
self._handle_second_ctrl_c()
# Mark that Ctrl+C was pressed
self.loop.ctrl_c_pressed = True
# Cancel current tasks that should be interruptible - this is crucial for immediate pausing
self._cancel_interruptible_tasks()
# Call pause callback if provided - this sets the paused flag
if self.pause_callback:
try:
self.pause_callback()
except Exception as e:
logger.error(f'Error in pause callback: {e}')
# Log pause message after pause_callback is called (not before)
print('----------------------------------------------------------------------', file=stderr)
def sigterm_handler(self) -> None:
"""
SIGTERM handler.
Always exits the program completely.
"""
global _exiting
if not _exiting:
_exiting = True
print('\n\n🛑 SIGTERM received. Exiting immediately...\n\n', file=stderr)
# Call custom exit callback if provided
if self.custom_exit_callback:
self.custom_exit_callback()
os._exit(0)
def _cancel_interruptible_tasks(self) -> None:
"""Cancel current tasks that should be interruptible."""
current_task = asyncio.current_task(self.loop)
for task in asyncio.all_tasks(self.loop):
if task != current_task and not task.done():
task_name = task.get_name() if hasattr(task, 'get_name') else str(task)
# Cancel tasks that match certain patterns
if any(pattern in task_name for pattern in self.interruptible_task_patterns):
logger.debug(f'Cancelling task: {task_name}')
task.cancel()
# Add exception handler to silence "Task exception was never retrieved" warnings
task.add_done_callback(lambda t: t.exception() if t.cancelled() else None)
# Also cancel the current task if it's interruptible
if current_task and not current_task.done():
task_name = current_task.get_name() if hasattr(current_task, 'get_name') else str(current_task)
if any(pattern in task_name for pattern in self.interruptible_task_patterns):
logger.debug(f'Cancelling current task: {task_name}')
current_task.cancel()
def wait_for_resume(self) -> None:
"""
Wait for user input to resume or exit.
This method should be called after handling the first Ctrl+C.
It temporarily restores default signal handling to allow catching
a second Ctrl+C directly.
"""
# Set flag to indicate we're waiting for input
setattr(self.loop, 'waiting_for_input', True)
# Temporarily restore default signal handling for SIGINT
# This ensures KeyboardInterrupt will be raised during input()
original_handler = signal.getsignal(signal.SIGINT)
try:
signal.signal(signal.SIGINT, signal.default_int_handler)
except ValueError:
# we are running in a thread other than the main thread
# or signal handlers are not supported for some other reason
pass
green = '\x1b[32;1m'
red = '\x1b[31m'
blink = '\033[33;5m'
unblink = '\033[0m'
reset = '\x1b[0m'
try: # escape code is to blink the ...
print(
f'➡️ Press {green}[Enter]{reset} to resume or {red}[Ctrl+C]{reset} again to exit{blink}...{unblink} ',
end='',
flush=True,
file=stderr,
)
input() # This will raise KeyboardInterrupt on Ctrl+C
# Call resume callback if provided
if self.resume_callback:
self.resume_callback()
except KeyboardInterrupt:
# Use the shared method to handle second Ctrl+C
self._handle_second_ctrl_c()
finally:
try:
# Restore our signal handler
signal.signal(signal.SIGINT, original_handler)
setattr(self.loop, 'waiting_for_input', False)
except Exception:
pass
def reset(self) -> None:
"""Reset state after resuming."""
# Clear the flags
if hasattr(self.loop, 'ctrl_c_pressed'):
self.loop.ctrl_c_pressed = False
if hasattr(self.loop, 'waiting_for_input'):
self.loop.waiting_for_input = False
def time_execution_sync(additional_text: str = '') -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Build a decorator that logs, at DEBUG level, how long each call took.

    Args:
        additional_text: Prefix placed in front of the logged message.

    The message is only logged when the wrapped function returns normally;
    its return value is passed through unchanged.
    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            began_at = time.time()
            outcome = func(*args, **kwargs)
            elapsed = time.time() - began_at
            logger.debug(f'{additional_text} Execution time: {elapsed:.2f} seconds')
            return outcome

        return wrapper

    return decorator
def time_execution_async(
    additional_text: str = '',
) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]:
    """Build a decorator that logs, at DEBUG level, how long each awaited call took.

    Args:
        additional_text: Prefix placed in front of the logged message.

    The measured span covers awaiting the wrapped coroutine function. The
    message is only logged when it completes normally; its result is passed
    through unchanged.
    """

    def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
        @wraps(func)
        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            began_at = time.time()
            outcome = await func(*args, **kwargs)
            elapsed = time.time() - began_at
            logger.debug(f'{additional_text} Execution time: {elapsed:.2f} seconds')
            return outcome

        return wrapper

    return decorator
def singleton(cls):
    """Class decorator that creates at most one instance of ``cls``.

    The decorated name becomes a factory function: its first call constructs
    the instance with the given arguments, and every later call returns that
    same object and ignores its arguments.
    """
    shared = None

    def wrapper(*args, **kwargs):
        nonlocal shared
        if shared is None:
            shared = cls(*args, **kwargs)
        return shared

    return wrapper
def check_env_variables(keys: list[str], any_or_all=all) -> bool:
    """Check whether the given environment variables are set to non-blank values.

    Args:
        keys: Names of the environment variables to look at.
        any_or_all: Reducer applied to the stripped values - ``all`` (default)
            requires every variable, ``any`` at least one.
    """
    stripped_values = (os.getenv(name, '').strip() for name in keys)
    return any_or_all(stripped_values)