Ausgabe der neuen DB Einträge

This commit is contained in:
hubobel 2022-01-02 21:50:48 +01:00
parent bad48e1627
commit cfbbb9ee3d
2399 changed files with 843193 additions and 43 deletions

View file

@ -0,0 +1,7 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web}.
"""

View file

@ -0,0 +1,103 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
General helpers for L{twisted.web} unit tests.
"""
from __future__ import division, absolute_import
from twisted.internet.defer import succeed
from twisted.web import server
from twisted.trial.unittest import TestCase
from twisted.python.failure import Failure
from twisted.web._flatten import flattenString
from twisted.web.error import FlattenerError
def _render(resource, request):
result = resource.render(request)
if isinstance(result, bytes):
request.write(result)
request.finish()
return succeed(None)
elif result is server.NOT_DONE_YET:
if request.finished:
return succeed(None)
else:
return request.notifyFinish()
else:
raise ValueError("Unexpected return value: %r" % (result,))
class FlattenTestCase(TestCase):
"""
A test case that assists with testing L{twisted.web._flatten}.
"""
def assertFlattensTo(self, root, target):
"""
Assert that a root element, when flattened, is equal to a string.
"""
d = flattenString(None, root)
d.addCallback(lambda s: self.assertEqual(s, target))
return d
def assertFlattensImmediately(self, root, target):
"""
Assert that a root element, when flattened, is equal to a string, and
performs no asynchronus Deferred anything.
This version is more convenient in tests which wish to make multiple
assertions about flattening, since it can be called multiple times
without having to add multiple callbacks.
@return: the result of rendering L{root}, which should be equivalent to
L{target}.
@rtype: L{bytes}
"""
results = []
it = self.assertFlattensTo(root, target)
it.addBoth(results.append)
# Do our best to clean it up if something goes wrong.
self.addCleanup(it.cancel)
if not results:
self.fail("Rendering did not complete immediately.")
result = results[0]
if isinstance(result, Failure):
result.raiseException()
return results[0]
def assertFlatteningRaises(self, root, exn):
"""
Assert flattening a root element raises a particular exception.
"""
d = self.assertFailure(self.assertFlattensTo(root, b''), FlattenerError)
d.addCallback(lambda exc: self.assertIsInstance(exc._exception, exn))
return d
def assertIsFilesystemTemporary(case, fileObj):
"""
Assert that C{fileObj} is a temporary file on the filesystem.
@param case: A C{TestCase} instance to use to make the assertion.
@raise: C{case.failureException} if C{fileObj} is not a temporary file on
the filesystem.
"""
# The tempfile API used to create content returns an instance of a
# different type depending on what platform we're running on. The point
# here is to verify that the request body is in a file that's on the
# filesystem. Having a fileno method that returns an int is a somewhat
# close approximation of this. -exarkun
case.assertIsInstance(fileObj.fileno(), int)
__all__ = ["_render", "FlattenTestCase", "assertIsFilesystemTemporary"]

View file

@ -0,0 +1,168 @@
"""
Helpers for URI and method injection tests.
@see: U{CVE-2019-12387}
"""
import string
UNPRINTABLE_ASCII = (
frozenset(range(0, 128)) -
frozenset(bytearray(string.printable, 'ascii'))
)
NONASCII = frozenset(range(128, 256))
class MethodInjectionTestsMixin(object):
"""
A mixin that runs HTTP method injection tests. Define
L{MethodInjectionTestsMixin.attemptRequestWithMaliciousMethod} in
a L{twisted.trial.unittest.SynchronousTestCase} subclass to test
how HTTP client code behaves when presented with malicious HTTP
methods.
@see: U{CVE-2019-12387}
"""
def attemptRequestWithMaliciousMethod(self, method):
"""
Attempt to send a request with the given method. This should
synchronously raise a L{ValueError} if either is invalid.
@param method: the method (e.g. C{GET\x00})
@param uri: the URI
@type method:
"""
raise NotImplementedError()
def test_methodWithCLRFRejected(self):
"""
Issuing a request with a method that contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
method = b"GET\r\nX-Injected-Header: value"
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
def test_methodWithUnprintableASCIIRejected(self):
"""
Issuing a request with a method that contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
method = b"GET%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
def test_methodWithNonASCIIRejected(self):
"""
Issuing a request with a method that contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
method = b"GET%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
class URIInjectionTestsMixin(object):
"""
A mixin that runs HTTP URI injection tests. Define
L{MethodInjectionTestsMixin.attemptRequestWithMaliciousURI} in a
L{twisted.trial.unittest.SynchronousTestCase} subclass to test how
HTTP client code behaves when presented with malicious HTTP
URIs.
"""
def attemptRequestWithMaliciousURI(self, method):
"""
Attempt to send a request with the given URI. This should
synchronously raise a L{ValueError} if either is invalid.
@param uri: the URI.
@type method:
"""
raise NotImplementedError()
def test_hostWithCRLFRejected(self):
"""
Issuing a request with a URI whose host contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
uri = b"http://twisted\r\n.invalid/path"
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_hostWithWithUnprintableASCIIRejected(self):
"""
Issuing a request with a URI whose host contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
uri = b"http://twisted%s.invalid/OK" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_hostWithNonASCIIRejected(self):
"""
Issuing a request with a URI whose host contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
uri = b"http://twisted%s.invalid/OK" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithCRLFRejected(self):
"""
Issuing a request with a URI whose path contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
uri = b"http://twisted.invalid/\r\npath"
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithWithUnprintableASCIIRejected(self):
"""
Issuing a request with a URI whose path contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
uri = b"http://twisted.invalid/OK%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithNonASCIIRejected(self):
"""
Issuing a request with a URI whose path contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
uri = b"http://twisted.invalid/OK%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")

View file

@ -0,0 +1,486 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Helpers related to HTTP requests, used by tests.
"""
from __future__ import division, absolute_import
__all__ = ['DummyChannel', 'DummyRequest']
from io import BytesIO
from zope.interface import implementer, verify
from twisted.python.compat import intToBytes
from twisted.python.deprecate import deprecated
from incremental import Version
from twisted.internet.defer import Deferred
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import ISSLTransport, IAddress
from twisted.trial import unittest
from twisted.web.http_headers import Headers
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET, Session, Site
from twisted.web._responses import FOUND
textLinearWhitespaceComponents = [
u"Foo%sbar" % (lw,) for lw in
[u'\r', u'\n', u'\r\n']
]
sanitizedText = "Foo bar"
bytesLinearWhitespaceComponents = [
component.encode('ascii') for component in
textLinearWhitespaceComponents
]
sanitizedBytes = sanitizedText.encode('ascii')
@implementer(IAddress)
class NullAddress(object):
"""
A null implementation of L{IAddress}.
"""
class DummyChannel:
class TCP:
port = 80
disconnected = False
def __init__(self, peer=None):
if peer is None:
peer = IPv4Address("TCP", '192.168.1.1', 12344)
self._peer = peer
self.written = BytesIO()
self.producers = []
def getPeer(self):
return self._peer
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("Can only write bytes to a transport, not %r" % (data,))
self.written.write(data)
def writeSequence(self, iovec):
for data in iovec:
self.write(data)
def getHost(self):
return IPv4Address("TCP", '10.0.0.1', self.port)
def registerProducer(self, producer, streaming):
self.producers.append((producer, streaming))
def unregisterProducer(self):
pass
def loseConnection(self):
self.disconnected = True
@implementer(ISSLTransport)
class SSL(TCP):
pass
site = Site(Resource())
def __init__(self, peer=None):
self.transport = self.TCP(peer)
def requestDone(self, request):
pass
def writeHeaders(self, version, code, reason, headers):
response_line = version + b" " + code + b" " + reason + b"\r\n"
headerSequence = [response_line]
headerSequence.extend(
name + b': ' + value + b"\r\n" for name, value in headers
)
headerSequence.append(b"\r\n")
self.transport.writeSequence(headerSequence)
def getPeer(self):
return self.transport.getPeer()
def getHost(self):
return self.transport.getHost()
def registerProducer(self, producer, streaming):
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
self.transport.unregisterProducer()
def write(self, data):
self.transport.write(data)
def writeSequence(self, iovec):
self.transport.writeSequence(iovec)
def loseConnection(self):
self.transport.loseConnection()
def endRequest(self):
pass
def isSecure(self):
return isinstance(self.transport, self.SSL)
class DummyRequest(object):
"""
Represents a dummy or fake request. See L{twisted.web.server.Request}.
@ivar _finishedDeferreds: L{None} or a C{list} of L{Deferreds} which will
be called back with L{None} when C{finish} is called or which will be
errbacked if C{processingFailed} is called.
@type requestheaders: C{Headers}
@ivar requestheaders: A Headers instance that stores values for all request
headers.
@type responseHeaders: C{Headers}
@ivar responseHeaders: A Headers instance that stores values for all
response headers.
@type responseCode: C{int}
@ivar responseCode: The response code which was passed to
C{setResponseCode}.
@type written: C{list} of C{bytes}
@ivar written: The bytes which have been written to the request.
"""
uri = b'http://dummy/'
method = b'GET'
client = None
def registerProducer(self, prod, s):
"""
Call an L{IPullProducer}'s C{resumeProducing} method in a
loop until it unregisters itself.
@param prod: The producer.
@type prod: L{IPullProducer}
@param s: Whether or not the producer is streaming.
"""
# XXX: Handle IPushProducers
self.go = 1
while self.go:
prod.resumeProducing()
def unregisterProducer(self):
self.go = 0
def __init__(self, postpath, session=None, client=None):
self.sitepath = []
self.written = []
self.finished = 0
self.postpath = postpath
self.prepath = []
self.session = None
self.protoSession = session or Session(0, self)
self.args = {}
self.requestHeaders = Headers()
self.responseHeaders = Headers()
self.responseCode = None
self._finishedDeferreds = []
self._serverName = b"dummy"
self.clientproto = b"HTTP/1.0"
def getAllHeaders(self):
"""
Return dictionary mapping the names of all received headers to the last
value received for each.
Since this method does not return all header information,
C{self.requestHeaders.getAllRawHeaders()} may be preferred.
NOTE: This function is a direct copy of
C{twisted.web.http.Request.getAllRawHeaders}.
"""
headers = {}
for k, v in self.requestHeaders.getAllRawHeaders():
headers[k.lower()] = v[-1]
return headers
def getHeader(self, name):
"""
Retrieve the value of a request header.
@type name: C{bytes}
@param name: The name of the request header for which to retrieve the
value. Header names are compared case-insensitively.
@rtype: C{bytes} or L{None}
@return: The value of the specified request header.
"""
return self.requestHeaders.getRawHeaders(name.lower(), [None])[0]
def setHeader(self, name, value):
"""TODO: make this assert on write() if the header is content-length
"""
self.responseHeaders.addRawHeader(name, value)
def getSession(self):
if self.session:
return self.session
assert not self.written, "Session cannot be requested after data has been written."
self.session = self.protoSession
return self.session
def render(self, resource):
"""
Render the given resource as a response to this request.
This implementation only handles a few of the most common behaviors of
resources. It can handle a render method that returns a string or
C{NOT_DONE_YET}. It doesn't know anything about the semantics of
request methods (eg HEAD) nor how to set any particular headers.
Basically, it's largely broken, but sufficient for some tests at least.
It should B{not} be expanded to do all the same stuff L{Request} does.
Instead, L{DummyRequest} should be phased out and L{Request} (or some
other real code factored in a different way) used.
"""
result = resource.render(self)
if result is NOT_DONE_YET:
return
self.write(result)
self.finish()
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("write() only accepts bytes")
self.written.append(data)
def notifyFinish(self):
"""
Return a L{Deferred} which is called back with L{None} when the request
is finished. This will probably only work if you haven't called
C{finish} yet.
"""
finished = Deferred()
self._finishedDeferreds.append(finished)
return finished
def finish(self):
"""
Record that the request is finished and callback and L{Deferred}s
waiting for notification of this.
"""
self.finished = self.finished + 1
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.callback(None)
def processingFailed(self, reason):
"""
Errback and L{Deferreds} waiting for finish notification.
"""
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.errback(reason)
def addArg(self, name, value):
self.args[name] = [value]
def setResponseCode(self, code, message=None):
"""
Set the HTTP status response code, but takes care that this is called
before any data is written.
"""
assert not self.written, "Response code cannot be set after data has been written: %s." % "@@@@".join(self.written)
self.responseCode = code
self.responseMessage = message
def setLastModified(self, when):
assert not self.written, "Last-Modified cannot be set after data has been written: %s." % "@@@@".join(self.written)
def setETag(self, tag):
assert not self.written, "ETag cannot be set after data has been written: %s." % "@@@@".join(self.written)
def getClientIP(self):
"""
Return the IPv4 address of the client which made this request, if there
is one, otherwise L{None}.
"""
if isinstance(self.client, (IPv4Address, IPv6Address)):
return self.client.host
return None
def getClientAddress(self):
"""
Return the L{IAddress} of the client that made this request.
@return: an address.
@rtype: an L{IAddress} provider.
"""
if self.client is None:
return NullAddress()
return self.client
def getRequestHostname(self):
"""
Get a dummy hostname associated to the HTTP request.
@rtype: C{bytes}
@returns: a dummy hostname
"""
return self._serverName
def getHost(self):
"""
Get a dummy transport's host.
@rtype: C{IPv4Address}
@returns: a dummy transport's host
"""
return IPv4Address('TCP', '127.0.0.1', 80)
def setHost(self, host, port, ssl=0):
"""
Change the host and port the request thinks it's using.
@type host: C{bytes}
@param host: The value to which to change the host header.
@type ssl: C{bool}
@param ssl: A flag which, if C{True}, indicates that the request is
considered secure (if C{True}, L{isSecure} will return C{True}).
"""
self._forceSSL = ssl # set first so isSecure will work
if self.isSecure():
default = 443
else:
default = 80
if port == default:
hostHeader = host
else:
hostHeader = host + b":" + intToBytes(port)
self.requestHeaders.addRawHeader(b"host", hostHeader)
def redirect(self, url):
"""
Utility function that does a redirect.
The request should have finish() called after this.
"""
self.setResponseCode(FOUND)
self.setHeader(b"location", url)
DummyRequest.getClientIP = deprecated(
Version('Twisted', 18, 4, 0),
replacement="getClientAddress",
)(DummyRequest.getClientIP)
class DummyRequestTests(unittest.SynchronousTestCase):
"""
Tests for L{DummyRequest}.
"""
def test_getClientIPDeprecated(self):
"""
L{DummyRequest.getClientIP} is deprecated in favor of
L{DummyRequest.getClientAddress}
"""
request = DummyRequest([])
request.getClientIP()
warnings = self.flushWarnings(
offendingFunctions=[self.test_getClientIPDeprecated])
self.assertEqual(1, len(warnings))
[warning] = warnings
self.assertEqual(warning.get("category"), DeprecationWarning)
self.assertEqual(
warning.get("message"),
("twisted.web.test.requesthelper.DummyRequest.getClientIP "
"was deprecated in Twisted 18.4.0; "
"please use getClientAddress instead"),
)
def test_getClientIPSupportsIPv6(self):
"""
L{DummyRequest.getClientIP} supports IPv6 addresses, just like
L{twisted.web.http.Request.getClientIP}.
"""
request = DummyRequest([])
client = IPv6Address("TCP", "::1", 12345)
request.client = client
self.assertEqual("::1", request.getClientIP())
def test_getClientAddressWithoutClient(self):
"""
L{DummyRequest.getClientAddress} returns an L{IAddress}
provider no C{client} has been set.
"""
request = DummyRequest([])
null = request.getClientAddress()
verify.verifyObject(IAddress, null)
def test_getClientAddress(self):
"""
L{DummyRequest.getClientAddress} returns the C{client}.
"""
request = DummyRequest([])
client = IPv4Address("TCP", "127.0.0.1", 12345)
request.client = client
address = request.getClientAddress()
self.assertIs(address, client)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,462 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.twcgi}.
"""
import sys
import os
import json
from io import BytesIO
from twisted.trial import unittest
from twisted.internet import address, reactor, interfaces, error
from twisted.python import util, failure, log
from twisted.web.http import NOT_FOUND, INTERNAL_SERVER_ERROR
from twisted.web import client, twcgi, server, resource, http_headers
from twisted.web.test._util import _render
from twisted.web.test.test_web import DummyRequest
DUMMY_CGI = '''\
print("Header: OK")
print("")
print("cgi output")
'''
DUAL_HEADER_CGI = '''\
print("Header: spam")
print("Header: eggs")
print("")
print("cgi output")
'''
BROKEN_HEADER_CGI = '''\
print("XYZ")
print("")
print("cgi output")
'''
SPECIAL_HEADER_CGI = '''\
print("Server: monkeys")
print("Date: last year")
print("")
print("cgi output")
'''
READINPUT_CGI = '''\
# This is an example of a correctly-written CGI script which reads a body
# from stdin, which only reads env['CONTENT_LENGTH'] bytes.
import os, sys
body_length = int(os.environ.get('CONTENT_LENGTH',0))
indata = sys.stdin.read(body_length)
print("Header: OK")
print("")
print("readinput ok")
'''
READALLINPUT_CGI = '''\
# This is an example of the typical (incorrect) CGI script which expects
# the server to close stdin when the body of the request is complete.
# A correct CGI should only read env['CONTENT_LENGTH'] bytes.
import sys
indata = sys.stdin.read()
print("Header: OK")
print("")
print("readallinput ok")
'''
NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI = '''\
print("content-type: text/cgi-duplicate-test")
print("")
print("cgi output")
'''
HEADER_OUTPUT_CGI = '''\
import json
import os
print("")
print("")
vals = {x:y for x,y in os.environ.items() if x.startswith("HTTP_")}
print(json.dumps(vals))
'''
class PythonScript(twcgi.FilteredScript):
filter = sys.executable
class CGITests(unittest.TestCase):
"""
Tests for L{twcgi.FilteredScript}.
"""
if not interfaces.IReactorProcess.providedBy(reactor):
skip = "CGI tests require a functional reactor.spawnProcess()"
def startServer(self, cgi):
root = resource.Resource()
cgipath = util.sibpath(__file__, cgi)
root.putChild(b"cgi", PythonScript(cgipath))
site = server.Site(root)
self.p = reactor.listenTCP(0, site)
return self.p.getHost().port
def tearDown(self):
if getattr(self, 'p', None):
return self.p.stopListening()
def writeCGI(self, source):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(source)
return cgiFilename
def test_CGI(self):
cgiFilename = self.writeCGI(DUMMY_CGI)
portnum = self.startServer(cgiFilename)
url = 'http://localhost:%d/cgi' % (portnum,)
url = url.encode("ascii")
d = client.Agent(reactor).request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self._testCGI_1)
return d
def _testCGI_1(self, res):
self.assertEqual(res, b"cgi output" + os.linesep.encode("ascii"))
def test_protectedServerAndDate(self):
"""
If the CGI script emits a I{Server} or I{Date} header, these are
ignored.
"""
cgiFilename = self.writeCGI(SPECIAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertNotIn('monkeys',
response.headers.getRawHeaders('server'))
self.assertNotIn('last year',
response.headers.getRawHeaders('date'))
d.addCallback(checkResponse)
return d
def test_noDuplicateContentTypeHeaders(self):
"""
If the CGI script emits a I{content-type} header, make sure that the
server doesn't add an additional (duplicate) one, as per ticket 4786.
"""
cgiFilename = self.writeCGI(NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertEqual(
response.headers.getRawHeaders('content-type'),
['text/cgi-duplicate-test'])
return response
d.addCallback(checkResponse)
return d
def test_noProxyPassthrough(self):
"""
The CGI script is never called with the Proxy header passed through.
"""
cgiFilename = self.writeCGI(HEADER_OUTPUT_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
headers = http_headers.Headers({b"Proxy": [b"foo"],
b"X-Innocent-Header": [b"bar"]})
d = agent.request(b"GET", url, headers=headers)
def checkResponse(response):
headers = json.loads(response.decode("ascii"))
self.assertEqual(
set(headers.keys()),
{"HTTP_HOST", "HTTP_CONNECTION", "HTTP_X_INNOCENT_HEADER"})
d.addCallback(client.readBody)
d.addCallback(checkResponse)
return d
def test_duplicateHeaderCGI(self):
"""
If a CGI script emits two instances of the same header, both are sent
in the response.
"""
cgiFilename = self.writeCGI(DUAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertEqual(
response.headers.getRawHeaders('header'), ['spam', 'eggs'])
d.addCallback(checkResponse)
return d
def test_malformedHeaderCGI(self):
"""
Check for the error message in the duplicated header
"""
cgiFilename = self.writeCGI(BROKEN_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
loggedMessages = []
def addMessage(eventDict):
loggedMessages.append(log.textFromEventDict(eventDict))
log.addObserver(addMessage)
self.addCleanup(log.removeObserver, addMessage)
def checkResponse(ignored):
self.assertIn("ignoring malformed CGI header: " + repr(b'XYZ'),
loggedMessages)
d.addCallback(checkResponse)
return d
def test_ReadEmptyInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(READINPUT_CGI)
portnum = self.startServer(cgiFilename)
agent = client.Agent(reactor)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = agent.request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadEmptyInput_1)
return d
test_ReadEmptyInput.timeout = 5
def _test_ReadEmptyInput_1(self, res):
expected = "readinput ok{}".format(os.linesep)
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_ReadInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(READINPUT_CGI)
portnum = self.startServer(cgiFilename)
agent = client.Agent(reactor)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = agent.request(
uri=url,
method=b"POST",
bodyProducer=client.FileBodyProducer(
BytesIO(b"Here is your stdin")),
)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadInput_1)
return d
test_ReadInput.timeout = 5
def _test_ReadInput_1(self, res):
expected = "readinput ok{}".format(os.linesep)
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_ReadAllInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(READALLINPUT_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = client.Agent(reactor).request(
uri=url,
method=b"POST",
bodyProducer=client.FileBodyProducer(
BytesIO(b"Here is your stdin")),
)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadAllInput_1)
return d
test_ReadAllInput.timeout = 5
def _test_ReadAllInput_1(self, res):
expected = "readallinput ok{}".format(os.linesep)
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_useReactorArgument(self):
"""
L{twcgi.FilteredScript.runProcess} uses the reactor passed as an
argument to the constructor.
"""
class FakeReactor:
"""
A fake reactor recording whether spawnProcess is called.
"""
called = False
def spawnProcess(self, *args, **kwargs):
"""
Set the C{called} flag to C{True} if C{spawnProcess} is called.
@param args: Positional arguments.
@param kwargs: Keyword arguments.
"""
self.called = True
fakeReactor = FakeReactor()
request = DummyRequest(['a', 'b'])
request.client = address.IPv4Address('TCP', '127.0.0.1', 12345)
resource = twcgi.FilteredScript("dummy-file", reactor=fakeReactor)
_render(resource, request)
self.assertTrue(fakeReactor.called)
class CGIScriptTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIScript}.
"""
def test_pathInfo(self):
"""
L{twcgi.CGIScript.render} sets the process environment
I{PATH_INFO} from the request path.
"""
class FakeReactor:
"""
A fake reactor recording the environment passed to spawnProcess.
"""
def spawnProcess(self, process, filename, args, env, wdir):
"""
Store the C{env} L{dict} to an instance attribute.
@param process: Ignored
@param filename: Ignored
@param args: Ignored
@param env: The environment L{dict} which will be stored
@param wdir: Ignored
"""
self.process_env = env
_reactor = FakeReactor()
resource = twcgi.CGIScript(self.mktemp(), reactor=_reactor)
request = DummyRequest(['a', 'b'])
request.client = address.IPv4Address('TCP', '127.0.0.1', 12345)
_render(resource, request)
self.assertEqual(_reactor.process_env["PATH_INFO"],
"/a/b")
class CGIDirectoryTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIDirectory}.
"""
def test_render(self):
"""
L{twcgi.CGIDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = twcgi.CGIDirectory(self.mktemp())
request = DummyRequest([''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_notFoundChild(self):
"""
L{twcgi.CGIDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{twcgi.CGIDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = twcgi.CGIDirectory(path)
request = DummyRequest(['foo'])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
class CGIProcessProtocolTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIProcessProtocol}.
"""
def test_prematureEndOfHeaders(self):
"""
If the process communicating with L{CGIProcessProtocol} ends before
finishing writing out headers, the response has I{INTERNAL SERVER
ERROR} as its status code.
"""
request = DummyRequest([''])
protocol = twcgi.CGIProcessProtocol(request)
protocol.processEnded(failure.Failure(error.ProcessTerminated()))
self.assertEqual(request.responseCode, INTERNAL_SERVER_ERROR)
def discardBody(response):
"""
Discard the body of a HTTP response.
@param response: The response.
@return: The response.
"""
return client.readBody(response).addCallback(lambda _: response)

View file

@ -0,0 +1,45 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for various parts of L{twisted.web}.
"""
from zope.interface import implementer, verify
from twisted.internet import defer, interfaces
from twisted.trial import unittest
from twisted.web import client
@implementer(interfaces.IStreamClientEndpoint)
class DummyEndPoint(object):
"""An endpoint that does not connect anywhere"""
def __init__(self, someString):
self.someString = someString
def __repr__(self):
return 'DummyEndPoint({})'.format(self.someString)
def connect(self, factory):
return defer.succeed(dict(factory=factory))
class HTTPConnectionPoolTests(unittest.TestCase):
"""
Unit tests for L{client.HTTPConnectionPoolTest}.
"""
def test_implements(self):
"""L{DummyEndPoint}s implements L{interfaces.IStreamClientEndpoint}"""
ep = DummyEndPoint("something")
verify.verifyObject(interfaces.IStreamClientEndpoint, ep)
def test_repr(self):
"""connection L{repr()} includes endpoint's L{repr()}"""
pool = client.HTTPConnectionPool(reactor=None)
ep = DummyEndPoint("this_is_probably_unique")
d = pool.getConnection('someplace', ep)
result = self.successResultOf(d)
representation = repr(result)
self.assertIn(repr(ep), representation)

View file

@ -0,0 +1,527 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.distrib}.
"""
from os.path import abspath
from xml.dom.minidom import parseString
try:
import pwd
except ImportError:
pwd = None
from zope.interface.verify import verifyObject
from twisted.python import filepath, failure
from twisted.internet import reactor, defer
from twisted.trial import unittest
from twisted.spread import pb
from twisted.spread.banana import SIZE_LIMIT
from twisted.web import distrib, client, resource, static, server
from twisted.web.test.test_web import DummyRequest, DummyChannel
from twisted.web.test._util import _render
from twisted.test import proto_helpers
from twisted.web.http_headers import Headers
from twisted.logger import globalLogPublisher
class MySite(server.Site):
pass
class PBServerFactory(pb.PBServerFactory):
"""
A PB server factory which keeps track of the most recent protocol it
created.
@ivar proto: L{None} or the L{Broker} instance most recently returned
from C{buildProtocol}.
"""
proto = None
def buildProtocol(self, addr):
self.proto = pb.PBServerFactory.buildProtocol(self, addr)
return self.proto
class ArbitraryError(Exception):
"""
An exception for this test.
"""
class DistribTests(unittest.TestCase):
port1 = None
port2 = None
sub = None
f1 = None
def tearDown(self):
"""
Clean up all the event sources left behind by either directly by
test methods or indirectly via some distrib API.
"""
dl = [defer.Deferred(), defer.Deferred()]
if self.f1 is not None and self.f1.proto is not None:
self.f1.proto.notifyOnDisconnect(lambda: dl[0].callback(None))
else:
dl[0].callback(None)
if self.sub is not None and self.sub.publisher is not None:
self.sub.publisher.broker.notifyOnDisconnect(
lambda: dl[1].callback(None))
self.sub.publisher.broker.transport.loseConnection()
else:
dl[1].callback(None)
if self.port1 is not None:
dl.append(self.port1.stopListening())
if self.port2 is not None:
dl.append(self.port2.stopListening())
return defer.gatherResults(dl)
def testDistrib(self):
# site1 is the publisher
r1 = resource.Resource()
r1.putChild(b"there", static.Data(b"root", "text/plain"))
site1 = server.Site(r1)
self.f1 = PBServerFactory(distrib.ResourcePublisher(site1))
self.port1 = reactor.listenTCP(0, self.f1)
self.sub = distrib.ResourceSubscription("127.0.0.1",
self.port1.getHost().port)
r2 = resource.Resource()
r2.putChild(b"here", self.sub)
f2 = MySite(r2)
self.port2 = reactor.listenTCP(0, f2)
agent = client.Agent(reactor)
url = "http://127.0.0.1:{}/here/there".format(
self.port2.getHost().port)
url = url.encode("ascii")
d = agent.request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self.assertEqual, b'root')
return d
def _setupDistribServer(self, child):
"""
Set up a resource on a distrib site using L{ResourcePublisher}.
@param child: The resource to publish using distrib.
@return: A tuple consisting of the host and port on which to contact
the created site.
"""
distribRoot = resource.Resource()
distribRoot.putChild(b"child", child)
distribSite = server.Site(distribRoot)
self.f1 = distribFactory = PBServerFactory(
distrib.ResourcePublisher(distribSite))
distribPort = reactor.listenTCP(
0, distribFactory, interface="127.0.0.1")
self.addCleanup(distribPort.stopListening)
addr = distribPort.getHost()
self.sub = mainRoot = distrib.ResourceSubscription(
addr.host, addr.port)
mainSite = server.Site(mainRoot)
mainPort = reactor.listenTCP(0, mainSite, interface="127.0.0.1")
self.addCleanup(mainPort.stopListening)
mainAddr = mainPort.getHost()
return mainPort, mainAddr
def _requestTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{Agent.request} when
requesting the resource.
@return: A L{Deferred} which fires with the result of the request.
"""
mainPort, mainAddr = self._setupDistribServer(child)
agent = client.Agent(reactor)
url = "http://%s:%s/child" % (mainAddr.host, mainAddr.port)
url = url.encode("ascii")
d = agent.request(b"GET", url, **kwargs)
d.addCallback(client.readBody)
return d
def _requestAgentTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{Agent.request} when
requesting the resource.
@return: A L{Deferred} which fires with a tuple consisting of a
L{twisted.test.proto_helpers.AccumulatingProtocol} containing the
body of the response and an L{IResponse} with the response itself.
"""
mainPort, mainAddr = self._setupDistribServer(child)
url = "http://{}:{}/child".format(mainAddr.host, mainAddr.port)
url = url.encode("ascii")
d = client.Agent(reactor).request(b"GET", url, **kwargs)
def cbCollectBody(response):
protocol = proto_helpers.AccumulatingProtocol()
response.deliverBody(protocol)
d = protocol.closedDeferred = defer.Deferred()
d.addCallback(lambda _: (protocol, response))
return d
d.addCallback(cbCollectBody)
return d
def test_requestHeaders(self):
"""
The request headers are available on the request object passed to a
distributed resource's C{render} method.
"""
requestHeaders = {}
logObserver = proto_helpers.EventLoggingObserver()
globalLogPublisher.addObserver(logObserver)
req = [None]
class ReportRequestHeaders(resource.Resource):
def render(self, request):
req[0] = request
requestHeaders.update(dict(
request.requestHeaders.getAllRawHeaders()))
return b""
def check_logs():
msgs = [e["log_format"] for e in logObserver]
self.assertIn('connected to publisher', msgs)
self.assertIn(
"could not connect to distributed web service: {msg}",
msgs
)
self.assertIn(req[0], msgs)
globalLogPublisher.removeObserver(logObserver)
request = self._requestTest(
ReportRequestHeaders(), headers=Headers({'foo': ['bar']}))
def cbRequested(result):
self.f1.proto.notifyOnDisconnect(check_logs)
self.assertEqual(requestHeaders[b'Foo'], [b'bar'])
request.addCallback(cbRequested)
return request
def test_requestResponseCode(self):
"""
The response code can be set by the request object passed to a
distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200)
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, b"")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, b"OK")
request.addCallback(cbRequested)
return request
def test_requestResponseCodeMessage(self):
"""
The response code and message can be set by the request object passed to
a distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200, b"some-message")
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, b"")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, b"some-message")
request.addCallback(cbRequested)
return request
def test_largeWrite(self):
"""
If a string longer than the Banana size limit is passed to the
L{distrib.Request} passed to the remote resource, it is broken into
smaller strings to be transported over the PB connection.
"""
class LargeWrite(resource.Resource):
def render(self, request):
request.write(b'x' * SIZE_LIMIT + b'y')
request.finish()
return server.NOT_DONE_YET
request = self._requestTest(LargeWrite())
request.addCallback(self.assertEqual, b'x' * SIZE_LIMIT + b'y')
return request
def test_largeReturn(self):
"""
Like L{test_largeWrite}, but for the case where C{render} returns a
long string rather than explicitly passing it to L{Request.write}.
"""
class LargeReturn(resource.Resource):
def render(self, request):
return b'x' * SIZE_LIMIT + b'y'
request = self._requestTest(LargeReturn())
request.addCallback(self.assertEqual, b'x' * SIZE_LIMIT + b'y')
return request
def test_connectionLost(self):
"""
If there is an error issuing the request to the remote publisher, an
error response is returned.
"""
# Using pb.Root as a publisher will cause request calls to fail with an
# error every time. Just what we want to test.
self.f1 = serverFactory = PBServerFactory(pb.Root())
self.port1 = serverPort = reactor.listenTCP(0, serverFactory)
self.sub = subscription = distrib.ResourceSubscription(
"127.0.0.1", serverPort.getHost().port)
request = DummyRequest([b''])
d = _render(subscription, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, 500)
# This is the error we caused the request to fail with. It should
# have been logged.
errors = self.flushLoggedErrors(pb.NoSuchMethod)
self.assertEqual(len(errors), 1)
# The error page is rendered as HTML.
expected = [
b'',
b'<html>',
b' <head><title>500 - Server Connection Lost</title></head>',
b' <body>',
b' <h1>Server Connection Lost</h1>',
b' <p>Connection to distributed server lost:'
b'<pre>'
b'[Failure instance: Traceback from remote host -- '
b'twisted.spread.flavors.NoSuchMethod: '
b'No such method: remote_request',
b']</pre></p>',
b' </body>',
b'</html>',
b''
]
self.assertEqual([b'\n'.join(expected)], request.written)
d.addCallback(cbRendered)
return d
def test_logFailed(self):
"""
When a request fails, the string form of the failure is logged.
"""
logObserver = proto_helpers.EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
f = failure.Failure(ArbitraryError())
request = DummyRequest([b''])
issue = distrib.Issue(request)
issue.failed(f)
self.assertEquals(1, len(logObserver))
self.assertIn(
"Failure instance",
logObserver[0]["log_format"]
)
def test_requestFail(self):
"""
When L{twisted.web.distrib.Request}'s fail is called, the failure
is logged.
"""
logObserver = proto_helpers.EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
err = ArbitraryError()
f = failure.Failure(err)
req = distrib.Request(DummyChannel())
req.fail(f)
self.flushLoggedErrors(ArbitraryError)
self.assertEquals(1, len(logObserver))
self.assertIs(logObserver[0]["log_failure"], f)
class _PasswordDatabase:
def __init__(self, users):
self._users = users
def getpwall(self):
return iter(self._users)
def getpwnam(self, username):
for user in self._users:
if user[0] == username:
return user
raise KeyError()
class UserDirectoryTests(unittest.TestCase):
"""
Tests for L{UserDirectory}, a resource for listing all user resources
available on a system.
"""
def setUp(self):
self.alice = ('alice', 'x', 123, 456, 'Alice,,,', self.mktemp(), '/bin/sh')
self.bob = ('bob', 'x', 234, 567, 'Bob,,,', self.mktemp(), '/bin/sh')
self.database = _PasswordDatabase([self.alice, self.bob])
self.directory = distrib.UserDirectory(self.database)
def test_interface(self):
"""
L{UserDirectory} instances provide L{resource.IResource}.
"""
self.assertTrue(verifyObject(resource.IResource, self.directory))
def _404Test(self, name):
"""
Verify that requesting the C{name} child of C{self.directory} results
in a 404 response.
"""
request = DummyRequest([name])
result = self.directory.getChild(name, request)
d = _render(result, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, 404)
d.addCallback(cbRendered)
return d
def test_getInvalidUser(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which does not correspond to any known
user.
"""
return self._404Test('carol')
def test_getUserWithoutResource(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which corresponds to a known user who has
neither a user directory nor a user distrib socket.
"""
return self._404Test('alice')
def test_getPublicHTMLChild(self):
"""
L{UserDirectory.getChild} returns a L{static.File} instance when passed
the name of a user with a home directory containing a I{public_html}
directory.
"""
home = filepath.FilePath(self.bob[-2])
public_html = home.child('public_html')
public_html.makedirs()
request = DummyRequest(['bob'])
result = self.directory.getChild('bob', request)
self.assertIsInstance(result, static.File)
self.assertEqual(result.path, public_html.path)
def test_getDistribChild(self):
"""
L{UserDirectory.getChild} returns a L{ResourceSubscription} instance
when passed the name of a user suffixed with C{".twistd"} who has a
home directory containing a I{.twistd-web-pb} socket.
"""
home = filepath.FilePath(self.bob[-2])
home.makedirs()
web = home.child('.twistd-web-pb')
request = DummyRequest(['bob'])
result = self.directory.getChild('bob.twistd', request)
self.assertIsInstance(result, distrib.ResourceSubscription)
self.assertEqual(result.host, 'unix')
self.assertEqual(abspath(result.port), web.path)
def test_invalidMethod(self):
"""
L{UserDirectory.render} raises L{UnsupportedMethod} in response to a
non-I{GET} request.
"""
request = DummyRequest([''])
request.method = 'POST'
self.assertRaises(
server.UnsupportedMethod, self.directory.render, request)
def test_render(self):
"""
L{UserDirectory} renders a list of links to available user content
in response to a I{GET} request.
"""
public_html = filepath.FilePath(self.alice[-2]).child('public_html')
public_html.makedirs()
web = filepath.FilePath(self.bob[-2])
web.makedirs()
# This really only works if it's a unix socket, but the implementation
# doesn't currently check for that. It probably should someday, and
# then skip users with non-sockets.
web.child('.twistd-web-pb').setContent(b"")
request = DummyRequest([''])
result = _render(self.directory, request)
def cbRendered(ignored):
document = parseString(b''.join(request.written))
# Each user should have an li with a link to their page.
[alice, bob] = document.getElementsByTagName('li')
self.assertEqual(alice.firstChild.tagName, 'a')
self.assertEqual(alice.firstChild.getAttribute('href'), 'alice/')
self.assertEqual(alice.firstChild.firstChild.data, 'Alice (file)')
self.assertEqual(bob.firstChild.tagName, 'a')
self.assertEqual(bob.firstChild.getAttribute('href'), 'bob.twistd/')
self.assertEqual(bob.firstChild.firstChild.data, 'Bob (twistd)')
result.addCallback(cbRendered)
return result
def test_passwordDatabase(self):
"""
If L{UserDirectory} is instantiated with no arguments, it uses the
L{pwd} module as its password database.
"""
directory = distrib.UserDirectory()
self.assertIdentical(directory._pwd, pwd)
if pwd is None:
test_passwordDatabase.skip = "pwd module required"

View file

@ -0,0 +1,305 @@
# -*- test-case-name: twisted.web.test.test_domhelpers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Specific tests for (some of) the methods in L{twisted.web.domhelpers}.
"""
from xml.dom import minidom
from twisted.python.compat import unicode
from twisted.trial.unittest import TestCase
from twisted.web import domhelpers, microdom
class DOMHelpersTestsMixin:
"""
A mixin for L{TestCase} subclasses which defines test methods for
domhelpers functionality based on a DOM creation function provided by a
subclass.
"""
dom = None
def test_getElementsByTagName(self):
doc1 = self.dom.parseString('<foo/>')
actual = domhelpers.getElementsByTagName(doc1, 'foo')[0].nodeName
expected = 'foo'
self.assertEqual(actual, expected)
el1 = doc1.documentElement
actual = domhelpers.getElementsByTagName(el1, 'foo')[0].nodeName
self.assertEqual(actual, expected)
doc2_xml = '<a><foo in="a"/><b><foo in="b"/></b><c><foo in="c"/></c><foo in="d"/><foo in="ef"/><g><foo in="g"/><h><foo in="h"/></h></g></a>'
doc2 = self.dom.parseString(doc2_xml)
tag_list = domhelpers.getElementsByTagName(doc2, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
expected = 'abcdefgh'
self.assertEqual(actual, expected)
el2 = doc2.documentElement
tag_list = domhelpers.getElementsByTagName(el2, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
self.assertEqual(actual, expected)
doc3_xml = '''
<a><foo in="a"/>
<b><foo in="b"/>
<d><foo in="d"/>
<g><foo in="g"/></g>
<h><foo in="h"/></h>
</d>
<e><foo in="e"/>
<i><foo in="i"/></i>
</e>
</b>
<c><foo in="c"/>
<f><foo in="f"/>
<j><foo in="j"/></j>
</f>
</c>
</a>'''
doc3 = self.dom.parseString(doc3_xml)
tag_list = domhelpers.getElementsByTagName(doc3, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
expected = 'abdgheicfj'
self.assertEqual(actual, expected)
el3 = doc3.documentElement
tag_list = domhelpers.getElementsByTagName(el3, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
self.assertEqual(actual, expected)
doc4_xml = '<foo><bar></bar><baz><foo/></baz></foo>'
doc4 = self.dom.parseString(doc4_xml)
actual = domhelpers.getElementsByTagName(doc4, 'foo')
root = doc4.documentElement
expected = [root, root.childNodes[-1].childNodes[0]]
self.assertEqual(actual, expected)
actual = domhelpers.getElementsByTagName(root, 'foo')
self.assertEqual(actual, expected)
def test_gatherTextNodes(self):
doc1 = self.dom.parseString('<a>foo</a>')
actual = domhelpers.gatherTextNodes(doc1)
expected = 'foo'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc1.documentElement)
self.assertEqual(actual, expected)
doc2_xml = '<a>a<b>b</b><c>c</c>def<g>g<h>h</h></g></a>'
doc2 = self.dom.parseString(doc2_xml)
actual = domhelpers.gatherTextNodes(doc2)
expected = 'abcdefgh'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc2.documentElement)
self.assertEqual(actual, expected)
doc3_xml = ('<a>a<b>b<d>d<g>g</g><h>h</h></d><e>e<i>i</i></e></b>' +
'<c>c<f>f<j>j</j></f></c></a>')
doc3 = self.dom.parseString(doc3_xml)
actual = domhelpers.gatherTextNodes(doc3)
expected = 'abdgheicfj'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc3.documentElement)
self.assertEqual(actual, expected)
def test_clearNode(self):
doc1 = self.dom.parseString('<a><b><c><d/></c></b></a>')
a_node = doc1.documentElement
domhelpers.clearNode(a_node)
self.assertEqual(
a_node.toxml(),
self.dom.Element('a').toxml())
doc2 = self.dom.parseString('<a><b><c><d/></c></b></a>')
b_node = doc2.documentElement.childNodes[0]
domhelpers.clearNode(b_node)
actual = doc2.documentElement.toxml()
expected = self.dom.Element('a')
expected.appendChild(self.dom.Element('b'))
self.assertEqual(actual, expected.toxml())
def test_get(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
doc = self.dom.Document()
node = domhelpers.get(doc1, "foo")
actual = node.toxml()
expected = doc.createElement('c')
expected.setAttribute('class', 'foo')
self.assertEqual(actual, expected.toxml())
node = domhelpers.get(doc1, "bar")
actual = node.toxml()
expected = doc.createElement('b')
expected.setAttribute('id', 'bar')
self.assertEqual(actual, expected.toxml())
self.assertRaises(domhelpers.NodeLookupError,
domhelpers.get,
doc1,
"pzork")
def test_getIfExists(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
doc = self.dom.Document()
node = domhelpers.getIfExists(doc1, "foo")
actual = node.toxml()
expected = doc.createElement('c')
expected.setAttribute('class', 'foo')
self.assertEqual(actual, expected.toxml())
node = domhelpers.getIfExists(doc1, "pzork")
self.assertIdentical(node, None)
def test_getAndClear(self):
doc1 = self.dom.parseString('<a><b id="foo"><c></c></b></a>')
doc = self.dom.Document()
node = domhelpers.getAndClear(doc1, "foo")
actual = node.toxml()
expected = doc.createElement('b')
expected.setAttribute('id', 'foo')
self.assertEqual(actual, expected.toxml())
def test_locateNodes(self):
doc1 = self.dom.parseString('<a><b foo="olive"><c foo="olive"/></b><d foo="poopy"/></a>')
doc = self.dom.Document()
node_list = domhelpers.locateNodes(
doc1.childNodes, 'foo', 'olive', noNesting=1)
actual = ''.join([node.toxml() for node in node_list])
expected = doc.createElement('b')
expected.setAttribute('foo', 'olive')
c = doc.createElement('c')
c.setAttribute('foo', 'olive')
expected.appendChild(c)
self.assertEqual(actual, expected.toxml())
node_list = domhelpers.locateNodes(
doc1.childNodes, 'foo', 'olive', noNesting=0)
actual = ''.join([node.toxml() for node in node_list])
self.assertEqual(actual, expected.toxml() + c.toxml())
def test_getParents(self):
doc1 = self.dom.parseString('<a><b><c><d/></c><e/></b><f/></a>')
node_list = domhelpers.getParents(
doc1.childNodes[0].childNodes[0].childNodes[0])
actual = ''.join([node.tagName for node in node_list
if hasattr(node, 'tagName')])
self.assertEqual(actual, 'cba')
def test_findElementsWithAttribute(self):
doc1 = self.dom.parseString('<a foo="1"><b foo="2"/><c foo="1"/><d/></a>')
node_list = domhelpers.findElementsWithAttribute(doc1, 'foo')
actual = ''.join([node.tagName for node in node_list])
self.assertEqual(actual, 'abc')
node_list = domhelpers.findElementsWithAttribute(doc1, 'foo', '1')
actual = ''.join([node.tagName for node in node_list])
self.assertEqual(actual, 'ac')
def test_findNodesNamed(self):
doc1 = self.dom.parseString('<doc><foo/><bar/><foo>a</foo></doc>')
node_list = domhelpers.findNodesNamed(doc1, 'foo')
actual = len(node_list)
self.assertEqual(actual, 2)
def test_escape(self):
j = 'this string " contains many & characters> xml< won\'t like'
expected = 'this string &quot; contains many &amp; characters&gt; xml&lt; won\'t like'
self.assertEqual(domhelpers.escape(j), expected)
def test_unescape(self):
j = 'this string &quot; has &&amp; entities &gt; &lt; and some characters xml won\'t like<'
expected = 'this string " has && entities > < and some characters xml won\'t like<'
self.assertEqual(domhelpers.unescape(j), expected)
def test_getNodeText(self):
"""
L{getNodeText} returns the concatenation of all the text data at or
beneath the node passed to it.
"""
node = self.dom.parseString('<foo><bar>baz</bar><bar>quux</bar></foo>')
self.assertEqual(domhelpers.getNodeText(node), "bazquux")
class MicroDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = microdom
def test_gatherTextNodesDropsWhitespace(self):
"""
Microdom discards whitespace-only text nodes, so L{gatherTextNodes}
returns only the text from nodes which had non-whitespace characters.
"""
doc4_xml = '''<html>
<head>
</head>
<body>
stuff
</body>
</html>
'''
doc4 = self.dom.parseString(doc4_xml)
actual = domhelpers.gatherTextNodes(doc4)
expected = '\n stuff\n '
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc4.documentElement)
self.assertEqual(actual, expected)
def test_textEntitiesNotDecoded(self):
"""
Microdom does not decode entities in text nodes.
"""
doc5_xml = '<x>Souffl&amp;</x>'
doc5 = self.dom.parseString(doc5_xml)
actual = domhelpers.gatherTextNodes(doc5)
expected = 'Souffl&amp;'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
class MiniDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = minidom
def test_textEntitiesDecoded(self):
"""
Minidom does decode entities in text nodes.
"""
doc5_xml = '<x>Souffl&amp;</x>'
doc5 = self.dom.parseString(doc5_xml)
actual = domhelpers.gatherTextNodes(doc5)
expected = 'Souffl&'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
def test_getNodeUnicodeText(self):
"""
L{domhelpers.getNodeText} returns a C{unicode} string when text
nodes are represented in the DOM with unicode, whether or not there
are non-ASCII characters present.
"""
node = self.dom.parseString("<foo>bar</foo>")
text = domhelpers.getNodeText(node)
self.assertEqual(text, u"bar")
self.assertIsInstance(text, unicode)
node = self.dom.parseString(u"<foo>\N{SNOWMAN}</foo>".encode('utf-8'))
text = domhelpers.getNodeText(node)
self.assertEqual(text, u"\N{SNOWMAN}")
self.assertIsInstance(text, unicode)

View file

@ -0,0 +1,481 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP errors.
"""
from __future__ import division, absolute_import
import re
import sys
import traceback
from twisted.trial import unittest
from twisted.python.compat import nativeString, _PY3
from twisted.web import error
from twisted.web.template import Tag
class CodeToMessageTests(unittest.TestCase):
"""
L{_codeToMessages} inverts L{_responses.RESPONSES}
"""
def test_validCode(self):
m = error._codeToMessage(b"302")
self.assertEqual(m, b"Found")
def test_invalidCode(self):
m = error._codeToMessage(b"987")
self.assertEqual(m, None)
def test_nonintegerCode(self):
m = error._codeToMessage(b"InvalidCode")
self.assertEqual(m, None)
class ErrorTests(unittest.TestCase):
"""
Tests for how L{Error} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{Error} constructor and the
C{code} argument is a valid HTTP status code, C{code} is mapped to a
descriptive string to which C{message} is assigned.
"""
e = error.Error(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatus(self):
"""
If no C{message} argument is passed to the L{Error} constructor and
C{code} isn't a valid HTTP status code, C{message} stays L{None}.
"""
e = error.Error(b"InvalidCode")
self.assertEqual(e.message, None)
def test_messageExists(self):
"""
If a C{message} argument is passed to the L{Error} constructor, the
C{message} isn't affected by the value of C{status}.
"""
e = error.Error(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
def test_str(self):
"""
C{str()} on an L{Error} returns the code and message it was
instantiated with.
"""
# Bytestring status
e = error.Error(b"200", b"OK")
self.assertEqual(str(e), "200 OK")
# int status
e = error.Error(200, b"OK")
self.assertEqual(str(e), "200 OK")
class PageRedirectTests(unittest.TestCase):
"""
Tests for how L{PageRedirect} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and the C{code} argument is a valid HTTP status code, C{code} is mapped
to a descriptive string to which C{message} is assigned.
"""
e = error.PageRedirect(b"200", location=b"/foo")
self.assertEqual(e.message, b"OK to /foo")
def test_noMessageValidStatusNoLocation(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{location} is also empty and the C{code} argument is a valid HTTP
status code, C{code} is mapped to a descriptive string to which
C{message} is assigned without trying to include an empty location.
"""
e = error.PageRedirect(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatusLocationExists(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{code} isn't a valid HTTP status code, C{message} stays L{None}.
"""
e = error.PageRedirect(b"InvalidCode", location=b"/foo")
self.assertEqual(e.message, None)
def test_messageExistsLocationExists(self):
"""
If a C{message} argument is passed to the L{PageRedirect} constructor,
the C{message} isn't affected by the value of C{status}.
"""
e = error.PageRedirect(b"200", b"My own message", location=b"/foo")
self.assertEqual(e.message, b"My own message to /foo")
def test_messageExistsNoLocation(self):
"""
If a C{message} argument is passed to the L{PageRedirect} constructor
and no location is provided, C{message} doesn't try to include the
empty location.
"""
e = error.PageRedirect(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
class InfiniteRedirectionTests(unittest.TestCase):
"""
Tests for how L{InfiniteRedirection} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and the C{code} argument is a valid HTTP status code,
C{code} is mapped to a descriptive string to which C{message} is
assigned.
"""
e = error.InfiniteRedirection(b"200", location=b"/foo")
self.assertEqual(e.message, b"OK to /foo")
def test_noMessageValidStatusNoLocation(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{location} is also empty and the C{code} argument is a
valid HTTP status code, C{code} is mapped to a descriptive string to
which C{message} is assigned without trying to include an empty
location.
"""
e = error.InfiniteRedirection(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatusLocationExists(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{code} isn't a valid HTTP status code, C{message} stays
L{None}.
"""
e = error.InfiniteRedirection(b"InvalidCode", location=b"/foo")
self.assertEqual(e.message, None)
def test_messageExistsLocationExists(self):
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor, the C{message} isn't affected by the value of C{status}.
"""
e = error.InfiniteRedirection(b"200", b"My own message",
location=b"/foo")
self.assertEqual(e.message, b"My own message to /foo")
def test_messageExistsNoLocation(self):
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor and no location is provided, C{message} doesn't try to
include the empty location.
"""
e = error.InfiniteRedirection(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
class RedirectWithNoLocationTests(unittest.TestCase):
"""
L{RedirectWithNoLocation} is a subclass of L{Error} which sets
a custom message in the constructor.
"""
def test_validMessage(self):
"""
When C{code}, C{message}, and C{uri} are passed to the
L{RedirectWithNoLocation} constructor, the C{message} and C{uri}
attributes are set, respectively.
"""
e = error.RedirectWithNoLocation(b"302", b"REDIRECT",
b"https://example.com")
self.assertEqual(e.message, b"REDIRECT to https://example.com")
self.assertEqual(e.uri, b"https://example.com")
class MissingRenderMethodTests(unittest.TestCase):
"""
Tests for how L{MissingRenderMethod} exceptions are initialized and
displayed.
"""
def test_constructor(self):
"""
Given C{element} and C{renderName} arguments, the
L{MissingRenderMethod} constructor assigns the values to the
corresponding attributes.
"""
elt = object()
e = error.MissingRenderMethod(elt, 'renderThing')
self.assertIs(e.element, elt)
self.assertIs(e.renderName, 'renderThing')
def test_repr(self):
"""
A L{MissingRenderMethod} is represented using a custom string
containing the element's representation and the method name.
"""
elt = object()
e = error.MissingRenderMethod(elt, 'renderThing')
self.assertEqual(
repr(e),
("'MissingRenderMethod': "
"%r had no render method named 'renderThing'") % elt)
class MissingTemplateLoaderTests(unittest.TestCase):
"""
Tests for how L{MissingTemplateLoader} exceptions are initialized and
displayed.
"""
def test_constructor(self):
"""
Given an C{element} argument, the L{MissingTemplateLoader} constructor
assigns the value to the corresponding attribute.
"""
elt = object()
e = error.MissingTemplateLoader(elt)
self.assertIs(e.element, elt)
def test_repr(self):
"""
A L{MissingTemplateLoader} is represented using a custom string
containing the element's representation and the method name.
"""
elt = object()
e = error.MissingTemplateLoader(elt)
self.assertEqual(
repr(e),
"'MissingTemplateLoader': %r had no loader" % elt)
class FlattenerErrorTests(unittest.TestCase):
"""
Tests for L{FlattenerError}.
"""
def makeFlattenerError(self, roots=[]):
try:
raise RuntimeError("oh noes")
except Exception as e:
tb = traceback.extract_tb(sys.exc_info()[2])
return error.FlattenerError(e, roots, tb)
def fakeFormatRoot(self, obj):
return 'R(%s)' % obj
def test_constructor(self):
"""
Given C{exception}, C{roots}, and C{traceback} arguments, the
L{FlattenerError} constructor assigns the roots to the C{_roots}
attribute.
"""
e = self.makeFlattenerError(roots=['a', 'b'])
self.assertEqual(e._roots, ['a', 'b'])
def test_str(self):
"""
The string form of a L{FlattenerError} is identical to its
representation.
"""
e = self.makeFlattenerError()
self.assertEqual(str(e), repr(e))
def test_reprWithRootsAndWithTraceback(self):
"""
The representation of a L{FlattenerError} initialized with roots and a
traceback contains a formatted representation of those roots (using
C{_formatRoot}) and a formatted traceback.
"""
e = self.makeFlattenerError(['a', 'b'])
e._formatRoot = self.fakeFormatRoot
self.assertTrue(
re.match('Exception while flattening:\n'
' R\(a\)\n'
' R\(b\)\n'
' File "[^"]*", line [0-9]*, in makeFlattenerError\n'
' raise RuntimeError\("oh noes"\)\n'
'RuntimeError: oh noes\n$',
repr(e), re.M | re.S),
repr(e))
def test_reprWithoutRootsAndWithTraceback(self):
"""
The representation of a L{FlattenerError} initialized without roots but
with a traceback contains a formatted traceback but no roots.
"""
e = self.makeFlattenerError([])
self.assertTrue(
re.match('Exception while flattening:\n'
' File "[^"]*", line [0-9]*, in makeFlattenerError\n'
' raise RuntimeError\("oh noes"\)\n'
'RuntimeError: oh noes\n$',
repr(e), re.M | re.S),
repr(e))
def test_reprWithoutRootsAndWithoutTraceback(self):
"""
The representation of a L{FlattenerError} initialized without roots but
with a traceback contains a formatted traceback but no roots.
"""
e = error.FlattenerError(RuntimeError("oh noes"), [], None)
self.assertTrue(
re.match('Exception while flattening:\n'
'RuntimeError: oh noes\n$',
repr(e), re.M | re.S),
repr(e))
def test_formatRootShortUnicodeString(self):
"""
The C{_formatRoot} method formats a short unicode string using the
built-in repr.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(nativeString('abcd')), repr('abcd'))
def test_formatRootLongUnicodeString(self):
"""
The C{_formatRoot} method formats a long unicode string using the
built-in repr with an ellipsis.
"""
e = self.makeFlattenerError()
longString = nativeString('abcde-' * 20)
self.assertEqual(e._formatRoot(longString),
repr('abcde-abcde-abcde-ab<...>e-abcde-abcde-abcde-'))
def test_formatRootShortByteString(self):
"""
The C{_formatRoot} method formats a short byte string using the
built-in repr.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(b'abcd'), repr(b'abcd'))
def test_formatRootLongByteString(self):
"""
The C{_formatRoot} method formats a long byte string using the
built-in repr with an ellipsis.
"""
e = self.makeFlattenerError()
longString = b'abcde-' * 20
self.assertEqual(e._formatRoot(longString),
repr(b'abcde-abcde-abcde-ab<...>e-abcde-abcde-abcde-'))
def test_formatRootTagNoFilename(self):
"""
The C{_formatRoot} method formats a C{Tag} with no filename information
as 'Tag <tagName>'.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(Tag('a-tag')), 'Tag <a-tag>')
def test_formatRootTagWithFilename(self):
"""
The C{_formatRoot} method formats a C{Tag} with filename information
using the filename, line, column, and tag information
"""
e = self.makeFlattenerError()
t = Tag('a-tag', filename='tpl.py', lineNumber=10, columnNumber=20)
self.assertEqual(e._formatRoot(t),
'File "tpl.py", line 10, column 20, in "a-tag"')
def test_string(self):
"""
If a L{FlattenerError} is created with a string root, up to around 40
bytes from that string are included in the string representation of the
exception.
"""
self.assertEqual(
str(error.FlattenerError(RuntimeError("reason"),
['abc123xyz'], [])),
"Exception while flattening:\n"
" 'abc123xyz'\n"
"RuntimeError: reason\n")
self.assertEqual(
str(error.FlattenerError(
RuntimeError("reason"), ['0123456789' * 10], [])),
"Exception while flattening:\n"
" '01234567890123456789"
"<...>01234567890123456789'\n" # TODO: re-add 0
"RuntimeError: reason\n")
def test_unicode(self):
"""
If a L{FlattenerError} is created with a unicode root, up to around 40
characters from that string are included in the string representation
of the exception.
"""
# the response includes the output of repr(), which differs between
# Python 2 and 3
u = {'u': ''} if _PY3 else {'u': 'u'}
self.assertEqual(
str(error.FlattenerError(
RuntimeError("reason"), [u'abc\N{SNOWMAN}xyz'], [])),
"Exception while flattening:\n"
" %(u)s'abc\\u2603xyz'\n" # Codepoint for SNOWMAN
"RuntimeError: reason\n" % u)
self.assertEqual(
str(error.FlattenerError(
RuntimeError("reason"), [u'01234567\N{SNOWMAN}9' * 10],
[])),
"Exception while flattening:\n"
" %(u)s'01234567\\u2603901234567\\u26039"
"<...>01234567\\u2603901234567"
"\\u26039'\n"
"RuntimeError: reason\n" % u)
class UnsupportedMethodTests(unittest.SynchronousTestCase):
"""
Tests for L{UnsupportedMethod}.
"""
def test_str(self):
"""
The C{__str__} for L{UnsupportedMethod} makes it clear that what it
shows is a list of the supported methods, not the method that was
unsupported.
"""
b = "b" if _PY3 else ""
e = error.UnsupportedMethod([b"HEAD", b"PATCH"])
self.assertEqual(
str(e), "Expected one of [{b}'HEAD', {b}'PATCH']".format(b=b),
)

View file

@ -0,0 +1,537 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the flattening portion of L{twisted.web.template}, implemented in
L{twisted.web._flatten}.
"""
import sys
import traceback
from xml.etree.cElementTree import XML
from collections import OrderedDict
from zope.interface import implementer
from twisted.python.compat import _PY35PLUS
from twisted.trial.unittest import TestCase
from twisted.test.testutils import XMLAssertionMixin
from twisted.internet.defer import passthru, succeed, gatherResults
from twisted.web.iweb import IRenderable
from twisted.web.error import UnfilledSlot, UnsupportedType, FlattenerError
from twisted.web.template import tags, Tag, Comment, CDATA, CharRef, slot
from twisted.web.template import Element, renderer, TagLoader, flattenString
from twisted.web.test._util import FlattenTestCase
class SerializationTests(FlattenTestCase, XMLAssertionMixin):
"""
Tests for flattening various things.
"""
def test_nestedTags(self):
"""
Test that nested tags flatten correctly.
"""
return self.assertFlattensTo(
tags.html(tags.body('42'), hi='there'),
b'<html hi="there"><body>42</body></html>')
def test_serializeString(self):
"""
Test that strings will be flattened and escaped correctly.
"""
return gatherResults([
self.assertFlattensTo('one', b'one'),
self.assertFlattensTo('<abc&&>123', b'&lt;abc&amp;&amp;&gt;123'),
])
def test_serializeSelfClosingTags(self):
"""
The serialized form of a self-closing tag is C{'<tagName />'}.
"""
return self.assertFlattensTo(tags.img(), b'<img />')
def test_serializeAttribute(self):
"""
The serialized form of attribute I{a} with value I{b} is C{'a="b"'}.
"""
self.assertFlattensImmediately(tags.img(src='foo'),
b'<img src="foo" />')
def test_serializedMultipleAttributes(self):
"""
Multiple attributes are separated by a single space in their serialized
form.
"""
tag = tags.img()
tag.attributes = OrderedDict([("src", "foo"), ("name", "bar")])
self.assertFlattensImmediately(tag, b'<img src="foo" name="bar" />')
def checkAttributeSanitization(self, wrapData, wrapTag):
"""
Common implementation of L{test_serializedAttributeWithSanitization}
and L{test_serializedDeferredAttributeWithSanitization},
L{test_serializedAttributeWithTransparentTag}.
@param wrapData: A 1-argument callable that wraps around the
attribute's value so other tests can customize it.
@param wrapData: callable taking L{bytes} and returning something
flattenable
@param wrapTag: A 1-argument callable that wraps around the outer tag
so other tests can customize it.
@type wrapTag: callable taking L{Tag} and returning L{Tag}.
"""
self.assertFlattensImmediately(
wrapTag(tags.img(src=wrapData("<>&\""))),
b'<img src="&lt;&gt;&amp;&quot;" />')
def test_serializedAttributeWithSanitization(self):
"""
Attribute values containing C{"<"}, C{">"}, C{"&"}, or C{'"'} have
C{"&lt;"}, C{"&gt;"}, C{"&amp;"}, or C{"&quot;"} substituted for those
bytes in the serialized output.
"""
self.checkAttributeSanitization(passthru, passthru)
def test_serializedDeferredAttributeWithSanitization(self):
"""
Like L{test_serializedAttributeWithSanitization}, but when the contents
of the attribute are in a L{Deferred
<twisted.internet.defer.Deferred>}.
"""
self.checkAttributeSanitization(succeed, passthru)
def test_serializedAttributeWithSlotWithSanitization(self):
"""
Like L{test_serializedAttributeWithSanitization} but with a slot.
"""
toss = []
self.checkAttributeSanitization(
lambda value: toss.append(value) or slot("stuff"),
lambda tag: tag.fillSlots(stuff=toss.pop())
)
def test_serializedAttributeWithTransparentTag(self):
"""
Attribute values which are supplied via the value of a C{t:transparent}
tag have the same substitution rules to them as values supplied
directly.
"""
self.checkAttributeSanitization(tags.transparent, passthru)
def test_serializedAttributeWithTransparentTagWithRenderer(self):
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is rendered by a renderer on an element.
"""
class WithRenderer(Element):
def __init__(self, value, loader):
self.value = value
super(WithRenderer, self).__init__(loader)
@renderer
def stuff(self, request, tag):
return self.value
toss = []
self.checkAttributeSanitization(
lambda value: toss.append(value) or
tags.transparent(render="stuff"),
lambda tag: WithRenderer(toss.pop(), TagLoader(tag))
)
def test_serializedAttributeWithRenderable(self):
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is a provider of L{IRenderable} rather than a transparent
tag.
"""
@implementer(IRenderable)
class Arbitrary(object):
def __init__(self, value):
self.value = value
def render(self, request):
return self.value
self.checkAttributeSanitization(Arbitrary, passthru)
def checkTagAttributeSerialization(self, wrapTag):
"""
Common implementation of L{test_serializedAttributeWithTag} and
L{test_serializedAttributeWithDeferredTag}.
@param wrapTag: A 1-argument callable that wraps around the attribute's
value so other tests can customize it.
@param wrapTag: callable taking L{Tag} and returning something
flattenable
"""
innerTag = tags.a('<>&"')
outerTag = tags.img(src=wrapTag(innerTag))
outer = self.assertFlattensImmediately(
outerTag,
b'<img src="&lt;a&gt;&amp;lt;&amp;gt;&amp;amp;&quot;&lt;/a&gt;" />')
inner = self.assertFlattensImmediately(
innerTag, b'<a>&lt;&gt;&amp;"</a>')
# Since the above quoting is somewhat tricky, validate it by making sure
# that the main use-case for tag-within-attribute is supported here: if
# we serialize a tag, it is quoted *such that it can be parsed out again
# as a tag*.
self.assertXMLEqual(XML(outer).attrib['src'], inner)
def test_serializedAttributeWithTag(self):
"""
L{Tag} objects which are serialized within the context of an attribute
are serialized such that the text content of the attribute may be
parsed to retrieve the tag.
"""
self.checkTagAttributeSerialization(passthru)
def test_serializedAttributeWithDeferredTag(self):
"""
Like L{test_serializedAttributeWithTag}, but when the L{Tag} is in a
L{Deferred <twisted.internet.defer.Deferred>}.
"""
self.checkTagAttributeSerialization(succeed)
def test_serializedAttributeWithTagWithAttribute(self):
"""
Similar to L{test_serializedAttributeWithTag}, but for the additional
complexity where the tag which is the attribute value itself has an
attribute value which contains bytes which require substitution.
"""
flattened = self.assertFlattensImmediately(
tags.img(src=tags.a(href='<>&"')),
b'<img src="&lt;a href='
b'&quot;&amp;lt;&amp;gt;&amp;amp;&amp;quot;&quot;&gt;'
b'&lt;/a&gt;" />')
# As in checkTagAttributeSerialization, belt-and-suspenders:
self.assertXMLEqual(XML(flattened).attrib['src'],
b'<a href="&lt;&gt;&amp;&quot;"></a>')
def test_serializeComment(self):
"""
Test that comments are correctly flattened and escaped.
"""
return self.assertFlattensTo(Comment('foo bar'), b'<!--foo bar-->'),
def test_commentEscaping(self):
"""
The data in a L{Comment} is escaped and mangled in the flattened output
so that the result is a legal SGML and XML comment.
SGML comment syntax is complicated and hard to use. This rule is more
restrictive, and more compatible:
Comments start with <!-- and end with --> and never contain -- or >.
Also by XML syntax, a comment may not end with '-'.
@see: U{http://www.w3.org/TR/REC-xml/#sec-comments}
"""
def verifyComment(c):
self.assertTrue(
c.startswith(b'<!--'),
"%r does not start with the comment prefix" % (c,))
self.assertTrue(
c.endswith(b'-->'),
"%r does not end with the comment suffix" % (c,))
# If it is shorter than 7, then the prefix and suffix overlap
# illegally.
self.assertTrue(
len(c) >= 7,
"%r is too short to be a legal comment" % (c,))
content = c[4:-3]
self.assertNotIn(b'--', content)
self.assertNotIn(b'>', content)
if content:
self.assertNotEqual(content[-1], b'-')
results = []
for c in [
'',
'foo---bar',
'foo---bar-',
'foo>bar',
'foo-->bar',
'----------------',
]:
d = flattenString(None, Comment(c))
d.addCallback(verifyComment)
results.append(d)
return gatherResults(results)
def test_serializeCDATA(self):
"""
Test that CDATA is correctly flattened and escaped.
"""
return gatherResults([
self.assertFlattensTo(CDATA('foo bar'), b'<![CDATA[foo bar]]>'),
self.assertFlattensTo(
CDATA('foo ]]> bar'),
b'<![CDATA[foo ]]]]><![CDATA[> bar]]>'),
])
def test_serializeUnicode(self):
"""
Test that unicode is encoded correctly in the appropriate places, and
raises an error when it occurs in inappropriate place.
"""
snowman = u'\N{SNOWMAN}'
return gatherResults([
self.assertFlattensTo(snowman, b'\xe2\x98\x83'),
self.assertFlattensTo(tags.p(snowman), b'<p>\xe2\x98\x83</p>'),
self.assertFlattensTo(Comment(snowman), b'<!--\xe2\x98\x83-->'),
self.assertFlattensTo(CDATA(snowman), b'<![CDATA[\xe2\x98\x83]]>'),
self.assertFlatteningRaises(
Tag(snowman), UnicodeEncodeError),
self.assertFlatteningRaises(
Tag('p', attributes={snowman: ''}), UnicodeEncodeError),
])
def test_serializeCharRef(self):
"""
A character reference is flattened to a string using the I{&#NNNN;}
syntax.
"""
ref = CharRef(ord(u"\N{SNOWMAN}"))
return self.assertFlattensTo(ref, b"&#9731;")
def test_serializeDeferred(self):
"""
Test that a deferred is substituted with the current value in the
callback chain when flattened.
"""
return self.assertFlattensTo(succeed('two'), b'two')
def test_serializeSameDeferredTwice(self):
"""
Test that the same deferred can be flattened twice.
"""
d = succeed('three')
return gatherResults([
self.assertFlattensTo(d, b'three'),
self.assertFlattensTo(d, b'three'),
])
def test_serializeCoroutine(self):
"""
Test that a coroutine returning a value is substituted with the that
value when flattened.
"""
from textwrap import dedent
namespace = {}
exec(dedent(
"""
async def coro(x):
return x
"""
), namespace)
coro = namespace["coro"]
return self.assertFlattensTo(coro('four'), b'four')
if not _PY35PLUS:
test_serializeCoroutine.skip = (
"coroutines not available before Python 3.5"
)
def test_serializeCoroutineWithAwait(self):
"""
Test that a coroutine returning an awaited deferred value is
substituted with that value when flattened.
"""
from textwrap import dedent
namespace = dict(succeed=succeed)
exec(dedent(
"""
async def coro(x):
return await succeed(x)
"""
), namespace)
coro = namespace["coro"]
return self.assertFlattensTo(coro('four'), b'four')
if not _PY35PLUS:
test_serializeCoroutineWithAwait.skip = (
"coroutines not available before Python 3.5"
)
def test_serializeIRenderable(self):
"""
Test that flattening respects all of the IRenderable interface.
"""
@implementer(IRenderable)
class FakeElement(object):
def render(ign,ored):
return tags.p(
'hello, ',
tags.transparent(render='test'), ' - ',
tags.transparent(render='test'))
def lookupRenderMethod(ign, name):
self.assertEqual(name, 'test')
return lambda ign, node: node('world')
return gatherResults([
self.assertFlattensTo(FakeElement(), b'<p>hello, world - world</p>'),
])
def test_serializeSlots(self):
"""
Test that flattening a slot will use the slot value from the tag.
"""
t1 = tags.p(slot('test'))
t2 = t1.clone()
t2.fillSlots(test='hello, world')
return gatherResults([
self.assertFlatteningRaises(t1, UnfilledSlot),
self.assertFlattensTo(t2, b'<p>hello, world</p>'),
])
def test_serializeDeferredSlots(self):
"""
Test that a slot with a deferred as its value will be flattened using
the value from the deferred.
"""
t = tags.p(slot('test'))
t.fillSlots(test=succeed(tags.em('four>')))
return self.assertFlattensTo(t, b'<p><em>four&gt;</em></p>')
def test_unknownTypeRaises(self):
"""
Test that flattening an unknown type of thing raises an exception.
"""
return self.assertFlatteningRaises(None, UnsupportedType)
# Use the co_filename mechanism (instead of the __file__ mechanism) because
# it is the mechanism traceback formatting uses. The two do not necessarily
# agree with each other. This requires a code object compiled in this file.
# The easiest way to get a code object is with a new function. I'll use a
# lambda to avoid adding anything else to this namespace. The result will
# be a string which agrees with the one the traceback module will put into a
# traceback for frames associated with functions defined in this file.
HERE = (lambda: None).__code__.co_filename
class FlattenerErrorTests(TestCase):
"""
Tests for L{FlattenerError}.
"""
def test_renderable(self):
"""
If a L{FlattenerError} is created with an L{IRenderable} provider root,
the repr of that object is included in the string representation of the
exception.
"""
@implementer(IRenderable)
class Renderable(object):
def __repr__(self):
return "renderable repr"
self.assertEqual(
str(FlattenerError(
RuntimeError("reason"), [Renderable()], [])),
"Exception while flattening:\n"
" renderable repr\n"
"RuntimeError: reason\n")
def test_tag(self):
"""
If a L{FlattenerError} is created with a L{Tag} instance with source
location information, the source location is included in the string
representation of the exception.
"""
tag = Tag(
'div', filename='/foo/filename.xhtml', lineNumber=17, columnNumber=12)
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [tag], [])),
"Exception while flattening:\n"
" File \"/foo/filename.xhtml\", line 17, column 12, in \"div\"\n"
"RuntimeError: reason\n")
def test_tagWithoutLocation(self):
"""
If a L{FlattenerError} is created with a L{Tag} instance without source
location information, only the tagName is included in the string
representation of the exception.
"""
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [Tag('span')], [])),
"Exception while flattening:\n"
" Tag <span>\n"
"RuntimeError: reason\n")
def test_traceback(self):
"""
If a L{FlattenerError} is created with traceback frames, they are
included in the string representation of the exception.
"""
# Try to be realistic in creating the data passed in for the traceback
# frames.
def f():
g()
def g():
raise RuntimeError("reason")
try:
f()
except RuntimeError as e:
# Get the traceback, minus the info for *this* frame
tbinfo = traceback.extract_tb(sys.exc_info()[2])[1:]
exc = e
else:
self.fail("f() must raise RuntimeError")
self.assertEqual(
str(FlattenerError(exc, [], tbinfo)),
"Exception while flattening:\n"
" File \"%s\", line %d, in f\n"
" g()\n"
" File \"%s\", line %d, in g\n"
" raise RuntimeError(\"reason\")\n"
"RuntimeError: reason\n" % (
HERE, f.__code__.co_firstlineno + 1,
HERE, g.__code__.co_firstlineno + 1))

View file

@ -0,0 +1,43 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.web import html
class WebHtmlTests(unittest.TestCase):
"""
Unit tests for L{twisted.web.html}.
"""
def test_deprecation(self):
"""
Calls to L{twisted.web.html} members emit a deprecation warning.
"""
def assertDeprecationWarningOf(method):
"""
Check that a deprecation warning is present.
"""
warningsShown = self.flushWarnings([self.test_deprecation])
self.assertEqual(len(warningsShown), 1)
self.assertIdentical(
warningsShown[0]['category'], DeprecationWarning)
self.assertEqual(
warningsShown[0]['message'],
'twisted.web.html.%s was deprecated in Twisted 15.3.0; '
'please use twisted.web.template instead' % (
method,),
)
html.PRE('')
assertDeprecationWarningOf('PRE')
html.UL([])
assertDeprecationWarningOf('UL')
html.linkList([])
assertDeprecationWarningOf('linkList')
html.output(lambda: None)
assertDeprecationWarningOf('output')

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,660 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.http_headers}.
"""
from __future__ import division, absolute_import
from twisted.trial.unittest import TestCase
from twisted.python.compat import _PY3, unicode
from twisted.web.http_headers import Headers
from twisted.web.test.requesthelper import (
bytesLinearWhitespaceComponents,
sanitizedBytes,
textLinearWhitespaceComponents,
)
def assertSanitized(testCase, components, expected):
"""
Assert that the components are sanitized to the expected value as
both a header name and value, across all of L{Header}'s setters
and getters.
@param testCase: A test case.
@param components: A sequence of values that contain linear
whitespace to use as header names and values; see
C{textLinearWhitespaceComponents} and
C{bytesLinearWhitespaceComponents}
@param expected: The expected sanitized form of the component for
both headers names and their values.
"""
for component in components:
headers = []
headers.append(Headers({component: [component]}))
added = Headers()
added.addRawHeader(component, component)
headers.append(added)
setHeader = Headers()
setHeader.setRawHeaders(component, [component])
headers.append(setHeader)
for header in headers:
testCase.assertEqual(list(header.getAllRawHeaders()),
[(expected, [expected])])
testCase.assertEqual(header.getRawHeaders(expected), [expected])
class BytesHeadersTests(TestCase):
"""
Tests for L{Headers}, using L{bytes} arguments for methods.
"""
def test_sanitizeLinearWhitespace(self):
"""
Linear whitespace in header names or values is replaced with a
single space.
"""
assertSanitized(self, bytesLinearWhitespaceComponents, sanitizedBytes)
def test_initializer(self):
"""
The header values passed to L{Headers.__init__} can be retrieved via
L{Headers.getRawHeaders}.
"""
h = Headers({b'Foo': [b'bar']})
self.assertEqual(h.getRawHeaders(b'foo'), [b'bar'])
def test_setRawHeaders(self):
"""
L{Headers.setRawHeaders} sets the header values for the given
header name to the sequence of byte string values.
"""
rawValue = [b"value1", b"value2"]
h = Headers()
h.setRawHeaders(b"test", rawValue)
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
self.assertEqual(h.getRawHeaders(b"test"), rawValue)
def test_rawHeadersTypeChecking(self):
"""
L{Headers.setRawHeaders} requires values to be of type list.
"""
h = Headers()
self.assertRaises(TypeError, h.setRawHeaders, b'key', {b'Foo': b'bar'})
def test_addRawHeader(self):
"""
L{Headers.addRawHeader} adds a new value for a given header.
"""
h = Headers()
h.addRawHeader(b"test", b"lemur")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
h.addRawHeader(b"test", b"panda")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur", b"panda"])
def test_getRawHeadersNoDefault(self):
"""
L{Headers.getRawHeaders} returns L{None} if the header is not found and
no default is specified.
"""
self.assertIsNone(Headers().getRawHeaders(b"test"))
def test_getRawHeadersDefaultValue(self):
"""
L{Headers.getRawHeaders} returns the specified default value when no
header is found.
"""
h = Headers()
default = object()
self.assertIdentical(h.getRawHeaders(b"test", default), default)
def test_getRawHeadersWithDefaultMatchingValue(self):
"""
If the object passed as the value list to L{Headers.setRawHeaders}
is later passed as a default to L{Headers.getRawHeaders}, the
result nevertheless contains encoded values.
"""
h = Headers()
default = [u"value"]
h.setRawHeaders(b"key", default)
self.assertIsInstance(h.getRawHeaders(b"key", default)[0], bytes)
self.assertEqual(h.getRawHeaders(b"key", default), [b"value"])
def test_getRawHeaders(self):
"""
L{Headers.getRawHeaders} returns the values which have been set for a
given header.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
self.assertEqual(h.getRawHeaders(b"Test"), [b"lemur"])
def test_hasHeaderTrue(self):
"""
Check that L{Headers.hasHeader} returns C{True} when the given header
is found.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
def test_hasHeaderFalse(self):
"""
L{Headers.hasHeader} returns C{False} when the given header is not
found.
"""
self.assertFalse(Headers().hasHeader(b"test"))
def test_removeHeader(self):
"""
Check that L{Headers.removeHeader} removes the given header.
"""
h = Headers()
h.setRawHeaders(b"foo", [b"lemur"])
self.assertTrue(h.hasHeader(b"foo"))
h.removeHeader(b"foo")
self.assertFalse(h.hasHeader(b"foo"))
h.setRawHeaders(b"bar", [b"panda"])
self.assertTrue(h.hasHeader(b"bar"))
h.removeHeader(b"Bar")
self.assertFalse(h.hasHeader(b"bar"))
def test_removeHeaderDoesntExist(self):
"""
L{Headers.removeHeader} is a no-operation when the specified header is
not found.
"""
h = Headers()
h.removeHeader(b"test")
self.assertEqual(list(h.getAllRawHeaders()), [])
def test_canonicalNameCaps(self):
"""
L{Headers._canonicalNameCaps} returns the canonical capitalization for
the given header.
"""
h = Headers()
self.assertEqual(h._canonicalNameCaps(b"test"), b"Test")
self.assertEqual(h._canonicalNameCaps(b"test-stuff"), b"Test-Stuff")
self.assertEqual(h._canonicalNameCaps(b"content-md5"), b"Content-MD5")
self.assertEqual(h._canonicalNameCaps(b"dnt"), b"DNT")
self.assertEqual(h._canonicalNameCaps(b"etag"), b"ETag")
self.assertEqual(h._canonicalNameCaps(b"p3p"), b"P3P")
self.assertEqual(h._canonicalNameCaps(b"te"), b"TE")
self.assertEqual(h._canonicalNameCaps(b"www-authenticate"),
b"WWW-Authenticate")
self.assertEqual(h._canonicalNameCaps(b"x-xss-protection"),
b"X-XSS-Protection")
def test_getAllRawHeaders(self):
"""
L{Headers.getAllRawHeaders} returns an iterable of (k, v) pairs, where
C{k} is the canonicalized representation of the header name, and C{v}
is a sequence of values.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemurs"])
h.setRawHeaders(b"www-authenticate", [b"basic aksljdlk="])
allHeaders = set([(k, tuple(v)) for k, v in h.getAllRawHeaders()])
self.assertEqual(allHeaders,
set([(b"WWW-Authenticate", (b"basic aksljdlk=",)),
(b"Test", (b"lemurs",))]))
def test_headersComparison(self):
"""
A L{Headers} instance compares equal to itself and to another
L{Headers} instance with the same values.
"""
first = Headers()
first.setRawHeaders(b"foo", [b"panda"])
second = Headers()
second.setRawHeaders(b"foo", [b"panda"])
third = Headers()
third.setRawHeaders(b"foo", [b"lemur", b"panda"])
self.assertEqual(first, first)
self.assertEqual(first, second)
self.assertNotEqual(first, third)
def test_otherComparison(self):
"""
An instance of L{Headers} does not compare equal to other unrelated
objects.
"""
h = Headers()
self.assertNotEqual(h, ())
self.assertNotEqual(h, object())
self.assertNotEqual(h, b"foo")
def test_repr(self):
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({%r: [%r, %r]})" % (foo, bar, baz))
def test_reprWithRawBytes(self):
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains, not attempting to decode any raw bytes.
"""
# There's no such thing as undecodable latin-1, you'll just get
# some mojibake
foo = b"foo"
# But this is invalid UTF-8! So, any accidental decoding/encoding will
# throw an exception.
bar = b"bar\xe1"
baz = b"baz\xe1"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({%r: [%r, %r]})" % (foo, bar, baz))
def test_subclassRepr(self):
"""
The L{repr} of an instance of a subclass of L{Headers} uses the name
of the subclass instead of the string C{"Headers"}.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
class FunnyHeaders(Headers):
pass
self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
"FunnyHeaders({%r: [%r, %r]})" % (foo, bar, baz))
def test_copy(self):
"""
L{Headers.copy} creates a new independent copy of an existing
L{Headers} instance, allowing future modifications without impacts
between the copies.
"""
h = Headers()
h.setRawHeaders(b'test', [b'foo'])
i = h.copy()
self.assertEqual(i.getRawHeaders(b'test'), [b'foo'])
h.addRawHeader(b'test', b'bar')
self.assertEqual(i.getRawHeaders(b'test'), [b'foo'])
i.addRawHeader(b'test', b'baz')
self.assertEqual(h.getRawHeaders(b'test'), [b'foo', b'bar'])
class UnicodeHeadersTests(TestCase):
"""
Tests for L{Headers}, using L{unicode} arguments for methods.
"""
def test_sanitizeLinearWhitespace(self):
"""
Linear whitespace in header names or values is replaced with a
single space.
"""
assertSanitized(self, textLinearWhitespaceComponents, sanitizedBytes)
def test_initializer(self):
"""
The header values passed to L{Headers.__init__} can be retrieved via
L{Headers.getRawHeaders}. If a L{bytes} argument is given, it returns
L{bytes} values, and if a L{unicode} argument is given, it returns
L{unicode} values. Both are the same header value, just encoded or
decoded.
"""
h = Headers({u'Foo': [u'bar']})
self.assertEqual(h.getRawHeaders(b'foo'), [b'bar'])
self.assertEqual(h.getRawHeaders(u'foo'), [u'bar'])
def test_setRawHeaders(self):
"""
L{Headers.setRawHeaders} sets the header values for the given
header name to the sequence of strings, encoded.
"""
rawValue = [u"value1", u"value2"]
rawEncodedValue = [b"value1", b"value2"]
h = Headers()
h.setRawHeaders("test", rawValue)
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
self.assertTrue(h.hasHeader("test"))
self.assertTrue(h.hasHeader("Test"))
self.assertEqual(h.getRawHeaders("test"), rawValue)
self.assertEqual(h.getRawHeaders(b"test"), rawEncodedValue)
def test_nameNotEncodable(self):
"""
Passing L{unicode} to any function that takes a header name will encode
said header name as ISO-8859-1, and if it cannot be encoded, it will
raise a L{UnicodeDecodeError}.
"""
h = Headers()
# Only these two functions take names
with self.assertRaises(UnicodeEncodeError):
h.setRawHeaders(u"\u2603", [u"val"])
with self.assertRaises(UnicodeEncodeError):
h.hasHeader(u"\u2603")
def test_nameEncoding(self):
"""
Passing L{unicode} to any function that takes a header name will encode
said header name as ISO-8859-1.
"""
h = Headers()
# We set it using a Unicode string.
h.setRawHeaders(u"\u00E1", [b"foo"])
# It's encoded to the ISO-8859-1 value, which we can use to access it
self.assertTrue(h.hasHeader(b"\xe1"))
self.assertEqual(h.getRawHeaders(b"\xe1"), [b'foo'])
# We can still access it using the Unicode string..
self.assertTrue(h.hasHeader(u"\u00E1"))
def test_rawHeadersValueEncoding(self):
"""
Passing L{unicode} to L{Headers.setRawHeaders} will encode the name as
ISO-8859-1 and values as UTF-8.
"""
h = Headers()
h.setRawHeaders(u"\u00E1", [u"\u2603", b"foo"])
self.assertTrue(h.hasHeader(b"\xe1"))
self.assertEqual(h.getRawHeaders(b"\xe1"), [b'\xe2\x98\x83', b'foo'])
def test_rawHeadersTypeChecking(self):
"""
L{Headers.setRawHeaders} requires values to be of type list.
"""
h = Headers()
self.assertRaises(TypeError, h.setRawHeaders, u'key', {u'Foo': u'bar'})
def test_addRawHeader(self):
"""
L{Headers.addRawHeader} adds a new value for a given header.
"""
h = Headers()
h.addRawHeader(u"test", u"lemur")
self.assertEqual(h.getRawHeaders(u"test"), [u"lemur"])
h.addRawHeader(u"test", u"panda")
self.assertEqual(h.getRawHeaders(u"test"), [u"lemur", u"panda"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur", b"panda"])
def test_getRawHeadersNoDefault(self):
"""
L{Headers.getRawHeaders} returns L{None} if the header is not found and
no default is specified.
"""
self.assertIsNone(Headers().getRawHeaders(u"test"))
def test_getRawHeadersDefaultValue(self):
"""
L{Headers.getRawHeaders} returns the specified default value when no
header is found.
"""
h = Headers()
default = object()
self.assertIdentical(h.getRawHeaders(u"test", default), default)
self.assertIdentical(h.getRawHeaders(u"test", None), None)
self.assertEqual(h.getRawHeaders(u"test", [None]), [None])
self.assertEqual(
h.getRawHeaders(u"test", [u"\N{SNOWMAN}"]),
[u"\N{SNOWMAN}"],
)
def test_getRawHeadersWithDefaultMatchingValue(self):
"""
If the object passed as the value list to L{Headers.setRawHeaders}
is later passed as a default to L{Headers.getRawHeaders}, the
result nevertheless contains decoded values.
"""
h = Headers()
default = [b"value"]
h.setRawHeaders(b"key", default)
self.assertIsInstance(h.getRawHeaders(u"key", default)[0], unicode)
self.assertEqual(h.getRawHeaders(u"key", default), [u"value"])
def test_getRawHeaders(self):
"""
L{Headers.getRawHeaders} returns the values which have been set for a
given header.
"""
h = Headers()
h.setRawHeaders(u"test\u00E1", [u"lemur"])
self.assertEqual(h.getRawHeaders(u"test\u00E1"), [u"lemur"])
self.assertEqual(h.getRawHeaders(u"Test\u00E1"), [u"lemur"])
self.assertEqual(h.getRawHeaders(b"test\xe1"), [b"lemur"])
self.assertEqual(h.getRawHeaders(b"Test\xe1"), [b"lemur"])
def test_hasHeaderTrue(self):
"""
Check that L{Headers.hasHeader} returns C{True} when the given header
is found.
"""
h = Headers()
h.setRawHeaders(u"test\u00E1", [u"lemur"])
self.assertTrue(h.hasHeader(u"test\u00E1"))
self.assertTrue(h.hasHeader(u"Test\u00E1"))
self.assertTrue(h.hasHeader(b"test\xe1"))
self.assertTrue(h.hasHeader(b"Test\xe1"))
def test_hasHeaderFalse(self):
"""
L{Headers.hasHeader} returns C{False} when the given header is not
found.
"""
self.assertFalse(Headers().hasHeader(u"test\u00E1"))
def test_removeHeader(self):
"""
Check that L{Headers.removeHeader} removes the given header.
"""
h = Headers()
h.setRawHeaders(u"foo", [u"lemur"])
self.assertTrue(h.hasHeader(u"foo"))
h.removeHeader(u"foo")
self.assertFalse(h.hasHeader(u"foo"))
self.assertFalse(h.hasHeader(b"foo"))
h.setRawHeaders(u"bar", [u"panda"])
self.assertTrue(h.hasHeader(u"bar"))
h.removeHeader(u"Bar")
self.assertFalse(h.hasHeader(u"bar"))
self.assertFalse(h.hasHeader(b"bar"))
def test_removeHeaderDoesntExist(self):
"""
L{Headers.removeHeader} is a no-operation when the specified header is
not found.
"""
h = Headers()
h.removeHeader(u"test")
self.assertEqual(list(h.getAllRawHeaders()), [])
def test_getAllRawHeaders(self):
"""
L{Headers.getAllRawHeaders} returns an iterable of (k, v) pairs, where
C{k} is the canonicalized representation of the header name, and C{v}
is a sequence of values.
"""
h = Headers()
h.setRawHeaders(u"test\u00E1", [u"lemurs"])
h.setRawHeaders(u"www-authenticate", [u"basic aksljdlk="])
h.setRawHeaders(u"content-md5", [u"kjdfdfgdfgnsd"])
allHeaders = set([(k, tuple(v)) for k, v in h.getAllRawHeaders()])
self.assertEqual(allHeaders,
set([(b"WWW-Authenticate", (b"basic aksljdlk=",)),
(b"Content-MD5", (b"kjdfdfgdfgnsd",)),
(b"Test\xe1", (b"lemurs",))]))
def test_headersComparison(self):
"""
A L{Headers} instance compares equal to itself and to another
L{Headers} instance with the same values.
"""
first = Headers()
first.setRawHeaders(u"foo\u00E1", [u"panda"])
second = Headers()
second.setRawHeaders(u"foo\u00E1", [u"panda"])
third = Headers()
third.setRawHeaders(u"foo\u00E1", [u"lemur", u"panda"])
self.assertEqual(first, first)
self.assertEqual(first, second)
self.assertNotEqual(first, third)
# Headers instantiated with bytes equivs are also the same
firstBytes = Headers()
firstBytes.setRawHeaders(b"foo\xe1", [b"panda"])
secondBytes = Headers()
secondBytes.setRawHeaders(b"foo\xe1", [b"panda"])
thirdBytes = Headers()
thirdBytes.setRawHeaders(b"foo\xe1", [b"lemur", u"panda"])
self.assertEqual(first, firstBytes)
self.assertEqual(second, secondBytes)
self.assertEqual(third, thirdBytes)
def test_otherComparison(self):
"""
An instance of L{Headers} does not compare equal to other unrelated
objects.
"""
h = Headers()
self.assertNotEqual(h, ())
self.assertNotEqual(h, object())
self.assertNotEqual(h, u"foo")
def test_repr(self):
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains. This shows only reprs of bytes values, as
undecodable headers may cause an exception.
"""
foo = u"foo\u00E1"
bar = u"bar\u2603"
baz = u"baz"
fooEncoded = "'foo\\xe1'"
barEncoded = "'bar\\xe2\\x98\\x83'"
if _PY3:
fooEncoded = "b" + fooEncoded
barEncoded = "b" + barEncoded
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({%s: [%s, %r]})" % (fooEncoded,
barEncoded,
baz.encode('utf8')))
def test_subclassRepr(self):
"""
The L{repr} of an instance of a subclass of L{Headers} uses the name
of the subclass instead of the string C{"Headers"}.
"""
foo = u"foo\u00E1"
bar = u"bar\u2603"
baz = u"baz"
fooEncoded = "'foo\\xe1'"
barEncoded = "'bar\\xe2\\x98\\x83'"
if _PY3:
fooEncoded = "b" + fooEncoded
barEncoded = "b" + barEncoded
class FunnyHeaders(Headers):
pass
self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
"FunnyHeaders({%s: [%s, %r]})" % (fooEncoded,
barEncoded,
baz.encode('utf8')))
def test_copy(self):
"""
L{Headers.copy} creates a new independent copy of an existing
L{Headers} instance, allowing future modifications without impacts
between the copies.
"""
h = Headers()
h.setRawHeaders(u'test\u00E1', [u'foo\u2603'])
i = h.copy()
# The copy contains the same value as the original
self.assertEqual(i.getRawHeaders(u'test\u00E1'), [u'foo\u2603'])
self.assertEqual(i.getRawHeaders(b'test\xe1'), [b'foo\xe2\x98\x83'])
# Add a header to the original
h.addRawHeader(u'test\u00E1', u'bar')
# Verify that the copy has not changed
self.assertEqual(i.getRawHeaders(u'test\u00E1'), [u'foo\u2603'])
self.assertEqual(i.getRawHeaders(b'test\xe1'), [b'foo\xe2\x98\x83'])
# Add a header to the copy
i.addRawHeader(u'test\u00E1', b'baz')
# Verify that the orignal does not have it
self.assertEqual(
h.getRawHeaders(u'test\u00E1'), [u'foo\u2603', u'bar'])
self.assertEqual(
h.getRawHeaders(b'test\xe1'), [b'foo\xe2\x98\x83', b'bar'])

View file

@ -0,0 +1,677 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._auth}.
"""
from __future__ import division, absolute_import
import base64
from zope.interface import implementer
from zope.interface.verify import verifyObject
from twisted.trial import unittest
from twisted.python.failure import Failure
from twisted.internet.error import ConnectionDone
from twisted.internet.address import IPv4Address
from twisted.cred import error, portal
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
from twisted.cred.checkers import ANONYMOUS, AllowAnonymousAccess
from twisted.cred.credentials import IUsernamePassword
from twisted.web.iweb import ICredentialFactory
from twisted.web.resource import IResource, Resource, getChildForRequest
from twisted.web._auth import basic, digest
from twisted.web._auth.wrapper import HTTPAuthSessionWrapper, UnauthorizedResource
from twisted.web._auth.basic import BasicCredentialFactory
from twisted.web.server import NOT_DONE_YET
from twisted.web.static import Data
from twisted.web.test.test_web import DummyRequest
from twisted.test.proto_helpers import EventLoggingObserver
from twisted.logger import globalLogPublisher
def b64encode(s):
return base64.b64encode(s).strip()
class BasicAuthTestsMixin:
"""
L{TestCase} mixin class which defines a number of tests for
L{basic.BasicCredentialFactory}. Because this mixin defines C{setUp}, it
must be inherited before L{TestCase}.
"""
def setUp(self):
self.request = self.makeRequest()
self.realm = b'foo'
self.username = b'dreid'
self.password = b'S3CuR1Ty'
self.credentialFactory = basic.BasicCredentialFactory(self.realm)
def makeRequest(self, method=b'GET', clientAddress=None):
"""
Create a request object to be passed to
L{basic.BasicCredentialFactory.decode} along with a response value.
Override this in a subclass.
"""
raise NotImplementedError("%r did not implement makeRequest" % (
self.__class__,))
def test_interface(self):
"""
L{BasicCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(
verifyObject(ICredentialFactory, self.credentialFactory))
def test_usernamePassword(self):
"""
L{basic.BasicCredentialFactory.decode} turns a base64-encoded response
into a L{UsernamePassword} object with a password which reflects the
one which was encoded in the response.
"""
response = b64encode(b''.join([self.username, b':', self.password]))
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(IUsernamePassword.providedBy(creds))
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + b'wrong'))
def test_incorrectPadding(self):
"""
L{basic.BasicCredentialFactory.decode} decodes a base64-encoded
response with incorrect padding.
"""
response = b64encode(b''.join([self.username, b':', self.password]))
response = response.strip(b'=')
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(verifyObject(IUsernamePassword, creds))
self.assertTrue(creds.checkPassword(self.password))
def test_invalidEncoding(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} if passed
a response which is not base64-encoded.
"""
response = b'x' # one byte cannot be valid base64 text
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode, response, self.makeRequest())
def test_invalidCredentials(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} when
passed a response which is not valid base64-encoded text.
"""
response = b64encode(b'123abc+/')
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode,
response, self.makeRequest())
class RequestMixin:
def makeRequest(self, method=b'GET', clientAddress=None):
"""
Create a L{DummyRequest} (change me to create a
L{twisted.web.http.Request} instead).
"""
if clientAddress is None:
clientAddress = IPv4Address("TCP", "localhost", 1234)
request = DummyRequest(b'/')
request.method = method
request.client = clientAddress
return request
class BasicAuthTests(RequestMixin, BasicAuthTestsMixin, unittest.TestCase):
"""
Basic authentication tests which use L{twisted.web.http.Request}.
"""
class DigestAuthTests(RequestMixin, unittest.TestCase):
"""
Digest authentication tests which use L{twisted.web.http.Request}.
"""
def setUp(self):
"""
Create a DigestCredentialFactory for testing
"""
self.realm = b"test realm"
self.algorithm = b"md5"
self.credentialFactory = digest.DigestCredentialFactory(
self.algorithm, self.realm)
self.request = self.makeRequest()
def test_decode(self):
"""
L{digest.DigestCredentialFactory.decode} calls the C{decode} method on
L{twisted.cred.digest.DigestCredentialFactory} with the HTTP method and
host of the request.
"""
host = b'169.254.0.1'
method = b'GET'
done = [False]
response = object()
def check(_response, _method, _host):
self.assertEqual(response, _response)
self.assertEqual(method, _method)
self.assertEqual(host, _host)
done[0] = True
self.patch(self.credentialFactory.digest, 'decode', check)
req = self.makeRequest(method, IPv4Address('TCP', host, 81))
self.credentialFactory.decode(response, req)
self.assertTrue(done[0])
def test_interface(self):
"""
L{DigestCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(
verifyObject(ICredentialFactory, self.credentialFactory))
def test_getChallenge(self):
"""
The challenge issued by L{DigestCredentialFactory.getChallenge} must
include C{'qop'}, C{'realm'}, C{'algorithm'}, C{'nonce'}, and
C{'opaque'} keys. The values for the C{'realm'} and C{'algorithm'}
keys must match the values supplied to the factory's initializer.
None of the values may have newlines in them.
"""
challenge = self.credentialFactory.getChallenge(self.request)
self.assertEqual(challenge['qop'], b'auth')
self.assertEqual(challenge['realm'], b'test realm')
self.assertEqual(challenge['algorithm'], b'md5')
self.assertIn('nonce', challenge)
self.assertIn('opaque', challenge)
for v in challenge.values():
self.assertNotIn(b'\n', v)
def test_getChallengeWithoutClientIP(self):
"""
L{DigestCredentialFactory.getChallenge} can issue a challenge even if
the L{Request} it is passed returns L{None} from C{getClientIP}.
"""
request = self.makeRequest(b'GET', None)
challenge = self.credentialFactory.getChallenge(request)
self.assertEqual(challenge['qop'], b'auth')
self.assertEqual(challenge['realm'], b'test realm')
self.assertEqual(challenge['algorithm'], b'md5')
self.assertIn('nonce', challenge)
self.assertIn('opaque', challenge)
class UnauthorizedResourceTests(RequestMixin, unittest.TestCase):
"""
Tests for L{UnauthorizedResource}.
"""
def test_getChildWithDefault(self):
"""
An L{UnauthorizedResource} is every child of itself.
"""
resource = UnauthorizedResource([])
self.assertIdentical(
resource.getChildWithDefault("foo", None), resource)
self.assertIdentical(
resource.getChildWithDefault("bar", None), resource)
def _unauthorizedRenderTest(self, request):
"""
Render L{UnauthorizedResource} for the given request object and verify
that the response code is I{Unauthorized} and that a I{WWW-Authenticate}
header is set in the response containing a challenge.
"""
resource = UnauthorizedResource([
BasicCredentialFactory('example.com')])
request.render(resource)
self.assertEqual(request.responseCode, 401)
self.assertEqual(
request.responseHeaders.getRawHeaders(b'www-authenticate'),
[b'basic realm="example.com"'])
def test_render(self):
"""
L{UnauthorizedResource} renders with a 401 response code and a
I{WWW-Authenticate} header and puts a simple unauthorized message
into the response body.
"""
request = self.makeRequest()
self._unauthorizedRenderTest(request)
self.assertEqual(b'Unauthorized', b''.join(request.written))
def test_renderHEAD(self):
"""
The rendering behavior of L{UnauthorizedResource} for a I{HEAD} request
is like its handling of a I{GET} request, but no response body is
written.
"""
request = self.makeRequest(method=b'HEAD')
self._unauthorizedRenderTest(request)
self.assertEqual(b'', b''.join(request.written))
def test_renderQuotesRealm(self):
"""
The realm value included in the I{WWW-Authenticate} header set in
the response when L{UnauthorizedResounrce} is rendered has quotes
and backslashes escaped.
"""
resource = UnauthorizedResource([
BasicCredentialFactory('example\\"foo')])
request = self.makeRequest()
request.render(resource)
self.assertEqual(
request.responseHeaders.getRawHeaders(b'www-authenticate'),
[b'basic realm="example\\\\\\"foo"'])
def test_renderQuotesDigest(self):
"""
The digest value included in the I{WWW-Authenticate} header
set in the response when L{UnauthorizedResource} is rendered
has quotes and backslashes escaped.
"""
resource = UnauthorizedResource([
digest.DigestCredentialFactory(b'md5', b'example\\"foo')])
request = self.makeRequest()
request.render(resource)
authHeader = request.responseHeaders.getRawHeaders(
b'www-authenticate'
)[0]
self.assertIn(b'realm="example\\\\\\"foo"', authHeader)
self.assertIn(b'hm="md5', authHeader)
implementer(portal.IRealm)
class Realm(object):
"""
A simple L{IRealm} implementation which gives out L{WebAvatar} for any
avatarId.
@type loggedIn: C{int}
@ivar loggedIn: The number of times C{requestAvatar} has been invoked for
L{IResource}.
@type loggedOut: C{int}
@ivar loggedOut: The number of times the logout callback has been invoked.
"""
def __init__(self, avatarFactory):
self.loggedOut = 0
self.loggedIn = 0
self.avatarFactory = avatarFactory
def requestAvatar(self, avatarId, mind, *interfaces):
if IResource in interfaces:
self.loggedIn += 1
return IResource, self.avatarFactory(avatarId), self.logout
raise NotImplementedError()
def logout(self):
self.loggedOut += 1
class HTTPAuthHeaderTests(unittest.TestCase):
"""
Tests for L{HTTPAuthSessionWrapper}.
"""
makeRequest = DummyRequest
def setUp(self):
"""
Create a realm, portal, and L{HTTPAuthSessionWrapper} to use in the tests.
"""
self.username = b'foo bar'
self.password = b'bar baz'
self.avatarContent = b"contents of the avatar resource itself"
self.childName = b"foo-child"
self.childContent = b"contents of the foo child of the avatar"
self.checker = InMemoryUsernamePasswordDatabaseDontUse()
self.checker.addUser(self.username, self.password)
self.avatar = Data(self.avatarContent, 'text/plain')
self.avatar.putChild(
self.childName, Data(self.childContent, 'text/plain'))
self.avatars = {self.username: self.avatar}
self.realm = Realm(self.avatars.get)
self.portal = portal.Portal(self.realm, [self.checker])
self.credentialFactories = []
self.wrapper = HTTPAuthSessionWrapper(
self.portal, self.credentialFactories)
def _authorizedBasicLogin(self, request):
"""
Add an I{basic authorization} header to the given request and then
dispatch it, starting from C{self.wrapper} and returning the resulting
L{IResource}.
"""
authorization = b64encode(self.username + b':' + self.password)
request.requestHeaders.addRawHeader(b'authorization',
b'Basic ' + authorization)
return getChildForRequest(self.wrapper, request)
def test_getChildWithDefault(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} instance when the request does
not have the required I{Authorization} headers.
"""
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def _invalidAuthorizationTest(self, response):
"""
Create a request with the given value as the value of an
I{Authorization} header and perform resource traversal with it,
starting at C{self.wrapper}. Assert that the result is a 401 response
code. Return a L{Deferred} which fires when this is all done.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b'authorization', response)
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChildWithDefaultUnauthorizedUser(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which does not exist.
"""
return self._invalidAuthorizationTest(
b'Basic ' + b64encode(b'foo:bar'))
def test_getChildWithDefaultUnauthorizedPassword(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which exists and the wrong
password.
"""
return self._invalidAuthorizationTest(
b'Basic ' + b64encode(self.username + b':bar'))
def test_getChildWithDefaultUnrecognizedScheme(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with an unrecognized scheme.
"""
return self._invalidAuthorizationTest(b'Quux foo bar baz')
def test_getChildWithDefaultAuthorized(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{IResource} which renders the L{IResource} avatar
retrieved from the portal when the request has a valid I{Authorization}
header.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.childContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_renderAuthorized(self):
"""
Resource traversal which terminates at an L{HTTPAuthSessionWrapper}
and includes correct authentication headers results in the
L{IResource} avatar (not one of its children) retrieved from the
portal being rendered.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
# Request it exactly, not any of its children.
request = self.makeRequest([])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.avatarContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChallengeCalledWithRequest(self):
"""
When L{HTTPAuthSessionWrapper} finds an L{ICredentialFactory} to issue
a challenge, it calls the C{getChallenge} method with the request as an
argument.
"""
@implementer(ICredentialFactory)
class DumbCredentialFactory(object):
scheme = b'dumb'
def __init__(self):
self.requests = []
def getChallenge(self, request):
self.requests.append(request)
return {}
factory = DumbCredentialFactory()
self.credentialFactories.append(factory)
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(factory.requests, [request])
d.addCallback(cbFinished)
request.render(child)
return d
def _logoutTest(self):
"""
Issue a request for an authentication-protected resource using valid
credentials and then return the C{DummyRequest} instance which was
used.
This is a helper for tests about the behavior of the logout
callback.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
class SlowerResource(Resource):
def render(self, request):
return NOT_DONE_YET
self.avatar.putChild(self.childName, SlowerResource())
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(self.realm.loggedOut, 0)
return request
def test_logout(self):
"""
The realm's logout callback is invoked after the resource is rendered.
"""
request = self._logoutTest()
request.finish()
self.assertEqual(self.realm.loggedOut, 1)
def test_logoutOnError(self):
"""
The realm's logout callback is also invoked if there is an error
generating the response (for example, if the client disconnects
early).
"""
request = self._logoutTest()
request.processingFailed(
Failure(ConnectionDone("Simulated disconnect")))
self.assertEqual(self.realm.loggedOut, 1)
def test_decodeRaises(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has a I{Basic
Authorization} header which cannot be decoded using base64.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b'authorization', b'Basic decode should fail')
child = getChildForRequest(self.wrapper, request)
self.assertIsInstance(child, UnauthorizedResource)
def test_selectParseResponse(self):
"""
L{HTTPAuthSessionWrapper._selectParseHeader} returns a two-tuple giving
the L{ICredentialFactory} to use to parse the header and a string
containing the portion of the header which remains to be parsed.
"""
basicAuthorization = b'Basic abcdef123456'
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization),
(None, None))
factory = BasicCredentialFactory('example.com')
self.credentialFactories.append(factory)
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization),
(factory, b'abcdef123456'))
def test_unexpectedDecodeError(self):
"""
Any unexpected exception raised by the credential factory's C{decode}
method results in a 500 response code and causes the exception to be
logged.
"""
logObserver = EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
class UnexpectedException(Exception):
pass
class BadFactory(object):
scheme = b'bad'
def getChallenge(self, client):
return {}
def decode(self, response, request):
raise UnexpectedException()
self.credentialFactories.append(BadFactory())
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b'authorization', b'Bad abc')
child = getChildForRequest(self.wrapper, request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEquals(1, len(logObserver))
self.assertIsInstance(
logObserver[0]["log_failure"].value,
UnexpectedException
)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_unexpectedLoginError(self):
"""
Any unexpected failure from L{Portal.login} results in a 500 response
code and causes the failure to be logged.
"""
logObserver = EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
class UnexpectedException(Exception):
pass
class BrokenChecker(object):
credentialInterfaces = (IUsernamePassword,)
def requestAvatarId(self, credentials):
raise UnexpectedException()
self.portal.registerChecker(BrokenChecker())
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEquals(1, len(logObserver))
self.assertIsInstance(
logObserver[0]["log_failure"].value,
UnexpectedException
)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_anonymousAccess(self):
"""
Anonymous requests are allowed if a L{Portal} has an anonymous checker
registered.
"""
unprotectedContents = b"contents of the unprotected child resource"
self.avatars[ANONYMOUS] = Resource()
self.avatars[ANONYMOUS].putChild(
self.childName, Data(unprotectedContents, 'text/plain'))
self.portal.registerChecker(AllowAnonymousAccess())
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [unprotectedContents])
d.addCallback(cbFinished)
request.render(child)
return d

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,573 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test for L{twisted.web.proxy}.
"""
from twisted.trial.unittest import TestCase
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from twisted.web.server import Site
from twisted.web.proxy import ReverseProxyResource, ProxyClientFactory
from twisted.web.proxy import ProxyClient, ProxyRequest, ReverseProxyRequest
from twisted.web.test.test_web import DummyRequest
class ReverseProxyResourceTests(TestCase):
"""
Tests for L{ReverseProxyResource}.
"""
def _testRender(self, uri, expectedURI):
"""
Check that a request pointing at C{uri} produce a new proxy connection,
with the path of this request pointing at C{expectedURI}.
"""
root = Resource()
reactor = MemoryReactor()
resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path", reactor)
root.putChild(b'index', resource)
site = Site(root)
transport = StringTransportWithDisconnection()
channel = site.buildProtocol(None)
channel.makeConnection(transport)
# Clear the timeout if the tests failed
self.addCleanup(channel.connectionLost, None)
channel.dataReceived(b"GET " +
uri +
b" HTTP/1.1\r\nAccept: text/html\r\n\r\n")
[(host, port, factory, _timeout, _bind_addr)] = reactor.tcpClients
# Check that one connection has been created, to the good host/port
self.assertEqual(host, u"127.0.0.1")
self.assertEqual(port, 1234)
# Check the factory passed to the connect, and its given path
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.headers[b"host"], b"127.0.0.1:1234")
def test_render(self):
"""
Test that L{ReverseProxyResource.render} initiates a connection to the
given server with a L{ProxyClientFactory} as parameter.
"""
return self._testRender(b"/index", b"/path")
def test_render_subpage(self):
"""
Test that L{ReverseProxyResource.render} will instantiate a child
resource that will initiate a connection to the given server
requesting the apropiate url subpath.
"""
return self._testRender(b"/index/page1", b"/path/page1")
def test_renderWithQuery(self):
"""
Test that L{ReverseProxyResource.render} passes query parameters to the
created factory.
"""
return self._testRender(b"/index?foo=bar", b"/path?foo=bar")
def test_getChild(self):
"""
The L{ReverseProxyResource.getChild} method should return a resource
instance with the same class as the originating resource, forward
port, host, and reactor values, and update the path value with the
value passed.
"""
reactor = MemoryReactor()
resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path", reactor)
child = resource.getChild(b'foo', None)
# The child should keep the same class
self.assertIsInstance(child, ReverseProxyResource)
self.assertEqual(child.path, b"/path/foo")
self.assertEqual(child.port, 1234)
self.assertEqual(child.host, u"127.0.0.1")
self.assertIdentical(child.reactor, resource.reactor)
def test_getChildWithSpecial(self):
"""
The L{ReverseProxyResource} return by C{getChild} has a path which has
already been quoted.
"""
resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path")
child = resource.getChild(b' /%', None)
self.assertEqual(child.path, b"/path/%20%2F%25")
class DummyChannel(object):
"""
A dummy HTTP channel, that does nothing but holds a transport and saves
connection lost.
@ivar transport: the transport used by the client.
@ivar lostReason: the reason saved at connection lost.
"""
def __init__(self, transport):
"""
Hold a reference to the transport.
"""
self.transport = transport
self.lostReason = None
def connectionLost(self, reason):
"""
Keep track of the connection lost reason.
"""
self.lostReason = reason
def getPeer(self):
"""
Get peer information from the transport.
"""
return self.transport.getPeer()
def getHost(self):
"""
Get host information from the transport.
"""
return self.transport.getHost()
class ProxyClientTests(TestCase):
"""
Tests for L{ProxyClient}.
"""
def _parseOutHeaders(self, content):
"""
Parse the headers out of some web content.
@param content: Bytes received from a web server.
@return: A tuple of (requestLine, headers, body). C{headers} is a dict
of headers, C{requestLine} is the first line (e.g. "POST /foo ...")
and C{body} is whatever is left.
"""
headers, body = content.split(b'\r\n\r\n')
headers = headers.split(b'\r\n')
requestLine = headers.pop(0)
return (
requestLine, dict(header.split(b': ') for header in headers), body)
def makeRequest(self, path):
"""
Make a dummy request object for the URL path.
@param path: A URL path, beginning with a slash.
@return: A L{DummyRequest}.
"""
return DummyRequest(path)
def makeProxyClient(self, request, method=b"GET", headers=None,
requestBody=b""):
"""
Make a L{ProxyClient} object used for testing.
@param request: The request to use.
@param method: The HTTP method to use, GET by default.
@param headers: The HTTP headers to use expressed as a dict. If not
provided, defaults to {'accept': 'text/html'}.
@param requestBody: The body of the request. Defaults to the empty
string.
@return: A L{ProxyClient}
"""
if headers is None:
headers = {b"accept": b"text/html"}
path = b'/' + request.postpath
return ProxyClient(
method, path, b'HTTP/1.0', headers, requestBody, request)
def connectProxy(self, proxyClient):
"""
Connect a proxy client to a L{StringTransportWithDisconnection}.
@param proxyClient: A L{ProxyClient}.
@return: The L{StringTransportWithDisconnection}.
"""
clientTransport = StringTransportWithDisconnection()
clientTransport.protocol = proxyClient
proxyClient.makeConnection(clientTransport)
return clientTransport
def assertForwardsHeaders(self, proxyClient, requestLine, headers):
"""
Assert that C{proxyClient} sends C{headers} when it connects.
@param proxyClient: A L{ProxyClient}.
@param requestLine: The request line we expect to be sent.
@param headers: A dict of headers we expect to be sent.
@return: If the assertion is successful, return the request body as
bytes.
"""
self.connectProxy(proxyClient)
requestContent = proxyClient.transport.value()
receivedLine, receivedHeaders, body = self._parseOutHeaders(
requestContent)
self.assertEqual(receivedLine, requestLine)
self.assertEqual(receivedHeaders, headers)
return body
def makeResponseBytes(self, code, message, headers, body):
lines = [b"HTTP/1.0 " + str(code).encode('ascii') + b' ' + message]
for header, values in headers:
for value in values:
lines.append(header + b': ' + value)
lines.extend([b'', body])
return b'\r\n'.join(lines)
def assertForwardsResponse(self, request, code, message, headers, body):
"""
Assert that C{request} has forwarded a response from the server.
@param request: A L{DummyRequest}.
@param code: The expected HTTP response code.
@param message: The expected HTTP message.
@param headers: The expected HTTP headers.
@param body: The expected response body.
"""
self.assertEqual(request.responseCode, code)
self.assertEqual(request.responseMessage, message)
receivedHeaders = list(request.responseHeaders.getAllRawHeaders())
receivedHeaders.sort()
expectedHeaders = headers[:]
expectedHeaders.sort()
self.assertEqual(receivedHeaders, expectedHeaders)
self.assertEqual(b''.join(request.written), body)
def _testDataForward(self, code, message, headers, body, method=b"GET",
requestBody=b"", loseConnection=True):
"""
Build a fake proxy connection, and send C{data} over it, checking that
it's forwarded to the originating request.
"""
request = self.makeRequest(b'foo')
client = self.makeProxyClient(
request, method, {b'accept': b'text/html'}, requestBody)
receivedBody = self.assertForwardsHeaders(
client, method + b' /foo HTTP/1.0',
{b'connection': b'close', b'accept': b'text/html'})
self.assertEqual(receivedBody, requestBody)
# Fake an answer
client.dataReceived(
self.makeResponseBytes(code, message, headers, body))
# Check that the response data has been forwarded back to the original
# requester.
self.assertForwardsResponse(request, code, message, headers, body)
# Check that when the response is done, the request is finished.
if loseConnection:
client.transport.loseConnection()
# Even if we didn't call loseConnection, the transport should be
# disconnected. This lets us not rely on the server to close our
# sockets for us.
self.assertFalse(client.transport.connected)
self.assertEqual(request.finished, 1)
def test_forward(self):
"""
When connected to the server, L{ProxyClient} should send the saved
request, with modifications of the headers, and then forward the result
to the parent request.
"""
return self._testDataForward(
200, b"OK", [(b"Foo", [b"bar", b"baz"])], b"Some data\r\n")
def test_postData(self):
"""
Try to post content in the request, and check that the proxy client
forward the body of the request.
"""
return self._testDataForward(
200, b"OK", [(b"Foo", [b"bar"])], b"Some data\r\n", b"POST", b"Some content")
def test_statusWithMessage(self):
"""
If the response contains a status with a message, it should be
forwarded to the parent request with all the information.
"""
return self._testDataForward(
404, b"Not Found", [], b"")
def test_contentLength(self):
"""
If the response contains a I{Content-Length} header, the inbound
request object should still only have C{finish} called on it once.
"""
data = b"foo bar baz"
return self._testDataForward(
200,
b"OK",
[(b"Content-Length", [str(len(data)).encode('ascii')])],
data)
def test_losesConnection(self):
"""
If the response contains a I{Content-Length} header, the outgoing
connection is closed when all response body data has been received.
"""
data = b"foo bar baz"
return self._testDataForward(
200,
b"OK",
[(b"Content-Length", [str(len(data)).encode('ascii')])],
data,
loseConnection=False)
def test_headersCleanups(self):
"""
The headers given at initialization should be modified:
B{proxy-connection} should be removed if present, and B{connection}
should be added.
"""
client = ProxyClient(b'GET', b'/foo', b'HTTP/1.0',
{b"accept": b"text/html", b"proxy-connection": b"foo"}, b'', None)
self.assertEqual(client.headers,
{b"accept": b"text/html", b"connection": b"close"})
def test_keepaliveNotForwarded(self):
"""
The proxy doesn't really know what to do with keepalive things from
the remote server, so we stomp over any keepalive header we get from
the client.
"""
headers = {
b"accept": b"text/html",
b'keep-alive': b'300',
b'connection': b'keep-alive',
}
expectedHeaders = headers.copy()
expectedHeaders[b'connection'] = b'close'
del expectedHeaders[b'keep-alive']
client = ProxyClient(b'GET', b'/foo', b'HTTP/1.0', headers, b'', None)
self.assertForwardsHeaders(
client, b'GET /foo HTTP/1.0', expectedHeaders)
def test_defaultHeadersOverridden(self):
"""
L{server.Request} within the proxy sets certain response headers by
default. When we get these headers back from the remote server, the
defaults are overridden rather than simply appended.
"""
request = self.makeRequest(b'foo')
request.responseHeaders.setRawHeaders(b'server', [b'old-bar'])
request.responseHeaders.setRawHeaders(b'date', [b'old-baz'])
request.responseHeaders.setRawHeaders(b'content-type', [b"old/qux"])
client = self.makeProxyClient(request, headers={b'accept': b'text/html'})
self.connectProxy(client)
headers = {
b'Server': [b'bar'],
b'Date': [b'2010-01-01'],
b'Content-Type': [b'application/x-baz'],
}
client.dataReceived(
self.makeResponseBytes(200, b"OK", headers.items(), b''))
self.assertForwardsResponse(
request, 200, b'OK', list(headers.items()), b'')
class ProxyClientFactoryTests(TestCase):
"""
Tests for L{ProxyClientFactory}.
"""
def test_connectionFailed(self):
"""
Check that L{ProxyClientFactory.clientConnectionFailed} produces
a B{501} response to the parent request.
"""
request = DummyRequest([b'foo'])
factory = ProxyClientFactory(b'GET', b'/foo', b'HTTP/1.0',
{b"accept": b"text/html"}, '', request)
factory.clientConnectionFailed(None, None)
self.assertEqual(request.responseCode, 501)
self.assertEqual(request.responseMessage, b"Gateway error")
self.assertEqual(
list(request.responseHeaders.getAllRawHeaders()),
[(b"Content-Type", [b"text/html"])])
self.assertEqual(
b''.join(request.written),
b"<H1>Could not connect</H1>")
self.assertEqual(request.finished, 1)
def test_buildProtocol(self):
"""
L{ProxyClientFactory.buildProtocol} should produce a L{ProxyClient}
with the same values of attributes (with updates on the headers).
"""
factory = ProxyClientFactory(b'GET', b'/foo', b'HTTP/1.0',
{b"accept": b"text/html"}, b'Some data',
None)
proto = factory.buildProtocol(None)
self.assertIsInstance(proto, ProxyClient)
self.assertEqual(proto.command, b'GET')
self.assertEqual(proto.rest, b'/foo')
self.assertEqual(proto.data, b'Some data')
self.assertEqual(proto.headers,
{b"accept": b"text/html", b"connection": b"close"})
class ProxyRequestTests(TestCase):
"""
Tests for L{ProxyRequest}.
"""
def _testProcess(self, uri, expectedURI, method=b"GET", data=b""):
"""
Build a request pointing at C{uri}, and check that a proxied request
is created, pointing a C{expectedURI}.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(len(data))
request.handleContentChunk(data)
request.requestReceived(method, b'http://example.com' + uri,
b'HTTP/1.0')
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], u"example.com")
self.assertEqual(reactor.tcpClients[0][1], 80)
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.command, method)
self.assertEqual(factory.version, b'HTTP/1.0')
self.assertEqual(factory.headers, {b'host': b'example.com'})
self.assertEqual(factory.data, data)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.father, request)
def test_process(self):
"""
L{ProxyRequest.process} should create a connection to the given server,
with a L{ProxyClientFactory} as connection factory, with the correct
parameters:
- forward comment, version and data values
- update headers with the B{host} value
- remove the host from the URL
- pass the request as parent request
"""
return self._testProcess(b"/foo/bar", b"/foo/bar")
def test_processWithoutTrailingSlash(self):
"""
If the incoming request doesn't contain a slash,
L{ProxyRequest.process} should add one when instantiating
L{ProxyClientFactory}.
"""
return self._testProcess(b"", b"/")
def test_processWithData(self):
"""
L{ProxyRequest.process} should be able to retrieve request body and
to forward it.
"""
return self._testProcess(
b"/foo/bar", b"/foo/bar", b"POST", b"Some content")
def test_processWithPort(self):
"""
Check that L{ProxyRequest.process} correctly parse port in the incoming
URL, and create an outgoing connection with this port.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(0)
request.requestReceived(b'GET', b'http://example.com:1234/foo/bar',
b'HTTP/1.0')
# That should create one connection, with the port parsed from the URL
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], u"example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
class DummyFactory(object):
"""
A simple holder for C{host} and C{port} information.
"""
def __init__(self, host, port):
self.host = host
self.port = port
class ReverseProxyRequestTests(TestCase):
"""
Tests for L{ReverseProxyRequest}.
"""
def test_process(self):
"""
L{ReverseProxyRequest.process} should create a connection to its
factory host/port, using a L{ProxyClientFactory} instantiated with the
correct parameters, and particularly set the B{host} header to the
factory host.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ReverseProxyRequest(channel, False, reactor)
request.factory = DummyFactory(u"example.com", 1234)
request.gotLength(0)
request.requestReceived(b'GET', b'/foo/bar', b'HTTP/1.0')
# Check that one connection has been created, to the good host/port
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], u"example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
# Check the factory passed to the connect, and its headers
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.headers, {b'host': b'example.com'})

View file

@ -0,0 +1,289 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.resource}.
"""
from twisted.trial.unittest import TestCase
from twisted.python.compat import _PY3
from twisted.web.error import UnsupportedMethod
from twisted.web.resource import (
NOT_FOUND, FORBIDDEN, Resource, ErrorPage, NoResource, ForbiddenResource,
getChildForRequest)
from twisted.web.http_headers import Headers
from twisted.web.test.requesthelper import DummyRequest
class ErrorPageTests(TestCase):
"""
Tests for L{ErrorPage}, L{NoResource}, and L{ForbiddenResource}.
"""
errorPage = ErrorPage
noResource = NoResource
forbiddenResource = ForbiddenResource
def test_getChild(self):
"""
The C{getChild} method of L{ErrorPage} returns the L{ErrorPage} it is
called on.
"""
page = self.errorPage(321, "foo", "bar")
self.assertIdentical(page.getChild(b"name", object()), page)
def _pageRenderingTest(self, page, code, brief, detail):
request = DummyRequest([b''])
template = (
u"\n"
u"<html>\n"
u" <head><title>%s - %s</title></head>\n"
u" <body>\n"
u" <h1>%s</h1>\n"
u" <p>%s</p>\n"
u" </body>\n"
u"</html>\n")
expected = template % (code, brief, brief, detail)
self.assertEqual(
page.render(request), expected.encode('utf-8'))
self.assertEqual(request.responseCode, code)
self.assertEqual(
request.responseHeaders,
Headers({b'content-type': [b'text/html; charset=utf-8']}))
def test_errorPageRendering(self):
"""
L{ErrorPage.render} returns a C{bytes} describing the error defined by
the response code and message passed to L{ErrorPage.__init__}. It also
uses that response code to set the response code on the L{Request}
passed in.
"""
code = 321
brief = "brief description text"
detail = "much longer text might go here"
page = self.errorPage(code, brief, detail)
self._pageRenderingTest(page, code, brief, detail)
def test_noResourceRendering(self):
"""
L{NoResource} sets the HTTP I{NOT FOUND} code.
"""
detail = "long message"
page = self.noResource(detail)
self._pageRenderingTest(page, NOT_FOUND, "No Such Resource", detail)
def test_forbiddenResourceRendering(self):
"""
L{ForbiddenResource} sets the HTTP I{FORBIDDEN} code.
"""
detail = "longer message"
page = self.forbiddenResource(detail)
self._pageRenderingTest(page, FORBIDDEN, "Forbidden Resource", detail)
class DynamicChild(Resource):
"""
A L{Resource} to be created on the fly by L{DynamicChildren}.
"""
def __init__(self, path, request):
Resource.__init__(self)
self.path = path
self.request = request
class DynamicChildren(Resource):
"""
A L{Resource} with dynamic children.
"""
def getChild(self, path, request):
return DynamicChild(path, request)
class BytesReturnedRenderable(Resource):
"""
A L{Resource} with minimal capabilities to render a response.
"""
def __init__(self, response):
"""
@param response: A C{bytes} object giving the value to return from
C{render_GET}.
"""
Resource.__init__(self)
self._response = response
def render_GET(self, request):
"""
Render a response to a I{GET} request by returning a short byte string
to be written by the server.
"""
return self._response
class ImplicitAllowedMethods(Resource):
"""
A L{Resource} which implicitly defines its allowed methods by defining
renderers to handle them.
"""
def render_GET(self, request):
pass
def render_PUT(self, request):
pass
class ResourceTests(TestCase):
"""
Tests for L{Resource}.
"""
def test_staticChildren(self):
"""
L{Resource.putChild} adds a I{static} child to the resource. That child
is returned from any call to L{Resource.getChildWithDefault} for the
child's path.
"""
resource = Resource()
child = Resource()
sibling = Resource()
resource.putChild(b"foo", child)
resource.putChild(b"bar", sibling)
self.assertIdentical(
child, resource.getChildWithDefault(b"foo", DummyRequest([])))
def test_dynamicChildren(self):
"""
L{Resource.getChildWithDefault} delegates to L{Resource.getChild} when
the requested path is not associated with any static child.
"""
path = b"foo"
request = DummyRequest([])
resource = DynamicChildren()
child = resource.getChildWithDefault(path, request)
self.assertIsInstance(child, DynamicChild)
self.assertEqual(child.path, path)
self.assertIdentical(child.request, request)
def test_staticChildPathType(self):
"""
Test that passing the wrong type to putChild results in a warning,
and a failure in Python 3
"""
resource = Resource()
child = Resource()
sibling = Resource()
resource.putChild(u"foo", child)
warnings = self.flushWarnings([self.test_staticChildPathType])
self.assertEqual(len(warnings), 1)
self.assertIn("Path segment must be bytes",
warnings[0]['message'])
if _PY3:
# We expect an error here because u"foo" != b"foo" on Py3k
self.assertIsInstance(
resource.getChildWithDefault(b"foo", DummyRequest([])),
ErrorPage)
resource.putChild(None, sibling)
warnings = self.flushWarnings([self.test_staticChildPathType])
self.assertEqual(len(warnings), 1)
self.assertIn("Path segment must be bytes",
warnings[0]['message'])
def test_defaultHEAD(self):
"""
When not otherwise overridden, L{Resource.render} treats a I{HEAD}
request as if it were a I{GET} request.
"""
expected = b"insert response here"
request = DummyRequest([])
request.method = b'HEAD'
resource = BytesReturnedRenderable(expected)
self.assertEqual(expected, resource.render(request))
def test_explicitAllowedMethods(self):
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to the value of the
C{allowedMethods} attribute of the L{Resource}, if it has one.
"""
expected = [b'GET', b'HEAD', b'PUT']
resource = Resource()
resource.allowedMethods = expected
request = DummyRequest([])
request.method = b'FICTIONAL'
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(set(expected), set(exc.allowedMethods))
def test_implicitAllowedMethods(self):
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to a list of the
methods supported by the L{Resource}, as determined by the
I{render_}-prefixed methods which it defines, if C{allowedMethods} is
not explicitly defined by the L{Resource}.
"""
expected = set([b'GET', b'HEAD', b'PUT'])
resource = ImplicitAllowedMethods()
request = DummyRequest([])
request.method = b'FICTIONAL'
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(expected, set(exc.allowedMethods))
class GetChildForRequestTests(TestCase):
"""
Tests for L{getChildForRequest}.
"""
def test_exhaustedPostPath(self):
"""
L{getChildForRequest} returns whatever resource has been reached by the
time the request's C{postpath} is empty.
"""
request = DummyRequest([])
resource = Resource()
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_leafResource(self):
"""
L{getChildForRequest} returns the first resource it encounters with a
C{isLeaf} attribute set to C{True}.
"""
request = DummyRequest([b"foo", b"bar"])
resource = Resource()
resource.isLeaf = True
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_postPathToPrePath(self):
"""
As path segments from the request are traversed, they are taken from
C{postpath} and put into C{prepath}.
"""
request = DummyRequest([b"foo", b"bar"])
root = Resource()
child = Resource()
child.isLeaf = True
root.putChild(b"foo", child)
self.assertIdentical(child, getChildForRequest(root, request))
self.assertEqual(request.prepath, [b"foo"])
self.assertEqual(request.postpath, [b"bar"])

View file

@ -0,0 +1,115 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.script}.
"""
import os
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
from twisted.web.http import NOT_FOUND
from twisted.web.script import ResourceScriptDirectory, PythonScript
from twisted.web.test._util import _render
from twisted.web.test.requesthelper import DummyRequest
class ResourceScriptDirectoryTests(TestCase):
"""
Tests for L{ResourceScriptDirectory}.
"""
def test_renderNotFound(self):
"""
L{ResourceScriptDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = ResourceScriptDirectory(self.mktemp())
request = DummyRequest([b''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_notFoundChild(self):
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{ResourceScriptDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = ResourceScriptDirectory(path)
request = DummyRequest([b'foo'])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_render(self):
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders a
response with the HTTP 200 status code and the content of the rpy's
C{request} global.
"""
tmp = FilePath(self.mktemp())
tmp.makedirs()
tmp.child("test.rpy").setContent(b"""
from twisted.web.resource import Resource
class TestResource(Resource):
isLeaf = True
def render_GET(self, request):
return b'ok'
resource = TestResource()""")
resource = ResourceScriptDirectory(tmp._asBytesPath())
request = DummyRequest([b''])
child = resource.getChild(b"test.rpy", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(b"".join(request.written), b"ok")
d.addCallback(cbRendered)
return d
class PythonScriptTests(TestCase):
"""
Tests for L{PythonScript}.
"""
def test_notFoundRender(self):
"""
If the source file a L{PythonScript} is initialized with doesn't exist,
L{PythonScript.render} sets the HTTP response code to I{NOT FOUND}.
"""
resource = PythonScript(self.mktemp(), None)
request = DummyRequest([b''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_renderException(self):
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders a
response with the HTTP 200 status code and the content of the rpy's
C{request} global.
"""
tmp = FilePath(self.mktemp())
tmp.makedirs()
child = tmp.child("test.epy")
child.setContent(b'raise Exception("nooo")')
resource = PythonScript(child._asBytesPath(), None)
request = DummyRequest([b''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertIn(b"nooo", b"".join(request.written))
d.addCallback(cbRendered)
return d

View file

@ -0,0 +1,166 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._stan} portion of the L{twisted.web.template}
implementation.
"""
from __future__ import absolute_import, division
from twisted.web.template import Comment, CDATA, CharRef, Tag
from twisted.trial.unittest import TestCase
from twisted.python.compat import _PY3
def proto(*a, **kw):
"""
Produce a new tag for testing.
"""
return Tag('hello')(*a, **kw)
class TagTests(TestCase):
"""
Tests for L{Tag}.
"""
def test_fillSlots(self):
"""
L{Tag.fillSlots} returns self.
"""
tag = proto()
self.assertIdentical(tag, tag.fillSlots(test='test'))
def test_cloneShallow(self):
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. If the shallow flag is C{False}, that's where it
stops.
"""
innerList = ["inner list"]
tag = proto("How are you", innerList,
hello="world", render="aSampleMethod")
tag.fillSlots(foo='bar')
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone(deep=False)
self.assertEqual(clone.attributes['hello'], 'world')
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertEqual(clone.children, ["How are you", innerList])
self.assertNotIdentical(clone.children, tag.children)
self.assertIdentical(clone.children[1], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_cloneDeep(self):
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. In its normal operating mode (where the deep flag is
C{True}, as is the default), it will clone all sub-lists and sub-tags.
"""
innerTag = proto("inner")
innerList = ["inner list"]
tag = proto("How are you", innerTag, innerList,
hello="world", render="aSampleMethod")
tag.fillSlots(foo='bar')
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone()
self.assertEqual(clone.attributes['hello'], 'world')
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertNotIdentical(clone.children, tag.children)
# sanity check
self.assertIdentical(tag.children[1], innerTag)
# clone should have sub-clone
self.assertNotIdentical(clone.children[1], innerTag)
# sanity check
self.assertIdentical(tag.children[2], innerList)
# clone should have sub-clone
self.assertNotIdentical(clone.children[2], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_clear(self):
"""
L{Tag.clear} removes all children from a tag, but leaves its attributes
in place.
"""
tag = proto("these are", "children", "cool", andSoIs='this-attribute')
tag.clear()
self.assertEqual(tag.children, [])
self.assertEqual(tag.attributes, {'andSoIs': 'this-attribute'})
def test_suffix(self):
"""
L{Tag.__call__} accepts Python keywords with a suffixed underscore as
the DOM attribute of that literal suffix.
"""
proto = Tag('div')
tag = proto()
tag(class_='a')
self.assertEqual(tag.attributes, {'class': 'a'})
def test_commentReprPy2(self):
"""
L{Comment.__repr__} returns a value which makes it easy to see what's
in the comment.
"""
self.assertEqual(repr(Comment(u"hello there")),
"Comment(u'hello there')")
def test_cdataReprPy2(self):
"""
L{CDATA.__repr__} returns a value which makes it easy to see what's in
the comment.
"""
self.assertEqual(repr(CDATA(u"test data")),
"CDATA(u'test data')")
def test_commentReprPy3(self):
"""
L{Comment.__repr__} returns a value which makes it easy to see what's
in the comment.
"""
self.assertEqual(repr(Comment(u"hello there")),
"Comment('hello there')")
def test_cdataReprPy3(self):
"""
L{CDATA.__repr__} returns a value which makes it easy to see what's in
the comment.
"""
self.assertEqual(repr(CDATA(u"test data")),
"CDATA('test data')")
if not _PY3:
test_commentReprPy3.skip = "Only relevant on Python 3."
test_cdataReprPy3.skip = "Only relevant on Python 3."
else:
test_commentReprPy2.skip = "Only relevant on Python 2."
test_cdataReprPy2.skip = "Only relevant on Python 2."
def test_charrefRepr(self):
"""
L{CharRef.__repr__} returns a value which makes it easy to see what
character is referred to.
"""
snowman = ord(u"\N{SNOWMAN}")
self.assertEqual(repr(CharRef(snowman)), "CharRef(9731)")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,346 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.tap}.
"""
from __future__ import absolute_import, division
import os
import stat
from twisted.internet import reactor, endpoints
from twisted.internet.interfaces import IReactorUNIX
from twisted.python.filepath import FilePath
from twisted.python.reflect import requireModule
from twisted.python.threadpool import ThreadPool
from twisted.python.usage import UsageError
from twisted.spread.pb import PBServerFactory
from twisted.trial.unittest import TestCase
from twisted.web import demo
from twisted.web.distrib import ResourcePublisher, UserDirectory
from twisted.web.script import PythonScript
from twisted.web.server import Site
from twisted.web.static import Data, File
from twisted.web.tap import Options, makeService
from twisted.web.tap import makePersonalServerFactory, _AddHeadersResource
from twisted.web.test.requesthelper import DummyRequest
from twisted.web.twcgi import CGIScript
from twisted.web.wsgi import WSGIResource
application = object()
class ServiceTests(TestCase):
"""
Tests for the service creation APIs in L{twisted.web.tap}.
"""
def _pathOption(self):
"""
Helper for the I{--path} tests which creates a directory and creates
an L{Options} object which uses that directory as its static
filesystem root.
@return: A two-tuple of a L{FilePath} referring to the directory and
the value associated with the C{'root'} key in the L{Options}
instance after parsing a I{--path} option.
"""
path = FilePath(self.mktemp())
path.makedirs()
options = Options()
options.parseOptions(['--path', path.path])
root = options['root']
return path, root
def test_path(self):
"""
The I{--path} option causes L{Options} to create a root resource
which serves responses from the specified path.
"""
path, root = self._pathOption()
self.assertIsInstance(root, File)
self.assertEqual(root.path, path.path)
def test_pathServer(self):
"""
The I{--path} option to L{makeService} causes it to return a service
which will listen on the server address given by the I{--port} option.
"""
path = FilePath(self.mktemp())
path.makedirs()
port = self.mktemp()
options = Options()
options.parseOptions(['--port', 'unix:' + port, '--path', path.path])
service = makeService(options)
service.startService()
self.addCleanup(service.stopService)
self.assertIsInstance(service.services[0].factory.resource, File)
self.assertEqual(service.services[0].factory.resource.path, path.path)
self.assertTrue(os.path.exists(port))
self.assertTrue(stat.S_ISSOCK(os.stat(port).st_mode))
if not IReactorUNIX.providedBy(reactor):
test_pathServer.skip = (
"The reactor does not support UNIX domain sockets")
def test_cgiProcessor(self):
"""
The I{--path} option creates a root resource which serves a
L{CGIScript} instance for any child with the C{".cgi"} extension.
"""
path, root = self._pathOption()
path.child("foo.cgi").setContent(b"")
self.assertIsInstance(root.getChild("foo.cgi", None), CGIScript)
def test_epyProcessor(self):
"""
The I{--path} option creates a root resource which serves a
L{PythonScript} instance for any child with the C{".epy"} extension.
"""
path, root = self._pathOption()
path.child("foo.epy").setContent(b"")
self.assertIsInstance(root.getChild("foo.epy", None), PythonScript)
def test_rpyProcessor(self):
"""
The I{--path} option creates a root resource which serves the
C{resource} global defined by the Python source in any child with
the C{".rpy"} extension.
"""
path, root = self._pathOption()
path.child("foo.rpy").setContent(
b"from twisted.web.static import Data\n"
b"resource = Data('content', 'major/minor')\n")
child = root.getChild("foo.rpy", None)
self.assertIsInstance(child, Data)
self.assertEqual(child.data, 'content')
self.assertEqual(child.type, 'major/minor')
def test_makePersonalServerFactory(self):
"""
L{makePersonalServerFactory} returns a PB server factory which has
as its root object a L{ResourcePublisher}.
"""
# The fact that this pile of objects can actually be used somehow is
# verified by twisted.web.test.test_distrib.
site = Site(Data(b"foo bar", "text/plain"))
serverFactory = makePersonalServerFactory(site)
self.assertIsInstance(serverFactory, PBServerFactory)
self.assertIsInstance(serverFactory.root, ResourcePublisher)
self.assertIdentical(serverFactory.root.site, site)
def test_personalServer(self):
"""
The I{--personal} option to L{makeService} causes it to return a
service which will listen on the server address given by the I{--port}
option.
"""
port = self.mktemp()
options = Options()
options.parseOptions(['--port', 'unix:' + port, '--personal'])
service = makeService(options)
service.startService()
self.addCleanup(service.stopService)
self.assertTrue(os.path.exists(port))
self.assertTrue(stat.S_ISSOCK(os.stat(port).st_mode))
if not IReactorUNIX.providedBy(reactor):
test_personalServer.skip = (
"The reactor does not support UNIX domain sockets")
def test_defaultPersonalPath(self):
"""
If the I{--port} option not specified but the I{--personal} option is,
L{Options} defaults the port to C{UserDirectory.userSocketName} in the
user's home directory.
"""
options = Options()
options.parseOptions(['--personal'])
path = os.path.expanduser(
os.path.join('~', UserDirectory.userSocketName))
self.assertEqual(options['ports'][0],
'unix:{}'.format(path))
if not IReactorUNIX.providedBy(reactor):
test_defaultPersonalPath.skip = (
"The reactor does not support UNIX domain sockets")
def test_defaultPort(self):
"""
If the I{--port} option is not specified, L{Options} defaults the port
to C{8080}.
"""
options = Options()
options.parseOptions([])
self.assertEqual(
endpoints._parseServer(options['ports'][0], None)[:2],
('TCP', (8080, None)))
def test_twoPorts(self):
"""
If the I{--http} option is given twice, there are two listeners
"""
options = Options()
options.parseOptions(['--listen', 'tcp:8001', '--listen', 'tcp:8002'])
self.assertIn('8001', options['ports'][0])
self.assertIn('8002', options['ports'][1])
def test_wsgi(self):
"""
The I{--wsgi} option takes the fully-qualifed Python name of a WSGI
application object and creates a L{WSGIResource} at the root which
serves that application.
"""
options = Options()
options.parseOptions(['--wsgi', __name__ + '.application'])
root = options['root']
self.assertTrue(root, WSGIResource)
self.assertIdentical(root._reactor, reactor)
self.assertTrue(isinstance(root._threadpool, ThreadPool))
self.assertIdentical(root._application, application)
# The threadpool should start and stop with the reactor.
self.assertFalse(root._threadpool.started)
reactor.fireSystemEvent('startup')
self.assertTrue(root._threadpool.started)
self.assertFalse(root._threadpool.joined)
reactor.fireSystemEvent('shutdown')
self.assertTrue(root._threadpool.joined)
def test_invalidApplication(self):
"""
If I{--wsgi} is given an invalid name, L{Options.parseOptions}
raises L{UsageError}.
"""
options = Options()
for name in [__name__ + '.nosuchthing', 'foo.']:
exc = self.assertRaises(
UsageError, options.parseOptions, ['--wsgi', name])
self.assertEqual(str(exc),
"No such WSGI application: %r" % (name,))
def test_HTTPSFailureOnMissingSSL(self):
"""
An L{UsageError} is raised when C{https} is requested but there is no
support for SSL.
"""
options = Options()
exception = self.assertRaises(
UsageError, options.parseOptions, ['--https=443'])
self.assertEqual('SSL support not installed', exception.args[0])
if requireModule('OpenSSL.SSL') is not None:
test_HTTPSFailureOnMissingSSL.skip = 'SSL module is available.'
def test_HTTPSAcceptedOnAvailableSSL(self):
"""
When SSL support is present, it accepts the --https option.
"""
options = Options()
options.parseOptions(['--https=443'])
self.assertIn('ssl', options['ports'][0])
self.assertIn('443', options['ports'][0])
if requireModule('OpenSSL.SSL') is None:
test_HTTPSAcceptedOnAvailableSSL.skip = 'SSL module is not available.'
def test_add_header_parsing(self):
"""
When --add-header is specific, the value is parsed.
"""
options = Options()
options.parseOptions(
['--add-header', 'K1: V1', '--add-header', 'K2: V2']
)
self.assertEqual(options['extraHeaders'], [('K1', 'V1'), ('K2', 'V2')])
def test_add_header_resource(self):
"""
When --add-header is specified, the resource is a composition that adds
headers.
"""
options = Options()
options.parseOptions(
['--add-header', 'K1: V1', '--add-header', 'K2: V2']
)
service = makeService(options)
resource = service.services[0].factory.resource
self.assertIsInstance(resource, _AddHeadersResource)
self.assertEqual(resource._headers, [('K1', 'V1'), ('K2', 'V2')])
self.assertIsInstance(resource._originalResource, demo.Test)
def test_noTracebacksDeprecation(self):
"""
Passing --notracebacks is deprecated.
"""
options = Options()
options.parseOptions(["--notracebacks"])
makeService(options)
warnings = self.flushWarnings([self.test_noTracebacksDeprecation])
self.assertEqual(warnings[0]['category'], DeprecationWarning)
self.assertEqual(
warnings[0]['message'],
"--notracebacks was deprecated in Twisted 19.7.0"
)
self.assertEqual(len(warnings), 1)
def test_displayTracebacks(self):
"""
Passing --display-tracebacks will enable traceback rendering on the
generated Site.
"""
options = Options()
options.parseOptions(["--display-tracebacks"])
service = makeService(options)
self.assertTrue(service.services[0].factory.displayTracebacks)
def test_displayTracebacksNotGiven(self):
"""
Not passing --display-tracebacks will leave traceback rendering on the
generated Site off.
"""
options = Options()
options.parseOptions([])
service = makeService(options)
self.assertFalse(service.services[0].factory.displayTracebacks)
class AddHeadersResourceTests(TestCase):
def test_getChildWithDefault(self):
"""
When getChildWithDefault is invoked, it adds the headers to the
response.
"""
resource = _AddHeadersResource(
demo.Test(), [("K1", "V1"), ("K2", "V2"), ("K1", "V3")])
request = DummyRequest([])
resource.getChildWithDefault("", request)
self.assertEqual(
request.responseHeaders.getRawHeaders("K1"), ["V1", "V3"])
self.assertEqual(request.responseHeaders.getRawHeaders("K2"), ["V2"])

View file

@ -0,0 +1,827 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.template}
"""
from __future__ import division, absolute_import
from zope.interface.verify import verifyObject
from twisted.internet.defer import succeed, gatherResults
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.trial.util import suppress as SUPPRESS
from twisted.web.template import (
Element, TagLoader, renderer, tags, XMLFile, XMLString)
from twisted.web.iweb import ITemplateLoader
from twisted.web.error import (FlattenerError, MissingTemplateLoader,
MissingRenderMethod)
from twisted.web.template import renderElement
from twisted.web._element import UnexposedMethodError
from twisted.web.test._util import FlattenTestCase
from twisted.web.test.test_web import DummyRequest
from twisted.web.server import NOT_DONE_YET
from twisted.python.compat import NativeStringIO as StringIO
from twisted.logger import globalLogPublisher
from twisted.test.proto_helpers import EventLoggingObserver
_xmlFileSuppress = SUPPRESS(category=DeprecationWarning,
message="Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.")
class TagFactoryTests(TestCase):
"""
Tests for L{_TagFactory} through the publicly-exposed L{tags} object.
"""
def test_lookupTag(self):
"""
HTML tags can be retrieved through C{tags}.
"""
tag = tags.a
self.assertEqual(tag.tagName, "a")
def test_lookupHTML5Tag(self):
"""
Twisted supports the latest and greatest HTML tags from the HTML5
specification.
"""
tag = tags.video
self.assertEqual(tag.tagName, "video")
def test_lookupTransparentTag(self):
"""
To support transparent inclusion in templates, there is a special tag,
the transparent tag, which has no name of its own but is accessed
through the "transparent" attribute.
"""
tag = tags.transparent
self.assertEqual(tag.tagName, "")
def test_lookupInvalidTag(self):
"""
Invalid tags which are not part of HTML cause AttributeErrors when
accessed through C{tags}.
"""
self.assertRaises(AttributeError, getattr, tags, "invalid")
def test_lookupXMP(self):
"""
As a special case, the <xmp> tag is simply not available through
C{tags} or any other part of the templating machinery.
"""
self.assertRaises(AttributeError, getattr, tags, "xmp")
class ElementTests(TestCase):
"""
Tests for the awesome new L{Element} class.
"""
def test_missingTemplateLoader(self):
"""
L{Element.render} raises L{MissingTemplateLoader} if the C{loader}
attribute is L{None}.
"""
element = Element()
err = self.assertRaises(MissingTemplateLoader, element.render, None)
self.assertIdentical(err.element, element)
def test_missingTemplateLoaderRepr(self):
"""
A L{MissingTemplateLoader} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self):
return 'Pretty Repr Element'
self.assertIn('Pretty Repr Element',
repr(MissingTemplateLoader(PrettyReprElement())))
def test_missingRendererMethod(self):
"""
When called with the name which is not associated with a render method,
L{Element.lookupRenderMethod} raises L{MissingRenderMethod}.
"""
element = Element()
err = self.assertRaises(
MissingRenderMethod, element.lookupRenderMethod, "foo")
self.assertIdentical(err.element, element)
self.assertEqual(err.renderName, "foo")
def test_missingRenderMethodRepr(self):
"""
A L{MissingRenderMethod} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self):
return 'Pretty Repr Element'
s = repr(MissingRenderMethod(PrettyReprElement(),
'expectedMethod'))
self.assertIn('Pretty Repr Element', s)
self.assertIn('expectedMethod', s)
def test_definedRenderer(self):
"""
When called with the name of a defined render method,
L{Element.lookupRenderMethod} returns that render method.
"""
class ElementWithRenderMethod(Element):
@renderer
def foo(self, request, tag):
return "bar"
foo = ElementWithRenderMethod().lookupRenderMethod("foo")
self.assertEqual(foo(None, None), "bar")
def test_render(self):
"""
L{Element.render} loads a document from the C{loader} attribute and
returns it.
"""
class TemplateLoader(object):
def load(self):
return "result"
class StubElement(Element):
loader = TemplateLoader()
element = StubElement()
self.assertEqual(element.render(None), "result")
def test_misuseRenderer(self):
"""
If the L{renderer} decorator is called without any arguments, it will
raise a comprehensible exception.
"""
te = self.assertRaises(TypeError, renderer)
self.assertEqual(str(te),
"expose() takes at least 1 argument (0 given)")
def test_renderGetDirectlyError(self):
"""
Called directly, without a default, L{renderer.get} raises
L{UnexposedMethodError} when it cannot find a renderer.
"""
self.assertRaises(UnexposedMethodError, renderer.get, None,
"notARenderer")
class XMLFileReprTests(TestCase):
"""
Tests for L{twisted.web.template.XMLFile}'s C{__repr__}.
"""
def test_filePath(self):
"""
An L{XMLFile} with a L{FilePath} returns a useful repr().
"""
path = FilePath("/tmp/fake.xml")
self.assertEqual('<XMLFile of %r>' % (path,), repr(XMLFile(path)))
def test_filename(self):
"""
An L{XMLFile} with a filename returns a useful repr().
"""
fname = "/tmp/fake.xml"
self.assertEqual('<XMLFile of %r>' % (fname,), repr(XMLFile(fname)))
test_filename.suppress = [_xmlFileSuppress]
def test_file(self):
"""
An L{XMLFile} with a file object returns a useful repr().
"""
fobj = StringIO("not xml")
self.assertEqual('<XMLFile of %r>' % (fobj,), repr(XMLFile(fobj)))
test_file.suppress = [_xmlFileSuppress]
class XMLLoaderTestsMixin(object):
"""
@ivar templateString: Simple template to use to exercise the loaders.
@ivar deprecatedUse: C{True} if this use of L{XMLFile} is deprecated and
should emit a C{DeprecationWarning}.
"""
loaderFactory = None
templateString = '<p>Hello, world.</p>'
def test_load(self):
"""
Verify that the loader returns a tag with the correct children.
"""
loader = self.loaderFactory()
tag, = loader.load()
warnings = self.flushWarnings(offendingFunctions=[self.loaderFactory])
if self.deprecatedUse:
self.assertEqual(len(warnings), 1)
self.assertEqual(warnings[0]['category'], DeprecationWarning)
self.assertEqual(
warnings[0]['message'],
"Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.")
else:
self.assertEqual(len(warnings), 0)
self.assertEqual(tag.tagName, 'p')
self.assertEqual(tag.children, [u'Hello, world.'])
def test_loadTwice(self):
"""
If {load()} can be called on a loader twice the result should be the
same.
"""
loader = self.loaderFactory()
tags1 = loader.load()
tags2 = loader.load()
self.assertEqual(tags1, tags2)
test_loadTwice.suppress = [_xmlFileSuppress]
class XMLStringLoaderTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLString}
"""
deprecatedUse = False
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with C{self.templateString}.
"""
return XMLString(self.templateString)
class XMLFileWithFilePathTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s L{FilePath} support.
"""
deprecatedUse = False
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a L{FilePath} pointing to a
file that contains C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString.encode("utf8"))
return XMLFile(fp)
class XMLFileWithFileTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated file object support.
"""
deprecatedUse = True
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a file object that contains
C{self.templateString}.
"""
return XMLFile(StringIO(self.templateString))
class XMLFileWithFilenameTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated filename support.
"""
deprecatedUse = True
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a filename that points to a
file containing C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString.encode('utf8'))
return XMLFile(fp.path)
class FlattenIntegrationTests(FlattenTestCase):
"""
Tests for integration between L{Element} and
L{twisted.web._flatten.flatten}.
"""
def test_roundTrip(self):
"""
Given a series of parsable XML strings, verify that
L{twisted.web._flatten.flatten} will flatten the L{Element} back to the
input when sent on a round trip.
"""
fragments = [
b"<p>Hello, world.</p>",
b"<p><!-- hello, world --></p>",
b"<p><![CDATA[Hello, world.]]></p>",
b'<test1 xmlns:test2="urn:test2">'
b'<test2:test3></test2:test3></test1>',
b'<test1 xmlns="urn:test2"><test3></test3></test1>',
b'<p>\xe2\x98\x83</p>',
]
deferreds = [
self.assertFlattensTo(Element(loader=XMLString(xml)), xml)
for xml in fragments]
return gatherResults(deferreds)
def test_entityConversion(self):
"""
When flattening an HTML entity, it should flatten out to the utf-8
representation if possible.
"""
element = Element(loader=XMLString('<p>&#9731;</p>'))
return self.assertFlattensTo(element, b'<p>\xe2\x98\x83</p>')
def test_missingTemplateLoader(self):
"""
Rendering an Element without a loader attribute raises the appropriate
exception.
"""
return self.assertFlatteningRaises(Element(), MissingTemplateLoader)
def test_missingRenderMethod(self):
"""
Flattening an L{Element} with a C{loader} which has a tag with a render
directive fails with L{FlattenerError} if there is no available render
method to satisfy that directive.
"""
element = Element(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="unknownMethod" />
"""))
return self.assertFlatteningRaises(element, MissingRenderMethod)
def test_transparentRendering(self):
"""
A C{transparent} element should be eliminated from the DOM and rendered as
only its children.
"""
element = Element(loader=XMLString(
'<t:transparent '
'xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'Hello, world.'
'</t:transparent>'
))
return self.assertFlattensTo(element, b"Hello, world.")
def test_attrRendering(self):
"""
An Element with an attr tag renders the vaule of its attr tag as an
attribute of its containing tag.
"""
element = Element(loader=XMLString(
'<a xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:attr name="href">http://example.com</t:attr>'
'Hello, world.'
'</a>'
))
return self.assertFlattensTo(element,
b'<a href="http://example.com">Hello, world.</a>')
def test_errorToplevelAttr(self):
"""
A template with a toplevel C{attr} tag will not load; it will raise
L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
name='something'
>hello</t:attr>
""")
def test_errorUnnamedAttr(self):
"""
A template with an C{attr} tag with no C{name} attribute will not load;
it will raise L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<html><t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
>hello</t:attr></html>""")
def test_lenientPrefixBehavior(self):
"""
If the parser sees a prefix it doesn't recognize on an attribute, it
will pass it on through to serialization.
"""
theInput = (
'<hello:world hello:sample="testing" '
'xmlns:hello="http://made-up.example.com/ns/not-real">'
'This is a made-up tag.</hello:world>')
element = Element(loader=XMLString(theInput))
self.assertFlattensTo(element, theInput.encode('utf8'))
def test_deferredRendering(self):
"""
An Element with a render method which returns a Deferred will render
correctly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return succeed("Hello, world.")
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""))
return self.assertFlattensTo(element, b"Hello, world.")
def test_loaderClassAttribute(self):
"""
If there is a non-None loader attribute on the class of an Element
instance but none on the instance itself, the class attribute is used.
"""
class SubElement(Element):
loader = XMLString("<p>Hello, world.</p>")
return self.assertFlattensTo(SubElement(), b"<p>Hello, world.</p>")
def test_directiveRendering(self):
"""
An Element with a valid render directive has that directive invoked and
the result added to the output.
"""
renders = []
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
renders.append((self, request))
return tag("Hello, world.")
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""))
return self.assertFlattensTo(element, b"<p>Hello, world.</p>")
def test_directiveRenderingOmittingTag(self):
"""
An Element with a render method which omits the containing tag
successfully removes that tag from the output.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return "Hello, world."
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""))
return self.assertFlattensTo(element, b"Hello, world.")
def test_elementContainingStaticElement(self):
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return tag(Element(
loader=XMLString("<em>Hello, world.</em>")))
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""))
return self.assertFlattensTo(element, b"<p><em>Hello, world.</em></p>")
def test_elementUsingSlots(self):
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return tag.fillSlots(test2='world.')
element = RenderfulElement(loader=XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"'
' t:render="renderMethod">'
'<t:slot name="test1" default="Hello, " />'
'<t:slot name="test2" />'
'</p>'
))
return self.assertFlattensTo(element, b"<p>Hello, world.</p>")
def test_elementContainingDynamicElement(self):
"""
Directives in the document factory of an Element returned from a render
method of another Element are satisfied from the correct object: the
"inner" Element.
"""
class OuterElement(Element):
@renderer
def outerMethod(self, request, tag):
return tag(InnerElement(loader=XMLString("""
<t:ignored
xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="innerMethod" />
""")))
class InnerElement(Element):
@renderer
def innerMethod(self, request, tag):
return "Hello, world."
element = OuterElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="outerMethod" />
"""))
return self.assertFlattensTo(element, b"<p>Hello, world.</p>")
def test_sameLoaderTwice(self):
"""
Rendering the output of a loader, or even the same element, should
return different output each time.
"""
sharedLoader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:transparent t:render="classCounter" /> '
'<t:transparent t:render="instanceCounter" />'
'</p>')
class DestructiveElement(Element):
count = 0
instanceCount = 0
loader = sharedLoader
@renderer
def classCounter(self, request, tag):
DestructiveElement.count += 1
return tag(str(DestructiveElement.count))
@renderer
def instanceCounter(self, request, tag):
self.instanceCount += 1
return tag(str(self.instanceCount))
e1 = DestructiveElement()
e2 = DestructiveElement()
self.assertFlattensImmediately(e1, b"<p>1 1</p>")
self.assertFlattensImmediately(e1, b"<p>2 2</p>")
self.assertFlattensImmediately(e2, b"<p>3 1</p>")
class TagLoaderTests(FlattenTestCase):
"""
Tests for L{TagLoader}.
"""
def setUp(self):
self.loader = TagLoader(tags.i('test'))
def test_interface(self):
"""
An instance of L{TagLoader} provides L{ITemplateLoader}.
"""
self.assertTrue(verifyObject(ITemplateLoader, self.loader))
def test_loadsList(self):
"""
L{TagLoader.load} returns a list, per L{ITemplateLoader}.
"""
self.assertIsInstance(self.loader.load(), list)
def test_flatten(self):
"""
L{TagLoader} can be used in an L{Element}, and flattens as the tag used
to construct the L{TagLoader} would flatten.
"""
e = Element(self.loader)
self.assertFlattensImmediately(e, b'<i>test</i>')
class TestElement(Element):
"""
An L{Element} that can be rendered successfully.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'Hello, world.'
'</p>')
class TestFailureElement(Element):
"""
An L{Element} that can be used in place of L{FailureElement} to verify
that L{renderElement} can render failures properly.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'I failed.'
'</p>')
def __init__(self, failure, loader=None):
self.failure = failure
class FailingElement(Element):
"""
An element that raises an exception when rendered.
"""
def render(self, request):
a = 42
b = 0
return a // b
class FakeSite(object):
"""
A minimal L{Site} object that we can use to test displayTracebacks
"""
displayTracebacks = False
class RenderElementTests(TestCase):
"""
Test L{renderElement}
"""
def setUp(self):
"""
Set up a common L{DummyRequest} and L{FakeSite}.
"""
self.request = DummyRequest([""])
self.request.site = FakeSite()
def test_simpleRender(self):
"""
L{renderElement} returns NOT_DONE_YET and eventually
writes the rendered L{Element} to the request before finishing the
request.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
b"".join(self.request.written),
b"<!DOCTYPE html>\n"
b"<p>Hello, world.</p>")
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailure(self):
"""
L{renderElement} handles failures by writing a minimal
error message to the request and finishing it.
"""
element = FailingElement()
d = self.request.notifyFinish()
def check(_):
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
b"".join(self.request.written),
(b'<!DOCTYPE html>\n'
b'<div style="font-size:800%;'
b'background-color:#FFF;'
b'color:#F00'
b'">An error occurred while rendering the response.</div>'))
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailureWithTraceback(self):
"""
L{renderElement} will render a traceback when rendering of
the element fails and our site is configured to display tracebacks.
"""
logObserver = EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
self.request.site.displayTracebacks = True
element = FailingElement()
d = self.request.notifyFinish()
def check(_):
self.assertEquals(1, len(logObserver))
f = logObserver[0]["log_failure"]
self.assertIsInstance(f.value, FlattenerError)
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
b"".join(self.request.written),
b"<!DOCTYPE html>\n<p>I failed.</p>")
self.assertTrue(self.request.finished)
d.addCallback(check)
renderElement(self.request, element, _failElement=TestFailureElement)
return d
def test_nonDefaultDoctype(self):
"""
L{renderElement} will write the doctype string specified by the
doctype keyword argument.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
b"".join(self.request.written),
(b'<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
b' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">\n'
b'<p>Hello, world.</p>'))
d.addCallback(check)
renderElement(
self.request,
element,
doctype=(
b'<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
b' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">'))
return d
def test_noneDoctype(self):
"""
L{renderElement} will not write out a doctype if the doctype keyword
argument is L{None}.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
b"".join(self.request.written),
b'<p>Hello, world.</p>')
d.addCallback(check)
renderElement(self.request, element, doctype=None)
return d

View file

@ -0,0 +1,366 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.util}.
"""
from __future__ import absolute_import, division
import gc
from twisted.python.failure import Failure
from twisted.trial.unittest import SynchronousTestCase, TestCase
from twisted.internet import defer
from twisted.python.compat import _PY3, intToBytes, networkString
from twisted.web import resource, util
from twisted.web.error import FlattenerError
from twisted.web.http import FOUND
from twisted.web.server import Request
from twisted.web.template import TagLoader, flattenString, tags
from twisted.web.test.requesthelper import DummyChannel, DummyRequest
from twisted.web.util import DeferredResource
from twisted.web.util import _SourceFragmentElement, _FrameElement
from twisted.web.util import _StackElement, FailureElement, formatFailure
from twisted.web.util import redirectTo, _SourceLineElement
class RedirectToTests(TestCase):
"""
Tests for L{redirectTo}.
"""
def test_headersAndCode(self):
"""
L{redirectTo} will set the C{Location} and C{Content-Type} headers on
its request, and set the response code to C{FOUND}, so the browser will
be redirected.
"""
request = Request(DummyChannel(), True)
request.method = b'GET'
targetURL = b"http://target.example.com/4321"
redirectTo(targetURL, request)
self.assertEqual(request.code, FOUND)
self.assertEqual(
request.responseHeaders.getRawHeaders(b'location'), [targetURL])
self.assertEqual(
request.responseHeaders.getRawHeaders(b'content-type'),
[b'text/html; charset=utf-8'])
def test_redirectToUnicodeURL(self) :
"""
L{redirectTo} will raise TypeError if unicode object is passed in URL
"""
request = Request(DummyChannel(), True)
request.method = b'GET'
targetURL = u'http://target.example.com/4321'
self.assertRaises(TypeError, redirectTo, targetURL, request)
class FailureElementTests(TestCase):
"""
Tests for L{FailureElement} and related helpers which can render a
L{Failure} as an HTML string.
"""
def setUp(self):
"""
Create a L{Failure} which can be used by the rendering tests.
"""
def lineNumberProbeAlsoBroken():
message = "This is a problem"
raise Exception(message)
# Figure out the line number from which the exception will be raised.
self.base = lineNumberProbeAlsoBroken.__code__.co_firstlineno + 1
try:
lineNumberProbeAlsoBroken()
except:
self.failure = Failure(captureVars=True)
self.frame = self.failure.frames[-1]
def test_sourceLineElement(self):
"""
L{_SourceLineElement} renders a source line and line number.
"""
element = _SourceLineElement(
TagLoader(tags.div(
tags.span(render="lineNumber"),
tags.span(render="sourceLine"))),
50, " print 'hello'")
d = flattenString(None, element)
expected = (
u"<div><span>50</span><span>"
u" \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}print 'hello'</span></div>")
d.addCallback(
self.assertEqual, expected.encode('utf-8'))
return d
def test_sourceFragmentElement(self):
"""
L{_SourceFragmentElement} renders source lines at and around the line
number indicated by a frame object.
"""
element = _SourceFragmentElement(
TagLoader(tags.div(
tags.span(render="lineNumber"),
tags.span(render="sourceLine"),
render="sourceLines")),
self.frame)
source = [
u' \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}message = '
u'"This is a problem"',
u' \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}raise Exception(message)',
u'# Figure out the line number from which the exception will be '
u'raised.',
]
d = flattenString(None, element)
if _PY3:
stringToCheckFor = ''.join([
'<div class="snippet%sLine"><span>%d</span><span>%s</span>'
'</div>' % (
["", "Highlight"][lineNumber == 1],
self.base + lineNumber,
(u" \N{NO-BREAK SPACE}" * 4 + sourceLine))
for (lineNumber, sourceLine)
in enumerate(source)]).encode("utf8")
else:
stringToCheckFor = ''.join([
'<div class="snippet%sLine"><span>%d</span><span>%s</span>'
'</div>' % (
["", "Highlight"][lineNumber == 1],
self.base + lineNumber,
(u" \N{NO-BREAK SPACE}" * 4 + sourceLine).encode('utf8'))
for (lineNumber, sourceLine)
in enumerate(source)])
d.addCallback(self.assertEqual, stringToCheckFor)
return d
def test_frameElementFilename(self):
"""
The I{filename} renderer of L{_FrameElement} renders the filename
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="filename")),
self.frame)
d = flattenString(None, element)
d.addCallback(
# __file__ differs depending on whether an up-to-date .pyc file
# already existed.
self.assertEqual,
b"<span>" + networkString(__file__.rstrip('c')) + b"</span>")
return d
def test_frameElementLineNumber(self):
"""
The I{lineNumber} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="lineNumber")),
self.frame)
d = flattenString(None, element)
d.addCallback(
self.assertEqual, b"<span>" + intToBytes(self.base + 1) + b"</span>")
return d
def test_frameElementFunction(self):
"""
The I{function} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="function")),
self.frame)
d = flattenString(None, element)
d.addCallback(
self.assertEqual, b"<span>lineNumberProbeAlsoBroken</span>")
return d
def test_frameElementSource(self):
"""
The I{source} renderer of L{_FrameElement} renders the source code near
the source filename/line number associated with the frame object used to
initialize the L{_FrameElement}.
"""
element = _FrameElement(None, self.frame)
renderer = element.lookupRenderMethod("source")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _SourceFragmentElement)
self.assertIdentical(result.frame, self.frame)
self.assertEqual([tag], result.loader.load())
def test_stackElement(self):
"""
The I{frames} renderer of L{_StackElement} renders each stack frame in
the list of frames used to initialize the L{_StackElement}.
"""
element = _StackElement(None, self.failure.frames[:2])
renderer = element.lookupRenderMethod("frames")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, list)
self.assertIsInstance(result[0], _FrameElement)
self.assertIdentical(result[0].frame, self.failure.frames[0])
self.assertIsInstance(result[1], _FrameElement)
self.assertIdentical(result[1].frame, self.failure.frames[1])
# They must not share the same tag object.
self.assertNotEqual(result[0].loader.load(), result[1].loader.load())
self.assertEqual(2, len(result))
def test_failureElementTraceback(self):
"""
The I{traceback} renderer of L{FailureElement} renders the failure's
stack frames using L{_StackElement}.
"""
element = FailureElement(self.failure)
renderer = element.lookupRenderMethod("traceback")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _StackElement)
self.assertIdentical(result.stackFrames, self.failure.frames)
self.assertEqual([tag], result.loader.load())
def test_failureElementType(self):
"""
The I{type} renderer of L{FailureElement} renders the failure's
exception type.
"""
element = FailureElement(
self.failure, TagLoader(tags.span(render="type")))
d = flattenString(None, element)
if _PY3:
exc = b"builtins.Exception"
else:
exc = b"exceptions.Exception"
d.addCallback(
self.assertEqual, b"<span>" + exc + b"</span>")
return d
def test_failureElementValue(self):
"""
The I{value} renderer of L{FailureElement} renders the value's exception
value.
"""
element = FailureElement(
self.failure, TagLoader(tags.span(render="value")))
d = flattenString(None, element)
d.addCallback(
self.assertEqual, b'<span>This is a problem</span>')
return d
class FormatFailureTests(TestCase):
"""
Tests for L{twisted.web.util.formatFailure} which returns an HTML string
representing the L{Failure} instance passed to it.
"""
def test_flattenerError(self):
"""
If there is an error flattening the L{Failure} instance,
L{formatFailure} raises L{FlattenerError}.
"""
self.assertRaises(FlattenerError, formatFailure, object())
def test_returnsBytes(self):
"""
The return value of L{formatFailure} is a C{str} instance (not a
C{unicode} instance) with numeric character references for any non-ASCII
characters meant to appear in the output.
"""
try:
raise Exception("Fake bug")
except:
result = formatFailure(Failure())
self.assertIsInstance(result, bytes)
if _PY3:
self.assertTrue(all(ch < 128 for ch in result))
else:
self.assertTrue(all(ord(ch) < 128 for ch in result))
# Indentation happens to rely on NO-BREAK SPACE
self.assertIn(b"&#160;", result)
class SDResource(resource.Resource):
def __init__(self,default):
self.default = default
def getChildWithDefault(self, name, request):
d = defer.succeed(self.default)
resource = util.DeferredResource(d)
return resource.getChildWithDefault(name, request)
class DeferredResourceTests(SynchronousTestCase):
"""
Tests for L{DeferredResource}.
"""
def testDeferredResource(self):
r = resource.Resource()
r.isLeaf = 1
s = SDResource(r)
d = DummyRequest(['foo', 'bar', 'baz'])
resource.getChildForRequest(s, d)
self.assertEqual(d.postpath, ['bar', 'baz'])
def test_render(self):
"""
L{DeferredResource} uses the request object's C{render} method to
render the resource which is the result of the L{Deferred} being
handled.
"""
rendered = []
request = DummyRequest([])
request.render = rendered.append
result = resource.Resource()
deferredResource = DeferredResource(defer.succeed(result))
deferredResource.render(request)
self.assertEqual(rendered, [result])
def test_renderNoFailure(self):
"""
If the L{Deferred} fails, L{DeferredResource} reports the failure via
C{processingFailed}, and does not cause an unhandled error to be
logged.
"""
request = DummyRequest([])
d = request.notifyFinish()
failure = Failure(RuntimeError())
deferredResource = DeferredResource(defer.fail(failure))
deferredResource.render(request)
self.assertEqual(self.failureResultOf(d), failure)
del deferredResource
gc.collect()
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(errors, [])

View file

@ -0,0 +1,200 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.vhost}.
"""
from __future__ import absolute_import, division
from twisted.internet.defer import gatherResults
from twisted.trial.unittest import TestCase
from twisted.web.http import NOT_FOUND
from twisted.web.resource import NoResource
from twisted.web.static import Data
from twisted.web.server import Site
from twisted.web.vhost import (_HostResource,
NameVirtualHost,
VHostMonsterResource)
from twisted.web.test.test_web import DummyRequest
from twisted.web.test._util import _render
class HostResourceTests(TestCase):
"""
Tests for L{_HostResource}.
"""
def test_getChild(self):
"""
L{_HostResource.getChild} returns the proper I{Resource} for the vhost
embedded in the URL. Verify that returning the proper I{Resource}
required changing the I{Host} in the header.
"""
bazroot = Data(b'root data', "")
bazuri = Data(b'uri data', "")
baztest = Data(b'test data', "")
bazuri.putChild(b'test', baztest)
bazroot.putChild(b'uri', bazuri)
hr = _HostResource()
root = NameVirtualHost()
root.default = Data(b'default data', "")
root.addHost(b'baz.com', bazroot)
request = DummyRequest([b'uri', b'test'])
request.prepath = [b'bar', b'http', b'baz.com']
request.site = Site(root)
request.isSecure = lambda: False
request.host = b''
step = hr.getChild(b'baz.com', request) # Consumes rest of path
self.assertIsInstance(step, Data)
request = DummyRequest([b'uri', b'test'])
step = root.getChild(b'uri', request)
self.assertIsInstance(step, NoResource)
class NameVirtualHostTests(TestCase):
"""
Tests for L{NameVirtualHost}.
"""
def test_renderWithoutHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not L{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data(b"correct result", "")
request = DummyRequest([''])
self.assertEqual(
virtualHostResource.render(request), b"correct result")
def test_renderWithoutHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is L{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([''])
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_renderWithHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the resource
which is the value in the instance's C{host} dictionary corresponding
to the key indicated by the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.addHost(b'example.org', Data(b"winner", ""))
request = DummyRequest([b''])
request.requestHeaders.addRawHeader(b'host', b'example.org')
d = _render(virtualHostResource, request)
def cbRendered(ignored, request):
self.assertEqual(b''.join(request.written), b"winner")
d.addCallback(cbRendered, request)
# The port portion of the Host header should not be considered.
requestWithPort = DummyRequest([b''])
requestWithPort.requestHeaders.addRawHeader(b'host', b'example.org:8000')
dWithPort = _render(virtualHostResource, requestWithPort)
def cbRendered(ignored, requestWithPort):
self.assertEqual(b''.join(requestWithPort.written), b"winner")
dWithPort.addCallback(cbRendered, requestWithPort)
return gatherResults([d, dWithPort])
def test_renderWithUnknownHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not L{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data(b"correct data", "")
request = DummyRequest([b''])
request.requestHeaders.addRawHeader(b'host', b'example.com')
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(b''.join(request.written), b"correct data")
d.addCallback(cbRendered)
return d
def test_renderWithUnknownHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is L{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([''])
request.requestHeaders.addRawHeader(b'host', b'example.com')
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_getChild(self):
"""
L{NameVirtualHost.getChild} returns correct I{Resource} based off
the header and modifies I{Request} to ensure proper prepath and
postpath are set.
"""
virtualHostResource = NameVirtualHost()
leafResource = Data(b"leaf data", "")
leafResource.isLeaf = True
normResource = Data(b"norm data", "")
virtualHostResource.addHost(b'leaf.example.org', leafResource)
virtualHostResource.addHost(b'norm.example.org', normResource)
request = DummyRequest([])
request.requestHeaders.addRawHeader(b'host', b'norm.example.org')
request.prepath = [b'']
self.assertIsInstance(virtualHostResource.getChild(b'', request),
NoResource)
self.assertEqual(request.prepath, [b''])
self.assertEqual(request.postpath, [])
request = DummyRequest([])
request.requestHeaders.addRawHeader(b'host', b'leaf.example.org')
request.prepath = [b'']
self.assertIsInstance(virtualHostResource.getChild(b'', request),
Data)
self.assertEqual(request.prepath, [])
self.assertEqual(request.postpath, [b''])
class VHostMonsterResourceTests(TestCase):
"""
Tests for L{VHostMonsterResource}.
"""
def test_getChild(self):
"""
L{VHostMonsterResource.getChild} returns I{_HostResource} and modifies
I{Request} with correct L{Request.isSecure}.
"""
vhm = VHostMonsterResource()
request = DummyRequest([])
self.assertIsInstance(vhm.getChild(b'http', request), _HostResource)
self.assertFalse(request.isSecure())
request = DummyRequest([])
self.assertIsInstance(vhm.getChild(b'https', request), _HostResource)
self.assertTrue(request.isSecure())

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,28 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The L{_response} module contains constants for all standard HTTP codes, along
with a mapping to the corresponding phrases.
"""
from __future__ import division, absolute_import
import string
from twisted.trial import unittest
from twisted.web import _responses
class ResponseTests(unittest.TestCase):
def test_constants(self):
"""
All constants besides C{RESPONSES} defined in L{_response} are
integers and are keys in C{RESPONSES}.
"""
for sym in dir(_responses):
if sym == 'RESPONSES':
continue
if all((c == '_' or c in string.ascii_uppercase) for c in sym):
val = getattr(_responses, sym)
self.assertIsInstance(val, int)
self.assertIn(val, _responses.RESPONSES)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,915 @@
# -*- test-case-name: twisted.web.test.test_xmlrpc -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for XML-RPC support in L{twisted.web.xmlrpc}.
"""
from __future__ import division, absolute_import
from twisted.python.compat import nativeString, networkString, NativeStringIO
from io import BytesIO
import datetime
from twisted.trial import unittest
from twisted.web import xmlrpc
from twisted.web.xmlrpc import XMLRPC, payloadTemplate, addIntrospection
from twisted.web.xmlrpc import _QueryFactory, withRequest, xmlrpclib
from twisted.web import server, client, http, static
from twisted.internet import reactor, defer
from twisted.internet.error import ConnectionDone
from twisted.python import failure
from twisted.python.reflect import namedModule
from twisted.test.proto_helpers import MemoryReactor, EventLoggingObserver
from twisted.web.test.test_web import DummyRequest
from twisted.logger import (globalLogPublisher, FilteringLogObserver,
LogLevelFilterPredicate, LogLevel)
try:
namedModule('twisted.internet.ssl')
except ImportError:
sslSkip = "OpenSSL not present"
else:
sslSkip = None
class AsyncXMLRPCTests(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of Deferreds.
"""
def setUp(self):
self.request = DummyRequest([''])
self.request.method = 'POST'
self.request.content = NativeStringIO(
payloadTemplate % ('async', xmlrpclib.dumps(())))
result = self.result = defer.Deferred()
class AsyncResource(XMLRPC):
def xmlrpc_async(self):
return result
self.resource = AsyncResource()
def test_deferredResponse(self):
"""
If an L{XMLRPC} C{xmlrpc_*} method returns a L{defer.Deferred}, the
response to the request is the result of that L{defer.Deferred}.
"""
self.resource.render(self.request)
self.assertEqual(self.request.written, [])
self.result.callback("result")
resp = xmlrpclib.loads(b"".join(self.request.written))
self.assertEqual(resp, (('result',), None))
self.assertEqual(self.request.finished, 1)
def test_interruptedDeferredResponse(self):
"""
While waiting for the L{Deferred} returned by an L{XMLRPC} C{xmlrpc_*}
method to fire, the connection the request was issued over may close.
If this happens, neither C{write} nor C{finish} is called on the
request.
"""
self.resource.render(self.request)
self.request.processingFailed(
failure.Failure(ConnectionDone("Simulated")))
self.result.callback("result")
self.assertEqual(self.request.written, [])
self.assertEqual(self.request.finished, 0)
class TestRuntimeError(RuntimeError):
pass
class TestValueError(ValueError):
pass
class Test(XMLRPC):
# If you add xmlrpc_ methods to this class, go change test_listMethods
# below.
FAILURE = 666
NOT_FOUND = 23
SESSION_EXPIRED = 42
def xmlrpc_echo(self, arg):
return arg
# the doc string is part of the test
def xmlrpc_add(self, a, b):
"""
This function add two numbers.
"""
return a + b
xmlrpc_add.signature = [['int', 'int', 'int'],
['double', 'double', 'double']]
# the doc string is part of the test
def xmlrpc_pair(self, string, num):
"""
This function puts the two arguments in an array.
"""
return [string, num]
xmlrpc_pair.signature = [['array', 'string', 'int']]
# the doc string is part of the test
def xmlrpc_defer(self, x):
"""Help for defer."""
return defer.succeed(x)
def xmlrpc_deferFail(self):
return defer.fail(TestValueError())
# don't add a doc string, it's part of the test
def xmlrpc_fail(self):
raise TestRuntimeError
def xmlrpc_fault(self):
return xmlrpc.Fault(12, "hello")
def xmlrpc_deferFault(self):
return defer.fail(xmlrpc.Fault(17, "hi"))
def xmlrpc_snowman(self, payload):
"""
Used to test that we can pass Unicode.
"""
snowman = u"\u2603"
if snowman != payload:
return xmlrpc.Fault(13, "Payload not unicode snowman")
return snowman
def xmlrpc_complex(self):
return {"a": ["b", "c", 12, []], "D": "foo"}
def xmlrpc_dict(self, map, key):
return map[key]
xmlrpc_dict.help = 'Help for dict.'
@withRequest
def xmlrpc_withRequest(self, request, other):
"""
A method decorated with L{withRequest} which can be called by
a test to verify that the request object really is passed as
an argument.
"""
return (
# as a proof that request is a request
request.method +
# plus proof other arguments are still passed along
' ' + other)
def lookupProcedure(self, procedurePath):
try:
return XMLRPC.lookupProcedure(self, procedurePath)
except xmlrpc.NoSuchFunction:
if procedurePath.startswith("SESSION"):
raise xmlrpc.Fault(self.SESSION_EXPIRED,
"Session non-existent/expired.")
else:
raise
class TestLookupProcedure(XMLRPC):
"""
This is a resource which customizes procedure lookup to be used by the tests
of support for this customization.
"""
def echo(self, x):
return x
def lookupProcedure(self, procedureName):
"""
Lookup a procedure from a fixed set of choices, either I{echo} or
I{system.listeMethods}.
"""
if procedureName == 'echo':
return self.echo
raise xmlrpc.NoSuchFunction(
self.NOT_FOUND, 'procedure %s not found' % (procedureName,))
class TestListProcedures(XMLRPC):
"""
This is a resource which customizes procedure enumeration to be used by the
tests of support for this customization.
"""
def listProcedures(self):
"""
Return a list of a single method this resource will claim to support.
"""
return ['foo']
class TestAuthHeader(Test):
"""
This is used to get the header info so that we can test
authentication.
"""
def __init__(self):
Test.__init__(self)
self.request = None
def render(self, request):
self.request = request
return Test.render(self, request)
def xmlrpc_authinfo(self):
return self.request.getUser(), self.request.getPassword()
class TestQueryProtocol(xmlrpc.QueryProtocol):
"""
QueryProtocol for tests that saves headers received and sent,
inside the factory.
"""
def connectionMade(self):
self.factory.transport = self.transport
xmlrpc.QueryProtocol.connectionMade(self)
def handleHeader(self, key, val):
self.factory.headers[key.lower()] = val
def sendHeader(self, key, val):
"""
Keep sent headers so we can inspect them later.
"""
self.factory.sent_headers[key.lower()] = val
xmlrpc.QueryProtocol.sendHeader(self, key, val)
class TestQueryFactory(xmlrpc._QueryFactory):
"""
QueryFactory using L{TestQueryProtocol} for saving headers.
"""
protocol = TestQueryProtocol
def __init__(self, *args, **kwargs):
self.headers = {}
self.sent_headers = {}
xmlrpc._QueryFactory.__init__(self, *args, **kwargs)
class TestQueryFactoryCancel(xmlrpc._QueryFactory):
"""
QueryFactory that saves a reference to the
L{twisted.internet.interfaces.IConnector} to test connection lost.
"""
def startedConnecting(self, connector):
self.connector = connector
class XMLRPCTests(unittest.TestCase):
def setUp(self):
self.p = reactor.listenTCP(0, server.Site(Test()),
interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def tearDown(self):
self.factories = []
return self.p.stopListening()
def queryFactory(self, *args, **kwargs):
"""
Specific queryFactory for proxy that uses our custom
L{TestQueryFactory}, and save factories.
"""
factory = TestQueryFactory(*args, **kwargs)
self.factories.append(factory)
return factory
def proxy(self, factory=None):
"""
Return a new xmlrpc.Proxy for the test site created in
setUp(), using the given factory as the queryFactory, or
self.queryFactory if no factory is provided.
"""
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" % self.port))
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
def test_results(self):
inputOutput = [
("add", (2, 3), 5),
("defer", ("a",), "a"),
("dict", ({"a": 1}, "a"), 1),
("pair", ("a", 1), ["a", 1]),
("snowman", (u"\u2603"), u"\u2603"),
("complex", (), {"a": ["b", "c", 12, []], "D": "foo"})]
dl = []
for meth, args, outp in inputOutput:
d = self.proxy().callRemote(meth, *args)
d.addCallback(self.assertEqual, outp)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_headers(self):
"""
Verify that headers sent from the client side and the ones we
get back from the server side are correct.
"""
d = self.proxy().callRemote("snowman", u"\u2603")
def check_server_headers(ing):
self.assertEqual(
self.factories[0].headers[b'content-type'],
b'text/xml; charset=utf-8')
self.assertEqual(
self.factories[0].headers[b'content-length'], b'129')
def check_client_headers(ign):
self.assertEqual(
self.factories[0].sent_headers[b'user-agent'],
b'Twisted/XMLRPClib')
self.assertEqual(
self.factories[0].sent_headers[b'content-type'],
b'text/xml; charset=utf-8')
self.assertEqual(
self.factories[0].sent_headers[b'content-length'], b'155')
d.addCallback(check_server_headers)
d.addCallback(check_client_headers)
return d
def test_errors(self):
"""
Verify that for each way a method exposed via XML-RPC can fail, the
correct 'Content-type' header is set in the response and that the
client-side Deferred is errbacked with an appropriate C{Fault}
instance.
"""
logObserver = EventLoggingObserver()
filtered = FilteringLogObserver(
logObserver,
[LogLevelFilterPredicate(defaultLogLevel=LogLevel.critical)]
)
globalLogPublisher.addObserver(filtered)
self.addCleanup(lambda: globalLogPublisher.removeObserver(filtered))
dl = []
for code, methodName in [(666, "fail"), (666, "deferFail"),
(12, "fault"), (23, "noSuchMethod"),
(17, "deferFault"), (42, "SESSION_TEST")]:
d = self.proxy().callRemote(methodName)
d = self.assertFailure(d, xmlrpc.Fault)
d.addCallback(lambda exc, code=code:
self.assertEqual(exc.faultCode, code))
dl.append(d)
d = defer.DeferredList(dl, fireOnOneErrback=True)
def cb(ign):
for factory in self.factories:
self.assertEqual(factory.headers[b'content-type'],
b'text/xml; charset=utf-8')
self.assertEquals(2, len(logObserver))
f1 = logObserver[0]["log_failure"].value
f2 = logObserver[1]["log_failure"].value
if isinstance(f1, TestValueError):
self.assertIsInstance(f2, TestRuntimeError)
else:
self.assertIsInstance(f1, TestRuntimeError)
self.assertIsInstance(f2, TestValueError)
self.flushLoggedErrors(TestRuntimeError, TestValueError)
d.addCallback(cb)
return d
def test_cancel(self):
"""
A deferred from the Proxy can be cancelled, disconnecting
the L{twisted.internet.interfaces.IConnector}.
"""
def factory(*args, **kw):
factory.f = TestQueryFactoryCancel(*args, **kw)
return factory.f
d = self.proxy(factory).callRemote('add', 2, 3)
self.assertNotEqual(factory.f.connector.state, "disconnected")
d.cancel()
self.assertEqual(factory.f.connector.state, "disconnected")
d = self.assertFailure(d, defer.CancelledError)
return d
def test_errorGet(self):
"""
A classic GET on the xml server should return a NOT_ALLOWED.
"""
agent = client.Agent(reactor)
d = agent.request(b"GET", networkString("http://127.0.0.1:%d/" % (self.port,)))
def checkResponse(response):
self.assertEqual(response.code, http.NOT_ALLOWED)
d.addCallback(checkResponse)
return d
def test_errorXMLContent(self):
"""
Test that an invalid XML input returns an L{xmlrpc.Fault}.
"""
agent = client.Agent(reactor)
d = agent.request(
uri=networkString("http://127.0.0.1:%d/" % (self.port,)),
method=b"POST",
bodyProducer=client.FileBodyProducer(BytesIO(b"foo")))
d.addCallback(client.readBody)
def cb(result):
self.assertRaises(xmlrpc.Fault, xmlrpclib.loads, result)
d.addCallback(cb)
return d
def test_datetimeRoundtrip(self):
"""
If an L{xmlrpclib.DateTime} is passed as an argument to an XML-RPC
call and then returned by the server unmodified, the result should
be equal to the original object.
"""
when = xmlrpclib.DateTime()
d = self.proxy().callRemote("echo", when)
d.addCallback(self.assertEqual, when)
return d
def test_doubleEncodingError(self):
"""
If it is not possible to encode a response to the request (for example,
because L{xmlrpclib.dumps} raises an exception when encoding a
L{Fault}) the exception which prevents the response from being
generated is logged and the request object is finished anyway.
"""
logObserver = EventLoggingObserver()
filtered = FilteringLogObserver(
logObserver,
[LogLevelFilterPredicate(defaultLogLevel=LogLevel.critical)]
)
globalLogPublisher.addObserver(filtered)
self.addCleanup(lambda: globalLogPublisher.removeObserver(filtered))
d = self.proxy().callRemote("echo", "")
# *Now* break xmlrpclib.dumps. Hopefully the client already used it.
def fakeDumps(*args, **kwargs):
raise RuntimeError("Cannot encode anything at all!")
self.patch(xmlrpclib, 'dumps', fakeDumps)
# It doesn't matter how it fails, so long as it does. Also, it happens
# to fail with an implementation detail exception right now, not
# something suitable as part of a public interface.
d = self.assertFailure(d, Exception)
def cbFailed(ignored):
# The fakeDumps exception should have been logged.
self.assertEquals(1, len(logObserver))
self.assertIsInstance(
logObserver[0]["log_failure"].value,
RuntimeError
)
self.assertEqual(len(self.flushLoggedErrors(RuntimeError)), 1)
d.addCallback(cbFailed)
return d
def test_closeConnectionAfterRequest(self):
"""
The connection to the web server is closed when the request is done.
"""
d = self.proxy().callRemote('echo', '')
def responseDone(ignored):
[factory] = self.factories
self.assertFalse(factory.transport.connected)
self.assertTrue(factory.transport.disconnected)
return d.addCallback(responseDone)
def test_tcpTimeout(self):
"""
For I{HTTP} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectTCP call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy(b"http://127.0.0.1:69", connectTimeout=2.0,
reactor=reactor)
proxy.callRemote("someMethod")
self.assertEqual(reactor.tcpClients[0][3], 2.0)
def test_sslTimeout(self):
"""
For I{HTTPS} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectSSL call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy(b"https://127.0.0.1:69", connectTimeout=3.0,
reactor=reactor)
proxy.callRemote("someMethod")
self.assertEqual(reactor.sslClients[0][4], 3.0)
test_sslTimeout.skip = sslSkip
class XMLRPCProxyWithoutSlashTests(XMLRPCTests):
"""
Test with proxy that doesn't add a slash.
"""
def proxy(self, factory=None):
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d" % self.port))
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
class XMLRPCPublicLookupProcedureTests(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of subclasses which override
C{lookupProcedure} and C{listProcedures}.
"""
def createServer(self, resource):
self.p = reactor.listenTCP(
0, server.Site(resource), interface="127.0.0.1")
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy(
networkString('http://127.0.0.1:%d' % self.port))
def test_lookupProcedure(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to find
procedures that are not defined using a C{xmlrpc_}-prefixed method name.
"""
self.createServer(TestLookupProcedure())
what = "hello"
d = self.proxy.callRemote("echo", what)
d.addCallback(self.assertEqual, what)
return d
def test_errors(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to raise
L{NoSuchFunction} to indicate that a requested method is not available
to be called, signalling a fault to the XML-RPC client.
"""
self.createServer(TestLookupProcedure())
d = self.proxy.callRemote("xxxx", "hello")
d = self.assertFailure(d, xmlrpc.Fault)
return d
def test_listMethods(self):
"""
A subclass of L{XMLRPC} can override C{listProcedures} to define
Overriding listProcedures should prevent introspection from being
broken.
"""
resource = TestListProcedures()
addIntrospection(resource)
self.createServer(resource)
d = self.proxy.callRemote("system.listMethods")
def listed(procedures):
# The list will also include other introspection procedures added by
# addIntrospection. We just want to see "foo" from our customized
# listProcedures.
self.assertIn('foo', procedures)
d.addCallback(listed)
return d
class SerializationConfigMixin:
"""
Mixin which defines a couple tests which should pass when a particular flag
is passed to L{XMLRPC}.
These are not meant to be exhaustive serialization tests, since L{xmlrpclib}
does all of the actual serialization work. They are just meant to exercise
a few codepaths to make sure we are calling into xmlrpclib correctly.
@ivar flagName: A C{str} giving the name of the flag which must be passed to
L{XMLRPC} to allow the tests to pass. Subclasses should set this.
@ivar value: A value which the specified flag will allow the serialization
of. Subclasses should set this.
"""
def setUp(self):
"""
Create a new XML-RPC server with C{allowNone} set to C{True}.
"""
kwargs = {self.flagName: True}
self.p = reactor.listenTCP(
0, server.Site(Test(**kwargs)), interface="127.0.0.1")
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy(
networkString("http://127.0.0.1:%d/" % (self.port,)), **kwargs)
def test_roundtripValue(self):
"""
C{self.value} can be round-tripped over an XMLRPC method call/response.
"""
d = self.proxy.callRemote('defer', self.value)
d.addCallback(self.assertEqual, self.value)
return d
def test_roundtripNestedValue(self):
"""
A C{dict} which contains C{self.value} can be round-tripped over an
XMLRPC method call/response.
"""
d = self.proxy.callRemote('defer', {'a': self.value})
d.addCallback(self.assertEqual, {'a': self.value})
return d
class XMLRPCAllowNoneTests(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing L{None} when the C{allowNone} flag is set.
"""
flagName = "allowNone"
value = None
class XMLRPCUseDateTimeTests(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing a C{datetime.datetime} instance when the C{useDateTime}
flag is set.
"""
flagName = "useDateTime"
value = datetime.datetime(2000, 12, 28, 3, 45, 59)
class XMLRPCAuthenticatedTests(XMLRPCTests):
"""
Test with authenticated proxy. We run this with the same input/output as
above.
"""
user = b"username"
password = b"asecret"
def setUp(self):
self.p = reactor.listenTCP(0, server.Site(TestAuthHeader()),
interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def test_authInfoInURL(self):
url = "http://%s:%s@127.0.0.1:%d/" % (
nativeString(self.user), nativeString(self.password), self.port)
p = xmlrpc.Proxy(networkString(url))
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_explicitAuthInfo(self):
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" % (
self.port,)), self.user, self.password)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_longPassword(self):
"""
C{QueryProtocol} uses the C{base64.b64encode} function to encode user
name and password in the I{Authorization} header, so that it doesn't
embed new lines when using long inputs.
"""
longPassword = self.password * 40
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" % (
self.port,)), self.user, longPassword)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, longPassword])
return d
def test_explicitAuthInfoOverride(self):
p = xmlrpc.Proxy(networkString("http://wrong:info@127.0.0.1:%d/" % (
self.port,)), self.user, self.password)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
class XMLRPCIntrospectionTests(XMLRPCTests):
def setUp(self):
xmlrpc = Test()
addIntrospection(xmlrpc)
self.p = reactor.listenTCP(0, server.Site(xmlrpc),interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def test_listMethods(self):
def cbMethods(meths):
meths.sort()
self.assertEqual(
meths,
['add', 'complex', 'defer', 'deferFail',
'deferFault', 'dict', 'echo', 'fail', 'fault',
'pair', 'snowman', 'system.listMethods',
'system.methodHelp',
'system.methodSignature', 'withRequest'])
d = self.proxy().callRemote("system.listMethods")
d.addCallback(cbMethods)
return d
def test_methodHelp(self):
inputOutputs = [
("defer", "Help for defer."),
("fail", ""),
("dict", "Help for dict.")]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodHelp", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_methodSignature(self):
inputOutputs = [
("defer", ""),
("add", [['int', 'int', 'int'],
['double', 'double', 'double']]),
("pair", [['array', 'string', 'int']])]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodSignature", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
class XMLRPCClientErrorHandlingTests(unittest.TestCase):
"""
Test error handling on the xmlrpc client.
"""
def setUp(self):
self.resource = static.Data(
b"This text is not a valid XML-RPC response.",
b"text/plain")
self.resource.isLeaf = True
self.port = reactor.listenTCP(0, server.Site(self.resource),
interface='127.0.0.1')
def tearDown(self):
return self.port.stopListening()
def test_erroneousResponse(self):
"""
Test that calling the xmlrpc client on a static http server raises
an exception.
"""
proxy = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" %
(self.port.getHost().port,)))
return self.assertFailure(proxy.callRemote("someMethod"), ValueError)
class QueryFactoryParseResponseTests(unittest.TestCase):
"""
Test the behaviour of L{_QueryFactory.parseResponse}.
"""
def setUp(self):
# The _QueryFactory that we are testing. We don't care about any
# of the constructor parameters.
self.queryFactory = _QueryFactory(
path=None, host=None, method='POST', user=None, password=None,
allowNone=False, args=())
# An XML-RPC response that will parse without raising an error.
self.goodContents = xmlrpclib.dumps(('',))
# An 'XML-RPC response' that will raise a parsing error.
self.badContents = 'invalid xml'
# A dummy 'reason' to pass to clientConnectionLost. We don't care
# what it is.
self.reason = failure.Failure(ConnectionDone())
def test_parseResponseCallbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as a callback
of L{_QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addCallback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.goodContents)
return d
def test_parseResponseErrbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as an errback
of L{_QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.badContents)
return d
def test_badStatusErrbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as an errback
of L{_QueryFactory.badStatus}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.badStatus('status', 'message')
return d
def test_parseResponseWithoutData(self):
"""
Some server can send a response without any data:
L{_QueryFactory.parseResponse} should catch the error and call the
result errback.
"""
content = """
<methodResponse>
<params>
<param>
</param>
</params>
</methodResponse>"""
d = self.queryFactory.deferred
self.queryFactory.parseResponse(content)
return self.assertFailure(d, IndexError)
class XMLRPCWithRequestTests(unittest.TestCase):
def setUp(self):
self.resource = Test()
def test_withRequest(self):
"""
When an XML-RPC method is called and the implementation is
decorated with L{withRequest}, the request object is passed as
the first argument.
"""
request = DummyRequest('/RPC2')
request.method = "POST"
request.content = NativeStringIO(xmlrpclib.dumps(
("foo",), 'withRequest'))
def valid(n, request):
data = xmlrpclib.loads(request.written[0])
self.assertEqual(data, (('POST foo',), None))
d = request.notifyFinish().addCallback(valid, request)
self.resource.render_POST(request)
return d