From 567dfd76b331b3bcb3a81946377ee51f2c48da70 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Moise=CC=81s=20Guimara=CC=83es?= <moises@wolfssl.com>
Date: Tue, 6 Dec 2016 12:11:02 -0300
Subject: [PATCH] adds initial code for SSLSocket

---
 wrapper/python/wolfssl/wolfssl/__init__.py | 361 ++++++++++++++++++++-
 1 file changed, 356 insertions(+), 5 deletions(-)

diff --git a/wrapper/python/wolfssl/wolfssl/__init__.py b/wrapper/python/wolfssl/wolfssl/__init__.py
index c3f9b3caf..72f8502e2 100644
--- a/wrapper/python/wolfssl/wolfssl/__init__.py
+++ b/wrapper/python/wolfssl/wolfssl/__init__.py
@@ -154,7 +154,11 @@ class SSLContext(object):
         private key in.
         """
 
-        if certfile:
+        if password is not None:
+            raise NotImplementedError("password callback support not "
+                                      "implemented yet")
+
+        if certfile is not None:
             ret = _lib.wolfSSL_CTX_use_certificate_chain_file(
                 self.native_object, t2b(certfile))
             if ret != _SSL_SUCCESS:
@@ -162,7 +166,7 @@ class SSLContext(object):
         else:
             raise TypeError("certfile should be a valid filesystem path")
 
-        if keyfile:
+        if keyfile is not None:
             ret = _lib.wolfSSL_CTX_use_PrivateKey_file(
                 self.native_object, t2b(keyfile), _SSL_FILETYPE_PEM)
             if ret != _SSL_SUCCESS:
@@ -185,7 +189,7 @@ class SSLContext(object):
         if cafile is None and capath is None and cadata is None:
             raise TypeError("cafile, capath and cadata cannot be all omitted")
 
-        if cafile or capath:
+        if cafile is not None or capath is not None:
             ret = _lib.wolfSSL_CTX_load_verify_locations(
                 self.native_object,
                 t2b(cafile) if cafile else _ffi.NULL,
@@ -194,13 +198,14 @@ class SSLContext(object):
             if ret != _SSL_SUCCESS:
                 raise SSLError("Unnable to load verify locations. Err %d" % ret)
 
-        if cadata:
+        if cadata is not None:
             ret = _lib.wolfSSL_CTX_load_verify_buffer(
                 self.native_object, t2b(cadata), len(cadata), _SSL_FILETYPE_PEM)
 
             if ret != _SSL_SUCCESS:
                 raise SSLError("Unnable to load verify locations. Err %d" % ret)
 
+
 class SSLSocket(socket):
     """
     This class implements a subtype of socket.socket that wraps the
@@ -215,7 +220,353 @@ class SSLSocket(socket):
                  sock_type=SOCK_STREAM, proto=0, fileno=None,
                  suppress_ragged_eofs=True, ciphers=None,
                  _context=None):
-        pass
+
+        # set options
+        self.do_handshake_on_connect = do_handshake_on_connect
+        self.suppress_ragged_eofs = suppress_ragged_eofs
+        self.server_side = server_side
+
+        # set context
+        if _context:
+            self._context = _context
+        else:
+            if server_side and not certfile:
+                raise ValueError("certfile must be specified for server-side "
+                                 "operations")
+
+            if keyfile and not certfile:
+                raise ValueError("certfile must be specified")
+
+            if certfile and not keyfile:
+                keyfile = certfile
+
+            self._context = SSLContext(ssl_version, server_side)
+            self._context.verify_mode = cert_reqs
+            if ca_certs:
+                self._context.load_verify_locations(ca_certs)
+            if certfile:
+                self._context.load_cert_chain(certfile, keyfile)
+            if ciphers:
+                self._context.set_ciphers(ciphers)
+
+            self.keyfile = keyfile
+            self.certfile = certfile
+            self.cert_reqs = cert_reqs
+            self.ssl_version = ssl_version
+            self.ca_certs = ca_certs
+            self.ciphers = ciphers
+
+        # preparing socket
+        if sock is not None:
+            # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
+            # mixed in.
+            if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
+                raise NotImplementedError("only stream sockets are supported")
+
+            socket.__init__(self,
+                            family=sock.family,
+                            sock_type=sock.type,
+                            proto=sock.proto,
+                            fileno=sock.fileno())
+            self.settimeout(sock.gettimeout())
+            sock.detach()
+
+        elif fileno is not None:
+            socket.__init__(self, fileno=fileno)
+
+        else:
+            socket.__init__(self, family=family, sock_type=sock_type,
+                            proto=proto)
+
+        # See if we are connected
+        try:
+            self.getpeername()
+        except OSError as exception:
+            if exception.errno != errno.ENOTCONN:
+                raise
+            connected = False
+        else:
+            connected = True
+
+        self._closed = False
+        self.native_object = _ffi.NULL
+        self._connected = connected
+
+        if connected:
+            # create the SSL object
+            try:
+                self.native_object = \
+                    _lib.wolfSSL_new(self.context.native_object)
+                if self.native_object == _ffi.NULL:
+                    raise MemoryError("Unnable to allocate ssl object")
+
+                ret = _lib.wolfSSL_set_fd(self.native_object, self.fileno)
+                if ret != _SSL_SUCCESS:
+                    raise ValueError("Unnable to set fd to ssl object")
+
+                if do_handshake_on_connect:
+                    self.do_handshake()
+            except (OSError, ValueError):
+                self.close()
+                raise
+
+
+    @property
+    def context(self):
+        """
+        Returns the context used by this object.
+        """
+        return self._context
+
+
+    def dup(self):
+        raise NotImplementedError("Can't dup() %s instances" %
+                                  self.__class__.__name__)
+
+
+    def _check_connected(self):
+        if not self._connected:
+            # getpeername() will raise ENOTCONN if the socket is really
+            # not connected; note that we can be connected even without
+            # _connected being set, e.g. if connect() first returned
+            # EAGAIN.
+            self.getpeername()
+
+
+    def write(self, data):
+        """
+        Write DATA to the underlying SSL channel. Returns
+        number of bytes of DATA actually transmitted.
+        """
+
+        if self.native_object == _ffi.NULL:
+            raise ValueError("Write on closed or unwrapped SSL socket")
+
+        data = t2b(data)
+
+        return _lib.wolfSSL_write(self.native_object, data, len(data))
+
+
+    def send(self, data, flags=0):
+        if self.native_object != _ffi.NULL:
+            if flags != 0:
+                raise ValueError(
+                    "non-zero flags not allowed in calls to send() on %s" %
+                    self.__class__)
+            return self.write(data)
+        else:
+            return socket.send(self, data, flags)
+
+
+    def sendto(self, data, flags_or_addr, addr=None):
+        if self.native_object != _ffi.NULL:
+            raise ValueError("sendto not allowed on instances of %s" %
+                             self.__class__)
+        elif addr is None:
+            return socket.sendto(self, data, flags_or_addr)
+        else:
+            return socket.sendto(self, data, flags_or_addr, addr)
+
+
+    def sendmsg(self, *args, **kwargs):
+        # Ensure programs don't send data unencrypted if they try to
+        # use this method.
+        raise NotImplementedError("sendmsg not allowed on instances of %s" %
+                                  self.__class__)
+
+
+    def sendall(self, data, flags=0):
+        if self.native_object != _ffi.NULL:
+            if flags != 0:
+                raise ValueError(
+                    "non-zero flags not allowed in calls to sendall() on %s" %
+                    self.__class__)
+
+            amount = len(data)
+            count = 0
+            while count < amount:
+                sent = self.send(data[count:])
+                count += sent
+            return amount
+        else:
+            return socket.sendall(self, data, flags)
+
+
+    def sendfile(self, file, offset=0, count=None):
+        """
+        Send a file, possibly by using os.sendfile() if this is a
+        clear-text socket. Return the total number of bytes sent.
+        """
+        # Ensure programs don't send unencrypted files if they try to
+        # use this method.
+        raise NotImplementedError("sendfile not allowed on instances of %s" %
+                                  self.__class__)
+
+
+    def read(self, length=1024, buffer=None):
+        """
+        Read up to LEN bytes and return them.
+        Return zero-length string on EOF.
+        """
+
+        if self.native_object == _ffi.NULL:
+            raise ValueError("Read on closed or unwrapped SSL socket")
+        
+        data = t2b("\0" * length)
+        length = _lib.WolfSSL_read(self.native_object, data, length)
+
+        if buffer is not None:
+            buffer.write(data, length)
+            return length
+        else:
+            raise MemoryError("")
+
+            return self._sslobj.read(len, buffer)
+        except SSLError as exception:
+            if exception.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
+                if buffer is not None:
+                    return 0
+                else:
+                    return b''
+            else:
+                raise
+
+
+    def recv(self, buflen=1024, flags=0):
+        self._checkClosed()
+        if self._sslobj:
+            if flags != 0:
+                raise ValueError(
+                    "non-zero flags not allowed in calls to recv() on %s" %
+                    self.__class__)
+            return self.read(buflen)
+        else:
+            return socket.recv(self, buflen, flags)
+
+
+    def recv_into(self, buffer, nbytes=None, flags=0):
+        self._checkClosed()
+        if buffer and (nbytes is None):
+            nbytes = len(buffer)
+        elif nbytes is None:
+            nbytes = 1024
+        if self._sslobj:
+            if flags != 0:
+                raise ValueError(
+                    "non-zero flags not allowed in calls to recv_into() on %s"
+                    % self.__class__)
+            return self.read(nbytes, buffer)
+        else:
+            return socket.recv_into(self, buffer, nbytes, flags)
+
+
+    def recvfrom(self, buflen=1024, flags=0):
+        self._checkClosed()
+        if self._sslobj:
+            raise ValueError("recvfrom not allowed on instances of %s" %
+                             self.__class__)
+        else:
+            return socket.recvfrom(self, buflen, flags)
+
+
+    def recvfrom_into(self, buffer, nbytes=None, flags=0):
+        self._checkClosed()
+        if self._sslobj:
+            raise ValueError("recvfrom_into not allowed on instances of %s" %
+                             self.__class__)
+        else:
+            return socket.recvfrom_into(self, buffer, nbytes, flags)
+
+
+    def recvmsg(self, *args, **kwargs):
+        raise NotImplementedError("recvmsg not allowed on instances of %s" %
+                                  self.__class__)
+
+
+    def recvmsg_into(self, *args, **kwargs):
+        raise NotImplementedError("recvmsg_into not allowed on instances of "
+                                  "%s" % self.__class__)
+
+
+    def shutdown(self, how):
+        self._checkClosed()
+        self._sslobj = None
+        socket.shutdown(self, how)
+
+
+    def unwrap(self):
+        if self._sslobj:
+            s = self._sslobj.unwrap()
+            self._sslobj = None
+            return s
+        else:
+            raise ValueError("No SSL wrapper around " + str(self))
+
+    def _real_close(self):
+        self._sslobj = None
+        socket._real_close(self)
+
+    def do_handshake(self, block=False):
+        """Perform a TLS/SSL handshake."""
+        self._check_connected()
+        timeout = self.gettimeout()
+        try:
+            if timeout == 0.0 and block:
+                self.settimeout(None)
+            self._sslobj.do_handshake()
+        finally:
+            self.settimeout(timeout)
+
+
+    def _real_connect(self, addr, connect_ex):
+        if self.server_side:
+            raise ValueError("can't connect in server-side mode")
+        # Here we assume that the socket is client-side, and not
+        # connected at the time of the call.  We connect it, then wrap it.
+        if self._connected:
+            raise ValueError("attempt to connect already-connected SSLSocket!")
+        sslobj = self.context._wrap_socket(self, False, self.server_hostname)
+        self._sslobj = SSLObject(sslobj, owner=self)
+        try:
+            if connect_ex:
+                rc = socket.connect_ex(self, addr)
+            else:
+                rc = None
+                socket.connect(self, addr)
+            if not rc:
+                self._connected = True
+                if self.do_handshake_on_connect:
+                    self.do_handshake()
+            return rc
+        except (OSError, ValueError):
+            self._sslobj = None
+            raise
+
+
+    def connect(self, addr):
+        """Connects to remote ADDR, and then wraps the connection in
+        an SSL channel."""
+        self._real_connect(addr, False)
+
+
+    def connect_ex(self, addr):
+        """Connects to remote ADDR, and then wraps the connection in
+        an SSL channel."""
+        return self._real_connect(addr, True)
+
+
+    def accept(self):
+        """Accepts a new connection from a remote client, and returns
+        a tuple containing that new connection wrapped with a server-side
+        SSL channel, and the address of the remote client."""
+
+        newsock, addr = socket.accept(self)
+        newsock = self.context.wrap_socket(
+            newsock,
+            do_handshake_on_connect=self.do_handshake_on_connect,
+            suppress_ragged_eofs=self.suppress_ragged_eofs,
+            server_side=True)
+        return newsock, addr
 
 
 def wrap_socket(sock, keyfile=None, certfile=None, server_side=False,