|
- import json
- import re
- from collections import defaultdict
- from typing import Any, TypedDict, Union
-
- import numpy as np
- import numpy.typing as npt
- from gymnasium import spaces
- from playwright.sync_api import CDPSession, Page, ViewportSize
-
- from browser_env.constants import (
- ASCII_CHARSET,
- FREQ_UNICODE_CHARSET,
- IGNORED_ACTREE_PROPERTIES,
- UTTERANCE_MAX_LENGTH,
- )
-
- from .utils import (
- AccessibilityTree,
- AccessibilityTreeNode,
- BrowserConfig,
- BrowserInfo,
- DOMNode,
- DOMTree,
- Observation,
- png_bytes_to_numpy,
- )
-
# Cut-off applied to the score returned by
# ``TextObervationProcessor.get_element_in_viewport_ratio``: when only the
# current viewport is observed, nodes scoring below it are removed.
IN_VIEWPORT_RATIO_THRESHOLD: float = 0.6
-
-
class ObservationProcessor:
    """Base interface for turning a browser page into an observation.

    Concrete processors override ``process``; the base implementation only
    signals that it must be overridden.
    """

    def process(self, page: Page, client: CDPSession) -> Observation:
        """Return an observation of ``page``; subclasses must implement this."""
        raise NotImplementedError
-
-
class ObservationMetadata(TypedDict):
    """Metadata recorded alongside an observation."""

    # per-element information keyed by element id (filled by the processors)
    obs_nodes_info: dict[str, Any]
-
-
def create_empty_metadata() -> ObservationMetadata:
    """Return a new metadata dict with an empty ``obs_nodes_info`` mapping.

    Both the outer dict and the inner mapping are freshly created on each
    call, so callers may mutate the result without affecting other callers.
    """
    return dict(obs_nodes_info={})
-
-
class TextObervationProcessor(ObservationProcessor):
    """Turn a browser page into a text observation.

    Two formats are supported, selected by ``observation_type``: ``"html"``
    (a flattened DOM snapshot) and ``"accessibility_tree"`` (the full
    accessibility tree from ``Accessibility.getFullAXTree``). ``process``
    returns the rendered text plus a map from element bounding boxes to
    element ids, and stores per-element metadata (backend DOM id, bounding
    box, rendered line) in ``obs_nodes_info`` and
    ``meta_data["obs_nodes_info"]``.
    """

    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ):
        """
        Args:
            observation_type: ``"html"`` or ``"accessibility_tree"``; any
                other value makes ``process`` raise ``ValueError``.
            current_viewport_only: if True, drop nodes that are not
                sufficiently inside the current window.
            viewport_size: mapping with ``"width"`` and ``"height"`` keys.
        """
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.viewport_size = viewport_size
        self.observation_tag = "text"
        # metadata of the most recent observation of this type
        self.meta_data = create_empty_metadata()
        # filled by ``process``; initialized here so ``get_element_center``
        # fails with KeyError rather than AttributeError before the first
        # observation
        self.obs_nodes_info: dict[str, Any] = {}

    def update_viewport_size(self, new_viewport_size: ViewportSize) -> None:
        """Replace the viewport size used for bound calibration/centering."""
        self.viewport_size = new_viewport_size

    def fetch_browser_info(
        self,
        page: Page,
        client: CDPSession,
    ) -> BrowserInfo:
        """Capture a DOM snapshot of ``page`` and its window geometry.

        Returns:
            ``{"DOMTree": <DOMSnapshot.captureSnapshot result>,
            "config": <window offsets/sizes and devicePixelRatio>}``.

        Raises:
            AssertionError: if ``window.devicePixelRatio`` is not 1.0.
        """
        tree = client.send(
            "DOMSnapshot.captureSnapshot",
            {
                "computedStyles": [],
                "includeDOMRects": True,
                "includePaintOrder": True,
            },
        )

        # Calibrate the bounds: in some cases they are scaled, so rescale
        # every bound such that the first one is exactly viewport-wide.
        # Skip calibration when there is no usable reference width (it
        # would otherwise divide by zero).
        layout = tree["documents"][0]["layout"]
        bounds = layout["bounds"]
        if bounds and bounds[0][2]:
            scale = bounds[0][2] / self.viewport_size["width"]
            layout["bounds"] = [
                [coord / scale for coord in bound] for bound in bounds
            ]

        win_top_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_top_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_top_bound": win_top_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }

        info: BrowserInfo = {"DOMTree": tree, "config": config}
        return info

    @staticmethod
    def get_bounding_client_rect(
        client: CDPSession, backend_node_id: str
    ) -> dict[str, Any]:
        """Query the live page for a node's bounding client rect.

        Text nodes (``nodeType == 3``) are measured through a DOM Range;
        other nodes through ``getBoundingClientRect``.

        Returns:
            The raw ``Runtime.callFunctionOn`` response (the rect is under
            ``["result"]["value"]``), or ``{"result": {"subtype": "error"}}``
            if any step fails. Failures are deliberately swallowed: a node
            that cannot be measured simply gets no bound.
        """
        try:
            remote_object = client.send(
                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
            )
            remote_object_id = remote_object["object"]["objectId"]
            response = client.send(
                "Runtime.callFunctionOn",
                {
                    "objectId": remote_object_id,
                    "functionDeclaration": """
                    function() {
                        if (this.nodeType == 3) {
                            var range = document.createRange();
                            range.selectNode(this);
                            var rect = range.getBoundingClientRect().toJSON();
                            range.detach();
                            return rect;
                        } else {
                            return this.getBoundingClientRect().toJSON();
                        }
                    }
                    """,
                    "returnByValue": True,
                },
            )
            return response
        except Exception:
            return {"result": {"subtype": "error"}}

    def _fetch_union_bound(
        self, client: CDPSession, backend_node_id: str
    ) -> Union[list[float], None]:
        """Return ``[x, y, width, height]`` of a node, or None if unavailable."""
        response = self.get_bounding_client_rect(client, backend_node_id)
        if response.get("result", {}).get("subtype", "") == "error":
            return None
        rect = response["result"]["value"]
        return [rect["x"], rect["y"], rect["width"], rect["height"]]

    @staticmethod
    def get_element_in_viewport_ratio(
        elem_left_bound: float,
        elem_top_bound: float,
        width: float,
        height: float,
        config: BrowserConfig,
    ) -> float:
        """Return the fraction (0..1) of the element's area inside the window.

        The window is ``[0, win_width] x [0, win_height]`` in the same
        client (viewport-relative) coordinates as the element bounds.
        Elements with non-positive area yield 0.0.
        """
        area = width * height
        if area <= 0:
            return 0.0

        elem_right_bound = elem_left_bound + width
        elem_lower_bound = elem_top_bound + height

        win_left_bound = 0
        win_right_bound = config["win_width"]
        win_top_bound = 0
        win_lower_bound = config["win_height"]

        overlap_width = max(
            0,
            min(elem_right_bound, win_right_bound)
            - max(elem_left_bound, win_left_bound),
        )
        overlap_height = max(
            0,
            min(elem_lower_bound, win_lower_bound)
            - max(elem_top_bound, win_top_bound),
        )

        # divide by the full area; ``a / width * height`` would instead
        # multiply by height and is not a ratio
        return overlap_width * overlap_height / area

    def _should_drop_from_viewport(
        self, union_bound: Union[list[float], None], config: BrowserConfig
    ) -> bool:
        """True if a node with this bound is not shown in a viewport-only view.

        A node is dropped when it has no bound, has zero width or height, or
        its in-viewport ratio is below ``IN_VIEWPORT_RATIO_THRESHOLD``.
        """
        if not union_bound:
            return True
        x, y, width, height = union_bound
        if width == 0 or height == 0:
            return True
        in_viewport_ratio = self.get_element_in_viewport_ratio(
            elem_left_bound=float(x),
            elem_top_bound=float(y),
            width=float(width),
            height=float(height),
            config=config,
        )
        return in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD

    def fetch_page_html(
        self,
        info: BrowserInfo,
        page: Page,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> DOMTree:
        """Flatten the DOM snapshot in ``info`` into a list of nodes.

        Adapted from natbot (https://github.com/nat/natbot). Node ``i`` of
        the first snapshot document becomes ``dom_tree[i]`` with
        ``nodeId == str(i)``; ``parentId``/``childIds`` are string indices
        into the same list (the root's ``parentId`` is ``"-1"``). Every
        non-root node's bound is queried from the live page; the root gets a
        fixed 10x10 bound at the origin so it always counts as visible.

        When ``current_viewport_only`` is True, nodes rejected by
        ``_should_drop_from_viewport`` are removed and their children are
        re-attached to the removed node's parent at the same position.
        ``page`` is not used; it is kept for interface compatibility.
        """
        tree = info["DOMTree"]
        strings = tree["strings"]
        nodes = tree["documents"][0]["nodes"]

        dom_tree: DOMTree = []
        graph: dict[str, list[str]] = defaultdict(list)
        for node_idx in range(len(nodes["nodeName"])):
            node_id = str(node_idx)
            parent_id = str(nodes["parentIndex"][node_idx])
            backend_node_id = str(nodes["backendNodeId"][node_idx])

            # DOMSnapshot reports nodeType as the numeric DOM node type, not
            # as an index into the string table
            node_type = str(nodes["nodeType"][node_idx])
            node_name = strings[nodes["nodeName"][node_idx]]

            # string indices outside the table (e.g. -1) mean "no value"
            value_idx = nodes["nodeValue"][node_idx]
            node_value = ""
            if 0 <= value_idx < len(strings):
                node_value = " ".join(strings[value_idx].split())

            # attributes are a flat [name, value, name, value, ...] list
            attribute_strings = [
                strings[i] for i in nodes["attributes"][node_idx]
            ]
            attribute_parts = []
            for attr_name, attr_value in zip(
                attribute_strings[0::2], attribute_strings[1::2]
            ):
                collapsed = " ".join(attr_value.split())
                attribute_parts.append(f'{attr_name}="{collapsed}"')
            node_attributes_str = " ".join(attribute_parts)

            if parent_id == "-1":
                # document root: fixed bound so it is always "in view"
                union_bound = [0.0, 0.0, 10.0, 10.0]
            else:
                graph[parent_id].append(node_id)
                union_bound = self._fetch_union_bound(client, backend_node_id)

            cur_node: DOMNode = {
                "nodeId": node_id,
                "nodeType": node_type,
                "nodeName": node_name,
                "nodeValue": node_value,
                "attributes": node_attributes_str,
                "backendNodeId": backend_node_id,
                "parentId": parent_id,
                "childIds": [],
                "cursor": 0,
                "union_bound": union_bound,
            }
            dom_tree.append(cur_node)

        for parent_id, child_ids in graph.items():
            dom_tree[int(parent_id)]["childIds"] = child_ids

        if current_viewport_only:

            def remove_node_in_graph(node: DOMNode) -> None:
                """Splice ``node`` out, putting its children in its place."""
                node_id = node["nodeId"]
                parent_id = node["parentId"]
                child_ids = node["childIds"]
                parent = dom_tree[int(parent_id)]

                assert parent["parentId"] != "[REMOVED]"
                index = parent["childIds"].index(node_id)
                parent["childIds"][index : index + 1] = child_ids
                for child_id in child_ids:
                    dom_tree[int(child_id)]["parentId"] = parent_id
                node["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in dom_tree:
                if self._should_drop_from_viewport(node["union_bound"], config):
                    remove_node_in_graph(node)

            dom_tree = [
                node
                for node in dom_tree
                if node.get("parentId", "-1") != "[REMOVED]"
            ]

        return dom_tree

    @staticmethod
    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
        """Render ``dom_tree`` as tab-indented text, one line per node.

        Each printed line reads ``[<cursor>] <nodeName attributes> value``,
        where ``cursor`` is the node's position in ``dom_tree``. Nodes with
        neither attributes nor a value are not printed and do not add a level
        of indentation for their children.

        Returns:
            The text (every line newline-terminated) and a dict mapping
            ``str(cursor)`` to ``{"backend_id", "union_bound", "text"}`` for
            every printed node. An empty tree yields ``("", {})``.
        """
        obs_nodes_info: dict[str, Any] = {}
        if not dom_tree:
            return "", obs_nodes_info

        nodeid_to_cursor = {
            node["nodeId"]: idx for idx, node in enumerate(dom_tree)
        }
        lines: list[str] = []

        def dfs(node_cursor: int, depth: int) -> None:
            node = dom_tree[node_cursor]
            indent = "\t" * depth
            valid_node = True
            try:
                node_str = f"[{node_cursor}] <{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    lines.append(f"{indent}{node_str}\n")
            except Exception:
                # malformed node: hide it but still render its children
                valid_node = False

            child_depth = depth + 1 if valid_node else depth
            for child_id in node["childIds"]:
                dfs(nodeid_to_cursor[child_id], child_depth)

        dfs(0, 0)
        return "".join(lines), obs_nodes_info

    def fetch_page_accessibility_tree(
        self,
        info: BrowserInfo,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> AccessibilityTree:
        """Fetch the full accessibility tree and attach a bound to each node.

        Duplicate ``nodeId``s are dropped (first occurrence kept). Each node
        gets ``union_bound``: None when it has no ``backendDOMNodeId``, a
        fixed 10x10 bound at the origin for the ``RootWebArea``, and the live
        bounding rect (or None on failure) otherwise.

        When ``current_viewport_only`` is True, nodes rejected by
        ``_should_drop_from_viewport`` are removed and their children are
        re-attached to the removed node's parent at the same position.
        """
        raw_nodes = client.send("Accessibility.getFullAXTree", {})["nodes"]

        # a few nodes are repeated in the accessibility tree
        seen_ids = set()
        accessibility_tree: AccessibilityTree = []
        for node in raw_nodes:
            if node["nodeId"] not in seen_ids:
                seen_ids.add(node["nodeId"])
                accessibility_tree.append(node)

        nodeid_to_cursor = {}
        for cursor, node in enumerate(accessibility_tree):
            nodeid_to_cursor[node["nodeId"]] = cursor
            if "backendDOMNodeId" not in node:
                # usually because the node is not visible etc
                node["union_bound"] = None
            elif node["role"]["value"] == "RootWebArea":
                # always inside the viewport
                node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                node["union_bound"] = self._fetch_union_bound(
                    client, str(node["backendDOMNodeId"])
                )

        if current_viewport_only:

            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
                """Splice ``node`` out, putting its children in its place."""
                node_id = node["nodeId"]
                parent_id = node["parentId"]
                child_ids = node["childIds"]
                parent = accessibility_tree[nodeid_to_cursor[parent_id]]

                assert parent.get("parentId", "Root") is not None
                index = parent["childIds"].index(node_id)
                parent["childIds"][index : index + 1] = child_ids
                for child_id in child_ids:
                    child = accessibility_tree[nodeid_to_cursor[child_id]]
                    child["parentId"] = parent_id
                node["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in accessibility_tree:
                if self._should_drop_from_viewport(node["union_bound"], config):
                    remove_node_in_graph(node)

            accessibility_tree = [
                node
                for node in accessibility_tree
                if node.get("parentId", "Root") != "[REMOVED]"
            ]

        return accessibility_tree

    @staticmethod
    def parse_accessibility_tree(
        accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any]]:
        """Render the accessibility tree as tab-indented text.

        Each shown node becomes ``[<nodeId>] <role> <repr(name)>`` followed
        by ``name: value`` pairs of its properties (those listed in
        ``IGNORED_ACTREE_PROPERTIES`` are omitted). Hidden nodes -- nodes
        whose name is blank and which either have no properties and one of a
        few container-like roles, or have properties and role ``listitem``
        -- do not add a level of indentation for their children. Lines are
        newline-separated with no trailing newline.

        Returns:
            The text and a dict mapping each recorded node id to
            ``{"backend_id", "union_bound", "text"}``. An empty tree yields
            ``("", {})``.
        """
        obs_nodes_info: dict[str, Any] = {}
        if not accessibility_tree:
            return "", obs_nodes_info

        node_id_to_idx = {
            node["nodeId"]: idx for idx, node in enumerate(accessibility_tree)
        }
        # roles that carry no information without a name or properties
        empty_name_hidden_roles = (
            "generic",
            "img",
            "list",
            "strong",
            "paragraph",
            "banner",
            "navigation",
            "Section",
            "LabelText",
            "Legend",
            "listitem",
        )

        def dfs(idx: int, obs_node_id: str, depth: int) -> str:
            tree_str = ""
            node = accessibility_tree[idx]
            indent = "\t" * depth
            valid_node = True
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                for prop in node.get("properties", []):
                    try:
                        if prop["name"] in IGNORED_ACTREE_PROPERTIES:
                            continue
                        properties.append(
                            f'{prop["name"]}: {prop["value"]["value"]}'
                        )
                    except KeyError:
                        pass

                if properties:
                    node_str += " " + " ".join(properties)

                if not name.strip():
                    if not properties:
                        if role in empty_name_hidden_roles:
                            valid_node = False
                    elif role == "listitem":
                        valid_node = False

                if valid_node:
                    # the line is emitted before the metadata lookup: a node
                    # without ``backendDOMNodeId`` is still printed, but is
                    # not recorded and does not indent its children
                    tree_str += f"{indent}{node_str}"
                    obs_nodes_info[obs_node_id] = {
                        "backend_id": node["backendDOMNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
            except Exception:
                valid_node = False

            # hidden nodes do not consume an indentation level (saves tokens)
            child_depth = depth + 1 if valid_node else depth
            for child_node_id in node["childIds"]:
                if child_node_id not in node_id_to_idx:
                    continue
                child_str = dfs(
                    node_id_to_idx[child_node_id], child_node_id, child_depth
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str

            return tree_str

        tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
        return tree_str, obs_nodes_info

    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """Drop StaticText lines that only repeat nearby text.

        A ``[id] StaticText '<text>'`` (or ``"<text>"``) line is dropped when
        ``<text>`` occurs in any of the three previously kept lines. All
        other lines, including StaticText lines the pattern cannot parse,
        are kept.
        """
        # repr() switches to double quotes when the text contains "'"
        static_text_pattern = re.compile(
            r"\[\d+\] StaticText (?:'([^']+)'|\"([^\"]+)\")"
        )
        clean_lines: list[str] = []
        for line in tree_str.split("\n"):
            if "statictext" in line.lower():
                match = static_text_pattern.search(line)
                if match:
                    static_text = match.group(1) or match.group(2)
                    if any(
                        static_text in prev_line
                        for prev_line in clean_lines[-3:]
                    ):
                        continue
            clean_lines.append(line)

        return "\n".join(clean_lines)

    @staticmethod
    def _describe_tabs(page: Page) -> str:
        """Return e.g. ``"Tab 0: A | Tab 1 (current): B"`` for the open tabs.

        If titles cannot be read (or ``page`` is not among its context's
        pages), fall back to ``"Tab 0 | Tab 1 | ..."``.
        """
        open_tabs = page.context.pages
        try:
            current_tab_idx = open_tabs.index(page)
            tab_titles = []
            for idx, tab in enumerate(open_tabs):
                marker = " (current)" if idx == current_tab_idx else ""
                tab_titles.append(f"Tab {idx}{marker}: {tab.title()}")
        except Exception:
            tab_titles = [f"Tab {idx}" for idx in range(len(open_tabs))]
        return " | ".join(tab_titles)

    @staticmethod
    def _build_union_bound_to_id_map(
        obs_nodes_info: dict[str, Any],
    ) -> dict[tuple, str]:
        """Map ``tuple(union_bound)`` to element id.

        For duplicate bounds the id that comes first in ``obs_nodes_info``
        wins (later ids are inserted first and then overwritten). Elements
        without a bound are skipped instead of crashing ``tuple(None)``.
        """
        bound_to_id: dict[tuple, str] = {}
        for node_id, properties in reversed(list(obs_nodes_info.items())):
            bound = properties["union_bound"]
            if bound is None:
                continue
            bound_to_id[tuple(bound)] = node_id
        return bound_to_id

    def process(
        self, page: Page, client: CDPSession
    ) -> tuple[str, dict[tuple, str]]:
        """Build the text observation of ``page``.

        Side effects: updates ``obs_nodes_info``, ``meta_data`` and
        ``browser_config``.

        Returns:
            ``(content, union_bound_to_id_map)``. ``content`` is the open-tab
            summary line, a blank line, then the rendered tree. The map is
            described in ``_build_union_bound_to_id_map``.

        Raises:
            ValueError: if ``observation_type`` is not supported.
        """
        tab_title_str = self._describe_tabs(page)

        try:
            browser_info = self.fetch_browser_info(page, client)
        except Exception:
            # presumably the page is still loading: wait briefly, retry once
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, client)

        if self.observation_type == "html":
            dom_tree = self.fetch_page_html(
                browser_info,
                page,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            content, obs_nodes_info = self.parse_html(dom_tree)
        elif self.observation_type == "accessibility_tree":
            accessibility_tree = self.fetch_page_accessibility_tree(
                browser_info,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            content, obs_nodes_info = self.parse_accessibility_tree(
                accessibility_tree
            )
            content = self.clean_accesibility_tree(content)
        else:
            raise ValueError(
                f"Invalid observation type: {self.observation_type}"
            )

        self.obs_nodes_info = obs_nodes_info
        self.meta_data["obs_nodes_info"] = obs_nodes_info
        union_bound_to_id_map = self._build_union_bound_to_id_map(
            obs_nodes_info
        )

        self.browser_config = browser_info["config"]
        content = f"{tab_title_str}\n\n{content}"
        return content, union_bound_to_id_map

    def get_element_center(self, element_id: str) -> tuple[float, float]:
        """Return the element's center as fractions of the viewport size.

        Uses the bound recorded for ``element_id`` by the last ``process``
        call; raises KeyError for unknown ids.
        """
        x, y, width, height = self.obs_nodes_info[element_id]["union_bound"]
        center_x = x + width / 2
        center_y = y + height / 2
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )
-
-
class ImageObservationProcessor(ObservationProcessor):
    """Turn a browser page into a screenshot observation."""

    def __init__(self, observation_type: str):
        """
        Args:
            observation_type: stored as-is; not interpreted by this class.
        """
        self.observation_type = observation_type
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def update_viewport_size(self, new_viewport_size: ViewportSize) -> None:
        """Record the current viewport size (not used by ``process``)."""
        self.viewport_size = new_viewport_size

    def process(self, page: Page, client: CDPSession) -> npt.NDArray[np.uint8]:
        """Return a screenshot of ``page`` decoded by ``png_bytes_to_numpy``.

        If the first screenshot attempt fails, wait until the page has
        reached the ``load`` state and retry once; a second failure
        propagates to the caller. ``client`` is unused.
        """
        try:
            return png_bytes_to_numpy(page.screenshot())
        except Exception:
            # wait_for_load_state returns immediately when the page has
            # already loaded, unlike wait_for_event("load"), which waits for
            # the *next* load event and can time out
            page.wait_for_load_state("load")
            return png_bytes_to_numpy(page.screenshot())
-
-
class ObservationHandler:
    """Main entry point to access all observation processors."""

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ) -> None:
        """
        Args:
            main_observation_type: ``"text"`` or ``"image"``; selects the
                processor returned by ``action_processor``.
            text_observation_type: forwarded to ``TextObervationProcessor``.
            image_observation_type: forwarded to
                ``ImageObservationProcessor``.
            current_viewport_only: forwarded to ``TextObervationProcessor``.
            viewport_size: initial viewport size (``"width"``/``"height"``).
        """
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type
        )
        self.viewport_size = viewport_size

    def update_viewport_size(self, new_viewport_size: ViewportSize) -> None:
        """Set a new viewport size here and on every processor."""
        self.viewport_size = new_viewport_size
        # make sure every observation processor also picks up the new size
        self.text_processor.update_viewport_size(new_viewport_size)
        self.image_processor.update_viewport_size(new_viewport_size)

    def get_observation_space(self) -> spaces.Dict:
        """Return the gym space: bounded text plus an RGB uint8 image.

        The image box has shape ``(height, width, 3)`` (height first) with
        values in ``[0, 255]``.
        """
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )

        shape = (self.viewport_size["height"], self.viewport_size["width"], 3)
        # build both bounds directly as uint8 instead of casting a float
        # array down to the box dtype
        image_space = spaces.Box(
            np.zeros(shape, dtype=np.uint8),
            np.full(shape, 255, dtype=np.uint8),
            dtype=np.uint8,
        )

        return spaces.Dict({"text": text_space, "image": image_space})

    def get_observation(
        self, page: Page, client: CDPSession
    ) -> dict[str, Any]:
        """Collect the text observation, the screenshot and the bound map.

        Returns:
            ``{"text": <text observation>, "image": <screenshot>,
            "map": <bound tuple -> element id>}``; the text processor runs
            before the image processor.
        """
        text_obs, union_bound_to_id_map = self.text_processor.process(
            page, client
        )
        image_obs = self.image_processor.process(page, client)
        return {
            "text": text_obs,
            "image": image_obs,
            "map": union_bound_to_id_map,
        }

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Return each processor's metadata dict (live references)."""
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space.

        Raises:
            ValueError: if ``main_observation_type`` is neither ``"text"``
                nor ``"image"``.
        """
        if self.main_observation_type == "text":
            return self.text_processor
        elif self.main_observation_type == "image":
            return self.image_processor
        else:
            raise ValueError("Invalid main observation type")
|