1
0
mirror of https://github.com/janeczku/calibre-web synced 2024-11-01 07:36:20 +00:00
calibre-web/cps/cw_advocate/connection.py

202 lines
7.0 KiB
Python
Raw Normal View History

#
# Copyright 2015 Jordan Milne
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Source: https://github.com/JordanMilne/Advocate
import ipaddress
import socket
from socket import timeout as SocketTimeout
from urllib3.connection import HTTPSConnection, HTTPConnection
from urllib3.exceptions import ConnectTimeoutError
from urllib3.util.connection import _set_socket_options
from urllib3.util.connection import create_connection as old_create_connection
from . import addrvalidator
from .exceptions import UnacceptableAddressException
def advocate_getaddrinfo(host, port, get_canonname=False):
    """Resolve *host*/*port* into TCP (``SOCK_STREAM``) addrinfo records.

    :param host: hostname or IP literal to resolve.
    :param port: port passed straight through to ``socket.getaddrinfo``.
    :param get_canonname: when true, request ``AI_CANONNAME`` so the
        resolver reports the canonical hostname.
    :return: the resolved records after :func:`fix_addrinfo` has parsed
        their IPs and copied the canonical name onto every record.
    """
    # The canonical name is what the DNS client actually sees the hostname
    # as. That correctly handles IDNs (all converted to punycode) and tricky
    # inputs like `private.foocorp.org\x00.google.com`.
    lookup_flags = socket.AI_CANONNAME if get_canonname else 0
    # Positional arguments on purpose: replacement getaddrinfo functions
    # installed by monkeypatching libraries may name their keywords differently.
    raw_records = socket.getaddrinfo(
        host, port, 0, socket.SOCK_STREAM, 0, lookup_flags,
    )
    return fix_addrinfo(raw_records)
def fix_addrinfo(records):
    """Propagate the canonname across records and parse their IPs.

    I'm not sure if this is just the behaviour of `getaddrinfo` on Linux, but
    it seems like only the first record in the set has the canonname field
    populated, so that value is copied onto every record.

    :param records: sequence of 5-tuples shaped like ``socket.getaddrinfo``
        output: ``(family, type, proto, canonname, sockaddr)``.
    :return: a tuple of 5-tuples with the first record's canonname in every
        slot 3, and ``sockaddr[0]`` replaced by an ``ipaddress`` address
        object (the rest of the sockaddr is kept as-is).
    """
    shared_canonname = None
    if records:
        first_record = records[0]
        assert len(first_record) == 5
        shared_canonname = first_record[3]

    return tuple(
        (
            rec[0],
            rec[1],
            rec[2],
            shared_canonname,
            (ipaddress.ip_address(rec[4][0]),) + rec[4][1:],
        )
        for rec in records
    )
# Lifted from requests' urllib3, which in turn lifted it from `socket.py`. Oy!
def validating_create_connection(address,
                                 timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
                                 source_address=None, socket_options=None,
                                 validator=None):
    """Connect to *address* and return the socket object.

    Convenience function. Connect to *address* (a 2-tuple ``(host,
    port)``) and return the socket object. Passing the optional
    *timeout* parameter will set the timeout on the socket instance
    before attempting to connect. If no *timeout* is supplied, the
    global default timeout setting returned by :func:`getdefaulttimeout`
    is used. If *source_address* is set it must be a tuple of (host, port)
    for the socket to bind as a source address before making the connection.
    A host of '' or port 0 tells the OS to use the default.

    Unlike the urllib3 original, every resolved address is first checked
    with *validator*; addresses it rejects are skipped without connecting.

    :param socket_options: socket options applied to each socket before it
        connects.
    :param validator: object providing ``hostname_blacklist``,
        ``is_hostname_allowed``, ``autodetect_local_addresses`` and
        ``is_addrinfo_allowed``. Required, despite the ``None`` default.
    :raises TypeError: if no *validator* is supplied.
    :raises UnacceptableAddressException: if *host* fails the hostname
        check, or if none of the resolved addresses was allowed.
    :raises socket.error: the last connection error if every allowed
        address failed to connect, or if resolution returned no records.
    """
    if validator is None:
        # The default only mirrors urllib3's create_connection signature;
        # connecting without validation would defeat this function's purpose.
        raise TypeError("validating_create_connection() requires a validator")

    host, port = address

    # We can skip asking for the canon name if we're not doing hostname-based
    # blacklisting.
    need_canonname = bool(validator.hostname_blacklist)
    if need_canonname:
        # We check both the non-canonical and canonical hostnames so we can
        # catch both of these:
        #   CNAME from nonblacklisted.com -> blacklisted.com
        #   CNAME from blacklisted.com -> nonblacklisted.com
        # Only the name as given is checked here. The canonical name rides
        # along in each addrinfo record; presumably is_addrinfo_allowed
        # checks it below (TODO confirm in addrvalidator).
        if not validator.is_hostname_allowed(host):
            raise UnacceptableAddressException(host)

    err = None
    addrinfo = advocate_getaddrinfo(host, port, get_canonname=need_canonname)
    if addrinfo:
        if validator.autodetect_local_addresses:
            local_addresses = addrvalidator.determine_local_addresses()
        else:
            local_addresses = []

        for res in addrinfo:
            # Are we allowed to connect with this result?
            if not validator.is_addrinfo_allowed(
                res,
                _local_addresses=local_addresses,
            ):
                continue

            af, socktype, proto, _canonname, sa = res
            # fix_addrinfo parsed the IP into an ipaddress object; turn the
            # validated IP back into a string for socket.connect().
            sa = (sa[0].exploded,) + sa[1:]
            sock = None
            try:
                sock = socket.socket(af, socktype, proto)
                # If provided, set socket level options before connecting.
                # This is the only addition urllib3 makes to this function.
                _set_socket_options(sock, socket_options)
                if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if source_address:
                    sock.bind(source_address)
                sock.connect(sa)
                return sock
            except socket.error as exc:
                # Remember the failure and move on to the next allowed address.
                err = exc
                if sock is not None:
                    sock.close()

        if err is None:
            # If we got here, none of the results were acceptable
            err = UnacceptableAddressException(address)

    if err is not None:
        try:
            raise err
        finally:
            # Break the reference cycle explicitly. The raised exception's
            # traceback references this frame, and the frame's `err` local
            # references the exception.
            err = None
    else:
        raise socket.error("getaddrinfo returns an empty list")
# TODO: Is there a better way to add this to multiple classes with different
# base classes? I tried a mixin, but it used the base method instead.
def _validating_new_conn(self):
    """Open a new socket for this connection, validating the destination.

    Assigned as ``_new_conn`` on ValidatingHTTPConnection and
    ValidatingHTTPSConnection. ``self.source_address`` and
    ``self.socket_options`` are forwarded to the connect function when set.

    :return: New socket connection.
    :raises ConnectTimeoutError: if the connection attempt times out.
    """
    connect_kwargs = {}
    if self.source_address:
        connect_kwargs['source_address'] = self.source_address
    if self.socket_options:
        connect_kwargs['socket_options'] = self.socket_options

    try:
        # Hack around HTTPretty's patched sockets: while HTTPretty has replaced
        # socket.getaddrinfo, fall back to urllib3's unvalidated
        # create_connection (no validator is passed in that case).
        # TODO: some better method of hacking around it that checks if we
        # _would have_ connected to a private addr?
        if socket.getaddrinfo.__module__.startswith("httpretty"):
            connect = old_create_connection
        else:
            connect = validating_create_connection
            connect_kwargs["validator"] = self._validator
        return connect((self.host, self.port), self.timeout, **connect_kwargs)
    except SocketTimeout:
        raise ConnectTimeoutError(
            self,
            "Connection to %s timed out. (connect timeout=%s)"
            % (self.host, self.timeout),
        )
# Don't silently break if the private API changes across urllib3 versions.
# This is an explicit check rather than `assert` because `python -O` strips
# assert statements, which would silently disable the guard. It still raises
# AssertionError, so the failure type is unchanged.
for _conn_cls in (HTTPConnection, HTTPSConnection):
    if not hasattr(_conn_cls, '_new_conn'):
        raise AssertionError(
            "%s has no _new_conn method; this urllib3 version is not "
            "supported by the validating connection classes"
            % _conn_cls.__name__
        )
del _conn_cls
class ValidatingHTTPConnection(HTTPConnection):
    """Plain-HTTP urllib3 connection whose socket setup goes through
    ``_validating_new_conn``.

    Must be constructed with a ``validator`` keyword argument. It is kept on
    the instance as ``_validator`` and removed from the keyword arguments
    before the base initializer runs.
    """

    _new_conn = _validating_new_conn

    def __init__(self, *args, **kwargs):
        # Missing ``validator`` raises KeyError here, before the base init.
        validator = kwargs.pop("validator")
        self._validator = validator
        HTTPConnection.__init__(self, *args, **kwargs)
class ValidatingHTTPSConnection(HTTPSConnection):
    """HTTPS urllib3 connection whose socket setup goes through
    ``_validating_new_conn``.

    Must be constructed with a ``validator`` keyword argument. It is kept on
    the instance as ``_validator`` and removed from the keyword arguments
    before the base initializer runs.
    """

    _new_conn = _validating_new_conn

    def __init__(self, *args, **kwargs):
        # Missing ``validator`` raises KeyError here, before the base init.
        validator = kwargs.pop("validator")
        self._validator = validator
        HTTPSConnection.__init__(self, *args, **kwargs)