[Add] browser-use and main.py

This commit is contained in:
tv0924@icloud.com 2025-05-18 21:57:54 +09:00
commit 96914d44ac
221 changed files with 30952 additions and 1 deletions

View file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,70 @@
import hashlib
from browser_use.dom.views import DOMElementNode
class ClickableElementProcessor:
    """Collect the highlighted (clickable) elements of a DOM tree and hash them.

    Every helper is a static method; the class only acts as a namespace.
    """

    @staticmethod
    def get_clickable_elements_hashes(dom_element: DOMElementNode) -> set[str]:
        """Return the hashes of all clickable elements below ``dom_element``."""
        clickable_elements = ClickableElementProcessor.get_clickable_elements(dom_element)
        return {ClickableElementProcessor.hash_dom_element(element) for element in clickable_elements}

    @staticmethod
    def get_clickable_elements(dom_element: DOMElementNode) -> list[DOMElementNode]:
        """Return every descendant element of ``dom_element`` that has a highlight index.

        The tree is walked depth-first; ``dom_element`` itself is never included
        and non-element children (e.g. text nodes) are skipped.
        """
        clickable_elements: list[DOMElementNode] = []
        for child in dom_element.children:
            if not isinstance(child, DOMElementNode):
                continue
            # ``None`` means "not highlighted". A plain truthiness test would
            # also drop the element whose highlight index is 0.
            if child.highlight_index is not None:
                clickable_elements.append(child)
            clickable_elements.extend(ClickableElementProcessor.get_clickable_elements(child))
        return clickable_elements

    @staticmethod
    def hash_dom_element(dom_element: DOMElementNode) -> str:
        """Return one SHA-256 hex digest identifying ``dom_element``.

        The digest combines the hashes of the element's branch path, its
        attributes and its xpath. The element's text is not part of the hash
        (``_text_hash`` exists but is not used here).
        """
        parent_branch_path = ClickableElementProcessor._get_parent_branch_path(dom_element)
        branch_path_hash = ClickableElementProcessor._parent_branch_path_hash(parent_branch_path)
        attributes_hash = ClickableElementProcessor._attributes_hash(dom_element.attributes)
        xpath_hash = ClickableElementProcessor._xpath_hash(dom_element.xpath)
        return ClickableElementProcessor._hash_string(f'{branch_path_hash}-{attributes_hash}-{xpath_hash}')

    @staticmethod
    def _get_parent_branch_path(dom_element: DOMElementNode) -> list[str]:
        """Return the tag names from the top of the tree down to ``dom_element``.

        The walk stops at the node whose ``parent`` is ``None``; that node is
        not included, while ``dom_element`` itself is the last entry.
        """
        parents: list[DOMElementNode] = []
        current_element: DOMElementNode = dom_element
        while current_element.parent is not None:
            parents.append(current_element)
            current_element = current_element.parent
        parents.reverse()
        return [parent.tag_name for parent in parents]

    @staticmethod
    def _parent_branch_path_hash(parent_branch_path: list[str]) -> str:
        """Hash the branch path joined with ``/``."""
        return ClickableElementProcessor._hash_string('/'.join(parent_branch_path))

    @staticmethod
    def _attributes_hash(attributes: dict[str, str]) -> str:
        """Hash the attributes as concatenated ``key=value`` pairs in dict order."""
        attributes_string = ''.join(f'{key}={value}' for key, value in attributes.items())
        return ClickableElementProcessor._hash_string(attributes_string)

    @staticmethod
    def _xpath_hash(xpath: str) -> str:
        """Hash the element's xpath."""
        return ClickableElementProcessor._hash_string(xpath)

    @staticmethod
    def _text_hash(dom_element: DOMElementNode) -> str:
        """Hash the text collected by ``get_all_text_till_next_clickable_element``."""
        text_string = dom_element.get_all_text_till_next_clickable_element()
        return ClickableElementProcessor._hash_string(text_string)

    @staticmethod
    def _hash_string(string: str) -> str:
        """Return the SHA-256 hex digest of ``string`` (encoded with ``str.encode()``)."""
        return hashlib.sha256(string.encode()).hexdigest()

View file

@ -0,0 +1,106 @@
import hashlib
from browser_use.dom.history_tree_processor.view import DOMHistoryElement, HashedDomElement
from browser_use.dom.views import DOMElementNode
class HistoryTreeProcessor:
    """
    Operations on the DOM elements

    @dev be careful - text nodes can change even if elements stay the same
    """

    @staticmethod
    def convert_dom_element_to_history_element(dom_element: DOMElementNode) -> DOMHistoryElement:
        """Build a ``DOMHistoryElement`` snapshot from a live DOM element."""
        # Imported inside the function rather than at module level
        # (presumably to avoid a circular import - confirm).
        from browser_use.browser.context import BrowserContext

        branch_tags = HistoryTreeProcessor._get_parent_branch_path(dom_element)
        selector = BrowserContext._enhanced_css_selector_for_element(dom_element)
        return DOMHistoryElement(
            tag_name=dom_element.tag_name,
            xpath=dom_element.xpath,
            highlight_index=dom_element.highlight_index,
            entire_parent_branch_path=branch_tags,
            attributes=dom_element.attributes,
            shadow_root=dom_element.shadow_root,
            css_selector=selector,
            page_coordinates=dom_element.page_coordinates,
            viewport_coordinates=dom_element.viewport_coordinates,
            viewport_info=dom_element.viewport_info,
        )

    @staticmethod
    def find_history_element_in_tree(dom_history_element: DOMHistoryElement, tree: DOMElementNode) -> DOMElementNode | None:
        """Return the first highlighted node in ``tree`` whose hash equals the history element's.

        Nodes are visited in pre-order (a node before its children, children
        left to right); ``None`` is returned when nothing matches.
        """
        wanted = HistoryTreeProcessor._hash_dom_history_element(dom_history_element)
        pending = [tree]
        while pending:
            node = pending.pop()
            # Only highlighted nodes are candidates, so only those get hashed.
            if node.highlight_index is not None and HistoryTreeProcessor._hash_dom_element(node) == wanted:
                return node
            # Push children reversed so the first child is popped first.
            pending.extend(child for child in reversed(node.children) if isinstance(child, DOMElementNode))
        return None

    @staticmethod
    def compare_history_element_and_dom_element(dom_history_element: DOMHistoryElement, dom_element: DOMElementNode) -> bool:
        """Return True when both elements produce the same ``HashedDomElement``."""
        history_hash = HistoryTreeProcessor._hash_dom_history_element(dom_history_element)
        element_hash = HistoryTreeProcessor._hash_dom_element(dom_element)
        return history_hash == element_hash

    @staticmethod
    def _hash_dom_history_element(dom_history_element: DOMHistoryElement) -> HashedDomElement:
        """Hash a stored history element using its recorded branch path, attributes and xpath."""
        return HashedDomElement(
            branch_path_hash=HistoryTreeProcessor._parent_branch_path_hash(dom_history_element.entire_parent_branch_path),
            attributes_hash=HistoryTreeProcessor._attributes_hash(dom_history_element.attributes),
            xpath_hash=HistoryTreeProcessor._xpath_hash(dom_history_element.xpath),
        )

    @staticmethod
    def _hash_dom_element(dom_element: DOMElementNode) -> HashedDomElement:
        """Hash a live DOM element; the branch path is derived by walking ``.parent``.

        The element's text is not part of the hash.
        """
        branch_tags = HistoryTreeProcessor._get_parent_branch_path(dom_element)
        return HashedDomElement(
            branch_path_hash=HistoryTreeProcessor._parent_branch_path_hash(branch_tags),
            attributes_hash=HistoryTreeProcessor._attributes_hash(dom_element.attributes),
            xpath_hash=HistoryTreeProcessor._xpath_hash(dom_element.xpath),
        )

    @staticmethod
    def _get_parent_branch_path(dom_element: DOMElementNode) -> list[str]:
        """Return tag names from the top of the tree down to ``dom_element`` (inclusive).

        The node whose ``parent`` is ``None`` is not part of the path.
        """
        tags_bottom_up: list[str] = []
        node: DOMElementNode = dom_element
        while node.parent is not None:
            tags_bottom_up.append(node.tag_name)
            node = node.parent
        return tags_bottom_up[::-1]

    @staticmethod
    def _parent_branch_path_hash(parent_branch_path: list[str]) -> str:
        """SHA-256 hex digest of the branch path joined with ``/``."""
        joined = '/'.join(parent_branch_path)
        return hashlib.sha256(joined.encode()).hexdigest()

    @staticmethod
    def _attributes_hash(attributes: dict[str, str]) -> str:
        """SHA-256 hex digest of the concatenated ``key=value`` pairs, in dict order."""
        pairs = [f'{key}={value}' for key, value in attributes.items()]
        return hashlib.sha256(''.join(pairs).encode()).hexdigest()

    @staticmethod
    def _xpath_hash(xpath: str) -> str:
        """SHA-256 hex digest of the xpath string."""
        return hashlib.sha256(xpath.encode()).hexdigest()

    @staticmethod
    def _text_hash(dom_element: DOMElementNode) -> str:
        """SHA-256 hex digest of the text returned by ``get_all_text_till_next_clickable_element``."""
        collected_text = dom_element.get_all_text_till_next_clickable_element()
        return hashlib.sha256(collected_text.encode()).hexdigest()

View file

@ -0,0 +1,69 @@
from dataclasses import dataclass
from pydantic import BaseModel
@dataclass
class HashedDomElement:
    """
    Hash of the dom element to be used as a unique identifier

    Two instances compare equal when all three component hashes are equal
    (dataclass-generated ``__eq__``).
    """

    # Hash of the element's branch path.
    branch_path_hash: str
    # Hash of the element's attributes.
    attributes_hash: str
    # Hash of the element's xpath.
    xpath_hash: str
    # text_hash: str
class Coordinates(BaseModel):
    """A single integer (x, y) point."""

    x: int
    y: int
class CoordinateSet(BaseModel):
    """Corner points, center point and size describing a rectangle."""

    # The four corners and the center, each as a ``Coordinates`` point.
    top_left: Coordinates
    top_right: Coordinates
    bottom_left: Coordinates
    bottom_right: Coordinates
    center: Coordinates
    # Size of the rectangle.
    width: int
    height: int
class ViewportInfo(BaseModel):
    """Scroll offsets and size of a viewport."""

    scroll_x: int
    scroll_y: int
    width: int
    height: int
@dataclass
class DOMHistoryElement:
    """Stored description of a DOM element (tag, xpath, attributes, optional geometry)."""

    tag_name: str
    xpath: str
    highlight_index: int | None
    entire_parent_branch_path: list[str]
    attributes: dict[str, str]
    shadow_root: bool = False
    css_selector: str | None = None
    page_coordinates: CoordinateSet | None = None
    viewport_coordinates: CoordinateSet | None = None
    viewport_info: ViewportInfo | None = None

    def to_dict(self) -> dict:
        """Return the element as a plain dict.

        The three optional pydantic members are converted with ``model_dump()``;
        members that are unset are emitted as ``None``.
        """

        def dump(model):
            return model.model_dump() if model else None

        return {
            'tag_name': self.tag_name,
            'xpath': self.xpath,
            'highlight_index': self.highlight_index,
            'entire_parent_branch_path': self.entire_parent_branch_path,
            'attributes': self.attributes,
            'shadow_root': self.shadow_root,
            'css_selector': self.css_selector,
            'page_coordinates': dump(self.page_coordinates),
            'viewport_coordinates': dump(self.viewport_coordinates),
            'viewport_info': dump(self.viewport_info),
        }

View file

@ -0,0 +1,203 @@
import json
import logging
from dataclasses import dataclass
from importlib import resources
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from playwright.async_api import Page
from browser_use.dom.views import (
DOMBaseNode,
DOMElementNode,
DOMState,
DOMTextNode,
SelectorMap,
)
from browser_use.utils import time_execution_async
logger = logging.getLogger(__name__)
@dataclass
class ViewportInfo:
    """Width and height of a viewport as reported in a node's ``viewport`` entry."""

    width: int
    height: int
class DomService:
    """Build a Python DOM tree for a page by evaluating ``buildDomTree.js`` inside it."""

    def __init__(self, page: 'Page'):
        """Bind the service to ``page`` and load the JavaScript source shipped with the package."""
        self.page = page
        # NOTE(review): not read anywhere in this class - confirm whether it is still needed.
        self.xpath_cache = {}
        self.js_code = resources.files('browser_use.dom').joinpath('buildDomTree.js').read_text()

    # region - Clickable elements
    @time_execution_async('--get_clickable_elements')
    async def get_clickable_elements(
        self,
        highlight_elements: bool = True,
        focus_element: int = -1,
        viewport_expansion: int = 0,
    ) -> DOMState:
        """Return the element tree and selector map of the current page as a ``DOMState``.

        The three arguments are forwarded unchanged to the in-page script.
        """
        element_tree, selector_map = await self._build_dom_tree(highlight_elements, focus_element, viewport_expansion)
        return DOMState(element_tree=element_tree, selector_map=selector_map)

    @time_execution_async('--get_cross_origin_iframes')
    async def get_cross_origin_iframes(self) -> list[str]:
        """Return the URLs of visible cross-origin frames that are not known ad/tracker frames."""
        # invisible cross-origin iframes are used for ads and tracking, dont open those
        hidden_frame_urls = await self.page.locator('iframe').filter(visible=False).evaluate_all('e => e.map(e => e.src)')

        ad_domains = ('doubleclick.net', 'adroll.com', 'googletagmanager.com')
        # Parsed once instead of once per frame.
        page_netloc = urlparse(self.page.url).netloc

        def is_ad_url(url: str) -> bool:
            netloc = urlparse(url).netloc
            return any(domain in netloc for domain in ad_domains)

        cross_origin_urls: list[str] = []
        for frame in self.page.frames:
            frame_netloc = urlparse(frame.url).netloc
            if not frame_netloc:  # exclude data:urls and about:blank
                continue
            if frame_netloc == page_netloc:  # exclude same-origin iframes
                continue
            if frame.url in hidden_frame_urls:  # exclude hidden frames
                continue
            if is_ad_url(frame.url):  # exclude most common ad network tracker frame URLs
                continue
            cross_origin_urls.append(frame.url)
        return cross_origin_urls

    @time_execution_async('--build_dom_tree')
    async def _build_dom_tree(
        self,
        highlight_elements: bool,
        focus_element: int,
        viewport_expansion: int,
    ) -> tuple[DOMElementNode, SelectorMap]:
        """Run the DOM-extraction script in the page and convert its result.

        Raises:
            ValueError: if the page cannot evaluate a trivial JavaScript expression.
            Exception: whatever ``page.evaluate`` raises for the script (logged, then re-raised).
        """
        if await self.page.evaluate('1+1') != 2:
            raise ValueError('The page cannot evaluate javascript code properly')

        if self.page.url == 'about:blank':
            # short-circuit if the page is a new empty tab for speed, no need to inject buildDomTree.js
            return (
                DOMElementNode(
                    tag_name='body',
                    xpath='',
                    attributes={},
                    children=[],
                    is_visible=False,
                    parent=None,
                ),
                {},
            )

        # NOTE: We execute JS code in the browser to extract important DOM information.
        #       The returned hash map contains information about the DOM tree and the
        #       relationship between the DOM elements.
        debug_mode = logger.getEffectiveLevel() == logging.DEBUG
        args = {
            'doHighlightElements': highlight_elements,
            'focusHighlightIndex': focus_element,
            'viewportExpansion': viewport_expansion,
            'debugMode': debug_mode,
        }

        try:
            eval_page: dict = await self.page.evaluate(self.js_code, args)
        except Exception as e:
            logger.error('Error evaluating JavaScript: %s', e)
            raise

        # Only log performance metrics in debug mode
        if debug_mode and 'perfMetrics' in eval_page:
            logger.debug(
                'DOM Tree Building Performance Metrics for: %s\n%s',
                self.page.url,
                json.dumps(eval_page['perfMetrics'], indent=2),
            )

        return await self._construct_dom_tree(eval_page)

    @time_execution_async('--construct_dom_tree')
    async def _construct_dom_tree(
        self,
        eval_page: dict,
    ) -> tuple[DOMElementNode, SelectorMap]:
        """Turn the script result (``map`` and ``rootId``) into linked node objects.

        Returns:
            The root element and a map from highlight index to element.

        Raises:
            ValueError: if the root id is missing from the parsed nodes or the
                root is not an element node.
        """
        js_node_map = eval_page['map']
        js_root_id = eval_page['rootId']

        selector_map: SelectorMap = {}
        node_map: dict = {}

        for node_id, node_data in js_node_map.items():
            node, children_ids = self._parse_node(node_data)
            if node is None:
                continue

            node_map[node_id] = node

            if not isinstance(node, DOMElementNode):
                continue

            if node.highlight_index is not None:
                selector_map[node.highlight_index] = node

            # NOTE: We know that we are building the tree bottom up
            #       and all children are already processed.
            for child_id in children_ids:
                child_node = node_map.get(child_id)
                if child_node is None:
                    continue
                child_node.parent = node
                node.children.append(child_node)

        # ``.get`` so that a missing root reaches the ValueError below instead
        # of surfacing as a bare KeyError.
        root = node_map.get(str(js_root_id))

        # Drop the references to the intermediate maps before returning.
        del node_map
        del js_node_map

        if not isinstance(root, DOMElementNode):
            raise ValueError('Failed to parse HTML to dictionary')

        return root, selector_map

    def _parse_node(
        self,
        node_data: dict,
    ) -> tuple[DOMBaseNode | None, list[int]]:
        """Convert one raw node dict into a node object plus the ids of its children.

        Returns ``(None, [])`` for empty input. Text nodes never have children;
        element nodes are created with an empty ``children`` list and no parent.
        """
        if not node_data:
            return None, []

        # Process text nodes immediately
        if node_data.get('type') == 'TEXT_NODE':
            text_node = DOMTextNode(
                text=node_data['text'],
                is_visible=node_data['isVisible'],
                parent=None,
            )
            return text_node, []

        # Process coordinates if they exist for element nodes
        viewport_info = None
        if 'viewport' in node_data:
            viewport_info = ViewportInfo(
                width=node_data['viewport']['width'],
                height=node_data['viewport']['height'],
            )

        element_node = DOMElementNode(
            tag_name=node_data['tagName'],
            xpath=node_data['xpath'],
            attributes=node_data.get('attributes', {}),
            children=[],
            is_visible=node_data.get('isVisible', False),
            is_interactive=node_data.get('isInteractive', False),
            is_top_element=node_data.get('isTopElement', False),
            is_in_viewport=node_data.get('isInViewport', False),
            highlight_index=node_data.get('highlightIndex'),
            shadow_root=node_data.get('shadowRoot', False),
            parent=None,
            viewport_info=viewport_info,
        )

        children_ids = node_data.get('children', [])

        return element_node, children_ids

View file

@ -0,0 +1,123 @@
import asyncio
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from browser_use.browser.browser import Browser, BrowserConfig
from browser_use.browser.context import BrowserContext
async def analyze_page_structure(url: str):
    """Analyze and print the structure of a webpage with enhanced debugging

    Opens ``url`` in a non-headless browser, prints viewport data, cookie/consent
    elements, fixed/sticky elements and the page structure, then waits for Enter.
    The browser is closed even if an error occurs.
    """
    # Script returning the viewport size and scroll offsets.
    viewport_script = """() => {
return {
viewport: {
width: window.innerWidth,
height: window.innerHeight,
scrollX: window.scrollX,
scrollY: window.scrollY
}
}
}"""

    # Enhanced debug information for cookie consent and fixed position elements
    debug_script = """() => {
function getElementInfo(element) {
const rect = element.getBoundingClientRect();
const style = window.getComputedStyle(element);
return {
tag: element.tagName.toLowerCase(),
id: element.id,
className: element.className,
position: style.position,
rect: {
top: rect.top,
right: rect.right,
bottom: rect.bottom,
left: rect.left,
width: rect.width,
height: rect.height
},
isFixed: style.position === 'fixed',
isSticky: style.position === 'sticky',
zIndex: style.zIndex,
visibility: style.visibility,
display: style.display,
opacity: style.opacity
};
}
// Find cookie-related elements
const cookieElements = Array.from(document.querySelectorAll('[id*="cookie"], [id*="consent"], [class*="cookie"], [class*="consent"]'));
const fixedElements = Array.from(document.querySelectorAll('*')).filter(el => {
const style = window.getComputedStyle(el);
return style.position === 'fixed' || style.position === 'sticky';
});
return {
cookieElements: cookieElements.map(getElementInfo),
fixedElements: fixedElements.map(getElementInfo)
};
}"""

    def print_element_summary(elem):
        # Lines shared by the cookie and the fixed/sticky listings.
        print(f'\nElement: {elem["tag"]}#{elem["id"]} .{elem["className"]}')
        print(f'Position: {elem["position"]}')
        print(f'Rect: {elem["rect"]}')
        print(f'Z-Index: {elem["zIndex"]}')

    browser = Browser(
        config=BrowserConfig(
            headless=False,  # Set to True if you don't need to see the browser
        )
    )
    context = BrowserContext(browser=browser)

    try:
        async with context as ctx:
            # Navigate to the URL
            page = await ctx.get_current_page()
            await page.goto(url)
            await page.wait_for_load_state('networkidle')

            # Get viewport dimensions
            viewport_info = await page.evaluate(viewport_script)
            print('\nViewport Information:')
            print(f'Width: {viewport_info["viewport"]["width"]}')
            print(f'Height: {viewport_info["viewport"]["height"]}')
            print(f'ScrollX: {viewport_info["viewport"]["scrollX"]}')
            print(f'ScrollY: {viewport_info["viewport"]["scrollY"]}')

            debug_info = await page.evaluate(debug_script)

            print('\nCookie-related Elements:')
            for elem in debug_info['cookieElements']:
                print_element_summary(elem)
                print(f'Visibility: {elem["visibility"]}')
                print(f'Display: {elem["display"]}')
                print(f'Opacity: {elem["opacity"]}')

            print('\nFixed/Sticky Position Elements:')
            for elem in debug_info['fixedElements']:
                print_element_summary(elem)

            print(f'\nPage Structure for {url}:\n')
            structure = await ctx.get_page_structure()
            print(structure)

            input('Press Enter to close the browser...')
    finally:
        await browser.close()
if __name__ == '__main__':
    # You can modify this URL to analyze different pages
    # NOTE(review): the zeiss.com entry is listed twice, so that page is analyzed twice - confirm intent.
    urls = [
        'https://www.mlb.com/yankees/stats/',
        'https://immobilienscout24.de',
        'https://www.zeiss.com/career/en/job-search.html?page=1',
        'https://www.zeiss.com/career/en/job-search.html?page=1',
        'https://reddit.com',
    ]
    # Each page gets its own event loop run (and therefore its own browser).
    for url in urls:
        asyncio.run(analyze_page_structure(url))

View file

@ -0,0 +1,182 @@
import asyncio
import os
import anyio
from langchain_openai import ChatOpenAI
from browser_use.agent.prompts import AgentMessagePrompt
from browser_use.browser.browser import Browser, BrowserConfig
from browser_use.browser.context import BrowserContext, BrowserContextConfig
from browser_use.dom.service import DomService
def count_string_tokens(string: str, model: str) -> tuple[int, float]:
    """Count the number of tokens in a string using a specified model.

    Returns:
        ``(token_count, price)`` where ``price`` is ``token_count`` multiplied by
        the per-token price of ``model``.

    Raises:
        KeyError: if ``model`` has no entry in the price table below (raised
            after the tokens have been counted).
    """
    # Price per token for the supported models.
    # @todo: move to utils, use a package or sth
    price_per_token = {
        'gpt-4o': 2.5 / 1e6,
        'gpt-4o-mini': 0.15 / 1e6,
    }

    token_count = ChatOpenAI(model=model).get_num_tokens(string)
    return token_count, token_count * price_per_token[model]
# NOTE(review): not referenced in the visible part of this script; unit presumably seconds - confirm.
TIMEOUT = 60

# Attribute names passed to AgentMessagePrompt as ``include_attributes``.
DEFAULT_INCLUDE_ATTRIBUTES = [
    'id', 'title', 'type', 'name', 'role',
    'aria-label', 'placeholder', 'value', 'alt',
    'aria-expanded', 'data-date-format',
]
async def test_focus_vs_all_elements():
    """Interactive manual test: list the highlighted elements of several sites and act on them.

    For each website the page state is fetched, the resulting prompt is written
    to ``./tmp/user_message.txt`` together with a token/price estimate, and the
    user is asked for an element index to click, ``index,text`` to type into an
    element, or ``q`` to move on to the next website. The browser is closed
    when the function exits.
    """
    config = BrowserContextConfig(
        # cookies_file='cookies3.json',
        disable_security=True,
        wait_for_network_idle_page_load_time=1,
    )

    browser = Browser(
        config=BrowserConfig(
            # browser_binary_path='/Applications/Google Chrome.app/Contents/MacOS/Google Chrome',
        )
    )
    context = BrowserContext(browser=browser, config=config)

    websites = [
        'https://demos.telerik.com/kendo-react-ui/treeview/overview/basic/func?theme=default-ocean-blue-a11y',
        'https://www.ycombinator.com/companies',
        'https://kayak.com/flights',
        # 'https://en.wikipedia.org/wiki/Humanist_Party_of_Ontario',
        # 'https://www.google.com/travel/flights?tfs=CBwQARoJagcIARIDTEpVGglyBwgBEgNMSlVAAUgBcAGCAQsI____________AZgBAQ&tfu=KgIIAw&hl=en-US&gl=US',
        # # 'https://www.concur.com/?&cookie_preferences=cpra',
        # 'https://immobilienscout24.de',
        'https://docs.google.com/spreadsheets/d/1INaIcfpYXlMRWO__de61SHFCaqt1lfHlcvtXZPItlpI/edit',
        'https://www.zeiss.com/career/en/job-search.html?page=1',
        'https://www.mlb.com/yankees/stats/',
        'https://www.amazon.com/s?k=laptop&s=review-rank&crid=1RZCEJ289EUSI&qid=1740202453&sprefix=laptop%2Caps%2C166&ref=sr_st_review-rank&ds=v1%3A4EnYKXVQA7DIE41qCvRZoNB4qN92Jlztd3BPsTFXmxU',
        'https://reddit.com',
        'https://codepen.io/geheimschriftstift/pen/mPLvQz',
        'https://www.google.com/search?q=google+hi&oq=google+hi&gs_lcrp=EgZjaHJvbWUyBggAEEUYOTIGCAEQRRhA0gEIMjI2NmowajSoAgCwAgE&sourceid=chrome&ie=UTF-8',
        'https://google.com',
        'https://amazon.com',
        'https://github.com',
    ]

    try:
        async with context as context:
            page = await context.get_current_page()

            for website in websites:
                await page.goto(website)
                # The sleep must be awaited; a bare ``asyncio.sleep(1)`` only
                # creates a coroutine object and never pauses.
                await asyncio.sleep(1)

                while True:
                    try:
                        print(f'\n{"=" * 50}\nTesting {website}\n{"=" * 50}')

                        # Get/refresh the state (includes removing old highlights)
                        print('\nGetting page state...')
                        all_elements_state = await context.get_state(True)

                        selector_map = all_elements_state.selector_map
                        total_elements = len(selector_map.keys())
                        print(f'Total number of elements: {total_elements}')

                        prompt = AgentMessagePrompt(
                            state=all_elements_state,
                            result=None,
                            include_attributes=DEFAULT_INCLUDE_ATTRIBUTES,
                            step_info=None,
                        )

                        # Write the user message to a file for analysis
                        user_message = prompt.get_user_message(use_vision=False).content
                        os.makedirs('./tmp', exist_ok=True)
                        async with await anyio.open_file('./tmp/user_message.txt', 'w', encoding='utf-8') as f:
                            await f.write(user_message)

                        token_count, price = count_string_tokens(user_message, model='gpt-4o')
                        print(f'Prompt token count: {token_count}, price: {round(price, 4)} USD')
                        print('User message written to ./tmp/user_message.txt')

                        answer = input("Enter element index to click, 'index,text' to input, or 'q' to quit: ")

                        if answer.lower() == 'q':
                            break

                        try:
                            if ',' in answer:
                                # Input text format: index,text
                                # (splitting once on a comma that is known to be present always yields two parts)
                                index_part, text_to_input = answer.split(',', 1)
                                try:
                                    target_index = int(index_part.strip())
                                except ValueError:
                                    print(f'Invalid index format: {index_part}')
                                else:
                                    if target_index in selector_map:
                                        element_node = selector_map[target_index]
                                        print(
                                            f"Inputting text '{text_to_input}' into element {target_index}: {element_node.tag_name}"
                                        )
                                        await context._input_text_element_node(element_node, text_to_input)
                                        print('Input successful.')
                                    else:
                                        print(f'Invalid index: {target_index}')
                            else:
                                # Click element format: index
                                try:
                                    clicked_index = int(answer)
                                except ValueError:
                                    print(f"Invalid input: '{answer}'. Enter an index, 'index,text', or 'q'.")
                                else:
                                    if clicked_index in selector_map:
                                        element_node = selector_map[clicked_index]
                                        print(f'Clicking element {clicked_index}: {element_node.tag_name}')
                                        await context._click_element_node(element_node)
                                        print('Click successful.')
                                    else:
                                        print(f'Invalid index: {clicked_index}')
                        except Exception as action_e:
                            # Best-effort: a failed action is reported and the loop continues.
                            print(f'Action failed: {action_e}')

                        # No explicit highlight removal here, get_state handles it at the start of the loop

                    except Exception as e:
                        print(f'Error in loop: {e}')
                        # Small delay before retrying
                        await asyncio.sleep(1)
    finally:
        # Exiting the context does not close the browser itself.
        await browser.close()
if __name__ == '__main__':
    # Script entry point: run the interactive element test.
    asyncio.run(test_focus_vs_all_elements())
    # asyncio.run(test_process_html_file())  # Commented out the other test

View file

@ -0,0 +1,43 @@
import asyncio
import json
import os
import time
import anyio
from browser_use.browser.browser import Browser, BrowserConfig
async def test_process_dom():
    """Manual test: run ``buildDomTree.js`` on a live page, time it and dump the result.

    The script result is written to ``./tmp/dom.json``. The function then waits
    for Enter so the page can be inspected, and finally closes the browser.
    """
    browser = Browser(config=BrowserConfig(headless=False))

    try:
        async with await browser.new_context() as context:
            page = await context.get_current_page()
            await page.goto('https://kayak.com/flights')
            # await page.goto('https://google.com/flights')
            # await page.goto('https://immobilienscout24.de')
            # await page.goto('https://seleniumbase.io/w3schools/iframes')

            await asyncio.sleep(3)

            # Explicit encoding so the read does not depend on the platform default.
            async with await anyio.open_file('browser_use/dom/buildDomTree.js', 'r', encoding='utf-8') as f:
                js_code = await f.read()

            # perf_counter is the monotonic clock intended for measuring durations.
            start = time.perf_counter()
            dom_tree = await page.evaluate(js_code)
            end = time.perf_counter()

            print(f'Time: {end - start:.2f}s')

            os.makedirs('./tmp', exist_ok=True)
            async with await anyio.open_file('./tmp/dom.json', 'w', encoding='utf-8') as f:
                await f.write(json.dumps(dom_tree, indent=1))

            # both of these work for immobilienscout24.de
            # await page.click('.sc-dcJsrY.ezjNCe')
            # await page.click(
            # 	'div > div:nth-of-type(2) > div > div:nth-of-type(2) > div > div:nth-of-type(2) > div > div > div > button:nth-of-type(2)'
            # )

            input('Press Enter to continue...')
    finally:
        # Exiting the context does not close the browser itself.
        await browser.close()

View file

@ -0,0 +1,265 @@
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Optional
from browser_use.dom.history_tree_processor.view import CoordinateSet, HashedDomElement, ViewportInfo
from browser_use.utils import time_execution_sync
# Avoid circular import issues
if TYPE_CHECKING:
from .views import DOMElementNode
@dataclass(frozen=False)
class DOMBaseNode:
    """Common base of all DOM nodes: a visibility flag and a link to the parent element."""

    is_visible: bool
    # Use None as default and set parent later to avoid circular reference issues
    parent: Optional['DOMElementNode']

    def __json__(self) -> dict:
        """Subclasses provide the dict form; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError('DOMBaseNode is an abstract class')
@dataclass(frozen=False)
class DOMTextNode(DOMBaseNode):
    """A text node of the DOM tree."""

    text: str
    type: str = 'TEXT_NODE'

    def has_parent_with_highlight_index(self) -> bool:
        """Return True if any ancestor carries a highlight index."""
        ancestor = self.parent
        # Climb until a highlighted ancestor is found (it will be handled separately)
        # or the top of the tree is passed.
        while ancestor is not None and ancestor.highlight_index is None:
            ancestor = ancestor.parent
        return ancestor is not None

    def is_parent_in_viewport(self) -> bool:
        """Return the parent's ``is_in_viewport`` flag, or False without a parent."""
        return self.parent is not None and self.parent.is_in_viewport

    def is_parent_top_element(self) -> bool:
        """Return the parent's ``is_top_element`` flag, or False without a parent."""
        return self.parent is not None and self.parent.is_top_element

    def __json__(self) -> dict:
        """Return the node as a dict holding its text and type."""
        return {'text': self.text, 'type': self.type}
@dataclass(frozen=False)
class DOMElementNode(DOMBaseNode):
    """
    xpath: the xpath of the element from the last root node (shadow root or iframe OR document if no shadow root or iframe).
    To properly reference the element we need to recursively switch the root node until we find the element (work your way up the tree with `.parent`)
    """

    tag_name: str
    xpath: str
    attributes: dict[str, str]
    children: list[DOMBaseNode]
    is_interactive: bool = False
    is_top_element: bool = False
    is_in_viewport: bool = False
    shadow_root: bool = False
    highlight_index: int | None = None
    viewport_coordinates: CoordinateSet | None = None
    page_coordinates: CoordinateSet | None = None
    viewport_info: ViewportInfo | None = None

    # State injected by the browser context.
    # The idea is that the clickable elements are sometimes persistent from the previous page
    # -> tells the model which objects are new/_how_ the state has changed
    is_new: bool | None = None

    def __json__(self) -> dict:
        """Return the element and, recursively, its children as a dict.

        ``parent``, ``viewport_info`` and ``is_new`` are not included; the
        coordinate members are emitted as stored (not converted).
        """
        return {
            'tag_name': self.tag_name,
            'xpath': self.xpath,
            'attributes': self.attributes,
            'is_visible': self.is_visible,
            'is_interactive': self.is_interactive,
            'is_top_element': self.is_top_element,
            'is_in_viewport': self.is_in_viewport,
            'shadow_root': self.shadow_root,
            'highlight_index': self.highlight_index,
            'viewport_coordinates': self.viewport_coordinates,
            'page_coordinates': self.page_coordinates,
            'children': [child.__json__() for child in self.children],
        }

    def __repr__(self) -> str:
        """Return a compact one-line form: the opening tag followed by status flags."""
        attributes_str = ''.join(f' {key}="{value}"' for key, value in self.attributes.items())
        tag_str = f'<{self.tag_name}{attributes_str}>'

        # Add extra info
        extras = []
        if self.is_interactive:
            extras.append('interactive')
        if self.is_top_element:
            extras.append('top')
        if self.shadow_root:
            extras.append('shadow-root')
        if self.highlight_index is not None:
            extras.append(f'highlight:{self.highlight_index}')
        if self.is_in_viewport:
            extras.append('in-viewport')

        if extras:
            tag_str += f' [{", ".join(extras)}]'

        return tag_str

    @cached_property
    def hash(self) -> HashedDomElement:
        """Hash of this element, computed once per instance by ``HistoryTreeProcessor``."""
        # Imported inside the property rather than at module level
        # (presumably to avoid a circular import - confirm).
        from browser_use.dom.history_tree_processor.service import (
            HistoryTreeProcessor,
        )

        return HistoryTreeProcessor._hash_dom_element(self)

    def get_all_text_till_next_clickable_element(self, max_depth: int = -1) -> str:
        """Return the text below this element, stopping at nested highlighted elements.

        Args:
            max_depth: maximum depth to descend (this element is depth 0);
                ``-1`` means unlimited.

        Returns:
            The collected text-node contents joined with newlines and stripped.
        """
        text_parts = []

        def collect_text(node: DOMBaseNode, current_depth: int) -> None:
            if max_depth != -1 and current_depth > max_depth:
                return

            # Skip this branch if we hit a highlighted element (except for the current node).
            # Identity check: ``!=`` would run the dataclass-generated field-by-field
            # comparison, while the question here is only "is this the start node?".
            if isinstance(node, DOMElementNode) and node is not self and node.highlight_index is not None:
                return

            if isinstance(node, DOMTextNode):
                text_parts.append(node.text)
            elif isinstance(node, DOMElementNode):
                for child in node.children:
                    collect_text(child, current_depth + 1)

        collect_text(self, 0)
        return '\n'.join(text_parts).strip()

    @time_execution_sync('--clickable_elements_to_string')
    def clickable_elements_to_string(self, include_attributes: list[str] | None = None) -> str:
        """Render the tree as indented text for the model.

        Highlighted elements become ``[index]<tag attrs>text />`` lines (the
        index is wrapped in ``*`` when ``is_new`` is truthy). Text nodes are
        emitted on their own line only when they have no highlighted ancestor
        and their parent is visible and a top element.

        Args:
            include_attributes: attribute names to print on highlighted
                elements; no attributes are printed when empty or ``None``.
        """
        formatted_text = []

        def process_node(node: DOMBaseNode, depth: int) -> None:
            next_depth = depth
            depth_str = depth * '\t'

            if isinstance(node, DOMElementNode):
                # Add element with highlight_index
                if node.highlight_index is not None:
                    # Everything below a highlighted element is indented one level deeper.
                    next_depth += 1

                    text = node.get_all_text_till_next_clickable_element()
                    attributes_html_str = ''
                    if include_attributes:
                        attributes_to_include = {
                            key: str(value) for key, value in node.attributes.items() if key in include_attributes
                        }

                        # Easy LLM optimizations
                        # if tag == role attribute, don't include it
                        if node.tag_name == attributes_to_include.get('role'):
                            del attributes_to_include['role']

                        # if aria-label == text of the node, don't include it
                        if (
                            attributes_to_include.get('aria-label')
                            and attributes_to_include.get('aria-label', '').strip() == text.strip()
                        ):
                            del attributes_to_include['aria-label']

                        # if placeholder == text of the node, don't include it
                        if (
                            attributes_to_include.get('placeholder')
                            and attributes_to_include.get('placeholder', '').strip() == text.strip()
                        ):
                            del attributes_to_include['placeholder']

                        if attributes_to_include:
                            # Format as key1='value1' key2='value2'
                            attributes_html_str = ' '.join(f"{key}='{value}'" for key, value in attributes_to_include.items())

                    # Build the line
                    if node.is_new:
                        highlight_indicator = f'*[{node.highlight_index}]*'
                    else:
                        highlight_indicator = f'[{node.highlight_index}]'

                    line = f'{depth_str}{highlight_indicator}<{node.tag_name}'

                    if attributes_html_str:
                        line += f' {attributes_html_str}'

                    if text:
                        # Add space before >text only if there were NO attributes added before
                        if not attributes_html_str:
                            line += ' '
                        line += f'>{text}'
                    # Add space before /> only if neither attributes NOR text were added
                    elif not attributes_html_str:
                        line += ' '

                    line += ' />'  # 1 token
                    formatted_text.append(line)

                # Process children regardless
                for child in node.children:
                    process_node(child, next_depth)

            elif isinstance(node, DOMTextNode):
                # Add text only if it doesn't have a highlighted parent
                if (
                    not node.has_parent_with_highlight_index()
                    and node.parent
                    and node.parent.is_visible
                    and node.parent.is_top_element
                ):  # and node.is_parent_top_element()
                    formatted_text.append(f'{depth_str}{node.text}')

        process_node(self, 0)
        return '\n'.join(formatted_text)

    def get_file_upload_element(self, check_siblings: bool = True) -> Optional['DOMElementNode']:
        """Return the first ``<input type="file">`` found at this element, below it, or among its siblings.

        Args:
            check_siblings: when True, siblings (and their subtrees) are
                searched as well; recursive calls always pass False.

        Returns:
            The matching element, or ``None`` if there is none.
        """
        # Check if current element is a file input
        if self.tag_name == 'input' and self.attributes.get('type') == 'file':
            return self

        # Check children
        for child in self.children:
            if isinstance(child, DOMElementNode):
                result = child.get_file_upload_element(check_siblings=False)
                if result:
                    return result

        # Check siblings only for the initial call
        if check_siblings and self.parent:
            for sibling in self.parent.children:
                if sibling is not self and isinstance(sibling, DOMElementNode):
                    result = sibling.get_file_upload_element(check_siblings=False)
                    if result:
                        return result

        return None
# Maps a highlight index to the element carrying it.
SelectorMap = dict[int, DOMElementNode]


@dataclass
class DOMState:
    """Root of the element tree together with the selector map."""

    element_tree: DOMElementNode
    selector_map: SelectorMap