Bug 1411281 - Unmarshal all responses in WPT WebDriver client draft
authorAndreas Tolfsen <ato@sny.no>
Wed, 25 Oct 2017 14:18:26 +0100
changeset 688668 66a80dd6524b9bb393fbfe00082a8f72b660cf5a
parent 688667 afb95a8f64b674745f6346080a95d5dee3539c29
child 688669 e606001490ea131046e9bbd45371d29f1f0f7252
push id86818
push userbmo:ato@sny.no
push dateMon, 30 Oct 2017 13:42:30 +0000
bugs1411281
milestone58.0a1
Bug 1411281 - Unmarshal all responses in WPT WebDriver client The WPT WebDriver client currently only unmarshals responses for some commands (notably execute_script, execute_async_script, and find.css). For the client API we want to unmarshal all response bodies automatically. This patch moves all JSON serialisation/deserialisation to a new webdriver.protocol package so that it is not scattered around the client API. It introduces specialisations of JSONEncoder and JSONDecoder that allows web element references to be recognised and converted to complex webdriver.Element objects. This change means it is no longer necessary for callers to invoke webdriver.Session._element to convert the response to a web element as this will be done automatically on any request- and response body to webdriver.Sesson.send_command. An important thing to note is that HTTPWireProtocol.send does not follow this behaviour by default. That is because session.transport.send is used throughout WebDriver tests in WPT as a way to get the raw JSON body without having to set up session state manually. MozReview-Commit-ID: 5UyDAe43Hgf
testing/web-platform/tests/tools/webdriver/webdriver/client.py
testing/web-platform/tests/tools/webdriver/webdriver/protocol.py
testing/web-platform/tests/tools/webdriver/webdriver/transport.py
--- a/testing/web-platform/tests/tools/webdriver/webdriver/client.py
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/client.py
@@ -1,12 +1,12 @@
-import json
 import urlparse
 
 import error
+import protocol
 import transport
 
 from mozlog import get_default_logger
 
 logger = get_default_logger()
 
 
 def command(func):
@@ -140,17 +140,17 @@ class ActionSequence(object):
         action = {
             "type": "pointerMove",
             "x": x,
             "y": y
         }
         if duration is not None:
             action["duration"] = duration
         if origin is not None:
-            action["origin"] = origin if isinstance(origin, basestring) else origin.json()
+            action["origin"] = origin
         self._actions.append(action)
         return self
 
     def pointer_up(self, button=0):
         """Queue a pointerUp action for `button`.
 
         :param button: Pointer button to perform action with.
                        Default: 0, which represents main device button.
@@ -292,28 +292,19 @@ class Find(object):
         self.session = session
 
     @command
     def css(self, selector, all=True):
         return self._find_element("css selector", selector, all)
 
     def _find_element(self, strategy, selector, all):
         route = "elements" if all else "element"
-
         body = {"using": strategy,
                 "value": selector}
-
-        data = self.session.send_session_command("POST", route, body)
-
-        if all:
-            rv = [self.session._element(item) for item in data]
-        else:
-            rv = self.session._element(data)
-
-        return rv
+        return self.session.send_session_command("POST", route, body)
 
 
 class Cookies(object):
     def __init__(self, session):
         self.session = session
 
     def __getitem__(self, name):
         self.session.send_session_command("GET", "cookie/%s" % name, {})
@@ -427,30 +418,34 @@ class Session(object):
         :param uri: "Command part" of the HTTP request URL,
             e.g. `window/rect`.
         :param body: Optional body of the HTTP request.
 
         :return: `None` if the HTTP response body was empty, otherwise
             the `value` field returned after parsing the response
             body as JSON.
 
+        :raises ValueError: If the response body does not contain a
+            `value` key.
         :raises error.WebDriverException: If the remote end returns
             an error.
         """
-        response = self.transport.send(method, url, body)
+        response = self.transport.send(
+            method, url, body,
+            encoder=protocol.Encoder, decoder=protocol.Decoder,
+            session=self)
 
         if response.status != 200:
             raise error.from_response(response)
 
         if "value" in response.body:
             value = response.body["value"]
         else:
-            raise error.UnknownErrorException(
-                "Expected 'value' key in response body:\n"
-                "%s" % json.dumps(response.body))
+            raise ValueError("Expected 'value' key in response body:\n"
+                "%s" % response)
 
         return value
 
     def send_session_command(self, method, uri, body=None):
         """
         Send a command to an established session and validate its success.
 
         :param method: HTTP method to use in request.
@@ -515,45 +510,33 @@ class Session(object):
         return self.send_session_command("POST", "window", body=body)
 
     def switch_frame(self, frame):
         if frame == "parent":
             url = "frame/parent"
             body = None
         else:
             url = "frame"
-            if isinstance(frame, Element):
-                body = {"id": frame.json()}
-            else:
-                body = {"id": frame}
+            body = {"id": frame}
 
         return self.send_session_command("POST", url, body)
 
     @command
     def close(self):
         return self.send_session_command("DELETE", "window")
 
     @property
     @command
     def handles(self):
         return self.send_session_command("GET", "window/handles")
 
     @property
     @command
     def active_element(self):
-        data = self.send_session_command("GET", "element/active")
-        if data is not None:
-            return self._element(data)
-
-    def _element(self, data):
-        elem_id = data[Element.identifier]
-        assert elem_id
-        if elem_id in self._element_cache:
-            return self._element_cache[elem_id]
-        return Element(elem_id, self)
+        return self.send_session_command("GET", "element/active")
 
     @command
     def cookies(self, name=None):
         if name is None:
             url = "cookie"
         else:
             url = "cookie/%s" % name
         return self.send_session_command("GET", url, {})
@@ -632,30 +615,33 @@ class Element(object):
 
         assert id not in self.session._element_cache
         self.session._element_cache[self.id] = self
 
     def __eq__(self, other):
         return isinstance(other, Element) and self.id == other.id \
                 and self.session == other.session
 
+    @classmethod
+    def from_json(cls, json, session):
+        assert Element.identifier in json
+        uuid = json[Element.identifier]
+        if uuid in session._element_cache:
+            return session._element_cache[uuid]
+        return cls(uuid, session)
+
     def send_element_command(self, method, uri, body=None):
         url = "element/%s/%s" % (self.id, uri)
         return self.session.send_session_command(method, url, body)
 
-    def json(self):
-        return {Element.identifier: self.id}
-
     @command
     def find_element(self, strategy, selector):
         body = {"using": strategy,
                 "value": selector}
-
-        elem = self.send_element_command("POST", "element", body)
-        return self.session._element(elem)
+        return self.send_element_command("POST", "element", body)
 
     @command
     def click(self):
         self.send_element_command("POST", "click", {})
 
     @command
     def tap(self):
         self.send_element_command("POST", "tap", {})
new file mode 100644
--- /dev/null
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/protocol.py
@@ -0,0 +1,35 @@
+import json
+
+import webdriver
+
+
+"""WebDriver wire protocol codecs."""
+
+
+class Encoder(json.JSONEncoder):
+    def __init__(self, *args, **kwargs):
+        kwargs.pop("session")
+        super(Encoder, self).__init__(*args, **kwargs)
+
+    def default(self, obj):
+        if isinstance(obj, (list, tuple)):
+            return [self.default(x) for x in obj]
+        elif isinstance(obj, webdriver.Element):
+            return {webdriver.Element.identifier: obj.id}
+        return super(ProtocolEncoder, self).default(obj)
+
+
+class Decoder(json.JSONDecoder):
+    def __init__(self, *args, **kwargs):
+        self.session = kwargs.pop("session")
+        super(Decoder, self).__init__(
+            object_hook=self.object_hook, *args, **kwargs)
+
+    def object_hook(self, payload):
+        if isinstance(payload, (list, tuple)):
+            return [self.object_hook(x) for x in payload]
+        elif isinstance(payload, dict) and webdriver.Element.identifier in payload:
+            return webdriver.Element.from_json(payload, self.session)
+        elif isinstance(payload, dict):
+            return {k: self.object_hook(v) for k, v in payload.iteritems()}
+        return payload
--- a/testing/web-platform/tests/tools/webdriver/webdriver/transport.py
+++ b/testing/web-platform/tests/tools/webdriver/webdriver/transport.py
@@ -1,70 +1,77 @@
 import httplib
 import json
 import urlparse
 
 import error
 
 
+"""Implements HTTP transport for the WebDriver wire protocol."""
+
+
 class Response(object):
     """
     Describes an HTTP response received from a remote end whose
     body has been read and parsed as appropriate.
     """
 
     def __init__(self, status, body):
         self.status = status
         self.body = body
 
     def __repr__(self):
-        cls_name = self.__class__.__name__
         if self.error:
             return "<%s status=%s error=%s>" % (cls_name, self.status, repr(self.error))
-        return "<% status=%s body=%s>" % (cls_name, self.status, self.body)
+        return "<% status=%s body=%s>" % (cls_name, self.status, json.dumps(self.body))
+
+    def __str__(self):
+        return json.dumps(self.body, indent=2)
 
     @property
     def error(self):
         if self.status != 200:
             return error.from_response(self)
         return None
 
     @classmethod
-    def from_http_response(cls, http_response):
-        status = http_response.status
-        body = http_response.read()
+    def from_http(cls, http_response, decoder=json.JSONDecoder, **kwargs):
+        try:
+            body = json.load(http_response, cls=decoder, **kwargs)
+        except ValueError:
+            raise ValueError("Failed to decode response body as JSON:\n"
+                "%s" % json.dumps(body, indent=2))
 
-        # SpecID: dfn-send-a-response
-        #
-        # > 3. Set the response's header with name and value with the following
-        # >    values:
-        # >
-        # >    "Content-Type"
-        # >       "application/json; charset=utf-8"
-        # >    "cache-control"
-        # >       "no-cache"
-
-        if body:
-            try:
-                body = json.loads(body)
-            except:
-                raise error.UnknownErrorException("Failed to decode body as json:\n%s" % body)
-
-        return cls(status, body)
-
-
-class ToJsonEncoder(json.JSONEncoder):
-    def default(self, obj):
-        return getattr(obj.__class__, "json", json.JSONEncoder().default)(obj)
+        return cls(http_response.status, body)
 
 
 class HTTPWireProtocol(object):
     """
     Transports messages (commands and responses) over the WebDriver
     wire protocol.
+
+    Complex objects, such as ``webdriver.Element``, are by default
+    not marshaled to enable use of `session.transport.send` in WPT tests::
+
+        session = webdriver.Session("127.0.0.1", 4444)
+        response = transport.send("GET", "element/active", None)
+        print response.body["value"]
+        # => {u'element-6066-11e4-a52e-4f735466cecf': u'<uuid>'}
+
+    Automatic marshaling is provided by ``webdriver.protocol.Encoder``
+    and ``webdriver.protocol.Decoder``, which can be passed in to
+    ``HTTPWireProtocol.send`` along with a reference to the current
+    ``webdriver.Session``::
+
+        session = webdriver.Session("127.0.0.1", 4444)
+        response = transport.send("GET", "element/active", None,
+            encoder=protocol.Encoder, decoder=protocol.Decoder,
+            session=session)
+        print response.body["value"]
+        # => webdriver.Element
     """
 
     def __init__(self, host, port, url_prefix="/", timeout=None):
         """
         Construct interface for communicating with the remote server.
 
         :param url: URL of remote WebDriver server.
         :param wait: Duration to wait for remote to appear.
@@ -73,48 +80,79 @@ class HTTPWireProtocol(object):
         self.port = port
         self.url_prefix = url_prefix
 
         self._timeout = timeout
 
     def url(self, suffix):
         return urlparse.urljoin(self.url_prefix, suffix)
 
-    def send(self, method, uri, body=None, headers=None):
+    def send(self,
+             method,
+             uri,
+             body=None,
+             headers=None,
+             encoder=json.JSONEncoder,
+             decoder=json.JSONDecoder,
+             **codec_kwargs):
         """
         Send a command to the remote.
 
+        The request `body` must be JSON serialisable unless a
+        custom `encoder` has been provided.  This means complex
+        objects such as ``webdriver.Element`` are not automatically
+        made into JSON.  This behaviour is, however, provided by
+        ``webdriver.protocol.Encoder``, should you want it.
+
+        Similarly, the response body is returned au natural
+        as plain JSON unless a `decoder` that converts web
+        element references to ``webdriver.Element`` is provided.
+        Use ``webdriver.protocol.Decoder`` to achieve this behaviour.
+
         :param method: `GET`, `POST`, or `DELETE`.
         :param uri: Relative endpoint of the requests URL path.
         :param body: Body of the request.  Defaults to an empty
             dictionary if ``method`` is `POST`.
-        :param headers: Additional headers to include in the request.
+        :param headers: Additional dictionary of headers to include
+            in the request.
+        :param encoder: JSON encoder class, which defaults to
+            ``json.JSONEncoder`` unless specified.
+        :param decoder: JSON decoder class, which defaults to
+            ``json.JSONDecoder`` unless specified.
+        :param codec_kwargs: Surplus arguments passed on to `encoder`
+            and `decoder` on construction.
 
-        :return: Instance of ``wdclient.Response`` describing the
-            HTTP response received from the remote end.
+        :return: Instance of ``webdriver.transport.Response``
+            describing the HTTP response received from the remote end.
+
+        :raises ValueError: If `body` or the response body are not
+            JSON serialisable.
         """
         if body is None and method == "POST":
             body = {}
 
-        if isinstance(body, dict):
-            body = json.dumps(body, cls=ToJsonEncoder)
-
-        if isinstance(body, unicode):
-            body = body.encode("utf-8")
+        try:
+            payload = json.dumps(body, cls=encoder, **codec_kwargs)
+        except ValueError:
+            raise ValueError("Failed to encode request body as JSON:\n"
+                "%s" % json.dumps(body, indent=2))
+        if isinstance(payload, unicode):
+            payload = body.encode("utf-8")
 
         if headers is None:
             headers = {}
 
         url = self.url(uri)
 
-        kwargs = {}
+        conn_kwargs = {}
         if self._timeout is not None:
-            kwargs["timeout"] = self._timeout
+            conn_kwargs["timeout"] = self._timeout
 
         conn = httplib.HTTPConnection(
-            self.host, self.port, strict=True, **kwargs)
-        conn.request(method, url, body, headers)
+            self.host, self.port, strict=True, **conn_kwargs)
+        conn.request(method, url, payload, headers)
 
         try:
             response = conn.getresponse()
-            return Response.from_http_response(response)
+            return Response.from_http(
+                response, decoder=decoder, **codec_kwargs)
         finally:
             conn.close()