|
- import json
- import re
- import time
- from collections import defaultdict
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any, Union
- import os
- import numpy as np
- import numpy.typing as npt
- from beartype import beartype
- from beartype.door import is_bearable
- from gymnasium import Env
- from gymnasium.spaces import Box, Text
- from playwright.sync_api import (
- CDPSession,
- Page,
- Playwright,
- ViewportSize,
- expect,
- sync_playwright,
- )
-
- from .actions import Action, execute_action, get_action_space
- from .processors import ObservationHandler, ObservationMetadata
- from .utils import (
- AccessibilityTree,
- DetachedPage,
- Observation,
- png_bytes_to_numpy,
- )
- from .label import label
-
# JavaScript evaluated via ``page.eval_on_selector("body", unmark)``.
# ``unmarkPage`` deletes every overlay ``<div marked="true">`` (the overlays the
# ``js_code`` marking script creates) and then removes the ``data-markindex``
# attribute from every element that still carries it, restoring the page
# to its unmarked state before it is marked again.
unmark = """
function unmarkPage() {
function clearDataIndexes() {
var allElements = document.querySelectorAll("[data-markindex]");
allElements.forEach(function(element) {
element.removeAttribute("data-markindex");
});
}
const markedElements = document.querySelectorAll("div[marked='true']");
markedElements.forEach(element => {
element.parentNode.removeChild(element);
});
clearDataIndexes();
}"""
# JavaScript evaluated via ``page.eval_on_selector("body", js_code)``.
# ``markPage`` selects interactive elements (form controls, buttons, links,
# elements with an onclick handler or a pointer cursor, iframes/videos, and TR
# rows whose title is a localhost URL) that have at least 20 square pixels of
# unobscured, in-viewport area. It keeps only the outermost of nested matches,
# draws a randomly coloured dashed outline plus a numeric label over each one,
# tags the element with ``data-markindex``, and returns one
# ``{index, text, category}`` record per labelled element.
# Raw string: the regex backslashes (``\d``, ``\/``, ``\s``) must reach the
# browser unchanged, and in a plain string they are invalid Python escapes.
js_code = r"""
function markPage() {
    // The viewport size does not change while marking, so read it once.
    var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
    var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

    var items = Array.prototype.slice.call(
        document.querySelectorAll('*')
    ).map(function(element) {
        // Keep client rects whose centre point hits this element (or one of its
        // descendants), then clip them to the viewport.
        var rects = [...element.getClientRects()].filter(bb => {
            var center_x = bb.left + bb.width / 2;
            var center_y = bb.top + bb.height / 2;
            var elAtCenter = document.elementFromPoint(center_x, center_y);

            return elAtCenter === element || element.contains(elAtCenter)
        }).map(bb => {
            const rect = {
                left: Math.max(0, bb.left),
                top: Math.max(0, bb.top),
                right: Math.min(vw, bb.right),
                bottom: Math.min(vh, bb.bottom)
            };
            return {
                ...rect,
                width: rect.right - rect.left,
                height: rect.bottom - rect.top
            }
        });

        var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);

        return {
            element: element,
            include:
                (element.tagName === "INPUT" || element.tagName === "TEXTAREA" || element.tagName === "SELECT" || element.tagName === "OPTION") ||
                (element.tagName === "BUTTON" || element.tagName === "A" || (element.onclick != null) || window.getComputedStyle(element).cursor == "pointer") ||
                (element.tagName === "IFRAME" || element.tagName === "VIDEO") || ((/^http:\/\/localhost:\d+(\/.*)?$/.test(element.title)) && element.tagName === "TR"),
            area,
            rects,
            text: element.textContent.trim().replace(/\s{2,}/g, ' ')
        };
    }).filter(item =>
        item.include && (item.area >= 20)
    );

    // Keep only the outermost items: drop any item whose element lies inside
    // another included item's element.
    var topLevelItems = items.filter(function(item) {
        for (var i = 0; i < items.length; i++) {
            if (items[i] !== item && items[i].element.contains(item.element)) {
                return false;
            }
        }
        return true;
    });

    function getRandomColor() {
        var letters = '0123456789ABCDEF';
        var color = '#';
        for (var i = 0; i < 6; i++) {
            color += letters[Math.floor(Math.random() * 16)];
        }
        return color;
    }

    // Draw a fixed-position dashed outline with a corner label over every rect
    // of every top-level item; pointerEvents "none" keeps the overlay from
    // intercepting clicks.
    topLevelItems.forEach(function(item, index) {
        item.element.dataset.markindex = index;
        item.rects.forEach((bbox) => {
            var newElement = document.createElement("div");
            var borderColor = getRandomColor();
            newElement.style.outline = `2px dashed ${borderColor}`;
            newElement.style.position = "fixed";
            newElement.setAttribute("marked", "true"); // lets unmarkPage() find the overlay
            newElement.setAttribute("data-markindex", index);
            newElement.style.left = bbox.left + "px";
            newElement.style.top = bbox.top + "px";
            newElement.style.width = bbox.width + "px";
            newElement.style.height = bbox.height + "px";
            newElement.style.pointerEvents = "none";
            newElement.style.boxSizing = "border-box";
            newElement.style.zIndex = 2147483647;

            // Floating label at the top-left corner.
            var label = document.createElement("span");
            label.textContent = index;
            label.style.position = "absolute";
            label.style.top = "-19px";
            label.style.left = "0px";
            label.style.background = borderColor;
            label.style.color = "white";
            label.style.padding = "2px 4px";
            label.style.fontSize = "18px";
            label.style.borderRadius = "2px";
            newElement.appendChild(label);

            document.body.appendChild(newElement);
        });
    });

    // Short textual description of an item's element.
    function describe(item) {
        var element = item.element;
        var text = item.text;
        if (element.tagName === "A") {
            if (element.title.trim()) {
                text = element.title; // prefer the link's title attribute
            } else if (element.className && !element.text.trim()) {
                // Text-less link: fall back to its first CSS class name.
                text = element.className.split(' ')[0];
            }
        } else if (element.tagName === "SELECT") {
            text = Array.from(element.options)
                .map(option => 'option: ' + option.text)
                .join(' ');
        } else if (element.tagName === "INPUT" && !item.text.trim()) {
            if (element.value.trim()) {
                text = element.value;
            } else if (element.placeholder.trim()) {
                text = element.placeholder;
            } else {
                text = element.id;
            }
        }
        // `required` is a boolean attribute: getAttribute() yields "" (falsy)
        // for a bare <input required>, so read the reflected property instead.
        if (element.tagName === "INPUT" && element.required) {
            text = text + " required";
        }
        return text;
    }

    // Every top-level item has at least one rect (its area is >= 20), so this
    // yields exactly one record per drawn label, in label order.
    return topLevelItems.map(function(item, index) {
        return {
            index: index,
            text: describe(item),
            category: item.element.tagName
        };
    });
}
"""
def update_viewport_size(self, new_viewport_size: ViewportSize) -> None:
    """
    Record ``new_viewport_size`` on ``self`` and push it to both processors.

    ``self`` must expose ``text_processor`` and ``image_processor``, each with
    its own ``update_viewport_size`` method; the text processor is updated
    first.

    NOTE(review): another function named ``update_viewport_size`` follows this
    one in the file. If both really sit at module level, the later definition
    replaces this one and this body is unreachable by name. Confirm where this
    was meant to live (it reads like an ``ObservationHandler`` method).
    """
    self.viewport_size = new_viewport_size
    # Look processors up one at a time, in order, so every processor stays in
    # sync with the new viewport.
    for attr_name in ("text_processor", "image_processor"):
        getattr(self, attr_name).update_viewport_size(new_viewport_size)
-
# Method meant to be added to TextObervationProcessor and
# ImageObservationProcessor so each can update its viewport size.
def update_viewport_size(self, new_viewport_size: ViewportSize) -> None:
    """
    Replace ``self.viewport_size`` with ``new_viewport_size``.

    Any other viewport-dependent attributes would need refreshing here as
    well; none are updated at present.
    """
    replacement = new_viewport_size
    self.viewport_size = replacement
-
-
- @dataclass
- class PlaywrightScript:
- function: str # goto, get_by_role
- destination: str # https://www.google.com/, combobox
- name: str | None = None # Search, Avatar 2009
- operation: str | None = None # click, fill, press
- value: str | None = None # avatar movie, Enter
-
-
- def parse_action(action: str) -> PlaywrightScript:
- splitted = action.strip().split(" ")
- assert len(splitted) >= 2
- match splitted[:2]:
- case ["goto", url]:
- assert len(splitted) == 2
- return PlaywrightScript("goto", url)
- case ["get_by_role", destination]:
- assert len(splitted) >= 4
- match splitted[2:]:
- case [name, operation]:
- return PlaywrightScript(
- "get_by_role", destination, name, operation
- )
- case [name, operation, value]:
- return PlaywrightScript(
- "get_by_role", destination, name, operation, value
- )
- case _:
- raise ValueError("Invalid action")
- case _:
- raise ValueError(f"Invalid action {action}")
-
-
- class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
- """
- The goal of this environment is to produce a prototype of a browser environment.
- In the end, we want to support a fully configurable browser environment with wide
- range of action spaces and observation spaces, both structured and unstructured.
- But in this prototype, we just support action space specified by Playwright script,
- and observation space is the html content of the page.
- """
-
- @beartype
- def __init__(
- self,
- max_page_length: int = 8192,
- headless: bool = True,
- slow_mo: int = 0,
- observation_type: str = "html",
- current_viewport_only: bool = False,
- viewport_size: ViewportSize = {"width": 1280, "height": 720},
- save_trace_enabled: bool = False,
- sleep_after_execution: float = 0.0,
- ):
- # TODO: make Space[Action] = ActionSpace
- self.action_space = get_action_space() # type: ignore[assignment]
- self.headless = headless
- self.slow_mo = slow_mo
- self.current_viewport_only = current_viewport_only
- self.reset_finished = False
- self.viewport_size = viewport_size
- self.save_trace_enabled = save_trace_enabled
- self.sleep_after_execution = sleep_after_execution
-
- match observation_type:
- case "html" | "accessibility_tree":
- self.text_observation_type = observation_type
- self.image_observation_type = ""
- self.main_observation_type = "text"
- case "image":
- self.image_observation_type = observation_type
- self.text_observation_type = "" # type: ignore[assignment]
- self.main_observation_type = "image"
- case _:
- raise ValueError(
- f"Unsupported observation type: {observation_type}"
- )
-
- self.observation_handler = ObservationHandler(
- self.main_observation_type,
- self.text_observation_type,
- self.image_observation_type,
- self.current_viewport_only,
- self.viewport_size,
- )
-
- self.observation_space = (
- self.observation_handler.get_observation_space()
- )
-
- @beartype
- def setup(self, config_file: Path | None = None) -> None:
- self.context_manager = sync_playwright()
- self.playwright = self.context_manager.__enter__()
- self.browser = self.playwright.chromium.launch(
- headless=self.headless, slow_mo=self.slow_mo
- )
-
- if config_file:
- with open(config_file, "r") as f:
- instance_config = json.load(f)
- else:
- instance_config = {}
-
- storage_state = instance_config.get("storage_state", None)
- start_url = instance_config.get("start_url", None)
- geolocation = instance_config.get("geolocation", None)
-
- self.context = self.browser.new_context(
- viewport=self.viewport_size,
- storage_state=storage_state,
- geolocation=geolocation,
- device_scale_factor=1,
- )
- if self.save_trace_enabled:
- self.context.tracing.start(screenshots=True, snapshots=True)
- if start_url:
- start_urls = start_url.split(" |AND| ")
- for url in start_urls:
- page = self.context.new_page()
- client = page.context.new_cdp_session(
- page
- ) # talk to chrome devtools
- if self.text_observation_type == "accessibility_tree":
- client.send("Accessibility.enable")
- page.client = client # type: ignore # TODO[shuyanzh], fix this hackey client
- page.goto(url)
- # set the first page as the current page
- self.page = self.context.pages[0]
- self.page.bring_to_front()
- else:
- self.page = self.context.new_page()
- client = self.page.context.new_cdp_session(self.page)
- if self.text_observation_type == "accessibility_tree":
- client.send("Accessibility.enable")
- self.page.client = client # type: ignore
-
- def get_page_client(self, page: Page) -> CDPSession:
- return page.client # type: ignore
-
- def _get_obs(self) -> dict[str, Observation]:
- obs = self.observation_handler.get_observation(
- self.page, self.get_page_client(self.page)
- )
- return obs
-
- def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
- metadata = self.observation_handler.get_observation_metadata()
- return metadata
-
- @beartype
- def reset(
- self,
- *,
- seed: int | None = None,
- options: dict[str, str] | None = None
- ) -> tuple[dict[str, Observation], dict[str, Any]]:
- """
- Reset the environment.
- :param options: options for the environment. The current supported options are:
- - "storage_state": the storage state of the browser. It is a file path to a json file.
- """
- super().reset(seed=seed, options=options)
- if self.reset_finished:
- self.context_manager.__exit__()
-
- if options is not None and "config_file" in options:
- config_file = Path(options["config_file"])
- if config_file.exists():
- self.setup(config_file=config_file)
- else:
- raise ValueError(f"Config file {config_file} does not exist.")
- else:
- self.setup()
- self.reset_finished = True
- if self.sleep_after_execution > 0:
- time.sleep(self.sleep_after_execution)
- # scroll_height = self.page.eval_on_selector("body", "el => el.scrollHeight")
- # self.page.set_viewport_size({"width": 1920, "height": scroll_height})
- # self.observation_handler.update_viewport_size(self.page.viewport_size)
- step_number = options.get('step_number', None)
- config_file_name = options.get('config_file_name', None)
- dir_html = os.path.join(config_file_name, "html")
- dir_screenshot = os.path.join(config_file_name, "screenshot")
- os.makedirs(dir_html, exist_ok=True)
- os.makedirs(dir_screenshot, exist_ok=True)
-
- self.page.screenshot(path="screenshot_origin.png") # 保存屏幕截图
- self.page.screenshot(path=os.path.join(dir_screenshot, f"screenshot_origin_{step_number}.png"))
-
- page_html = self.page.content()
- with open(os.path.join(dir_html, f'step_{step_number}.html'), 'w', encoding='utf-8') as file:
- file.write(page_html)
-
- result = self.page.eval_on_selector("body", js_code) # 执行JavaScript代码
- time.sleep(2)
-
- self.page.screenshot(path="screenshot.png")
- self.page.screenshot(path=os.path.join(dir_screenshot, f"screenshot_mark_{step_number}.png"))
-
-
- # self.page.eval_on_selector("body", unmark)
- # time.sleep(1)
-
-
-
- observation = self._get_obs()
- observation_metadata = self._get_obs_metadata()
-
-
-
- labels_string = result
- # for item in labels:
- # bound_key = (item['left'], item['top'], item['width'], item['height'])
- # index = observation["map"].get(bound_key)
- # label_info = (f"bound_index:{item['index']} -Text: {item['text']} - tree Index: {index}\n")
- # labels_string += label_info
- info = {
- "page": DetachedPage(self.page.url, ""),
- "fail_error": "",
- "observation_metadata": observation_metadata,
- "bound2tree":labels_string
- }
- return (observation, info)
-
- def save_trace(self, trace_path: str | Path) -> None:
- if self.save_trace_enabled:
- self.context.tracing.stop(path=trace_path)
-
- def close(self) -> None:
- if self.reset_finished:
- self.context_manager.__exit__()
-
- def step(
- self, action: Action, step_number: int, config_file_name: str
- ) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
- if not self.reset_finished:
- raise RuntimeError("Call reset first before calling step.")
- success = False
- fail_error = ""
- try:
- self.page = execute_action(
- action,
- self.page,
- self.context,
- self.observation_handler.action_processor,
- )
- success = True
- except Exception as e:
- fail_error = str(e)
- # hard sleep TODO[shuyanzh] suboptimal, may need to check network
- if self.sleep_after_execution > 0:
- time.sleep(self.sleep_after_execution)
- # scroll_height = self.page.eval_on_selector("body", "el => el.scrollHeight")
- # self.page.set_viewport_size({"width": 1920, "height": scroll_height})
- # self.observation_handler.update_viewport_size(self.page.viewport_size)
-
- dir_html = os.path.join(config_file_name, "html")
- dir_screenshot = os.path.join(config_file_name, "screenshot")
- os.makedirs(dir_html, exist_ok=True)
- os.makedirs(dir_screenshot, exist_ok=True)
- page_html = self.page.content()
-
- self.page.eval_on_selector("body", unmark)
- time.sleep(2)
- self.page.screenshot(path="screenshot_origin.png") # 保存屏幕截图
- self.page.screenshot(path=os.path.join(dir_screenshot, f"screenshot_origin_{step_number}.png"))
-
- with open(os.path.join(dir_html, f'step_{step_number}.html'), 'w', encoding='utf-8') as file:
- file.write(page_html)
-
- result = labels = self.page.eval_on_selector("body", js_code) # 执行JavaScript代码
-
- time.sleep(1)
- print(self.page.url)
- self.page.screenshot(path="screenshot.png") # 保存屏幕截图
- self.page.screenshot(path=os.path.join(dir_screenshot, f"screenshot_mark_{step_number}.png"))
-
-
- observation = self._get_obs()
- observation_metadata = self._get_obs_metadata()
- labels_string = result
-
-
- info = {
- "page": DetachedPage(self.page.url, self.page.content()),
- "fail_error": fail_error,
- "observation_metadata": observation_metadata,
- "bound2tree":labels_string
- }
- msg = (
- observation,
- float(success), # reward
- False, # terminated
- False, # truncated
- info,
- )
- return msg
|